#!/usr/bin/env python3
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

def load_signal(label):
    """
    Load a signal from <label>.csv.

    CSV format: time, value (no header), comma-separated. Any columns
    after the second are ignored.

    Parameters
    ----------
    label : str
        Base name of the file; ".csv" is appended to form the path
        (relative to the current working directory).

    Returns
    -------
    (t, y) : tuple of numpy.ndarray
        1-D arrays holding column 0 (time) and column 1 (value), in the
        order the rows appear in the file.

    Raises
    ------
    FileNotFoundError
        If <label>.csv does not exist.
    ValueError
        If the file has no data rows or fewer than two columns (or if
        numpy cannot parse its contents).
    """
    filename = f"{label}.csv"
    if not os.path.exists(filename):
        raise FileNotFoundError(f"File not found: {filename}")
    # ndmin=2 keeps a single-row file 2-D (shape (1, k)) instead of letting
    # loadtxt collapse it to shape (k,), which would look like "one column".
    data = np.loadtxt(filename, delimiter=",", ndmin=2)
    if data.shape[0] == 0:
        raise ValueError(f"File {filename} contains no data rows.")
    if data.ndim != 2 or data.shape[1] < 2:
        raise ValueError(f"File {filename} does not have two columns (time, value).")
    t = data[:, 0]
    y = data[:, 1]
    return t, y

def main():
    """
    Command-line entry point: plot QE-scaled narrowband channels against L_broad.

    Arguments (sys.argv[1:]): LABEL1 ... LABELN QE1 ... QEN. The first half
    are signal labels, each loaded from <label>.csv via load_signal; the
    second half are the matching quantum efficiencies in percent.

    Exactly one label must be 'L_broad'. Its (time-sorted) time axis is the
    reference onto which every signal is linearly interpolated, with values
    outside a signal's own time range set to 0. Each resampled signal is
    multiplied by its QE as a 0-1 fraction. All labels except L_broad are
    drawn as a stacked area plot, L_broad as a line, and the figure is shown
    with plt.show().

    On bad arguments, a missing L_broad, duplicate labels, or a CSV that
    cannot be loaded, a message is printed to stdout and the process exits
    with status 1.
    """
    if len(sys.argv) < 3 or len(sys.argv[1:]) % 2 != 0:
        print("Usage:")
        print("  python SpectroStack.py LABEL1 LABEL2 ... LABELN QE1 QE2 ... QEN")
        print("Example:")
        print("  python SpectroStack.py "
              "L_HeI_447 L_OI_777 L_NII_568 L_CII_514 L_Hbeta_486 "
              "L_broad L_HeI_587 L_Halpha_656 "
              "45 45 30 30 45 10 30 45")
        sys.exit(1)

    args = sys.argv[1:]
    n = len(args) // 2
    labels = args[:n]
    qe_values = args[n:]

    # Convert QE from percent to a 0-1 fraction
    try:
        qe = [float(x) / 100.0 for x in qe_values]
    except ValueError:
        print("All QE values must be numeric (percent).")
        sys.exit(1)

    if "L_broad" not in labels:
        print("Error: One of the labels must be 'L_broad' (reference broadband channel).")
        sys.exit(1)

    # Signals and QE factors are keyed by label, so a repeated label would
    # silently overwrite the earlier channel's data and QE.
    duplicates = sorted({lab for lab in labels if labels.count(lab) > 1})
    if duplicates:
        print(f"Error: Each label may be given only once (duplicates: {', '.join(duplicates)}).")
        sys.exit(1)

    # Map label -> QE factor
    qe_map = dict(zip(labels, qe))

    # --- 1) Load every signal, ordered by time ---
    time_map = {}
    raw_map = {}
    for label in labels:
        try:
            t, y = load_signal(label)
        except (OSError, ValueError) as err:
            # Report load problems like the other argument errors instead of a traceback
            print(f"Error: {err}")
            sys.exit(1)
        # np.interp needs increasing sample points; CSV rows may be in any order
        order = np.argsort(t, kind="stable")
        time_map[label] = t[order]
        raw_map[label] = y[order]

    # --- 2) Use the L_broad time axis as the reference ---
    t_ref = time_map["L_broad"]

    # --- 3) Resample every signal onto t_ref and apply its QE ---
    signals_scaled = {}
    for label in labels:
        # Linear interpolation onto t_ref; points outside this signal's time range become 0
        y_interp = np.interp(t_ref, time_map[label], raw_map[label], left=0.0, right=0.0)
        signals_scaled[label] = y_interp * qe_map[label]

    # Reference broadband signal
    broad = signals_scaled["L_broad"]

    # Narrowband channels = every label except L_broad
    narrow_labels = [lab for lab in labels if lab != "L_broad"]
    narrow_arrays = [signals_scaled[lab] for lab in narrow_labels]

    # --- 4) Plot ---
    fig, ax = plt.subplots(figsize=(10, 6))

    # Colored stack of the narrowband channels
    if narrow_arrays:
        ax.stackplot(t_ref, *narrow_arrays, labels=narrow_labels, alpha=0.8)

    # Broadband reference L_broad drawn as a line on top
    ax.plot(t_ref, broad, linewidth=2.0, label="L_broad (reference)")

    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Scaled light emission [arb. units]")
    ax.set_title("Spectral contributions scaled by quantum efficiency")

    ax.legend(loc="upper right")
    fig.tight_layout()
    plt.show()

# Run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
