#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib
matplotlib.rcParams['backend'] = 'Agg'
matplotlib.rc('font',  size='10')
matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
from matplotlib.pyplot import *
import os
#import  scipy.signal
from scipy import signal
from scipy.signal import fftconvolve
from multiprocessing import Process, Pool, cpu_count
import time
from numpy import *
from CMWT import *
from SoundGenerator import *
from pygolem_lite import Shot
from pygolem_lite.modules import list2array,deconvolveExp,save_adv,saveconst
from scipy.stats.mstats  import mquantiles


# Sign correction (+1/-1) for each of the 16 ring-coil channels: LoadData
# multiplies channel k by RingCoilOrientation[k].  Earlier candidate sets:
#RingCoilOrientation = (1,1,-1,-1, 1,-1,1,1,1,-1,-1,1,1,1,1,-1) #TODO important
#RingCoilOrientation = (-1,1,1,1,-1,1,-1,-1,-1,1,-1,1,-1,-1,-1,-1) #TODO important
RingCoilOrientation = (-1,1,1,1,-1,1,-1,-1,-1,1,-1,1,-1,-1,-1,-1)

# NOTE(review): this tuple is rebound by the list assignment directly below
# before anything reads it, so it has no effect; kept only for reference.
AEff = (68.93e-4, 140.68e-4, 138.83e-4, 140.43e-4, 148.59e-4, 134.47e-4,134.28e-4, 142.46e-4, 67.62e-4, 142.80e-4, 140.43e-4, 138.02e-4, 76.32e-4, 142.18e-4, 139.82e-4, 139.33e-4) # in m^2
# Effective coil areas actually in use: LoadData divides channel k by AEff[k].
# NOTE(review): units are not stated -- presumably m^2 like the tuple above; confirm.
AEff = [ 0.00209134 , 0.00446023 , 0.00428289  ,0.00374073 , 0.00320856 , 0.00182771,
  0.00267794,  0.00284156 , 0.00145525  ,0.00222005  ,0.00206269 , 0.00108452,
  0.00054154 , 0.00078395 , 0.00186135 , 0.00274475]
# NOTE(review): not referenced in the visible part of this file.
MirnovCoil = (1,1,1,1)  # TODO important


# Draw axis ticks pointing outward, with longer major/minor tick marks.
matplotlib.rcParams['xtick.direction'] = 'out'
matplotlib.rcParams['ytick.direction'] = 'out'

matplotlib.rcParams['xtick.major.size'] = 10
matplotlib.rcParams['xtick.minor.size'] = 7

matplotlib.rcParams['ytick.major.size'] = 10
matplotlib.rcParams['ytick.minor.size'] = 7



def PrepareData(t, X, fs=1e6):
    """Integrate the raw coil signals and band-limit them with two FIR filters.

    Each channel is integrated (cumulative sum times the sample period
    ``1/fs``) and shifted to start at zero.  A 1 kHz low-pass copy is
    subtracted to remove the slow trend, and the result is low-passed at
    100 kHz.  Both filters are 101-tap symmetric (linear-phase) FIR filters,
    whose group delay is ``(N - 1) // 2`` samples.  Every filtered stream is
    advanced by exactly that delay, and ``t`` is trimmed from the end so time
    and data stay aligned.

    Parameters
    ----------
    t : 1-D array
        Time vector, one entry per row of ``X``.
    X : 2-D array, shape (n_samples, n_channels)
        Raw signals, one channel per column.
    fs : float, optional
        Sampling frequency in Hz.  Defaults to 1 MHz, the value the
        integration step and filter cut-offs were previously hard-coded for.

    Returns
    -------
    t, X : arrays
        Time vector of length ``n_samples - 2 * delay`` and the filtered
        signals with shape ``(n_samples - 2 * delay, n_channels)``.
    """
    N = 101                    # number of FIR taps (odd -> symmetric, linear phase)
    delay = (N - 1) // 2       # group delay of a symmetric N-tap FIR, in samples
    nyq = 0.5 * fs
    dt = 1.0 / fs

    # Integrate along time (channels become rows); start every channel at zero.
    X = cumsum(X.T, axis=1) * dt
    X = X - X[:, :1]

    # Remove the slow (< 1 kHz) trend: subtract a delay-compensated low-pass copy.
    taps = signal.firwin(N, 1e3 / nyq, window='hamming')
    trend = signal.lfilter(taps, 1.0, X, axis=1)
    X = X[:, :-delay] - trend[:, delay:]
    t = t[:-delay]

    # Suppress content above 100 kHz, again compensating the filter delay.
    taps = signal.firwin(N, 1e5 / nyq, window='hamming')
    X = signal.lfilter(taps, 1.0, X, axis=1)[:, delay:]
    t = t[:-delay]

    return t, X.T
    
  
      
