#!/usr/bin/env python
# -*- coding: utf-8 -*-
from matplotlib.pyplot import *
from numpy import *
#from scipy.fftpack import rfft, irfft,fftshift,ifft
from numpy.fft import rfft, irfft,fftshift,ifft
from _extend import extend
import sys
from scipy import signal   

"""
    Continuous Multi-Wavelet analysis  0.1

    Description:
    Computes the continuous wavelet analysis of multiple simultaneous signals.
    Instead of a single one-dimensional wavelet, an n-dimensional Morlet
    wavelet whose components are mutually shifted in phase is created.
    The moving scalar product with such wavelets helps to identify and amplify
    correctly time-shifted signals and to suppress the others.



"""




# Public names exported by ``from <this module> import *``.
__all__ = ["NTM_CWT", "angularfreq", "scales", "compute_s0"]


def angularfreq(N, dt):
    """Compute angular frequencies of the N-point DFT bins.

    Bin ``i`` gets ``2*pi*i / (N*dt)`` for ``i <= N/2`` and
    ``2*pi*(i - N) / (N*dt)`` otherwise, so the bin at ``N/2`` (even N)
    is positive and the upper half of the bins are negative.

    :Parameters:
      N : integer
        number of data samples
      dt : float
        time step

    :Returns:
      angular frequencies :  1d numpy array of length N
    """

    # See (5) at page 64.
    k = arange(N)
    # Bins above N/2 wrap around to negative frequencies.
    k = where(k <= N / 2.0, k, k - N)
    return (2 * pi * k) / (N * dt)


def scales(N, dj, dt, s0):
    """Compute the wavelet scales ``s_j = s0 * 2**(j*dj)``, ``j = 0..J``.

    ``J = floor(log2(N*dt/s0) / dj)``, so the largest scale does not exceed
    the signal duration ``N*dt``.

    :Parameters:
      N : integer
        number of data samples
      dj : float
        scale resolution (spacing of the scales in log2 units)
      dt : float
        time step
      s0 : float
        smallest scale

    :Returns:
      scales : 1d numpy array
        ``J + 1`` scales; empty when ``J + 1 <= 0`` (``N*dt`` too short
        for even the smallest scale).
    """

    #  See (9) and (10) at page 67.
    J = floor(dj**-1 * log2((N * dt) / s0))
    # floor() returns a float, but array sizes must be integers
    # (recent NumPy rejects float sizes); clamp instead of raising on
    # a negative count.
    n_scales = int(J) + 1
    if n_scales < 0:
        n_scales = 0
    return s0 * 2**(arange(n_scales) * dj)


def compute_s0(dt, p):
    """Compute the smallest scale s0 for the Morlet wavelet.

    s0 is chosen so that its equivalent Fourier period,
    ``4*pi*s0 / (p + sqrt(2 + p**2))``, equals ``2*dt``.

    :Parameters:
      dt : float
        time step
      p : float
        Morlet wavelet parameter omega0. The scale-to-period factor used
        here is the Morlet one; Paul and DoG wavelets use different factors,
        so this function does not apply to them.

    :Returns:
      s0 : float
    """
    # Morlet scale-to-Fourier-period factor (period = 4*pi*s / factor).
    morlet_factor = p + sqrt(2 + p**2)
    return (dt * morlet_factor) / (2 * pi)
    
   

    
    
    
    
    
  

