#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib
matplotlib.rcParams['backend'] = 'Agg'
matplotlib.rc('font',  size='10')
matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
from matplotlib.pyplot import *
import os
#import  scipy.signal
from scipy import signal
from scipy.signal import fftconvolve
from multiprocessing import Process, Pool, cpu_count
import time
from numpy import *
from CMWT import *
from SoundGenerator import *
from pygolem_lite import Shot
from pygolem_lite.modules import list2array,deconvolveExp,save_adv,saveconst
from scipy.stats.mstats  import mquantiles

#from scipy.stats.mstats import mquantiles

# Sign (+1 / -1) applied to each of the 16 ring-coil channels in LoadData.
# TODO (translated original note: "important"): keep in sync with the coil wiring.
# Superseded orientation sets, kept for reference:
#   (1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, 1, 1, 1, 1, -1)
#   (-1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1)
RingCoilOrientation = (-1, 1, 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1)

# Effective area of each ring coil (in m^2 per the original notes); LoadData
# divides every channel by its entry.
# Superseded calibrations, kept for reference:
#   (68e-4, 140e-4, 138e-4, 140e-4, 68e-4, 134e-4, 134e-4, 142e-4,
#    67e-4, 142e-4, 140e-4, 138e-4, 76e-4, 142e-4, 139e-4, 139e-4)
#   (68e-4, 140e-4, 138e-4, 140e-4, 130e-4, 134e-4, 134e-4, 142e-4,
#    67e-4, 142e-4, 140e-4, 140e-4, 140e-4, 142e-4, 139e-4, 139e-4)
#   [1.56253461e-02, 1.50618720e-02, 1.45473064e-02, 1.27421132e-02,
#    1.44558132e-02, 1.80693924e-02, 1.56932626e-02, 1.05514271e-02,
#    3.90988379e-03, 5.96870657e-03, 8.48762208e-03, 8.47233307e-05,
#    1.70823107e-02, 8.08770169e-03, 1.44213854e-02, 1.58269600e-02]
AEff = [
    0.00209134, 0.00446023, 0.00428289, 0.00374073,
    0.00320856, 0.00182771, 0.00267794, 0.00284156,
    0.00145525, 0.00222005, 0.00206269, 0.00108452,
    0.00054154, 0.00078395, 0.00186135, 0.00274475,
]

# Sign of each of the 4 Mirnov coils; in this file it is only referenced from
# commented-out code.  TODO (translated original note: "important").
MirnovCoil = (1, 1, 1, 1)


# Tick marks point out of the axes; major ticks get size 10, minor ticks size 7.
# Item assignment (not dict.update) so that rcParams validates each value.
for _key, _value in (('xtick.direction', 'out'),
                     ('ytick.direction', 'out'),
                     ('xtick.major.size', 10),
                     ('xtick.minor.size', 7),
                     ('ytick.major.size', 10),
                     ('ytick.minor.size', 7)):
    matplotlib.rcParams[_key] = _value
del _key, _value

def mad(x, axis=-1):
    """Return the median absolute deviation (MAD) of `x` along `axis`.

    The raw MAD is multiplied by 1.4826 so that, for normally distributed
    data, it estimates the standard deviation.

    Parameters
    ----------
    x : array_like
        Input data of any rank.
    axis : int, optional
        Axis along which the statistic is computed (default: last axis).

    Returns
    -------
    ndarray or numpy scalar
        The scaled MAD, with `axis` removed from the shape of `x`.
    """
    import numpy as np

    x = np.asarray(x)
    # Re-insert the reduced axis so the median broadcasts against `x` for any
    # rank and axis (the former swapaxes trick only lined up for 1-D/2-D input).
    med = np.expand_dims(np.median(x, axis=axis), axis)
    return np.median(np.abs(x - med), axis=axis) * 1.4826
 
