#!/usr/bin/env python
# -*- coding: utf-8 -*-


# ================  HXR signal deconvolution algorithm  0.1 ============================================
#  Algorithm based on a naive implementation of blind deconvolution, used to improve the time resolution of the HXR detector.
# In the first step the peak positions are estimated by the function peaksCounter.
# In the second step a "back deconvolution" is done by the function back_deconvolution to find the system response function.
# In the third step the signal is deconvolved by the more sophisticated function deconvolution.
# In the final step the function peaksCounter is used again to identify peaks and prepare data for a histogram.

# function: back_deconvolution;  data - signal from the HXR detector, events_pos - positions of some of the identified peaks, events_area - their areas,
#   peak_width - estimated width of the response function support, lam - regularization parameter, upsample - upsampling factor to increase resolution and precision,
#  description - this function solves the inverse problem to the deconvolution (it fits the response function to the known events), which is comparatively well conditioned. It also supports corrupted samples in the signal (NaNs).

# function: deconvolution; data - signal from the HXR detector, response - response function of the system,
#   win - width of the windows used for the calculation, lam - regularization parameter, upsample - upsampling factor to increase resolution and precision,
# description - deconvolution algorithm based on Tikhonov regularization with nonlinear weights: the weights depend on the size of the deconvolved signal, which also pushes the solution towards positive values.
# The algorithm can deal with data corrupted by NaNs (for example when the DAS exceeded its maximal range), a constant offset is removed by subtracting the median, and upsampling is supported.

#Author: Tomas Odstrcil 2012 tomasodstrcil at gmail.com


from numpy import *
from numpy import linalg
import time
from scipy import sparse
from numpy import linalg
import matplotlib
matplotlib.rcParams['backend'] = 'Qt4Agg' 
from matplotlib.pyplot import *
from scipy.interpolate import interp1d


from scikits.sparse.cholmod import cholesky, analyze, cholesky_AAt


from numpy.matlib import repmat
from numpy import *
from matplotlib.pyplot import *
from scipy.sparse import *
import time

def MAD(x):
    """Return the scaled median absolute deviation of `x` along axis 0.

    The scale factor 1.48 is presumably the usual ~1.4826 constant that makes
    the MAD a standard-deviation estimate for normally distributed data.
    """
    center = median(x, axis=0)
    deviation = abs(x - center)
    return 1.48*median(deviation, axis=0)


def peaksCounter(signal, threshold_min, threshold_max):
    """Identify peaks one after another, from the largest to the smallest.

    The median of the signal is subtracted first.  Then, repeatedly, the
    highest remaining sample is taken as a peak, the peak interval is grown
    to both sides while the signal keeps decreasing, its area is summed and
    the whole interval is zeroed before searching for the next peak.

    Parameters
    ----------
    signal : 1-D array_like
        Input signal; it is copied, the caller's array is not modified.
    threshold_min : float
        Stop when the highest remaining (median-subtracted) sample is not
        above this level (values <= 0 behave like 0).
    threshold_max : float
        Peaks whose height is >= threshold_max (e.g. saturated ones) are
        removed from the signal but not reported.

    Returns
    -------
    ndarray of shape (n_peaks, 2)
        Column 0: sub-sample peak position (vertex of the parabola through
        the three samples around the maximum; the integer index at the
        signal edges or on flat tops).  Column 1: peak area, i.e. the sum of
        the median-subtracted samples of the peak interval.
    """
    signal = array(signal, dtype=float)  # float copy: works for int input too
    n = len(signal)
    if n == 0:
        return zeros((0, 2))
    signal -= median(signal)

    # Found peaks are zeroed, so a stop level <= 0 would never be undercut
    # and the loop would not terminate.
    stop_level = threshold_min if threshold_min > 0 else 0

    peaks = []
    x0 = argmax(signal)
    while signal[x0] > stop_level:
        height = signal[x0]

        # grow the peak interval while the signal keeps decreasing on each side
        xl = x0
        while xl > 0 and signal[xl] >= signal[xl - 1]:
            xl -= 1
        xr = x0
        while xr + 1 < n and signal[xr] >= signal[xr + 1]:
            xr += 1

        # sub-sample refinement; needs both neighbours and a non-flat top
        pos = float(x0)
        if 0 < x0 < n - 1:
            curvature = signal[x0 + 1] - 2*height + signal[x0 - 1]
            if curvature < 0:
                pos -= (signal[x0 + 1] - signal[x0 - 1])/curvature/2

        area = signal[xl:xr + 1].sum()
        signal[xl:xr + 1] = 0  # remove this peak before looking for the next one

        if height < threshold_max:
            peaks.append((pos, area))

        x0 = argmax(signal)

    return asarray(peaks, dtype=float).reshape(-1, 2)
    
    
