#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib 

#matplotlib.rcParams['backend'] = 'TkAgg'
#matplotlib.rcParams['backend'] = 'WXAgg'
#matplotlib.rcParams['backend'] = 'GTKAgg'
#matplotlib.rcParams['backend'] = 'QtAgg'
##matplotlib.rcParams['backend'] = 'Qt4Agg'
matplotlib.rcParams['backend'] = 'Agg'

from matplotlib.pyplot import *
from numpy import *
import os
from scipy import signal
from scipy.signal import fftconvolve
from multiprocessing import Process, Pool, cpu_count
import time
from scipy.stats.mstats  import mquantiles
import matplotlib.animation as manimation
import re
from CWT import cwt
from fftshift import fftshift
from SoundGenerator import soundGenerator
from numexpr import evaluate


# Switch read by AnimateSpectrogram(): while it is False that function
# returns immediately without rendering anything.
animation = False

# Per-channel sign (+1/-1) for the 16 ring-coil channels.  PrepareSignal()
# uses the entry keyed 12000 when shot > 12000 and the entry keyed 0 otherwise.
RingCoilOrientation = {12000:(-1,1,1,1,-1,1,-1,-1,-1,1,-1,1,-1,-1,-1,-1),
 0:(-1,-1,1,1,-1,1,-1,1,-1,1,-1,-1,-1,-1,-1,-1)}

# Sign and scale factors for the 4 Mirnov-coil channels; PrepareSignal()
# appends them to the 16 ring-coil values selected above.
MirnovCoilOrientation = array((-1,-1,1,1))
MirnovEffectiveArea = array((1,1,1,1))


# Per-channel divisors for the 16 ring-coil channels, selected by
# PrepareSignal() with the same keys as RingCoilOrientation (12000 or 0).
# The entry keyed 1 is not referenced by any code visible in this file.
# NOTE(review): units are not stated anywhere in this file -- presumably m^2; confirm.
EffectiveArea = {12000:[ 0.00209134 , 0.00446023 , 0.00428289  ,0.00374073 , 0.00320856 , 0.00182771,
  0.00267794,  0.00284156 , 0.00145525  ,0.00222005  ,0.00206269 , 0.00108452,
  0.00054154 , 0.00078395 , 0.00186135 , 0.00274475],
  1:[ 0.00209134 , 0.00446023 , 0.00428289  ,0.00374073 , 0.00320856 , 0.00252771,
  0.00267794,  0.00284156 , 0.00145525  ,0.00222005  ,0.00206269 , 0.00005452,
  0.00254154 , 0.00078395 , 0.00186135 , 0.00274475],
   0:[    2.45766087e-03  , 3.33461911e-03 ,  3.21871869e-03 ,  4.5e-03,
   2.39870295e-03  , 3.99826880e-03  , 3.47274534e-03 ,  3.93465853e-03,
   2.25236901e-03  , 3.22086160e-03 ,  3.27836414e-03 ,  6.87539836e-05,
   3.77966762e-03  , 1.78949877e-03  , 3.19212003e-03  , 3.50243652e-03]}



def smooth(x,window_len=11,window='hanning',axis=-1):
    """Smooth `x` along `axis` by convolution with a normalised window.

    Every 1-D slice is extended at both ends with a point-reflected copy of
    itself (``2*edge - mirrored samples``), convolved with ``w / w.sum()``
    through an FFT convolution and trimmed back to its original length.

    Parameters
    ----------
    x : ndarray
        1-D or 2-D input array.  It is not modified.
    window_len : int or float
        Window length in samples; non-integer values such as ``1e3`` are
        truncated with ``int()``.  When `window` is an array, its own length
        replaces this value for the convolution, but the two checks below
        still use `window_len`.
    window : str, tuple or array_like
        A window specification understood by ``scipy.signal.get_window`` or
        an explicit 1-D array of window coefficients.
    axis : int
        Axis to smooth along.  For 1-D input the only axis is used.

    Returns
    -------
    ndarray
        The smoothed data with singleton dimensions squeezed out, or `x`
        itself (unchanged) when ``window_len < 3``.

    Raises
    ------
    ValueError
        If `x` is not 1-D/2-D, is shorter than the window along `axis`, or
        an array `window` is not 1-D.
    """
    if x.ndim not in (1, 2):
        raise ValueError("smooth only accepts 1 or 2 dimension arrays.")
    # Callers pass floats such as 1e3; slice bounds below must be integers.
    window_len = int(window_len)
    if x.shape[axis] < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len < 3:
        return x

    if x.ndim == 1:
        # atleast_2d() turns the vector into a single row, so the data always
        # lives on the last axis regardless of the axis that was requested.
        axis = -1
    if x.dtype.kind in 'biu':
        # An integer output buffer would silently truncate the smoothed values.
        x = x.astype(float)

    x = atleast_2d(x)
    x = swapaxes(x, -1, axis)
    y = empty_like(x)

    if isinstance(window, (str, tuple)):
        if window == 'hanning':
            # 'hann' is SciPy's canonical name for this window; the 'hanning'
            # alias is not accepted by every SciPy release.
            window = 'hann'
        w = signal.get_window(window, window_len)
    else:
        w = asarray(window)
        if w.ndim != 1:
            raise ValueError('window must be 1-D')
        if w.shape[0] > x.shape[-1]:
            raise ValueError('window is longer than x.')
        window_len = w.shape[0]

    kernel = w/w.sum()
    for i in range(x.shape[0]):
        row = x[i, :]
        # Point-reflect the row about its first and last sample so that the
        # convolution does not see an artificial step at the borders.
        padded = r_[2*row[0]-row[window_len-1::-1], row, 2*row[-1]-row[-1:-window_len:-1]]
        y[i, :] = fftconvolve(padded, kernel, mode='same')[window_len:-window_len+1]

    y = swapaxes(y, -1, axis)

    return squeeze(y)