def stft(x, fs, framesz, hop):
    """Compute a Hamming-windowed short-time Fourier transform of `x`.

    Parameters
    ----------
    x : 1-D array_like
        Real-valued signal.
    fs : float
        Sampling frequency [Hz].
    framesz : float
        Requested frame length [s]; rounded up to a power-of-two sample count.
    hop : float
        Hop between the starts of consecutive frames [s].

    Returns
    -------
    ndarray, shape (n_frames, framesamp // 2 + 1)
        One-sided complex spectrum of every full frame.  `n_frames` is 0 when
        `x` is shorter than one frame.

    Raises
    ------
    ValueError
        If ``hop * fs`` rounds down to less than one sample.
    """
    import numpy as np

    x = np.asarray(x)
    framesamp = int(2 ** np.ceil(np.log2(framesz * fs)))
    hopsamp = int(hop * fs)
    if hopsamp < 1:
        raise ValueError('hop * fs must be at least one sample, got %r' % (hop * fs))

    window = np.hamming(framesamp)
    # "+ 1" keeps the last frame that ends exactly at len(x).
    starts = range(0, len(x) - framesamp + 1, hopsamp)
    frames = [np.fft.rfft(window * x[i:i + framesamp]) for i in starts]
    if not frames:
        return np.empty((0, framesamp // 2 + 1), dtype=complex)
    return np.array(frames)

#def PrepareData(data,typ):
    #sign = MirnovCoil   #WARNING závisí na aktuálním nastavení cívek!!TODO dát to do externího configu
    #n = size(data, 0)
    #envelope = zeros(shape(data))
    
    #dt = data[0,1]-data[0,0]
    #n_smooth = int(0.002/dt)*2+1  #omezení časového rozlišení pro nízké frekvence 250Hz
    
    #for i in range(1,n): 
	#data[i,:]*= sign[i-1]
	#baseline  = fftconvolve((data[i,:]), ones(n_smooth)/n_smooth,mode = 'same' )
	#data[i,:]-= baseline
	#envelope[i,:] = fftconvolve(abs(data[i,:]), ones(n_smooth/10)/n_smooth/10,mode = 'same' )
	#plot(baseline, label = str(i))
    
    #legend()
    ##show()
    #savefig('data.png')
    #close()
    #total_envelope  = mean(envelope,0)
    #for i in range(1,n):    
	#data[i,:]*= total_envelope/(envelope[i,:]+1e-6)

    #return data
  
      
def plotData(data):
    """Plot a low-pass-filtered view of every channel to ./graphs/raw_data.png.

    Row 0 of `data` is the time axis and rows 1.. are channels.  Each channel
    is low-passed by zeroing every rFFT bin from index ``n // 16`` upward
    (n = number of samples).  All channels are drawn into one figure with
    x-range 0.016-0.0171 (units of row 0) and y-range -0.2..0.2.
    """
    tvec = data[0, :]
    n = size(data, 1)
    cutoff = n // 16  # integer bin index under both Python 2 and 3 division
    for channel in data[1:]:
        spectrum = fft.rfft(channel)
        spectrum[cutoff:] = 0
        # Pass n explicitly: without it an odd-length record comes back one
        # sample shorter than tvec and plot() fails on the shape mismatch.
        smoothed = fft.irfft(spectrum, n)
        plot(tvec, smoothed)
    # Fixed zoom window; setting it once after plotting gives the same limits.
    xlim(0.016, 0.0171)
    ylim(-0.2, 0.2)
    savefig('./graphs/raw_data.png')
    close()
	
def loadconst(fname):
    """Return the first line of text file `fname` converted to a float.

    Raises ValueError when that line is not a number (e.g. an empty file).
    """
    with open(fname, 'r') as handle:
        first_line = handle.readline()
    return float(first_line)

	
#def LoadData():
    #Data = Shot()

    #gd = Shot().get_data
    #das, m1  = gd('any', 'mirnov_1', return_channel = True)
    #das, m5  = gd('any', 'mirnov_5', return_channel = True)
    #das, m9  = gd('any', 'mirnov_9', return_channel = True)
    #das, m13 = gd('any', 'mirnov_13', return_channel = True)

    #Papouch = list2array( Data[das, [m1, m5, m9, m13]] ).T 
      
    #plasma = Data['plasma']

    #plasma_start = Data['plasma_start']
    #plasma_end = Data['plasma_end']

    #return Papouch,plasma_start, plasma_end,plasma
    
def LoadData():
    """Load the ring-coil signals and plasma timing of the current shot.

    Returns
    -------
    Papouch : ndarray
        Row 0 is used as the time axis by CalculateSpectrogram.  Rows 1.. are
        the ring-coil channels, each multiplied by its sign in
        ``RingCoilOrientation`` and divided by its area in ``AEff``.
    plasma_start, plasma_end
        ``Data['plasma_start']`` and ``Data['plasma_end']`` of the shot.
    plasma
        ``Data['plasma']``; CalculateSpectrogram treats it as a truth value.
    """
    n_coils = len(RingCoilOrientation)
    Data = Shot()
    # Find the acquisition device carrying the ring coils.  Reuse the Shot
    # created above instead of constructing a second one.
    das, _ = Data.get_data('any', 'ring_1', return_channel=True)

    # assumes list2array(...).T yields time in row 0 followed by one row per
    # channel (as CalculateSpectrogram expects) -- TODO confirm
    Papouch = list2array(Data[das, list(range(n_coils))]).T
    Papouch[1:, :] *= array(RingCoilOrientation)[:, None]
    Papouch[1:, :] /= array(AEff)[:, None]

    plasma = Data['plasma']
    plasma_start = Data['plasma_start']
    plasma_end = Data['plasma_end']

    return Papouch, plasma_start, plasma_end, plasma
    
        
    
    
#def LoadData(path, shot_num, skip_det_num):
    
    #path += str(shot_num)+'/'
    #try:
	#data = load(path+'data.npy')
	#if size(data,1) == 0:
	    #raise "wrong data"
    #except:
	##print 'reload'
	#data = list()
	#tvec = None
	#for i in range(0,16):
	    #print 'load NIturbo_%2.2i' %(i+1)

	    #if i in skip_det_num:
		#data.append(zeros(shape(data[0])))
		#continue
	    ##single_data =  loadtxt(path+'NIturbo_%2.2i' %(i+1)+'.asc', usecols = (1,))
	    #single_data =  loadtxt(path+'NIturbo_%2.2i' %(i+1), usecols = (1,))
	    
	    #if tvec == None:
		##tvec =  loadtxt(path+'NIturbo_%2.2i' %(i+1)+'.asc', usecols = (0,))	
		#tvec =  loadtxt(path+'NIturbo_%2.2i' %(i+1), usecols = (0,))	

	    #data.append(single_data)
	    
	    
	#data = array(data)
	#if size(data,1) == 0:
	    #raise "wrong data"
	#data = vstack((tvec, data))
	#save(path+'data', single(data))
    
    #try:
	#start = loadconst(path+'PlasmaStart')/1000
	#end   = loadconst(path+'PlasmaEnd'  )/1000
    #except:
	#start = nan
	#end = nan
    #print start,end
    #return data,start,end
	
    
#def LoadData_npz(path, shot_num, skip_det_num):  #FIXME provizorní verze!!
    #data = load('Nidatap.npz')
    ##data  data['data']
    
    #tvec = linspace(data['t_start'], data['t_end'], size(data['data'],0))
    #data = data['data']
    #data[:,skip_det_num] = 0
    #data = vstack((tvec, data.T))

    #start = data[0,8300]
    #end = data[0,19260]
    #return data,start,end

##def LoadData_mirnov(path, shot_num):
    
    ##path += str(shot_num)+'/'
    ##try:
	##data = load(path+'data.npy')
	##if size(data,1) == 0:
	    ##raise "wrong data"
	##if shot_num > 6000 and shot_num < 9500:
	    ##data = data.T
	    ##data[1:,1:] = diff(data[1:,:], axis = 1)/(data[0,1]-data[0,0])
    ##except:
	##data = list()
	##tvec = None
	##for i in range(0,4):
	    ##print 'load PapouchSt_%2.2i' %(i+1)


	    ##single_data =  loadtxt(path+'PapouchSt_%2.2i' %(i+1), usecols = (1,))
	    
	    ##if tvec == None:
		##tvec =  loadtxt(path+'PapouchSt_%2.2i' %(i+1), usecols = (0,))	

	    ##data.append(single_data)
	    
	    
	##data = array(data)
	##if size(data,1) == 0:
	    ##raise "wrong data"
	##data = vstack((tvec, data))
	##save(path+'data', single(data))
    

    ##try:
	##start = loadconst(path+'PlasmaStart')/1000
	##end   = loadconst(path+'PlasmaEnd'  )/1000
    ##except:
	##start = nan
	##end = nan
    ##print start,end
    ##return data,start,end    
    
def PlotSpec(args):
    """Save the spectrogram of one mode as graphs/spectrogram<sufname>.png.

    Takes a single tuple so it can be used directly with ``Pool.map``:
    ``(freq, field, t, sufname, vmin, vmax)``.

    - `freq`: frequencies [Hz] of the rows of `field` (presumably ordered
      like the rows, since the image extent runs from freq[-1] to freq[0]).
    - `field`: spectrogram; its absolute value is plotted.
    - `t`: time axis [s]; it is shown in ms.
    - `sufname`: file-name suffix.
    - `vmin` / `vmax`: colour limits shared by all modes.
    """
    # Explicit unpacking replaces the Python-2-only tuple parameter (PEP 3113);
    # callers still pass one tuple.
    freq, field, t, sufname, vmin, vmax = args
    t_ms = t * 1000

    fig = figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])
    # NOTE(review): matplotlib >= 3.3 renamed `nonposy` to `nonpositive`.
    ax.set_yscale('log', nonposy='clip')
    ax.imshow(abs(field), extent=[t_ms[0], t_ms[-1], freq[-1], freq[0]],
              aspect='auto', vmin=vmin, vmax=vmax)
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.axis([t_ms[0], t_ms[-1], amin(freq), amax(freq)])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Frequency [Hz]')

    fig.savefig('graphs/spectrogram' + str(sufname) + '.png', bbox_inches='tight')
    close(fig)
    
    