def plotData(data):
    """Plot an FFT-low-passed copy of every channel and save ./graphs/raw_data.png.

    Row 0 of ``data`` is used as the x-axis (time) and rows 1.. as channels.
    Each channel is low-passed by zeroing every real-FFT bin from ``n // 16``
    upward, where ``n`` is the number of samples.  All channels are drawn
    inside a fixed zoom window, and the figure is closed after saving.
    """
    tvec = data[0, :]
    n = size(data, 1)
    n_keep = n // 16   # number of lowest-frequency bins that are kept
    for channel in data[1:, :]:
        spectrum = fft.rfft(channel)
        spectrum[n_keep:] = 0
        # Pass n explicitly: without it irfft returns n - 1 samples for odd n,
        # which no longer matches tvec and makes plot() fail.
        plot(tvec, fft.irfft(spectrum, n))
    xlim(0.016, 0.0171)
    ylim(-0.2, 0.2)

    savefig('./graphs/raw_data.png')
    close()
	
def loadconst(fname):
    """Read the text file *fname* and return its first line parsed as a float."""
    with open(fname, 'r') as stream:
        first_line = stream.readline()
    return float(first_line)

	
#def LoadData():
    #Data = Shot()

    #gd = Shot().get_data
    #das, m1  = gd('any', 'mirnov_1', return_channel = True)
    #das, m5  = gd('any', 'mirnov_5', return_channel = True)
    #das, m9  = gd('any', 'mirnov_9', return_channel = True)
    #das, m13 = gd('any', 'mirnov_13', return_channel = True)

    #Papouch = list2array( Data[das, [m1, m5, m9, m13]] ).T 
      
    #plasma = Data['plasma']

    #plasma_start = Data['plasma_start']
    #plasma_end = Data['plasma_end']

    #return Papouch,plasma_start, plasma_end,plasma
    
def LoadData():
    """Load the ring-coil channels of the current shot and save an overview plot.

    The 16 coil signals (rows 1..16 of the loaded array) are multiplied by
    ``RingCoilOrientation`` and divided by ``AEff``.  The plasma window,
    padded by 1e-3 on both sides, is then detrended, integrated, normalised
    and drawn with coil k offset by k.  The plot is saved to
    ./graphs/raw_data.png.

    Returns
    -------
    Papouch : 2-D array
        Row 0 is the time vector, rows 1..16 the corrected coil signals.
    plasma_start, plasma_end, plasma
        Shot values read under the keys of the same names.
    """
    Data = Shot()
    # Use the same Shot for the lookup: the returned `das` handle indexes Data.
    das, _ = Data.get_data('any', 'ring_1', return_channel=True)
    Papouch = list2array(Data[das, range(16)]).T

    plasma = Data['plasma']
    plasma_start = Data['plasma_start']
    plasma_end = Data['plasma_end']

    Papouch[1:, :] *= array(RingCoilOrientation)[:, None]
    Papouch[1:, :] /= array(AEff)[:, None]

    tvec = Papouch[0, :]
    ind = (plasma_start - 1e-3 < tvec) & (plasma_end + 1e-3 > tvec)
    X = signal.detrend(Papouch[1:, ind], axis=1)
    X = cumsum(X, axis=1).T * 1e-6      # -> shape (time, coil)
    # NOTE(review): axis=1 removes the median across coils at each time
    # sample; if a per-coil offset was intended this should be axis=0.
    X -= median(X, axis=1)[:, None]
    X /= std(X)
    X += arange(16)[None, :]            # stack the traces: coil k around k
    plot(tvec[ind] * 1000, X, linewidth=.3)
    axis('tight')
    xlabel('t [ms]')                    # time is plotted multiplied by 1000
    ylabel('Coil number')

    savefig('./graphs/raw_data.png')
    clf()

    return Papouch, plasma_start, plasma_end, plasma
    
        
    
    