def NTM_CWT(args):
    """Compute the phase-weighted multi-channel continuous wavelet transform.

    Every detector channel is transformed with the same Morlet-like wavelet.
    Channel ``d`` is additionally multiplied by the phase factor
    ``exp(-2j*pi*m*d/n_detec)`` and the channels are summed and averaged.
    As a result, components whose phase advances by ``2*pi*m/n_detec`` from
    one channel to the next add coherently, and the others are suppressed.

    All parameters are packed into ONE tuple argument:
    ``NTM_CWT((x, dt, dj, p, m, res, fmin, fmax))``.

    :Parameters:
      args : tuple
        ``(x, dt, dj, p, m, res, fmin, fmax)`` where

        x : 2d array
          signals, axis 0 is time, one column per detector
        dt : float
          time step
        dj : float
          scale resolution
        p : float
          Morlet omega0
        m : float
          phase-shift mode number (see above)
        res : number
          requested number of samples per output row over the padded
          length; rows are decimated by ``max(int(len_padded / res), 1)``
        fmin, fmax : float
          only scales whose Fourier frequency lies strictly inside
          ``(fmin, fmax)`` are computed

    :Returns:
      spec : 2d complex array
        one row per retained scale, averaged over detectors, with a number
        of columns proportional to the padding fraction trimmed from each end
      s : 1d array
        the retained scales multiplied by ``max(abs(m), 1)``
    """
    x, dt, dj, p, m, res, fmin, fmax = args

    n_detec = size(x, 1)
    length = size(x, 0)
    # FFT size: next power of two >= length.
    length_ext = int(2**ceil(log2(length)))

    x = signal.detrend(x, axis=0, type='linear')
    # NOTE(review): assumes extend(..., method='zeros') zero-pads x to
    # length_ext rows -- TODO confirm against _extend.
    x_new = extend(x, method='zeros')

    w = angularfreq(length_ext, dt)
    s0 = compute_s0(dt, p)
    s = scales(length_ext, dj, dt, s0)

    # Fourier frequency of each Morlet scale; keep the scales inside the band.
    freq = (p + sqrt(2.0 + p**2)) / (4 * pi * s)
    s = s[(freq > fmin) & (freq < fmax)]

    x_hat = rfft(x_new, axis=0)
    # Negate every other frequency bin, i.e. circularly shift the time signal
    # by half the FFT length ("magic" in the original code).
    # NOTE(review): presumably paired with the fftshift applied to each row
    # below -- confirm.
    x_hat[::2] *= -1

    # Decimation step for the output rows (at least 1).
    step = int(length_ext / float(res))
    if step < 1:
        step = 1

    # Per-detector phase factors exp(-2*pi*i*m*d/n_detec), d = 0..n_detec-1.
    phase = arange(n_detec) / float(n_detec) * m
    exp_phase = exp(-2 * pi * 1j * phase)

    half = length_ext // 2
    w_pos = w[:half]      # non-negative angular frequencies
    w_nyq = w[half]       # angular frequency of bin length_ext//2 (Nyquist)

    rows = []
    for scale in s:
        # Bins where the Gaussian exp(-arg**2/2) is non-negligible
        # (|arg| < 5, i.e. above ~4e-6). Explicit indices into w_pos keep
        # every array the same length (modern NumPy rejects short masks).
        idx = nonzero(abs(scale * w_pos - p) < 5)[0]
        w_sel = w_pos[idx]

        arg = scale * w_sel - p
        # Small correction that grows towards the Nyquist frequency.
        arg += 1e-2 / (1 - w_sel / w_nyq + 0.001)
        # TODO: a different window could be used for each frequency here.
        wavelet = (1 + sign(w_sel)) * exp(-arg**2 / 2) * sqrt(abs(scale) / dt)

        # Weight every detector by wavelet * phase factor and sum over the
        # detectors; only non-negative frequencies are filled, so the inverse
        # FFT is an analytic (complex) signal.
        weights = wavelet[:, newaxis] * exp_phase[newaxis, :]
        wft = zeros(length_ext, dtype=complex128)
        wft[idx] = (x_hat[idx] * weights).sum(axis=1)

        rows.append(fftshift(ifft(wft)[::step]))

    if rows:
        spec = asarray(rows)
    else:
        # No scale inside (fmin, fmax): keep a 2d shape so trimming works.
        spec = zeros((0, (length_ext + step - 1) // step), dtype=complex128)
    spec /= n_detec

    # Trim the same number of columns from each end, proportional to the
    # padding fraction. When length is already a power of two n_edge is 0;
    # the explicit end index then keeps every column ([0:-0] would be empty).
    n_cols = spec.shape[1]
    n_edge = int((length_ext / float(length) - 1) * n_cols / 2)
    spec = spec[:, n_edge:n_cols - n_edge]

    # NOTE(review): the returned scales are multiplied by max(|m|, 1); the
    # reason is not visible here -- confirm with callers.
    factor = abs(m)
    if factor < 1:
        factor = 1
    s *= factor

    return spec, s
    
    
    
    
    
    
    
    
    
      
    
    
    
    
    
    
    
    
    
    
    

