#!/usr/bin/env python3
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import sys
from matplotlib.ticker import MaxNLocator, FormatStrFormatter

# --- Command line ---
# argv: <shot_no> <diagnostic_name> followed by one label per DAS channel.
if len(sys.argv) < 4:
    print("Usage: PlotRawData.py <shot_no> <diagnostic_name> '<label1>' '<label2>' ...",
          file=sys.stderr)
    sys.exit(1)

shot_no, diagnostic_name, *labels = sys.argv[1:]

# Channel CSVs are read from this directory (relative to the working directory),
# and two of the output images are written into it as well.
base_dir = "DAS_raw_data_dir"
if not os.path.isdir(base_dir):  # a plain file with this name is an error, too
    print(f"❌ Error: directory not found: {base_dir}", file=sys.stderr)
    sys.exit(1)

# The k-th label (1-based) belongs to the file ch<k>.csv.
n_channels = len(labels)
file_names = [os.path.join(base_dir, f"ch{k}.csv") for k in range(1, n_channels + 1)]

# --- Layout ---
# Two columns of subplots (a single column when there is only one channel) and
# enough rows for every channel; all subplots share one time axis.
ncols = 2 if n_channels > 1 else 1
nrows = int(np.ceil(n_channels / ncols))
# squeeze=False guarantees a 2-D (nrows, ncols) array of Axes even for a single
# row or a single subplot, so axs[row, col] indexing is always valid.
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 2 * nrows),
                        sharex=True, squeeze=False)
fig.suptitle(f"{diagnostic_name} – GOLEM Shot #{shot_no}", fontsize=14)

# Channels fill the grid column by column (top to bottom, then left to right);
# axs.flatten() would fill it row by row instead.
axes_flat = axs.T.flatten()

# --- Global y scale ---
# Every channel gets the same y range: mean ± 3 std over all samples of all
# channel files that exist, so amplitudes are comparable between panels.
sample_arrays = [
    pd.read_csv(fname, header=None)[1].to_numpy(dtype=float)
    for fname in file_names
    if os.path.exists(fname)
]
ymin, ymax = -1, 1  # fallback when there is nothing finite to scale to
if sample_arrays:
    samples = np.concatenate(sample_arrays)
    # NaN/Inf samples would turn the limits into NaN, which set_ylim rejects.
    samples = samples[np.isfinite(samples)]
    if samples.size:
        y_mean, y_std = samples.mean(), samples.std()
        # A constant signal (std == 0) would give ymin == ymax; use a ±1
        # window around the mean instead of a degenerate range.
        half_range = 3 * y_std if y_std > 0 else 1
        ymin, ymax = y_mean - half_range, y_mean + half_range

# --- Plot channels ---
# Draw each channel into its own subplot slot; a missing CSV only hides its
# slot (with a warning) instead of aborting the whole figure.
for slot, (fname, label) in enumerate(zip(file_names, labels)):
    panel = axes_flat[slot]

    if os.path.exists(fname):
        table = pd.read_csv(fname, header=None)
        time_col = table[0].to_numpy()
        signal = table[1].to_numpy()
        # x1000 so the x axis reads in ms (assumes column 0 holds seconds)
        panel.plot(time_col * 1000, signal, lw=0.8)
        panel.set_ylabel(f"{label} [V]", fontsize=8)
        panel.set_ylim(ymin, ymax)
        panel.yaxis.set_major_locator(MaxNLocator(nbins=3))
        panel.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        panel.grid(True, linestyle=':', linewidth=0.5)
    else:
        print(f"⚠️ Warning: {fname} not found, skipping.")
        panel.set_visible(False)

# --- Axis decorations ---
# axes_flat is ordered column by column, so reshaping it to (ncols, nrows)
# yields one row of Axes per subplot column, whatever shape `axs` came back in.
columns = np.asarray(axes_flat).reshape(ncols, nrows)

# Grid slots beyond the last channel carry no data: hide them.
for spare in axes_flat[n_channels:]:
    spare.set_visible(False)

# Right column → y ticks and y label on the right-hand side
if ncols == 2:
    for ax in columns[1]:
        ax.yaxis.set_label_position("right")
        ax.yaxis.tick_right()

# X-axis label on the lowest *visible* subplot of each column. With sharex=True
# matplotlib hides x tick labels on all rows but the last, so switch them back
# on when that lowest visible subplot sits higher up in the grid.
for column in columns:
    visible = [ax for ax in column if ax.get_visible()]
    if visible:
        bottom = visible[-1]
        bottom.set_xlabel("Time [ms]")
        bottom.tick_params(axis="x", labelbottom=True)

# --- Save ---
plt.tight_layout(rect=[0, 0, 1, 0.95])  # keep the top 5% free for the suptitle
# A '/' in the diagnostic name would otherwise point the screenshot into a
# (probably nonexistent) sub-directory.
safe_name = diagnostic_name.replace(' ', '_').replace('/', '_')
outputs = [
    (f"ScreenShot_{safe_name}_{shot_no}.png", 150),
    (os.path.join(base_dir, "ScreenShotAll.png"), 150),
    (os.path.join(base_dir, "rawdata.jpg"), 20),  # very low dpi -> small image
]
for out_file, dpi in outputs:
    fig.savefig(out_file, dpi=dpi)
    print(f"✅ Saved plot: {out_file}")


# Example invocation (shot 1000, 12 DRP channels):
# python ../../Infrastructure/Homepage/tools/PlotRawData.py 1000 'DRP' DRP-R1 DRP-R2 DRP-R3 DRP-R4 DRP-R5 DRP-R6 DRP-L1 DRP-L2 DRP-L3 DRP-L4 DRP-L5 DRP-L6