#def LoadData(path, shot_num, skip_det_num):
    
    #path += str(shot_num)+'/'
    #try:
	#data = load(path+'data.npy')
	#if size(data,1) == 0:
	    #raise "wrong data"
    #except:
	##print 'reload'
	#data = list()
	#tvec = None
	#for i in range(0,16):
	    #print 'load NIturbo_%2.2i' %(i+1)

	    #if i in skip_det_num:
		#data.append(zeros(shape(data[0])))
		#continue
	    ##single_data =  loadtxt(path+'NIturbo_%2.2i' %(i+1)+'.asc', usecols = (1,))
	    #single_data =  loadtxt(path+'NIturbo_%2.2i' %(i+1), usecols = (1,))
	    
	    #if tvec == None:
		##tvec =  loadtxt(path+'NIturbo_%2.2i' %(i+1)+'.asc', usecols = (0,))	
		#tvec =  loadtxt(path+'NIturbo_%2.2i' %(i+1), usecols = (0,))	

	    #data.append(single_data)
	    
	    
	#data = array(data)
	#if size(data,1) == 0:
	    #raise "wrong data"
	#data = vstack((tvec, data))
	#save(path+'data', single(data))
    
    #try:
	#start = loadconst(path+'PlasmaStart')/1000
	#end   = loadconst(path+'PlasmaEnd'  )/1000
    #except:
	#start = nan
	#end = nan
    #print start,end
    #return data,start,end
	
    
#def LoadData_npz(path, shot_num, skip_det_num):  #FIXME provizorní verze!!
    #data = load('Nidatap.npz')
    ##data  data['data']
    
    #tvec = linspace(data['t_start'], data['t_end'], size(data['data'],0))
    #data = data['data']
    #data[:,skip_det_num] = 0
    #data = vstack((tvec, data.T))

    #start = data[0,8300]
    #end = data[0,19260]
    #return data,start,end

##def LoadData_mirnov(path, shot_num):
    
    ##path += str(shot_num)+'/'
    ##try:
	##data = load(path+'data.npy')
	##if size(data,1) == 0:
	    ##raise "wrong data"
	##if shot_num > 6000 and shot_num < 9500:
	    ##data = data.T
	    ##data[1:,1:] = diff(data[1:,:], axis = 1)/(data[0,1]-data[0,0])
    ##except:
	##data = list()
	##tvec = None
	##for i in range(0,4):
	    ##print 'load PapouchSt_%2.2i' %(i+1)


	    ##single_data =  loadtxt(path+'PapouchSt_%2.2i' %(i+1), usecols = (1,))
	    
	    ##if tvec == None:
		##tvec =  loadtxt(path+'PapouchSt_%2.2i' %(i+1), usecols = (0,))	

	    ##data.append(single_data)
	    
	    
	##data = array(data)
	##if size(data,1) == 0:
	    ##raise "wrong data"
	##data = vstack((tvec, data))
	##save(path+'data', single(data))
    

    ##try:
	##start = loadconst(path+'PlasmaStart')/1000
	##end   = loadconst(path+'PlasmaEnd'  )/1000
    ##except:
	##start = nan
	##end = nan
    ##print start,end
    ##return data,start,end    
    