class LogFormatterTeXExponent(LogFormatter, object):
    """LogFormatter variant that wraps tick labels in TeX math mode.

    The label produced by the parent formatter is parsed back into a number.
    If ``|log10(x)| > 2`` an exponent of the form ``e<sign><digits>`` in the
    label is rewritten as ``\\cdot 10^{<sign><digits>}``; otherwise the number
    is printed in fixed-point notation with ``max(0, -int(log10(x) - 1))``
    decimals.  In both cases the result is enclosed in ``$...$``.
    """

    def __init__(self, *args, **kwargs):
        super(LogFormatterTeXExponent, self).__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Return the parent formatter's label converted to TeX notation.

        An empty label from the parent (a tick that should carry no text) is
        returned unchanged instead of being passed to ``float()``.
        """
        label = super(LogFormatterTeXExponent, self).__call__(*args, **kwargs)
        # NOTE(review): this removes every '-' from the label, which also
        # drops the sign of a negative exponent such as 'e-05' before the
        # value is parsed below -- confirm that this is intended.
        label = label.replace('-','')

        if not label.strip():
            # float('') would raise ValueError; keep unlabeled ticks unlabeled.
            return label

        x = float(label)
        if abs(log10(x)) > 2:
            label = re.sub(r'e(\S)0?(\d+)',r'\\cdot 10^{\1\2}',str(label))
            label = "$" + label + "$"
        else:
            n = max(0,-int(log10(x)-1))
            label = ('$%.'+str(n)+'f$')%x

        return label
	


def mad(x,axis = -1):
    """Return the scaled median absolute deviation of `x` along `axis`.

    The result is ``median(|x - median(x)|) * 1.4826`` evaluated along
    `axis`; that axis is removed from the output.  The median is re-expanded
    along `axis` before the subtraction, so arrays of any dimensionality
    broadcast correctly.
    """
    x = asarray(x)
    centre = expand_dims(median(x, axis=axis), axis)
    return median(absolute(x - centre), axis=axis)*1.4826

def DFT(x, y,nf):
    """Decompose `y` into `nf` harmonics sampled at arbitrary positions `x`.

    With integer mode numbers ``k = floor(-nf/2) ... floor(nf/2) - 1`` and
    ``E[k, j] = exp(-1j * k * x[j])`` the function returns

        ``dot(y, conj(pinv(E))) * nf``

    i.e. `nf` times the pseudo-inverse (least-squares) solution ``c`` of
    ``y = dot(c, conj(E))``.

    Parameters
    ----------
    x : 1-D array of sample positions (used as phases, so in radians).
    y : array whose last axis has ``len(x)`` entries.
    nf : int, number of harmonics.

    Returns
    -------
    Complex array shaped like `y` with the last axis replaced by the modes.
    """
    # Explicit floor division keeps the mode numbers integral regardless of
    # whether `/` means integer or true division in the running interpreter.
    mode_numbers = arange((-nf)//2, nf//2)

    E = exp(-1j*outer(mode_numbers, x))
    return dot(y, conj(linalg.pinv(E)))*nf




def LoadSignal2():
    """Load the ring- and Mirnov-coil signals through pygolem_lite.

    Reads 16 ring-coil channels and the four Mirnov channels
    ('mirnov_1', '_5', '_9', '_13') from a default-constructed ``Shot``,
    rescales the Mirnov columns and hands everything to PrepareSignal().

    Returns
    -------
    (tvec, data) exactly as returned by PrepareSignal().
    """
    from pygolem_lite import Shot
    from pygolem_lite.modules import list2array,deconvolveExp,save_adv,saveconst

    current = Shot()
    fetch = Shot().get_data

    # Only the DAS identifier of this lookup is used for the ring coils.
    das, data = fetch('any', 'ring_1', return_channel = True)
    ring = list2array( current[das, range(16)] )

    # `das` ends up holding the identifier returned for the last Mirnov coil.
    mirnov_channels = []
    for coil in ('mirnov_1', 'mirnov_5', 'mirnov_9', 'mirnov_13'):
        das, channel = fetch('any', coil, return_channel = True)
        mirnov_channels.append(channel)

    shot = current.shot_num

    mirnov = list2array( current[das, mirnov_channels] )

    # Column 0 is the time axis; the Mirnov time axis is the one kept.
    ring = ring[:,1:]
    tvec = mirnov[:,0]
    mirnov = mirnov[:,1:]

    scale = r_[ones(16),ones(4)*0.0005]
    data = c_[ring, mirnov]/scale[None,:]

    plasma = current['plasma']  # read but not used further

    tstart = current['plasma_start']
    tend = current['plasma_end']
    tvec, data = PrepareSignal(tvec, data,tstart,tend,shot)
    return tvec, data


def LoadSignal(shot):
    """Load the coil signals of `shot` through ``golem_data``.

    The 'niturbo' and 'nistandard6132' records are concatenated column-wise,
    the last four columns are divided by 0.0005, and the result is passed to
    PrepareSignal() together with the 'plasma_start'/'plasma_end' values.

    Returns
    -------
    (tvec, data) exactly as returned by PrepareSignal().
    """
    print('loading')
    from golem_data import golem_data

    turbo = golem_data(shot, 'niturbo')
    data = turbo.data
    tvec = turbo.tvec
    tend   = golem_data(shot, 'plasma_end').data
    tstart = golem_data(shot, 'plasma_start').data

    standard = golem_data(shot, 'nistandard6132').data
    scale = r_[ones(16),ones(4)*0.0005]
    data = c_[data, standard]/scale[None,:]

    tvec, data = PrepareSignal(tvec, data,tstart,tend,shot)
    return tvec, data
    
	
def PrepareSignal(tvec, data,tstart, tfinish,shot ):
    """Calibrate, integrate, high-pass and crop the raw coil signals.

    Steps, in order:
      1. divide every column by its effective area times its polarity
         (in place, so the caller's array is modified),
      2. remove a linear trend, integrate with ``cumsum`` and scale by 1e-6,
      3. subtract a copy smoothed with a 101-tap Hamming-windowed FIR
         low-pass kernel,
      4. crop to the samples closest to `tstart` and `tfinish` and detrend
         the cropped part once more.

    Parameters
    ----------
    tvec : 1-D time axis matching the rows of `data`.
    data : 2-D float array with 20 columns (16 ring + 4 Mirnov channels).
    tstart, tfinish : crop limits, in the units of `tvec`.
    shot : shot number; values above 12000 select the calibration tables
        keyed 12000, everything else the tables keyed 0.

    Returns
    -------
    (tvec, data) restricted to the cropped interval.
    """
    print('preparing')

    table = 12000 if shot > 12000 else 0
    AEff = r_[array(EffectiveArea[table]),MirnovEffectiveArea]
    polarity = r_[array(RingCoilOrientation[table]),MirnovCoilOrientation]

    data/= (AEff*polarity)[None,:]
    data = signal.detrend(data, axis=0)

    # NOTE(review): 1e-6 presumably is the sampling period in seconds
    # (1 MHz acquisition) -- confirm.
    data = cumsum(data, axis=0)*1e-6

    # Low-pass kernel with normalised cutoff 1e3/(0.5*1e6); subtracting the
    # smoothed signal leaves the faster part of the data.
    N = 101
    taps = signal.firwin(N, 1e3/(0.5*1e6) , window='hamming', pass_zero=True)
    data -= smooth(data,window_len=N,window=taps,axis=0)

    nstart = argmin(abs(tvec-tstart))
    nend = argmin(abs(tvec-tfinish))
    crop = slice(nstart, nend)

    data = signal.detrend(data[crop,:], axis=0)
    tvec = tvec[crop]

    return tvec, data


def GenerateSound(mode_signal,modes):
    """Render one sound file per mode with ``soundGenerator``.

    Parameters
    ----------
    mode_signal : 2-D array; column ``i`` holds the signal of one mode.  Only
        the real part is used.
    modes : iterable of ``(column_index, mode_number)`` pairs.

    For every pair a job ``(signal, 2000, './mp3/sound<mode_number>')`` is
    sent to ``soundGenerator`` in a process pool sized to the CPU count.  The
    elapsed wall-clock time is printed afterwards.
    """
    # use only one "polarization" of the signal
    mode_signal = real(mode_signal)

    path = './mp3/'
    if not os.path.exists(path): os.makedirs(path)

    t2 = time.time()

    jobs = [(mode_signal[:,i],2000, path+'sound'+str(m)) for i,m in modes]

    pool = Pool(cpu_count())
    try:
        pool.map(soundGenerator, jobs)
    finally:
        # Release the worker processes even when a job raised.
        pool.close()
        pool.join()

    print('generated sound %s' % (time.time()-t2))


def PlotSpecrograms(freq, field, t, modes, logscale=True):
    """Save one spectrogram image per mode to ./graphs/spectrogram<m>.png.

    Parameters
    ----------
    freq : sequence of frequencies; only the first/last entries and the
        minimum/maximum are used for the axes (plotted divided by 1e3).
    field : 3-D array whose last axis enumerates the modes.  It is compressed
        with ``log(1 + (field/std(field))**2)/2`` before plotting.
    t : time axis; only ``t[0]`` and ``t[-1]`` are used (plotted times 1e3).
    modes : iterable of ``(index, mode_number)`` pairs.  The slice drawn for
        a pair is ``field[..., n_modes - index - 1]``.
    logscale : use a logarithmic frequency axis with TeX tick labels.
    """
    print('PlotSpecrograms')

    field = log(1+(field/std(field))**2)/2

    # One shared colour range for all modes.
    lowest = amin(field)
    highest = amax(field)

    fig = figure('spectrogram')
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])

    if logscale:
        ax.set_yscale('log', nonposy='clip')
        ax.yaxis.set_minor_formatter(
            LogFormatterTeXExponent(base=10, labelOnlyBase=False))
        ax.yaxis.set_major_formatter(
            LogFormatterTeXExponent(base=10, labelOnlyBase=False))

    t_first = t[0]*1e3
    t_last = t[-1]*1e3

    # Placeholder image; the real data is swapped in per mode below.
    img = ax.imshow(zeros((1,1)),
                    extent=[t_first, t_last, freq[-1]/1e3, freq[0]/1e3],
                    aspect='auto', vmin=lowest, vmax=highest,
                    interpolation='bicubic')

    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.axis([t_first, t_last, amin(freq)/1e3, amax(freq)/1e3])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Frequency [kHz]')

    started = time.time()
    n_modes = size(field,2)
    for i,m in modes:
        img.set_data(field[...,n_modes-i-1])
        fig.savefig('./graphs/spectrogram'+str(m)+'.png')

    print('%s %s' % ('plotted spectograms', time.time()-started))

    fig.clf()
    

def AnalyzeSpectrograms(spectrogram,tvec,fvec,fmin,fmax,log_scale=True):
    """Plot peak amplitude and peak frequency over time; save graphs/analyze.png.

    For every time column the maximum of `spectrogram` along axis 0 and the
    index where it occurs are taken; the index is median-filtered over 31
    time steps and converted to a frequency through `fvec`.  The columns
    5, 4, 3, 2 of the last axis are drawn and labelled 'M=-2' ... 'M=-5'.

    Parameters
    ----------
    spectrogram : 3-D array indexed (frequency, time, mode).
    tvec : time axis; only its first and last entries are used.
    fvec : array mapping frequency indices to frequencies.
    fmin, fmax : limits of the frequency axis (drawn divided by 1e3).
    log_scale : use a logarithmic frequency axis with TeX tick labels.
    """
    print('AnalyzeSpectrograms')

    peak_index = int_(signal.medfilt(argmax(spectrogram, axis=0), (31,1)))
    max_freq = fvec[peak_index]
    max_amplitude = amax(spectrogram, axis=0)

    columns = range(5,1,-1)
    mode_labels = ['M=%d'%i for i in range(-2,-6,-1)]
    t_first = tvec[0]*1e3
    t_last = tvec[-1]*1e3
    tvec_spec = linspace(tvec[0],tvec[-1], size(spectrogram,1))*1e3

    def add_legend(axes, lines):
        # Semi-transparent legend with thick sample lines.
        legend = axes.legend(lines, mode_labels, loc='best', fancybox=True)
        legend.get_frame().set_alpha(0.5)
        for handle in legend.get_lines():
            handle.set_linewidth(3)

    fig = figure(figsize=(8,8))
    subplots_adjust(hspace=0.05, wspace = 0)

    # Upper panel: peak amplitude.
    top = fig.add_subplot(211)
    add_legend(top, [top.plot(tvec_spec,max_amplitude[:,i]*10)[0] for i in columns])
    top.set_ylabel('Aplitude [a.u.]')
    top.xaxis.set_major_formatter(NullFormatter() )
    top.set_xlim(t_first, t_last)

    # Lower panel: frequency of the peak.
    bottom = fig.add_subplot(212)
    if log_scale:
        bottom.set_yscale('log', nonposy='clip')
        bottom.yaxis.set_minor_formatter(
            LogFormatterTeXExponent(base=10, labelOnlyBase=False))
        bottom.yaxis.set_major_formatter(
            LogFormatterTeXExponent(base=10, labelOnlyBase=False))

    add_legend(bottom, [bottom.plot(tvec_spec,max_freq[:,i]/1e3,linewidth=1)[0] for i in columns])

    bottom.set_ylim((fmin-1)/1e3,(fmax+1)/1e3)
    bottom.set_xlabel('t [ms]')
    bottom.set_ylabel('Frequency [kHz]')
    bottom.set_xlim(t_first, t_last)

    fig.savefig('graphs/analyze.png')#,bbox_inches='tight'
    close()
   


def  AnimateSpectrogram(x,spectrogram,fvec,tvec,theta,nmods,
            nch,framesamp,hopsamp,log_scale=True):
    """Write an animation of the mode spectrum and raw signals to graphs/AnimCWT.mp4.

    Does nothing unless the module-level flag ``animation`` is true.

    Left panel: for each frame the spectrogram slice at the frame centre
    (frequency vs. mode number), normalised by its standard deviation.
    Right panel: the `framesamp` raw samples of every channel, scaled by
    their MAD and offset by the channel angle `theta`.

    Parameters
    ----------
    x : 2-D array (time, channel) of signals.
    spectrogram : 3-D array indexed (frequency, time, mode).
    fvec : frequency axis; first/last entries give the image extent.
    tvec : time axis matching the rows of `x`.
    theta : per-channel vertical offsets (poloidal angle).
    nmods : number of modes.
    nch : number of channels.
    framesamp, hopsamp : frame length and hop between frames, in samples.
    log_scale : use a logarithmic frequency axis with TeX tick labels.
    """
    if not animation:
        return

    # Imported here so the function does not depend on `sys` leaking in
    # through one of the module's star imports.
    import sys

    nt = size(x,0)
    n_spec = size(spectrogram,1)
    # Explicit floor division: these values are used as shapes and indices.
    half_frame = framesamp//2
    half_modes = nmods//2

    fig = figure(figsize=(10,5))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    mag_img = ax1.imshow(zeros((framesamp, half_modes+1)),
        extent=[fvec[0]/1e3, fvec[-1]/1e3, (-nmods)//2+0.5, half_modes+0.5],
        aspect='auto', animated=True, interpolation='nearest')
    mag_img.set_cmap('YlOrBr')

    plot2, = ax2.plot([],[], linewidth=0.3, animated=False)

    ax2.set_ylim(-0.3,2*pi+0.1)
    ax2.set_ylabel('poloidal angle [rad]')
    ax2.axvline(x=0)

    Dt = tvec[framesamp]-tvec[0]
    ax2.set_xlim(-Dt/2*1e6, Dt/2*1e6)
    # Raw strings: same characters as before, without invalid escape sequences.
    ax2.set_xlabel(r'$\Delta$t [$\mu$s]')
    ax1.set_xlabel('f [kHz]')

    time_text = ax2.text(0.05, 0.9, '', transform=ax2.transAxes,animated=False)

    ax1.yaxis.set_major_locator(MultipleLocator(1))
    ax2.xaxis.grid(color='r', linestyle='-', linewidth=.2)
    ax1.set_ylabel('Poloidal mode number M')

    if log_scale:
        # The x axis takes the 'nonposx' keyword; 'nonposy' belongs to y axes.
        ax1.set_xscale('log', nonposx='clip')
        ax_format = LogFormatterTeXExponent(base=10,labelOnlyBase=False)
        ax1.xaxis.set_minor_formatter(ax_format)
        ax1.xaxis.set_major_formatter(ax_format)

    # The time axis of the right panel is the same for every frame: the frame
    # times relative to their mean, repeated per channel, with a NaN at the
    # start of each channel so that the traces are not joined.
    frame_t = tvec[:framesamp]
    T = tile(frame_t-frame_t.mean(),(1,nch)).T
    T[::framesamp] = nan
    T_us = T*1e6

    writer_class = manimation.writers['mencoder']
    metadata = dict(title='NTM', artist='Matplotlib', comment='Animated CWT spectrogram')
    writer = writer_class(fps=10, metadata=metadata,codec='mpeg4',bitrate=2000)

    t = 0
    with writer.saving(fig, 'graphs/AnimCWT.mp4', 100):
        for j in range(0, nt-framesamp, hopsamp):
            centre = j+half_frame
            i_spec = (centre*n_spec)//nt

            time_text.set_text('i=%d t=%.2fms'%(centre,1e3*tvec[centre]))

            y = copy(x[j:j+framesamp, : ])
            y/= mad(y,axis=0)[None,:]

            image = spectrogram[:,i_spec,:].T
            image = hypot(image/std(image), 1)

            mag_img.set_data(image)
            mag_img.set_clim(1,amax(image[half_modes+2:,:]))

            plot2.set_data(T_us, (y/8+theta[None,:]).T.ravel())

            elapsed = time.time()-t
            fps = 1/elapsed if elapsed > 0 else float('inf')
            t = time.time()

            sys.stdout.write('\r  plotting %2.1f%%  fps:%2.1f'%(j*100./nt,fps))
            sys.stdout.flush()
            writer.grab_frame()

      
    
    

def  multi_cwt(x, fs, framesz, hop,tvec=None):
    """Mode-resolved wavelet analysis of the coil signals, with all outputs.

    The signals are normalised by a moving estimate of their amplitude,
    decomposed into 16 modes with DFT() using the coil angles stored in
    'phi.npy', and every mode is transformed with ``cwt`` in a process pool.
    The results are passed to GenerateSound(), AnalyzeSpectrograms(),
    PlotSpecrograms() and AnimateSpectrogram().

    Parameters
    ----------
    x : 2-D float array (time, channel).  NOTE: it is normalised in place.
    fs : sampling frequency.
    framesz : animation frame length, in seconds.
    hop : time step between spectrogram columns / animation frames, seconds.
    tvec : optional time axis; defaults to ``arange(len(x)) / fs``.
    """
    framesamp = int(framesz*fs)
    hopsamp = int(hop*fs)
    # `is None`: comparing an array with `== None` is element-wise.
    if tvec is None:
        tvec = arange(size(x,0))/float(fs)

    nmods = 16
    nt, nch =  shape(x)

    fmax = 8e4
    fmin = 3e3
    dj = 0.05  #frequency resolution of CWT
    p = 10    #length of the wavelet

    # Intensity envelope as a moving standard deviation.  The window length
    # is an integer (it is used as a slice bound) and 'hann' is SciPy's
    # canonical name of the window formerly requested as 'hanning'.
    filtered_std = sqrt(abs(smooth(x**2,window_len=1000,window='hann',axis=0)))

    x /= filtered_std/mean(filtered_std,axis=1)[:,None]

    #poloidal angles of the coils in the !magnetics! coordinates
    theta = load('phi.npy')
    theta = median(theta,axis=0)
    theta += 2*pi*r_[arange(16)/16.,arange(4)/4.]

    #DFT over the channels with nonequidistant sampling
    fx = DFT(theta, x,nmods)

    print('calc spectrograms')

    jobs = [(fx[:,i], 1./fs, dj, p ,nt//hopsamp, fmin,fmax) for i in range(nmods)]
    pool = Pool(cpu_count())
    try:
        out = pool.map(cwt, jobs)
    finally:
        # Release the worker processes even when a job raised.
        pool.close()
        pool.join()
    freq = out[0][2][::-1]

    # pop() hands back the results last-to-first, so the mode axis of the
    # stacked array is reversed; it also frees each result once consumed.
    spectrogram = [abs(out.pop()[0]) for i in range(nmods)]
    spectrogram = dstack(spectrogram)

    # A list, because the pairs are iterated by more than one consumer below.
    modes = list(zip(arange(nmods), arange((-nmods)//2, nmods//2)))

    #create sound of the modes
    GenerateSound(fx,modes)

    #analyze the spectrograms
    AnalyzeSpectrograms(spectrogram[::-1,...],tvec,freq,fmin,fmax )
    spectrogram = spectrogram[::-1,:,::-1]

    #plot spectrograms
    PlotSpecrograms(freq,spectrogram,tvec,modes)

    #plot animated 3d spectrograms
    AnimateSpectrogram(x,spectrogram,freq,tvec,theta,nmods,nch,framesamp,hopsamp)
 
    
def calc_stft(args):
    """Compute the mode-resolved spectrum of one signal frame.

    Parameters
    ----------
    args : tuple ``(y, w, nfmax, theta)``
        Packed into a single argument so the function can be used with
        ``Pool.map``.  ``y`` is the frame (time, channel), ``w`` the window
        applied along time, ``nfmax`` the number of frequency bins to keep
        and ``theta`` the channel angles handed to DFT().

    Returns
    -------
    Complex array with the first `nfmax` FFT bins for each of the 16 modes.
    """
    y, w, nfmax, theta = args
    y = y/mad(y,axis=0)[None,:]
    fy = DFT(theta, y,16)
    f = fft.fft(fy*w[:,None], axis=0)
    return f[:nfmax,:]
	
	
	
def stft(x, fs, framesz, hop,tvec=None):
    """Mode-resolved short-time Fourier analysis with plots.

    The signal is cut into Hamming-windowed frames whose length is `framesz`
    rounded up to a power of two samples; every frame is processed by
    calc_stft() in a process pool and the magnitudes are passed to
    AnalyzeSpectrograms() and PlotSpecrograms().

    Parameters
    ----------
    x : 2-D array (time, channel).
    fs : sampling frequency.
    framesz : requested frame length, in seconds.
    hop : time step between frames, in seconds.
    tvec : optional time axis; defaults to ``arange(len(x)) / fs``.
    """
    print('stft')
    nmods = 16
    fmax = 8e4
    nt, nch =  x.shape

    framesamp = int(2**ceil(log2(framesz*fs)))
    hopsamp = int(hop*fs)
    w = hamming(framesamp)

    # `is None`: comparing an array with `== None` is element-wise.
    if tvec is None:
        tvec = arange(size(x,0))/float(fs)

    # Number of FFT bins below fmax (bin spacing is fs/framesamp).
    nfmax = int(fmax/fs*framesamp)
    theta = load('phi.npy')

    theta  = median(theta,axis=0)
    theta += 2*pi*r_[arange(16)/16.,arange(4)/4.]

    ind = range(0, size(x,0)-framesamp, hopsamp)
    jobs = [(x[i:i+framesamp,:],w,nfmax,theta) for i in ind]

    pool = Pool(cpu_count())
    try:
        out = pool.map(calc_stft, jobs)
    finally:
        # Release the worker processes even when a job raised.
        pool.close()
        pool.join()

    # pop() hands back the frames last-to-first, i.e. reversed in time.
    # NOTE(review): AnalyzeSpectrograms() below receives this reversed order
    # while PlotSpecrograms() gets the time axis flipped back -- confirm
    # which of the two is intended.
    frames = [abs(out.pop()) for i in ind]

    # Stack to (frequency, mode, time), then reorder to (frequency, time, mode).
    spectrogram = swapaxes(dstack(frames), 1, 2)

    # Frequency of each retained FFT bin; consistent with nfmax above and
    # with the [0, fmax] extent passed to PlotSpecrograms().
    fvec = arange(nfmax)*fs/float(framesamp)
    # A list, so the pairs can be iterated more than once.
    modes = list(zip(arange(nmods), arange((-nmods)//2, nmods//2)))

    AnalyzeSpectrograms( spectrogram,tvec,fvec,0,fmax)
    PlotSpecrograms([0,fmax],spectrogram[:,::-1,:], tvec,modes, logscale=False)


def LSQfit(tvec,data):
    """Fit a single rotating mode to the coil signals with Powell's method.

    The model is ``A*sin(2*pi*f*t - theta*M + phi)`` with one common
    amplitude, frequency and phase and a per-channel angle correction added
    to the nominal coil angles.  The cost is the norm of the residual
    (scaled by each channel's standard deviation) plus the norm of the
    differences between neighbouring angle corrections.

    Parameters
    ----------
    tvec : time axis matching the rows of `data`.
    data : 2-D array (time, channel); a detrended, unit-variance copy is fitted.

    Returns
    -------
    The fitted per-channel angle corrections (``P[4:]``).

    Side effects: reads 'phi.npy' and 'phi2.npy' and shows two figures.
    """
    # Imported locally: the module's import block does not provide it.
    from scipy.optimize import fmin_powell

    data = signal.detrend(data,axis=0)
    data/= std(data,axis=0)

    nch = size(data,1)
    theta0 = 2*pi*r_[arange(16)/16.,arange(4)/4.]
    # TODO: a separate amplitude and offset for each channel?
    def Model(par,tvec,show_plt=False):
        # Parameter layout: [frequency, phase, mode number, amplitude,
        # angle correction of channel 0, 1, ...].
        A = par[3]
        f = par[0]

        phi = par[1]

        M =  int(par[2])
        if M != 0:
            # Shift corrections outside [-pi/2, pi/2] back by pi/2
            # (modifies `par` in place).
            par[4:][par[4:] <-pi/2] +=pi/2
            par[4:][par[4:] > pi/2] -=pi/2

        # View into `par`: removing the median also updates `par` in place.
        Dtheta =  par[4:]
        Dtheta-= median(Dtheta)

        theta = theta0[:nch]+Dtheta
        model = sin(2*pi*(tvec *f)-theta[:,None]*M+phi)*A

        if show_plt:
            # NOTE(review): `/2*pi` multiplies by pi; `/(2*pi)` may have been
            # intended for the vertical offsets -- confirm.
            plot(model.T/2+theta0[None,:nch]/2*pi,'--',linewidth=0.3)
            plot(data/2+theta0[None,:nch]/2*pi,linewidth=0.3)
        return linalg.norm((model-data.T)/std(data,axis=0)[:,None])+linalg.norm(diff(Dtheta))

    # Initial frequency: centroid of the three bins around the peak of the
    # channel-averaged spectrum (the 1e6 assumes 1 MHz sampling).
    f = linspace(0,1e6/2, size(data,0)//2+1)
    F = mean(abs(fft.rfft(data*hamming(size(data,0))[:,None],axis=0)),axis=1)
    i = argmax(F)
    lo = max(i-1, 0)  # a peak in bin 0 must not produce an empty slice
    f_max = sum((f*F)[lo:i+2])/sum(F[lo:i+2])

    f_max = max(29000,f_max)

    theta_ = median(load('phi2.npy'),axis=0)
    theta__ = median(load('phi.npy'),axis=0)

    x0 = r_[f_max,pi,-2,0.8,theta_[:nch]]

    P, fopt, direc, _, _, warnflag = fmin_powell(Model, x0, args=(tvec,),maxiter=1e5,maxfun=1e6, xtol=1e-6, ftol=1e-6,full_output=True)

    fig = figure()
    title(fopt)

    Model(P,tvec,show_plt=True)
    fig.show()

    fig =figure('phase')
    plot(theta0[:16], P[4:20]-median(P[4:20]))
    plot(theta0[:16], theta_[:16],':')
    plot(theta0[:16], theta__[:16],'-.')

    if nch == 20:
        plot(theta0[16:], theta__[16:],'-.')
        plot(theta0[16:20], P[20:24],'--')

    xlim(0,2*pi)
    fig.show()
    pause(0.1)
    return P[4:]
 
    


def main():
    """Command-line entry point.

    Creates the 'graphs' and 'mp3' output directories if necessary.  With no
    argument, or with the argument ``plots``, it loads the current shot,
    runs multi_cwt(), creates ``icon.png`` from the first spectrogram with
    ImageMagick's ``convert`` and stores the constant ``status = 0``.
    Any other first argument only creates the directories.
    """
    # Imported here so the function does not depend on `sys` leaking in
    # through one of the module's star imports.
    import sys

    for path in ['graphs', 'mp3']:
        if not os.path.exists(path):
            os.mkdir(path)

    if len(sys.argv)==1 or sys.argv[1] ==  "plots":

        #tvec,sig = LoadSignal(12686)
        tvec,sig = LoadSignal2()

        # At most about 800 spectrogram columns, but never finer than 1e-5.
        hop = max((tvec[-1]- tvec[0])/800, 1e-5)
        multi_cwt(sig, 1e6, 5e-4, hop, tvec = tvec)

        from pygolem_lite.modules import saveconst

        # Raw string: the backslash before '!' is passed through to the shell.
        os.system(r'convert -resize 150x120\! graphs/spectrogram0.png icon.png')
        saveconst('status', 0)





# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
   