# From the estimated peak positions, try to guess the convolution (response) function.
def back_deconvolution(data, events_pos, events_area, peak_width, lam, upsample):
    """Estimate the detector response function from identified peak events.

    Solves the regularized least-squares problem

        min_g ||E g - data||^2 + lam*upsample*||E^T E||_F/1e4 * ||D g||^2

    where D is the first-difference operator (smoothness of g) and E places,
    for every event i and every sample k of its window, events_area[i] at
    the column of g that corresponds to the (upsampled) lag between k and
    the event position.  Each window starts 20 samples before the event and
    spans min(100, peak_width) samples.

    Parameters
    ----------
    data : 1-D array
        Signal from the HXR detector.  NaN samples (e.g. saturated DAS) are
        excluded from the fit.  The array is copied, not modified.
    events_pos : sequence of float
        Sub-sample positions (in samples) of the identified peaks.
    events_area : sequence of float
        Area of each identified peak, same length as events_pos.
    peak_width : int
        Width of the response-function support in samples.
    lam : float
        Regularization parameter.
    upsample : int
        Upsampling factor of the returned response.

    Returns
    -------
    g : 1-D array of length upsample*peak_width
        Estimated response function.

    Raises
    ------
    ValueError
        If no event overlaps a valid (non-NaN) sample, so E is empty.
    """
    # TODO (original author): weighting of the least-squares problem?
    data = array(data, dtype=float)
    i_nan = isnan(data)
    data -= median(data[~i_nan])
    data[i_nan] = 0  # those rows of E are zeroed below; keep NaN out of E^T*data

    n = len(data)
    n_peaks = len(events_pos)
    lam = lam*upsample

    npix = upsample*peak_width
    Events_matrix = sparse.lil_matrix((n, npix))

    for i in range(n_peaks):
        base = int(floor(events_pos[i]))
        s = int((events_pos[i] - base)*upsample)  # sub-sample shift of the event
        first = base - 20  # unclipped window start: column 0 of g is always this lag
        last = base + 80
        if last > first + peak_width:  # keep column indices inside the npix columns
            last = first + peak_width
        row_lo = first if first > 0 else 0
        row_hi = last if last < n else n
        for row in range(row_lo, row_hi):
            j = row - first
            # += so that coinciding events superpose instead of overwriting
            Events_matrix[row, j*upsample + upsample - s - 1] += events_area[i]

    nogaps = spdiags(int_(~i_nan), 0, n, n, format='csr')
    Events_matrix = nogaps*csc_matrix(Events_matrix)

    diag_data = ones((2, npix))
    diag_data[1, :] *= -1
    D = spdiags(diag_data, (0, 1), npix, npix, format='csr')
    DD = D.T*D

    EE = Events_matrix.T*Events_matrix
    # Frobenius norm computed on the sparse matrix (no dense copy)
    normEE = sqrt(EE.multiply(EE).sum())
    if normEE == 0:
        raise ValueError('no identified event overlaps valid data')

    factor = cholesky(EE + lam*normEE/1e4*DD)
    g = squeeze(factor(Events_matrix.T*data))
    return copy(g)
      
    
    
    





    
