#!/usr/bin/python2
# -*- coding: utf-8 -*-


""" CREATED: 7/2012
    AUTHOR: Tomas Odstrcil
    version 2.0 7/2013 - decreased CPU time consumption
"""


import matplotlib
matplotlib.rcParams['backend'] = 'Agg'


from numpy import *
from matplotlib.pyplot import *
from matplotlib.ticker import MultipleLocator, FormatStrFormatter,LogLocator
import os,sys
from pygolem_lite import load_adv, saveconst, Shot
from  multiprocessing import Process, Pool, cpu_count
from scipy.stats.mstats import mquantiles
from scipy.signal import *
from CWT import cwt
import numexpr as ne

import time
import re

# Tick styling shared by every figure: ticks point outward, with longer
# major (10) and minor (7) marks on both axes.
matplotlib.rcParams.update({
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'xtick.major.size': 10,
    'xtick.minor.size': 7,
    'ytick.major.size': 10,
    'ytick.minor.size': 7,
})

#############
# Key of the data source read via Shot()[DAS] in main();
# 'nistandard6132' is the alternative source.
DAS = 'nistandard'
##########


class LogFormatterTeXExponent(LogFormatter, object):
    """Extend matplotlib's LogFormatter to emit TeX tick labels.

    Labels whose magnitude is above 1e2 or below 1e-2 are rewritten from
    ``1e+05`` style to ``1\\cdot 10^{+5}``.  Other labels are printed in
    fixed-point form, with just enough decimals to show the leading digit.
    """

    # Exponent part of a label such as '1e+05' / '1e-05': group 1 is the
    # sign, group 2 the exponent digits without one leading zero.
    _EXPONENT_RE = re.compile(r'e(\S)0?(\d+)')

    def __call__(self, *args, **kwargs):
        """Wrap the parent formatter and convert its label to TeX notation."""
        label = super(LogFormatterTeXExponent,
                      self).__call__(*args, **kwargs)
        if not label:
            # LogFormatter returns '' for ticks it does not label;
            # float('') would raise, so pass the empty label through.
            return label
        # Keep the sign inside the exponent: stripping every '-' (as the
        # previous version did) turned '1e-05' into '1e05'.  abs() keeps
        # the magnitude test valid should a negative tick ever appear.
        x = abs(float(label))
        if x == 0:
            return '$0$'  # log10(0) is -inf; nothing to scale
        if abs(log10(x)) > 2:
            label = self._EXPONENT_RE.sub(r'\\cdot 10^{\1\2}', str(label))
            return "$" + label + "$"
        # Number of decimals needed so that the leading digit is visible.
        n = max(0, -int(log10(x) - 1))
        return ('$%.' + str(n) + 'f$') % x
	



# Per-channel calibration applied to the raw data in main().  The sign of
# the fourth channel depends on the shot number (changes after 12079).
shot = Shot()['shotno']
calib = -array([261, 261, -261, 261 if shot > 12079 else -261])  # T/(V*s)

def PlotSpecrograms(freq, field, signals, t, start, end, plasma, contrast, logscale=True):
    """Save a spectrogram PNG and a signal PNG for every channel, plus an icon.

    freq      -- frequencies [Hz] of the image rows; freq[0] is drawn at the
                 top of the image, freq[-1] at the bottom.
    field     -- array whose last axis indexes channels; field[..., i] is
                 the image for channel i.
    signals   -- 2-D array; signals[:, i] is the time trace of channel i.
    t         -- time vector [s]; plotted in ms.
    start/end -- times [ms] marked with vertical dashed lines.
    plasma    -- if true, start/end are also marked on the signal plots.
    contrast  -- gain in the compression log(1 + contrast*|field|).
    logscale  -- use a logarithmic frequency axis.

    Writes ./graphs/spectrogram_<i>.png and ./graphs/signal_<i>.png, then
    calls ImageMagick ``convert`` to make icon.png from spectrogram_1.png.
    """
    print('PlotSpecrograms')

    path = './graphs/'
    if not os.path.exists(path):
        os.makedirs(path)

    # Compress the dynamic range.  One vmin/vmax for all channels keeps
    # the colour scale comparable between the saved images.
    field = log(1 + contrast * abs(field))
    vmin = amin(field)
    vmax = amax(field)

    fig = figure('spectrogram')
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])
    if logscale:
        ax.set_yscale('log', nonposy='clip')

    t_first_ms = t[0] * 1e3
    t_last_ms = t[-1] * 1e3
    # Set the axes up once with a placeholder image; the loop below only
    # swaps the image data for each channel.
    img = ax.imshow(zeros((1, 1)), extent=[t_first_ms, t_last_ms, freq[-1], freq[0]],
                    aspect='auto', vmin=vmin, vmax=vmax, interpolation='bicubic')
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.axis([t_first_ms, t_last_ms, freq[-1], freq[0]])
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('Frequency [Hz]')
    ax.axvline(x=start, c='w', ls='--')
    ax.axvline(x=end, c='w', ls='--')

    t_ = time.time()
    for i in range(size(field, -1)):
        img.set_data(field[..., i])
        fig.savefig(path + 'spectrogram_' + str(i) + '.png')
    print('plotted spectograms %s' % (time.time() - t_))

    fig.clf()
    ax = fig.add_subplot(111)
    # Placeholder y data; replaced by each channel's trace below.
    data_plot, = ax.plot(1e3 * t, t, 'k', lw=0.1)
    if plasma:
        ax.axvline(x=start, c='r', ls='--')
        ax.axvline(x=end, c='r', ls='--')
    ax.set_xlabel('time [ms]')
    ax.set_ylabel('B [T/s]')
    ax.set_xlim(1e3 * t[0], 1e3 * t[-1])

    for i in range(size(signals, 1)):
        trace = signals[:, i]
        # Robust y-range from the 1 % and 99 % quantiles (ignores spikes).
        lo, hi = mquantiles(trace, [0.01, 0.99])
        # Symmetric 20 % margin.  The previous code widened the bottom first
        # and then padded the top with the already-widened range.
        margin = (hi - lo) * 0.2
        data_plot.set_ydata(trace)
        ax.set_ylim(lo - margin, hi + margin)
        fig.savefig(path + 'signal_' + str(i) + '.png')

    fig.clf()

    os.system('convert -resize 150 %sspectrogram_1.png icon.png' % path)