def PrepareData(t, X):
    """Integrate the coil signals and band-pass them between ~1 kHz and ~100 kHz.

    The signals are integrated with a cumulative sum scaled by 1e-6 (= 1/fs
    for the assumed fs = 1 MHz), and their first value is subtracted.  The
    slow component from a 101-tap ~1 kHz low-pass FIR is then removed, and
    the rest is low-passed by a 101-tap ~100 kHz FIR (both use Hamming
    windows).  After each filter, `trim` samples are cut to compensate the
    filter delay.

    Parameters
    ----------
    t : 1-D ndarray
        Time axis, one entry per row of `X`.
    X : 2-D ndarray, shape (n_time, n_channels)
        Raw signals, one column per channel.

    Returns
    -------
    t : 1-D ndarray, length n_time - 2 * trim
    X : 2-D ndarray, shape (n_time - 2 * trim, n_channels)
    """
    fs = 1e6        # sampling rate assumed by the filter design [Hz]
    n_taps = 101    # length of both FIR filters
    # Samples trimmed after each filter (51).  This equals the Python 2 value of
    # the former ``-N/2`` slice bound, i.e. (-101)/2 == -51; integer division
    # keeps it a valid index under true division too.
    # NOTE(review): the linear-phase delay of an N-tap FIR is (N-1)/2 = 50
    # samples, so the alignment is off by one sample.
    trim = (n_taps + 1) // 2

    # Integrate along time (rows become channels) and remove the offset.
    X = cumsum(X.T, axis=1) * 1e-6
    X = X - X[:, :1]

    # Subtract the slow (< ~1 kHz) trend, shifted by the filter delay.
    low_taps = signal.firwin(n_taps, 1e3 / (0.5 * fs), window='hamming')
    slow = signal.lfilter(low_taps, 1.0, X, axis=1)
    X = X[:, :-trim] - slow[:, trim:]
    t = t[:-trim]

    # Low-pass the remainder at ~100 kHz and drop the delayed start-up samples.
    high_taps = signal.firwin(n_taps, 1e5 / (0.5 * fs), window='hamming')
    X = signal.lfilter(high_taps, 1.0, X, axis=1)[:, trim:]
    t = t[:-trim]

    return t, X.T
    
    
   