def _extend_signal(data, win):
    """Pad `data` on both sides to a multiple of `win` samples.

    Returns (ext_signal, pad_left).  At least win/4 padding samples lie on
    each side, so the central halves of the half-overlapping windows used
    by `deconvolution` cover every original sample.  The left/right padding
    is filled with the median of the first/last win/2 finite samples.
    Assumes win is a multiple of 4 and `data` has a finite sample.
    """
    n = len(data)
    half = win//2
    finite = data[~isnan(data)]
    n_ext = ((n + half)//win + 1)*win  # total padding is > win/2
    pad_left = (n_ext - n)//2
    ext_signal = empty(n_ext)
    ext_signal[:pad_left] = median(finite[:half])
    ext_signal[pad_left:pad_left + n] = data
    ext_signal[pad_left + n:] = median(finite[-half:])
    return ext_signal, pad_left


def _convolution_operator(response, win, upsample):
    """Return the sparse (win x upsample*win) forward operator of one window.

    It convolves an upsampled signal g with `response` (aligned on the
    response maximum) and keeps the first of every `upsample` fine samples,
    giving the signal on the original sampling.
    """
    n_resp = len(response)
    npix = upsample*win
    peak = argmax(response)
    offsets = arange(peak - n_resp + 1, peak + 1)
    diag_data = repmat(response[::-1], npix, 1).T
    ConvMatrix = spdiags(diag_data, offsets, npix, npix, format='csc')
    ReductMatrix = kron(eye(win, win), eye(1, upsample), format='csr')
    return ReductMatrix*ConvMatrix


def deconvolution(data, response, win, lam, upsample, n_iter=5, plot_result=True, verbose=True):
    """Deconvolve `data` with the system `response`, window by window.

    The median-subtracted signal is padded and split into windows of `win`
    samples shifted by win/2.  On each window the upsampled signal g solves
    the normal equations

        (T^T T + lam*W) g = T^T f

    where T is the forward operator (convolution with the response plus
    decimation).  The diagonal weights W = mean(g')/g', g' = max(g, 0.01),
    are updated `n_iter` times, so components where g is small or negative
    are penalized strongly, which pushes g towards positive values.  Only
    the central half of every window is kept.  NaN samples are set to 0
    before the fit.

    Parameters
    ----------
    data : 1-D array
        Signal from the HXR detector (may contain NaNs); not modified.
    response : 1-D array
        Response function of the system; copied, normalized so that
        |sum| == 1 and multiplied by `upsample`.
    win : int
        Window width in samples, rounded down to a multiple of 4 so that
        the window quarters are whole samples.
    lam : float
        Regularization parameter.
    upsample : int
        Upsampling factor of the deconvolved signal.
    n_iter : int, optional
        Number of reweighting iterations per window (default 5).
    plot_result : bool, optional
        Plot deconvolved, raw and refitted signal and call show() (default
        True, as before).
    verbose : bool, optional
        Print per-window diagnostics and the total run time (default True).

    Returns
    -------
    deconv : 1-D array of length len(data)*upsample
    retrofit : 1-D array of length len(data)
        T*g, i.e. the deconvolved signal convolved back with the response.

    Raises
    ------
    ValueError
        For win < 4, n_iter < 1 or data without any finite sample.
    """
    t_start = time.time()
    win = 4*(int(win)//4)
    if win < 4:
        raise ValueError('win must be at least 4 samples')
    if n_iter < 1:
        raise ValueError('n_iter must be at least 1')

    data = array(data, dtype=float)  # work on a copy, the caller's array is kept
    n = len(data)
    valid = ~isnan(data)
    if not valid.any():
        raise ValueError('data contains no finite samples')
    data -= median(data[valid])  # NaN-safe removal of the constant offset

    response = array(response, dtype=float)
    response /= abs(response.sum())
    response *= upsample

    ext_signal, pad_left = _extend_signal(data, win)
    i_nan = isnan(ext_signal)
    ext_signal[i_nan] = 0

    npix = upsample*win
    half = win//2
    quarter = win//4
    T = _convolution_operator(response, win, upsample)
    TT = T.T*T
    # TT + lam*W has the same sparsity pattern for every window and
    # iteration, so the symbolic factorization is done only once.
    factor = analyze(TT + identity(npix, format='csc'))

    n_ext = len(ext_signal)
    deconv = zeros(n_ext*upsample)
    retrofit = zeros(n_ext)

    for start in range(0, n_ext - win + 1, half):
        Tf = T.T*ext_signal[start:start + win]
        W = ones(npix)  # weights of the regularization
        for _ in range(n_iter):
            factor.cholesky_inplace(TT + lam*sparse.spdiags(W, 0, npix, npix))
            g = squeeze(factor(Tf))
            g_pos = g.copy()
            g_pos[g_pos < 0.01] = 0.01  # clip negative values (only a positive result is wanted)
            W = mean(g_pos)/g_pos       # new weights for the next iteration

        f_i = T*g
        if verbose and (f_i > 0.2).any():
            resid = f_i - median(f_i)
            total = abs(resid).sum()
            if total > 0:
                print('%s %s %s' % (resid.sum()/total*2, g.sum(), total))

        # keep only the central half of the window
        lo = start + quarter
        deconv[lo*upsample:(lo + half)*upsample] = g[quarter*upsample:(quarter + half)*upsample]
        retrofit[lo:lo + half] = f_i[quarter:quarter + half]

    if verbose:
        print(time.time() - t_start)

    deconv = deconv[pad_left*upsample:(pad_left + n)*upsample]
    retrofit = retrofit[pad_left:pad_left + n]
    i_nan = i_nan[pad_left:pad_left + n]

    if plot_result:
        plot(linspace(0, n, n*upsample), deconv*upsample, label='deconv')
        plot(linspace(0, n, n), data, label='raw')
        plot(linspace(0, n, n), retrofit, label='retrofit')
        plot(i_nan)
        legend()
        show()
    return deconv, retrofit
    
    
    
#TODO incorporate plasma start and plasma end
    
#TODO   weight the peaks when searching for them; ignore the saturated ones???
# run the algorithm
    
    
    

# Settings of the analysis (the same for every shot).
win = 600              # deconvolution window [samples]
upsample = 1           # upsampling factor of the deconvolved signal
peakwidth = 100        # support of the response function [samples]
DAS_limit = 10         # [V] samples above this level are treated as missing (clipped peaks)
regularization = 0.1   # regularization parameter of the deconvolution
Bt_trigger = 5e-3      # [s] time of the Bt trigger, whose peak is removed from the data

for shot_num in range(0, 1):
    print('shot_num %s' % shot_num)
    shot_dir = './data/' + str(shot_num) + '/'
    try:
        data = loadtxt(shot_dir + 'HXR_')
    except (IOError, OSError, ValueError):
        continue  # no (readable) HXR data for this shot
    try:
        plasma_start = loadtxt(shot_dir + 'PlasmaStart')/1000
        plasma_end = loadtxt(shot_dir + 'PlasmaEnd')/1000
    except (IOError, OSError, ValueError):
        plasma_start = 0
        plasma_end = 4e-2

    # analysed interval: the plasma duration extended by 1 ms, clipped to [0, 40 ms]
    plasma_start_adv = plasma_start - 1e-3
    if plasma_start_adv < 0:
        plasma_start_adv = 0
    plasma_end_adv = plasma_end + 1e-3
    if plasma_end_adv > 4e-2:
        plasma_end_adv = 4e-2

    dt = 0.04/len(data)
    i_trig = int(Bt_trigger/dt)
    trig = slice(i_trig - 100 if i_trig > 100 else 0, i_trig + 200)
    data[trig] = median(data[trig])  # remove the trigger peak

    plot(arange(len(data))*dt, data)
    axvline(x=plasma_start)
    axvline(x=plasma_end)
    ylim(-1, 10)
    i_start = int(plasma_start_adv/dt)
    i_end = int(plasma_end_adv/dt)
    data = data[i_start:i_end]
    plot(arange(i_start, i_start + len(data))*dt, data)
    savefig('./graphs/' + str(shot_num) + '_data.png')
    close()

    peaks = peaksCounter(data, 0.1, 10)
    if len(peaks) == 0:
        print('no radiation')
        continue

    hist(peaks[:, 1], int(sqrt(len(peaks))))  # bins must be an integer
    ylabel('counts [-]')
    xlabel('energy [keV]')
    xlim([0, amax(peaks[:, 1])])
    savefig('./graphs/' + str(shot_num) + 'historogram_.png')
    close()

    data[data > DAS_limit] = nan  # clipped (saturated) peaks
    response_fun = back_deconvolution(data, peaks[:, 0], peaks[:, 1], peakwidth, 1e3, upsample)
    savetxt('response_fun', response_fun)
    plot(response_fun)
    savefig('./graphs/' + str(shot_num) + 'response_fun.png')
    close()

    deconv, retrofit = deconvolution(data, response_fun, win, regularization, upsample)
    deconv /= upsample
    peaks = peaksCounter(deconv, 0.2, 10)
    if len(peaks) == 0:
        continue  # nothing to histogram

    hist(peaks[:, 1], int(sqrt(len(peaks))))
    xlim([0, amax(peaks[:, 1])])
    savefig('./graphs/' + str(shot_num) + 'historogram.png')
    close()

    # TODO return total area / plasma duration = mean radiated power