def Spectrogram(args):
    """Despike one channel and return its continuous wavelet transform.

    ``args`` is a single tuple ``(omega0, frequenciCutOff_min, tvec, signal)``
    so the function can be handed directly to ``Pool.map``:

    omega0              -- wavelet parameter forwarded to ``cwt``
    frequenciCutOff_min -- lower frequency limit [Hz] forwarded to ``cwt``
    tvec                -- sample times [s]; assumed uniformly spaced
    signal              -- samples of one channel; the caller's array is
                           not modified

    Returns ``(spect, freq)`` as produced by ``cwt`` (its ``scale`` output
    is dropped).
    """
    omega0, frequenciCutOff_min, tvec, signal = args
    # Work on a copy: the despiking below writes into the array, and the
    # caller may pass a view into its own data (e.g. a column).
    signal = array(signal, copy=True)

    N = len(tvec)
    # N uniformly spaced samples span N - 1 sampling intervals.
    dt = (tvec[-1] - tvec[0]) / max(N - 1, 1)

    # Replace values outside the 1 %..99 % quantile band by the median.
    # The lower quantile and its median are taken after the upper
    # replacement, as before.
    signal[signal > mquantiles(signal, 0.99)] = median(signal)
    signal[signal < mquantiles(signal, 0.01)] = median(signal)

    Nker = 2 * (N // 500) + 1  # medfilt requires an odd kernel length
    print("Delka jadra medfilt %s" % Nker)

    # Three passes: of the n_out samples furthest from the running median,
    # those deviating by more than 20 % of the filtered maximum are
    # replaced by the median-filtered value.
    n_out = N // 400
    if n_out > 0:  # with n_out == 0, idx[-0:] would select every sample
        for _ in range(3):
            filtered = medfilt(signal, Nker)
            res = abs(signal - filtered)
            farthest = zeros(N, dtype=bool)
            farthest[argsort(res)[-n_out:]] = True
            outliers = farthest & (res > 0.2 * amax(filtered))
            signal[outliers] = filtered[outliers]

    # 1/dt/2 is the Nyquist frequency of the sampled signal.
    spect, scale, freq = cwt((signal, dt, 0.05, omega0, 1000,
                              frequenciCutOff_min, 1 / dt / 2))
    return spect, freq

   


 
 



def main():
    """Compute and save wavelet spectrograms for every channel of the shot.

    Runs only when the script is called without arguments or with the
    argument ``plots``; otherwise it returns without doing anything.
    Spectrograms are computed in parallel (one task per channel), plotted
    by PlotSpecrograms, and finally ``status = 0`` is stored via saveconst.
    """
    if not (len(sys.argv) == 1 or sys.argv[1] == "plots"):
        return

    Data = Shot()
    plasma = Data['plasma']

    # Analysis window [ms].  It was once taken from Data['plasma_start'] /
    # Data['plasma_end'] when plasma was present; now it is fixed.
    start = 0
    end = 40

    tvec, data = Data[DAS]
    tvec = asarray(tvec, dtype=float)
    contrast = 1e9
    omega0 = 25  # higher value => higher frequency / lower time resolution
    frequenciCutOff_min = 1e4  # [Hz] (previously 900 Hz)

    # Widen the window (20 % on each side), clipped to the recorded range [ms].
    startAdv = max(start * 0.8, tvec[0] * 1e3)
    endAdv = min((end - start) * 1.2 + start, tvec[-1] * 1e3)

    # Cut the signal to the widened window (tvec is in seconds).
    ind = (tvec >= startAdv * 1e-3) & (tvec <= endAdv * 1e-3)
    data = data[ind, :]
    tvec = tvec[ind]

    Ndim = size(data, 1)
    try:
        data *= calib[None, :]
    except (ValueError, TypeError):
        # Best effort: e.g. channel count differs from len(calib) (shape
        # mismatch) or the in-place product cannot be cast to data's dtype.
        print("Calibration failed !!")

    pool = Pool(cpu_count())
    out = pool.map(Spectrogram, [(omega0, frequenciCutOff_min, tvec, data[:, i])
                                 for i in range(Ndim)])
    pool.close()
    pool.join()  # reap the worker processes
    freq = out[0][1]

    # Pop from the end so each complex CWT result is dropped from `out` as
    # soon as its magnitude is taken (presumably to limit peak memory),
    # then restore the original channel order.
    spectrograms = [abs(out.pop()[0]) for _ in range(Ndim)]
    spectrograms.reverse()
    spectrograms = dstack(spectrograms)
    PlotSpecrograms(freq, spectrograms, data, tvec, start, end, plasma, contrast)

    saveconst('status', 0)





if __name__ == "__main__":
    # Script entry point: run the spectrogram analysis.
    main()
   