def PlotSpec(args):
    """Render one wavelet spectrogram to graphs/spectrogram<sufname>.png.

    Takes a single tuple so it can be used directly with ``Pool.map``:
    ``(freq, field, t, sufname, vmin, vmax)``.  ``field`` has one row per
    entry of ``freq`` and spans ``t`` horizontally.  ``t`` is multiplied by
    1000 for the ms axis.  ``vmin``/``vmax`` bound the colour scale, except
    for ``sufname == 0``, which is autoscaled at the top.

    Tuple unpacking in the signature (``def f((a, b))``) was removed from
    the language by PEP 3113; the tuple is unpacked in the body instead.
    """
    freq, field, t, sufname, vmin, vmax = args
    if sufname == 0:
        vmax = None
    t = t * 1000

    fig = figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])
    ax.set_yscale('log', nonposy='clip')
    ax.imshow(abs(field), extent=[t[0], t[-1], freq[-1], freq[0]],
              aspect='auto', vmin=vmin, vmax=vmax)
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.axis([t[0], t[-1], amin(freq), amax(freq)])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Frequency [Hz]')

    fig.savefig('graphs/spectrogram' + str(sufname) + '.png', bbox_inches='tight')
    close(fig)
    
    
def PlotModes(field, t):
    """Plot the spatial (cross-coil) Fourier amplitude versus time.

    ``field`` is indexed (time, coil).  The FFT is taken across the coils.
    The real-input FFT returns only the non-negative bins 0 .. n_coils // 2.
    For real data |M| and -|M| have equal amplitude, so row i is |M| = i.
    The previous fftshift of this one-sided spectrum mislabelled bins 5..8
    as -4..-1.

    Saves graphs/poloidalM.png.  It also stores, via saveconst, the |M| with
    the largest time-averaged amplitude ('mode_M') and that amplitude
    ('mode_M_abs').
    """
    t = t * 1000
    fig = figure('poloidal M')
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])

    Fsig = abs(fft.rfft(field, axis=1))
    # Smooth each mode amplitude in time with a 1001-tap low-pass FIR
    # (1 kHz assuming 1 MHz sampling).  NOTE(review): not delay-compensated;
    # the smoothed amplitudes lag by 500 samples.
    b = signal.firwin(1001, cutoff=1e3 / (0.5 * 1e6), window='hamming')
    Fsig = signal.lfilter(b, 1, Fsig, axis=0)
    Fsig[:, 0] = 0     # drop M = 0 (the sum over all coils)
    Fsig *= 1e4        # rescale to readable colour-bar values
    n_mod = size(Fsig, 1)

    img = ax.imshow(Fsig.T, extent=[t[0], t[-1], n_mod - 0.5, -0.5],
                    aspect='auto', interpolation='nearest', vmin=0)
    fig.colorbar(img)
    # A Locator keeps a reference to the axis it is attached to, so each
    # axis needs its own instance.
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    ax.axis([t[0], t[-1], -0.5, n_mod - 0.5])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Poloidal mode number $|M|$')

    mean_amp = mean(Fsig, axis=0)
    saveconst('mode_M', argmax(mean_amp))
    saveconst('mode_M_abs', amax(mean_amp))

    savefig('graphs/poloidalM.png', bbox_inches='tight')
    close()
    
    
def _pool_map(func, items):
    """Map *func* over *items* in a process pool, releasing the pool even on error."""
    pool = Pool(cpu_count())
    try:
        result = pool.map(func, items)
        pool.close()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()
    return result


def _generate_mode_sounds(coils, modes):
    """Project the (time, coil) signals onto each mode number and pass each to soundGenerator."""
    n_samples, n_det = coils.shape
    coils_fft = fft.rfft(coils, axis=0)
    for m in modes:
        print("wave order %s" % m)
        # Weight coil k by exp(-2j*pi*m*k/n_det) and sum over the coils.
        phase = arange(n_det) / double(n_det) * m
        mode_fft = coils_fft.dot(exp(-2j * pi * phase))
        # Pin the output length; irfft alone returns n_samples - 1 for odd lengths.
        mode_signal = fft.irfft(mode_fft, n_samples)
        try:
            soundGenerator(mode_signal, 2000, 'mp3/sound' + str(m))
        except Exception as e:
            # best effort: a failed sound file must not stop the analysis
            print("sound gener. failed err: %s" % e)


