#!/usr/bin/env python3
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import sys
from matplotlib.ticker import MaxNLocator, FormatStrFormatter

# ------------------------------------------------------------
# CLI:
#   PlotRawData.py <shot_no> <diagnostic_name> <y_mode> '<label1>' ...
#
# y_mode:
#   - "GlobalSigma"   → common y-scale, mean ± 3σ over all channels
#   - "GlobalMaximum" → common y-scale, global min/max over all channels
#   - "Full"          → individual y-scale per channel
# ------------------------------------------------------------
# Need at least: shot number, diagnostic name, y-mode and one channel label.
if len(sys.argv) < 5:
    print("Usage: PlotRawData.py <shot_no> <diagnostic_name> <y_mode> '<label1>' '<label2>' ...")
    sys.exit(1)

# Positional arguments; everything after the y-mode is a channel label.
shot_no, diagnostic_name, y_mode, *labels = sys.argv[1:]

# --- Input files: current working directory ----------------------
# One CSV per channel label, looked up in the working directory
# (label "R1" → file "R1.csv").
file_names = [f"{name}.csv" for name in labels]
n_channels = len(file_names)

# --- Plasma timing -----------------------------------------------
# Plasma-detection results, located relative to the working directory:
#   ../PlasmaDetection/Results/t_plasma_start  (values in ms)
_cwd = os.getcwd()
plasma_results_dir = os.path.abspath(os.path.join(_cwd, "../PlasmaDetection", "Results"))

def read_scalar(filename, directory=None):
    """Read a single float value from a small text file.

    Parameters
    ----------
    filename : str
        Name of the file inside *directory*.
    directory : str, optional
        Folder containing the file; defaults to ``plasma_results_dir``.

    Returns
    -------
    float or None
        The parsed value, or None (after printing a warning) when the file
        is missing, cannot be read, or does not contain a single float.
    """
    if directory is None:
        directory = plasma_results_dir
    path = os.path.join(directory, filename)
    try:
        with open(path, "r", encoding="utf-8") as f:
            return float(f.read().strip())
    except FileNotFoundError:
        print(f"⚠️ Warning: plasma file not found: {path}")
        return None
    except ValueError:
        # Also covers UnicodeDecodeError (a ValueError subclass).
        print(f"⚠️ Warning: cannot parse plasma value from: {path}")
        return None
    except OSError as err:
        # Permission denied, a directory in place of the file, ... –
        # degrade to "value unknown" like a missing file instead of
        # aborting the whole plot.
        print(f"⚠️ Warning: cannot read plasma file: {path} ({err})")
        return None

# Timing values are stored in milliseconds; b_plasma is a dimensionless
# flag. The files are read in this fixed order, so any warnings appear in
# a predictable sequence.
(
    t_plasma_start_ms,
    t_plasma_end_ms,
    t_plasma_qs_start_ms,
    t_plasma_qs_end_ms,
    b_plasma,
) = (
    read_scalar(name)
    for name in (
        "t_plasma_start",
        "t_plasma_end",
        "t_plasma_qs_start",
        "t_plasma_qs_end",
        "b_plasma",
    )
)

# A plasma is assumed unless the flag file explicitly reads zero.
has_plasma = not (b_plasma is not None and b_plasma == 0.0)


def _ms_to_s(value_ms):
    """Return *value_ms* converted to seconds, or None without plasma or value."""
    if has_plasma and value_ms is not None:
        return value_ms / 1000.0
    return None


# Seconds, for comparison with the CSV time axis (which is in seconds)
t_plasma_start_s    = _ms_to_s(t_plasma_start_ms)
t_plasma_end_s      = _ms_to_s(t_plasma_end_ms)
t_plasma_qs_start_s = _ms_to_s(t_plasma_qs_start_ms)
t_plasma_qs_end_s   = _ms_to_s(t_plasma_qs_end_ms)