def PlotSignal(t, sig, tbeg, tend):
    """Save a detector-vs-time image of normalised signals to graphs/signal.png.

    Samples with ``tbeg < t < tend`` are selected.  Each channel (column of
    `sig`) is linearly detrended and divided by its standard deviation.
    Detector k (1-based) is drawn centred on y = k.  The colour scale is
    symmetric and clipped at the 99 % quantile of |signal|.
    """
    ind = (t > tbeg) & (t < tend)
    t_ms = t[ind] * 1000
    sig = signal.detrend(sig[ind, :], axis=0)
    sig /= std(sig, axis=0)

    n_det = sig.shape[1]
    # mquantiles returns a 1-element array; use a plain scalar colour limit.
    lim = float(mquantiles(abs(sig), 0.99)[0])

    # Rows centred on detector numbers 1..n_det (the previous labels assumed
    # 15 detectors and spanned only 1.5-15.5).
    imshow(sig.T, aspect='auto', extent=[t_ms[0], t_ms[-1], n_det + 0.5, 0.5],
           vmin=-lim, vmax=lim, interpolation='nearest')
    axis([t_ms[0], t_ms[-1], 0.5, n_det + 0.5])

    colorbar(format='%.1e')
    savefig('graphs/signal.png')
    clf()
    
def PlotModes(field, t):
    """Plot the poloidal mode spectrum of the coil signals over time.

    Processing steps:

    1. The signals are integrated and band-passed with ``PrepareData``.
    2. For every time sample, the amplitude of the rFFT across the detectors
       gives row m = mode number |M| (0 .. n_det // 2).  For real-valued
       signals the spatial spectrum is symmetric, so the sign of M cannot be
       resolved here.
    3. The amplitudes are median-filtered over 101 time samples.
    4. The result is saved to graphs/poloidalM.png.
    5. The mode with the largest time-averaged amplitude is stored via
       ``saveconst`` as ``mode_M``, and that amplitude as ``mode_M_abs``.

    Parameters
    ----------
    field : 2-D ndarray, shape (n_time, n_det)
        Raw coil signals, one column per detector.
    t : 1-D ndarray
        Time axis [s] matching the rows of `field`.
    """
    fig = figure('poloidal M')
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])

    # Same integration + band-pass as the rest of the analysis.
    t, X = PrepareData(t, field)
    t = t * 1000      # [ms]
    X = X.T           # (n_det, n_time)

    # Spatial FFT across detectors -> (n_mod, n_time); smooth along time only.
    Fsig = abs(fft.rfft(X, axis=0))
    Fsig = signal.medfilt(Fsig, [1, 101])
    n_mod = Fsig.shape[0]

    lim = float(mquantiles(Fsig, 0.99)[0])
    # Row m is centred on y = m.  The former code rotated the *time* columns
    # and labelled rows -n_mod/2 .. n_mod/2.
    img = ax.imshow(Fsig, extent=[t[0], t[-1], n_mod - 0.5, -0.5],
                    aspect='auto', interpolation='nearest', vmin=0, vmax=lim)
    fig.colorbar(img)
    # Separate locator instances: a locator must not be shared between axes.
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    ax.axis([t[0], t[-1], -0.5, n_mod - 0.5])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Poloidal mode number M')

    # Average each mode over time (axis 1) and report the strongest one; the
    # former code took argmax over time and so stored a time index.
    mode_profile = mean(Fsig, axis=1)
    saveconst('mode_M', int(argmax(mode_profile)))
    saveconst('mode_M_abs', mode_profile.max())

    fig.savefig('graphs/poloidalM.png', bbox_inches='tight')
    close(fig)
    # No exit() here: terminating the process is the caller's decision.
    
    