def CalculateSpectrogram():
    """Run the full mode analysis for the current shot.

    Loads and preprocesses the coil data and plots the cross-coil mode
    spectrum.  It then writes one sound file per searched mode number,
    computes the wavelet spectrograms (NTM_CWT) for each mode in parallel,
    and plots them (PlotSpec) in parallel.
    """
    print("CalculateSpectrogram")
    t1 = time.time()

    data, plasma_start, plasma_end, plasma = LoadData()
    if not plasma:
        # No plasma detected: analyse a fixed 0 - 0.04 window instead.
        plasma_start = 0
        plasma_end = 0.04

    tvec = data[0, :]
    dt = tvec[1] - tvec[0]
    n_start = argmin(abs(plasma_start - tvec))
    n_end = argmin(abs(plasma_end - tvec))

    # Named `coils` rather than `signal` so the scipy.signal module stays visible.
    # PrepareData trims samples from the end of t, so indices found on the raw
    # time vector still address the same start of the record.
    t, coils = PrepareData(tvec, data[1:, :].T)
    t = t[n_start:n_end]
    coils = coils[n_start:n_end, :]
    PlotModes(coils, t)

    omega0 = 40
    horiz_res = 2000
    f_min = 1e3    # Hz
    f_max = 200e4  # Hz  NOTE(review): 2 MHz -- check against the sampling rate
    modes = array([-4, -3, -2, -1, 0, 1, 2, 3, 4])   # searched MHD mode numbers

    t2 = time.time()
    _generate_mode_sounds(coils, modes)
    print("generated sound %s" % (time.time() - t2))

    print("started wavelets")
    out = _pool_map(NTM_CWT, [(coils, dt, 0.005, omega0, m, horiz_res, f_min, f_max)
                              for m in modes])
    spec_all = array([single(abs(spec)) for spec, _ in out])
    scale_all = array([scale for _, scale in out])
    print("calc time %s" % (time.time() - t1))

    t_plot = time.time()
    # Wavelet scale -> frequency (Morlet relation for centre frequency omega0).
    freq = (omega0 + sqrt(2.0 + omega0 ** 2)) / (4 * pi * scale_all)
    contrast = 10
    spec_all = log(1 + contrast * abs(spec_all))
    vmin = 0
    vmax = amax(spec_all[modes != 0, :])   # shared colour scale from the non-zero modes
    _pool_map(PlotSpec, [(freq[i, ...], spec_all[i, ...], t, m, vmin, vmax)
                         for i, m in enumerate(modes)])

    print("plot. time %s" % (time.time() - t_plot))




def main():
    """Command-line entry point.

    Always ensures the ./graphs and ./mp3 output directories exist.  With
    ``plots`` as the first argument it runs CalculateSpectrogram.  It then
    resizes graphs/spectrogram0.png into icon.png via the external
    ``convert`` command and stores status 0 with saveconst.  Without an
    argument it prints a usage line instead of failing with IndexError.
    """
    import sys   # previously only reachable through star-import leakage

    for path in ['graphs', 'mp3']:
        if not os.path.exists(path):
            os.mkdir(path)

    if len(sys.argv) < 2:
        print("usage: <script> plots")
        return

    if sys.argv[1] == "plots":
        CalculateSpectrogram()
        # Raw string: the backslash escapes '!' for the shell; '\!' is not a
        # valid Python escape sequence.  The command string is unchanged.
        os.system(r'convert -resize 150x120\! graphs/spectrogram0.png icon.png')
        saveconst('status', 0)





# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
   