# --- Layout -------------------------------------------------------
# Two columns as soon as there is more than one channel.
ncols = 2 if n_channels > 1 else 1
nrows = int(np.ceil(n_channels / ncols))
# squeeze=False makes axs always a 2-D (nrows, ncols) array of Axes.
# A manual reshape(-1, 1) of the 1-D array returned for a single row turned
# the 1x2 grid of a two-channel plot into 2x1, which broke the later
# axs[:, 1] / axs[-1, :] indexing.
fig, axs = plt.subplots(
    nrows, ncols, figsize=(12, 2 * nrows), sharex=True, squeeze=False
)

fig.suptitle(f"{diagnostic_name} – GOLEM Shot #{shot_no}", fontsize=14)

# Flatten column-wise: channels fill the left column top to bottom first,
# then the right column.
axes_flat = axs.T.flatten()

# --- Global y-scale statistics (if needed) -----------------------
# Fallback limits, kept when no usable samples are found.
ymin_global, ymax_global = -1.0, 1.0

# Every mode except "Full" draws on this common scale (the plotting loop
# uses it for unknown modes too), so compute it for all of them.
if y_mode not in ("GlobalSigma", "GlobalMaximum", "Full"):
    print(f"⚠️ Warning: unknown y_mode '{y_mode}', using GlobalSigma scaling")

y_chunks = []

if y_mode != "Full":
    for fname in file_names:
        if not os.path.exists(fname):
            continue

        data = pd.read_csv(fname, header=None)
        t = data[0].to_numpy()      # seconds
        y = data[1].to_numpy()
        t_ms = t * 1000.0

        # Same window selection as the plotting loop, so the common scale is
        # derived from exactly the samples that end up on screen.
        if has_plasma and (t_plasma_start_s is not None) and (t_plasma_end_s is not None):
            # plasma interval only – avoids pre-trigger spikes
            mask = (t >= t_plasma_start_s) & (t <= t_plasma_end_s)
        elif not has_plasma:
            # vacuum shot – fixed 0–30 ms window
            mask = (t_ms >= 0.0) & (t_ms <= 30.0)
        else:
            # plasma flagged but timing unknown – the full record is plotted
            mask = np.ones_like(t, dtype=bool)

        y_sel = y[mask]
        if y_sel.size == 0:
            # the plot falls back to the full record in this case as well
            y_sel = y

        y_chunks.append(y_sel)

all_y = np.concatenate(y_chunks).astype(float) if y_chunks else np.empty(0)
# Drop NaN/Inf samples: they would propagate into the limits and make
# Axes.set_ylim raise.
all_y = all_y[np.isfinite(all_y)]

if all_y.size > 0:
    global_min = float(np.min(all_y))
    global_max = float(np.max(all_y))

    if y_mode == "GlobalMaximum":
        # Use the absolute global min/max; add a small margin to avoid clipping
        if global_min == global_max:
            ymin_global = global_min - 1.0
            ymax_global = global_max + 1.0
        else:
            margin = 0.05 * (global_max - global_min)
            ymin_global = global_min - margin
            ymax_global = global_max + margin
    else:
        # "GlobalSigma" (and unknown modes): mean ± 3σ
        y_mean = float(np.mean(all_y))
        y_std = float(np.std(all_y))
        if y_std == 0.0:
            # constant signal – widen like the other degenerate cases
            ymin_global = y_mean - 1.0
            ymax_global = y_mean + 1.0
        else:
            ymin_global = y_mean - 3.0 * y_std
            ymax_global = y_mean + 3.0 * y_std