def _mode_sounds(sig, modes):
    """Render each poloidal mode m in `modes` as audio to mp3/sound<m> (best effort)."""
    # sig: (n_time, n_det).  Each mode is obtained by weighting detector k with
    # exp(-2j*pi*m*k/n_det) in the frequency domain, summing and transforming back.
    n_det = sig.shape[1]
    sig_fft = fft.rfft(sig, axis=0)
    det_frac = arange(n_det) / double(n_det)
    for m in modes:
        print("wave order %s" % m)
        phase = det_frac * m
        weights = exp(-2 * pi * 1j * phase)
        mode_signal = fft.irfft((sig_fft * weights[None, :]).sum(axis=1))
        try:
            soundGenerator(mode_signal, 2000, 'mp3/sound' + str(m))
        except Exception as err:   # audio is optional; report and continue
            print("sound gener. failed err: %s" % err)


def _mode_wavelets(sig, dt, modes, omega0, horiz_res, f_min, f_max):
    """Run NTM_CWT for every mode in parallel; return (|spectra| as float32, scales)."""
    pool = Pool(cpu_count())
    try:
        out = pool.map(NTM_CWT, [(sig, dt, 0.005, omega0, m, horiz_res, f_min, f_max)
                                 for m in modes])
    finally:
        pool.close()
        pool.join()
    spec_all = array([single(abs(spec)) for spec, _ in out])
    scale_all = array([scale for _, scale in out])
    return spec_all, scale_all


