## This code is written by Davide Albanese, <albanese@fbk.eu> and
## Marco Chierici, <chierici@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## See: C. Torrence and G. P. Compo, "A Practical Guide to Wavelet Analysis",
## Bulletin of the American Meteorological Society, 79(1), 61-78, 1998.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

from numpy import *
import _extend

# Public names exported by `from <this module> import *`.
# NOTE(review): "icwt" is listed here but no `icwt` definition is visible in
# this module section; if it is not defined elsewhere in the file, a star
# import of this module raises AttributeError -- confirm.
__all__ = ["cwt", "icwt", "angularfreq", "scales", "compute_s0"]


def angularfreq(N, dt):
    """Return the angular frequencies of an N-point discrete Fourier transform.

    Follows eq. (5) at page 64 of Torrence & Compo: sample index k maps to
    2*pi*k / (N*dt) while k <= N/2, and to 2*pi*(k - N) / (N*dt) above that.
    For even N the Nyquist bin (k == N/2) is therefore kept positive.

    :Parameters:
      N : integer
        number of data samples
      dt : float
        time step

    :Returns:
      angular frequencies : 1d numpy array (float)
    """

    idx = arange(N)
    # Indices past the midpoint wrap onto the negative frequency axis.
    signed_idx = where(idx <= N / 2.0, idx, idx - N)
    return (2 * pi * signed_idx) / (N * dt)


def scales(N, dj, dt, s0):
    """Compute scales.

    Implements eqs. (9) and (10) at page 67 of Torrence & Compo:
    s_j = s0 * 2**(j*dj) for j = 0, ..., J, with
    J = floor(log2(N*dt / s0) / dj).

    :Parameters:
      N : integer
        number of data samples
      dj : float
        scale resolution
      dt : float
        time step
      s0 : float
        smallest scale

    :Returns:
      scales : 1d numpy array (float)
        scales in increasing order; empty when N*dt < s0 (the data
        span is too short for even the smallest scale)
    """

    span = (N * dt) / s0
    # `not span >= 1` also catches span <= 0 (log2 -> -inf) and NaN, which
    # would otherwise crash in int(floor(...)) or produce a negative size.
    if not span >= 1:
        return empty(0)

    # Cast to int: NumPy no longer accepts a float as an array size.
    J = int(floor(dj**-1 * log2(span)))

    # 2.0 keeps the result floating point even for integer s0/dj.
    return s0 * 2.0 ** (arange(J + 1) * dj)


def compute_s0(dt, p, wf):
    """Return the smallest wavelet scale s0 for the chosen wavelet family.

    For 'morlet' the value makes the scale frequency
    (p + sqrt(2 + p**2)) / (4*pi*s0) used by cwt equal to 1 / (2*dt), the
    Nyquist frequency. 'dog' and 'paul' use analogous closed forms
    (presumably derived from Torrence & Compo's Fourier factors -- verify).

    :Parameters:
      dt : float
        time step
      p : float
        omega0 ('morlet') or order ('paul', 'dog')
      wf : string
        wavelet function ('morlet', 'paul', 'dog')

    :Returns:
      s0 : float

    :Raises:
      ValueError : if wf names an unsupported wavelet
    """

    # Ordered (name, formula) table; each formula is evaluated only on match.
    formulas = (
        ("dog", lambda: (dt * sqrt(p + 0.5)) / pi),
        ("paul", lambda: (dt * ((2 * p) + 1)) / (2 * pi)),
        ("morlet", lambda: (dt * (p + sqrt(2 + p**2))) / (2 * pi)),
    )

    for name, formula in formulas:
        if wf == name:
            return formula()

    raise ValueError("wavelet '%s' is not available" % wf)


def cwt(x, dt, dj, wf="dog", p=2, extmethod='none', extlength='powerof2',
        res=2000, fmin=0, fmax=inf):
    """Continuous Wavelet Transform.

    The data are optionally extended and then transformed with a real FFT.
    The spectrum is multiplied, one scale at a time, by the Fourier
    transform of the wavelet, and an inverse FFT gives the coefficients
    for that scale.

    NOTE: the frequency-domain kernel is always the Morlet-shaped
    (1 + sign(w)) * exp(-(s*w - p)**2 / 2) * sqrt(s/dt), with p acting as
    omega0. `wf` only changes how the smallest scale s0 is chosen (see
    compute_s0), and the scale frequencies used for fmin/fmax filtering
    always use the Morlet Fourier factor.

    :Parameters:
      x : 1d numpy array (or sequence)
        data
      dt : float
         time step
      dj : float
         scale resolution (smaller values of dj give finer resolution)
      wf : string ('morlet', 'paul', 'dog')
         wavelet function used to compute the smallest scale s0
      p : float
        wavelet function parameter (omega0 of the kernel)
      extmethod : string ('none', 'reflection', 'periodic', 'zeros')
                indicates which extension method to use
      extlength : string ('powerof2', 'double')
                indicates how to determine the length of the extended data
      res : integer
        approximate number of time samples kept per scale; the output
        is decimated with step max(int(len(x) / res), 1)
      fmin, fmax : float
        only scales whose frequency f satisfies fmin < f < fmax are kept

    :Returns:
      (X, scales) : (2d numpy array complex, 1d numpy array float)
                  transformed data (one row per kept scale), kept scales

    Example:

    >>> import numpy as np
    >>> import mlpy
    >>> x = np.array([1,2,3,4,3,2,1,0])
    >>> X, s = mlpy.cwt(x=x, dt=1, dj=2, wf='morlet', p=6)
    >>> X.shape
    (2, 8)
    """

    x = asarray(x)
    length = x.shape[0]  # original (unextended) number of samples

    if extmethod != 'none':
        x = _extend.extend(x, method=extmethod, length=extlength)

    # Angular frequencies of the (possibly extended) signal.
    w = angularfreq(x.shape[0], dt)
    s0 = compute_s0(dt, p, wf)
    # Scales are based on the original length, not the extended one.
    s = scales(length, dj, dt, s0)

    # Frequency of each scale, 1 / lambda, with the Morlet Fourier factor
    # lambda = 4*pi*s / (p + sqrt(2 + p**2)).
    freq = (p + sqrt(2.0 + p**2)) / (4 * pi * s)
    s = s[(freq > fmin) & (freq < fmax)]

    xf = fft.rfft(x, axis=0)

    # Time decimation so that roughly `res` samples are kept per scale.
    # (Explicit clamp instead of max(): `from numpy import *` may shadow it.)
    step = int(length / res)
    if step < 1:
        step = 1

    # Only the first len(w) // 2 (non-negative) frequencies carry the kernel;
    # rfft returns len(w) // 2 + 1 bins, so these indices are always valid.
    # Integer division keeps the slice index an int under Python 3.
    half = len(w) // 2
    spec = zeros((len(s), len(w[0:length:step])), dtype=complex)

    for i, scale in enumerate(s):
        # Truncate the Gaussian where |s*w - p| >= 3 (exp(-4.5) ~ 0.011).
        interv = where(abs(scale * w[0:half] - p) < 3)
        wavelet = ((1 + sign(w[interv])) * exp(-(scale * w[interv] - p)**2 / 2)
                   * sqrt(abs(scale) / dt))
        wft = zeros(len(w), dtype=complex)
        wft[interv] = xf[interv] * wavelet
        spec[i, :] = fft.ifft(wft)[0:length:step]

    return spec, s

