import matplotlib
matplotlib.use('Agg')  # neinteraktivní backend

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import sys
from matplotlib.ticker import MaxNLocator, FormatStrFormatter

# --- Command-line argument ---
# Usage: python3 PlotIt.py <shot_no>
# The shot number is only used in the figure title; "?" when not supplied.
if len(sys.argv) > 1:
    shot_no = sys.argv[1]
else:
    shot_no = "?"

# Twelve channel files: ch1.csv ... ch12.csv.
channels = [f'ch{i}.csv' for i in range(1, 13)]

# 6 rows x 2 columns of axes sharing one x axis.
fig, axs = plt.subplots(6, 2, figsize=(12, 10), sharex=True)
fig.suptitle(f"GOLEM Shot {shot_no}")
axs = np.array(axs)  # make sure we have a 2-D ndarray so column slicing works

# Left column (index 0) holds the "R" group ch1-ch6,
# right column (index 1) holds the "L" group ch7-ch12.
axs_R = list(axs[:, 0])
axs_L = list(axs[:, 1])
group_R, group_L = channels[:6], channels[6:]

# Right-hand column: put y tick labels and the y-axis label on the right side.
for ax in axs_L:
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
def _reference_limits(ref_file, scale_std):
    """Return ``(ymin, ymax)`` = ``mean -/+ scale_std * std`` of a reference signal.

    The reference signal is column 1 of the headerless CSV *ref_file*.
    Returns ``None`` when no usable limits can be computed: the file does
    not exist, or the statistics are not finite (e.g. a single-sample
    signal, whose sample std is NaN). The caller can then leave the y-axis
    on autoscale instead of passing NaN limits to ``set_ylim``.
    """
    if not os.path.exists(ref_file):
        return None
    y_ref = pd.read_csv(ref_file, header=None)[1]
    mean = y_ref.mean()
    std = y_ref.std()
    ymin = mean - scale_std * std
    ymax = mean + scale_std * std
    if not (np.isfinite(ymin) and np.isfinite(ymax)):
        return None
    if ymin == ymax:
        # Zero spread (constant signal or scale_std == 0) would give a
        # zero-height y range; widen it slightly to avoid degeneracy.
        ymax = ymin + 1e-3
    return ymin, ymax


def plot_group(axs, group, ref_file, scale_std, labels_prefix):
    """Plot each CSV in *group* into the matching axis of *axs*, with shared y-limits.

    Parameters
    ----------
    axs : sequence of matplotlib Axes
        Target axes, paired with *group* in order. Unpaired entries on
        either side are ignored (``zip`` semantics).
    group : sequence of str
        Paths of headerless CSV files. Column 0 is plotted (multiplied by
        1000) on x and column 1 on y. Missing files are skipped, and their
        axis is left empty (but still consumes a label number).
    ref_file : str
        CSV whose column 1 defines the y-limits for every axis in the group
        (mean -/+ ``scale_std`` standard deviations). If it is missing or
        yields non-finite limits, the axes keep matplotlib's autoscaling.
    scale_std : float
        Half-width of the y range, in standard deviations of the reference.
    labels_prefix : str
        Prefix of the y-axis label. Axes are numbered from 1, giving labels
        like ``"R1 [V]"``, ``"R2 [V]"``, ...
    """
    limits = _reference_limits(ref_file, scale_std)

    for label, (ax, fname) in enumerate(zip(axs, group), start=1):
        if not os.path.exists(fname):
            continue
        data = pd.read_csv(fname, header=None)
        t, y = data[0].to_numpy(), data[1].to_numpy()

        # x is scaled by 1000 -- presumably seconds -> milliseconds; TODO confirm CSV time unit.
        ax.plot(t * 1000, y, linestyle='-', linewidth=0.8)
        ax.set_ylabel(f"{labels_prefix}{label} [V]")

        if limits is not None:
            ax.set_ylim(*limits)
        # nbins is the *maximum* number of tick intervals (<= 4 ticks),
        # keeping the small subplots uncluttered.
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3, prune=None))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.grid(True, linestyle=':', linewidth=0.5)

# Draw both groups; each uses its own first channel as the y-range
# reference, with limits at +/- 3 standard deviations of that signal.
for group_axes, files, prefix in ((axs_R, group_R, 'R'), (axs_L, group_L, 'L')):
    plot_group(group_axes, files, files[0], 3.0, prefix)

# X axis: label and tick positions on the bottom row of each column.
for ax in (axs_R[-1], axs_L[-1]):
    ax.set_xlabel("Time [ms]")
    ax.set_xticks([10, 20, 30, 40])

# Keep the top 3 % of the figure free (room for the suptitle), then save.
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.savefig("ScreenShotAll.png", dpi=150)
# The non-interactive Agg backend is selected at import time, so the PNG
# is the only output; plt.show() would not open a window.