def _plot_mode_spectrograms(spec_all, scale_all, t, modes, omega0):
    """Plot every mode's spectrogram in parallel with a common colour scale."""
    freq = (omega0 + sqrt(2.0 + omega0 ** 2)) / (4 * pi * scale_all)
    contrast = 10
    spec_all = log(1 + contrast * abs(spec_all))
    vmin = amin(spec_all)
    vmax = amax(spec_all)
    pool = Pool(cpu_count())
    try:
        pool.map(PlotSpec, [(freq[i, ...], spec_all[i, ...], t, m, vmin, vmax)
                            for i, m in enumerate(modes)])
    finally:
        pool.close()
        pool.join()


def CalculateSpectrogram():
    """Run the full mode analysis of the current shot.

    Steps:

    - Load the data and cut the plasma window (0-40 ms when ``plasma`` is
      falsy).
    - Plot the poloidal mode spectrum.
    - Band-pass the signals and plot them for 16.3-16.7 ms.
    - Render one sound per mode in `modes`.
    - Compute per-mode wavelet spectrograms and plot them to
      graphs/spectrogram<m>.png.
    """
    print("CalculateSpectrogram")
    t1 = time.time()

    # `data` / `sig` rather than `signal`, which would shadow scipy.signal.
    data, plasma_start, plasma_end, plasma = LoadData()
    if not plasma:
        plasma_start = 0
        plasma_end = 0.04

    tvec = data[0, :]
    dt = tvec[1] - tvec[0]
    n_start = argmin(abs(plasma_start - tvec))
    n_end = argmin(abs(plasma_end - tvec))
    t = tvec[n_start:n_end]
    sig = data[1:, n_start:n_end].T   # (n_time, n_det)

    PlotModes(sig, t)
    # The former unconditional exit() here left everything below unreachable,
    # so main() never wrote the icon or the 'status' constant.
    t, sig = PrepareData(t, sig)
    PlotSignal(t, sig, 16.3e-3, 16.7e-3)

    omega0 = 40
    horiz_res = 2000
    f_min = 1e3    # Hz
    # NOTE(review): 1 MHz exceeds the 500 kHz Nyquist limit of the 1 MHz
    # sampling assumed by PrepareData -- was 100e3 intended?
    f_max = 100e4  # Hz
    # searched MHD modes
    modes = [-4, -3, -2, -1, 0, 1, 2, 3, 4]

    t2 = time.time()
    _mode_sounds(sig, modes)
    print('generated sound %s' % (time.time() - t2))

    print("started wavelets")
    spec_all, scale_all = _mode_wavelets(sig, dt, modes, omega0, horiz_res, f_min, f_max)
    print('calc time %s' % (time.time() - t1))

    t_plot = time.time()
    _plot_mode_spectrograms(spec_all, scale_all, t, modes, omega0)
    print('plot. time %s' % (time.time() - t_plot))




def main():
    """Command-line entry point: ``<script> plots`` runs the analysis.

    Always creates the ``graphs`` and ``mp3`` output directories if missing.
    With the ``plots`` argument it does three things:

    - runs CalculateSpectrogram;
    - renders ``icon.png`` from ``graphs/spectrogram0.png`` with the external
      ``convert`` tool (presumably ImageMagick);
    - saves the ``status`` constant as 0.

    Any other or missing argument does nothing further.
    """
    import sys  # not imported explicitly at module level

    for path in ('graphs', 'mp3'):
        if not os.path.exists(path):
            os.mkdir(path)

    # Guard against a missing argument instead of raising IndexError.
    if len(sys.argv) > 1 and sys.argv[1] == "plots":
        CalculateSpectrogram()
        # '\\!' is the same backslash-bang the shell needs; the doubled
        # backslash avoids Python's invalid-escape warning.
        os.system('convert -resize 150x120\\! graphs/spectrogram0.png icon.png')
        saveconst('status', 0)


if __name__ == "__main__":
    main()
   