# --- Plot channels -----------------------------------------------
for i, (fname, label) in enumerate(zip(file_names, labels)):
    ax = axes_flat[i]

    if not os.path.exists(fname):
        print(f"⚠️ Warning: {fname} not found, skipping.")
        ax.set_visible(False)
        continue

    data = pd.read_csv(fname, header=None)
    t = data[0].to_numpy()      # [s]
    y = data[1].to_numpy()
    t_ms = t * 1000.0           # for plotting

    # Select time interval
    if has_plasma and (t_plasma_start_s is not None) and (t_plasma_end_s is not None):
        # Plot only during plasma existence
        mask = (t >= t_plasma_start_s) & (t <= t_plasma_end_s)
        x_min = t_plasma_start_ms
        x_max = t_plasma_end_ms
    elif not has_plasma:
        # Vacuum shot: fixed 0–30 ms time window
        mask = (t_ms >= 0.0) & (t_ms <= 30.0)
        x_min, x_max = 0.0, 30.0
    else:
        # Fallback: unknown plasma times → show full time range
        mask = np.ones_like(t, dtype=bool)
        x_min, x_max = float(t_ms.min()), float(t_ms.max())

    t_plot = t_ms[mask]
    y_plot = y[mask]

    if t_plot.size == 0:
        # If the mask removed everything, fall back to full data
        t_plot = t_ms
        y_plot = y
        x_min, x_max = float(t_ms.min()), float(t_ms.max())

    ax.plot(t_plot, y_plot, lw=0.8)
    ax.set_ylabel(f"{label} [V]", fontsize=8)
    # NOTE(review): the axes were created with sharex=True, so this call
    # changes the x-limits of every panel; if one channel hits the
    # full-range fallback above, the last channel plotted decides the
    # final window.
    ax.set_xlim(x_min, x_max)

    # --- y-limits according to y_mode -----------------------------
    if y_mode == "Full":
        # Individual y-scale for each channel. NaN/Inf samples are ignored:
        # they would turn the limits into NaN and make set_ylim raise.
        y_finite = y_plot[np.isfinite(y_plot)]
        if y_finite.size > 0:
            y_min_ch = float(np.min(y_finite))
            y_max_ch = float(np.max(y_finite))
            if y_min_ch == y_max_ch:
                y_min_ch -= 1.0
                y_max_ch += 1.0
            else:
                margin = 0.05 * (y_max_ch - y_min_ch)
                y_min_ch -= margin
                y_max_ch += margin
            ax.set_ylim(y_min_ch, y_max_ch)
        # otherwise leave matplotlib's autoscaling in place
    else:
        # "GlobalSigma", "GlobalMaximum" and unknown modes share the common
        # scale computed before this loop.
        ax.set_ylim(ymin_global, ymax_global)

    ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.grid(True, linestyle=':', linewidth=0.5)

    # Vertical lines for quasi-stationary interval (only if plasma exists)
    if has_plasma and (t_plasma_qs_start_ms is not None):
        ax.axvline(t_plasma_qs_start_ms, linestyle='--', linewidth=0.8)
    if has_plasma and (t_plasma_qs_end_ms is not None):
        ax.axvline(t_plasma_qs_end_ms, linestyle='--', linewidth=0.8)

# View the axes as an explicit (nrows, ncols) grid so the row/column
# indexing below is valid however axs was shaped earlier (subplots returns
# a 1-D array for a single row).
grid = np.asarray(axs).reshape(nrows, ncols)

# Grid cells beyond the last channel (odd channel count with two columns)
# carry no data; hide them instead of leaving empty frames.
for spare_ax in grid.T.flatten()[n_channels:]:
    spare_ax.set_visible(False)

# Right column → ticks on the right side
if ncols == 2:
    for ax in grid[:, 1]:
        ax.yaxis.set_label_position("right")
        ax.yaxis.tick_right()

# X-axis label on the lowest visible panel of each column. With sharex=True
# only the bottom row shows x tick labels, so re-enable them when the bottom
# cell of a column is hidden (spare cell or missing input file).
for column_axes in grid.T:
    shown = [ax for ax in column_axes if ax.get_visible()]
    if shown:
        bottom_ax = shown[-1]
        bottom_ax.tick_params(axis="x", which="both", labelbottom=True)
        bottom_ax.set_xlabel("Time [ms]")

plt.tight_layout(rect=[0, 0, 1, 0.95])

out_file = f"ScreenShot_{diagnostic_name.replace(' ', '_')}_{shot_no}.png"
plt.savefig(out_file, dpi=150)
plt.savefig("ScreenShotAll.png", dpi=150)
plt.savefig("rawdata.jpg", dpi=20)
plt.close(fig)
print(f"✅ Saved plot: {out_file}")
