
from scipy.signal import get_window
from numpy import *
import _extend

__all__ = ["stft"]



def stft(x, dt, dTime, p, extmethod='none', extlength='powerof2', res=1000,
         fmin=0, fmax=inf):
    """Short-time Fourier transform using a truncated Gaussian window.

    The mean of `x` is removed and the signal is optionally extended.
    Then, for every frame i, a Gaussian window exp(-(k - c)**2 / p**2) is
    centred at sample c = i * step, with step = floor(dTime / dt). The
    window is truncated to |k - c| <= 3*p samples. The positive-frequency
    part of the real FFT of each windowed frame is restricted to
    fmin < f < fmax and averaged down to `res` frequency rows.

    :Parameters:
      x : 1d numpy array (or sequence)
        data; the caller's array is not modified
      dt : float
        sampling interval
      dTime : float
        time between consecutive frames, in the same units as `dt`
        (smaller values give finer time resolution); must be >= dt
      p : float
        width of the Gaussian window, in samples; must be > 0
      extmethod : string ('none', 'reflection', 'periodic', 'zeros')
        extension method passed to _extend.extend
      extlength : string ('powerof2', 'double')
        how _extend.extend determines the length of the extended data
      res : int
        number of frequency rows in the output
      fmin, fmax : float
        only frequencies strictly between fmin and fmax are kept

    :Returns:
      (spec, freqs) : (2d numpy array complex, 1d numpy array float)
        `spec` has shape (res, len(x) // step). `freqs` has length `res`;
        each entry is the mean frequency of the FFT bins averaged into
        the corresponding row of `spec`.

    :Raises:
      ValueError : if dTime < dt, if p <= 0, or if fewer than `res` FFT
        bins fall inside (fmin, fmax)

    Example:

    >>> import numpy as np
    >>> x = np.array([1, 2, 3, 4, 3, 2, 1, 0])
    >>> spec, freqs = stft(x, dt=1, dTime=2, p=2, res=3)
    >>> spec.shape
    (3, 4)
    """
    # Work on a mean-removed copy: in-place subtraction would modify the
    # caller's array and fails outright for integer input.
    x = asarray(x)
    x = x - x.mean()
    length = x.shape[0]

    # Frame hop in samples. Round before flooring so that, for example,
    # 0.3 / 0.1 (= 2.9999999999999996) gives 3 rather than 2.
    step = int(floor(round(float(dTime) / dt, 9)))
    if step < 1:
        raise ValueError("dTime must be >= dt (got dTime=%r, dt=%r)"
                         % (dTime, dt))
    if p <= 0:
        raise ValueError("p must be > 0 (got %r)" % (p,))

    if extmethod != 'none':
        x = _extend.extend(x, method=extmethod, length=extlength)

    n = x.shape[0]
    half = n // 2
    # Frequencies of rfft bins 0 .. half-1 (the top rfft bin is dropped).
    # The same `half` is used when slicing each frame's rfft below, so the
    # mask length always matches, including for odd n.
    freqs = fft.fftfreq(n, d=dt)[:half]
    ind = (freqs > fmin) & (freqs < fmax)
    freqs = freqs[ind]

    # Number of consecutive FFT bins averaged into each of the `res` rows.
    # Leftover bins at the top of the band are discarded.
    per_row = freqs.shape[0] // res
    if per_row < 1:
        raise ValueError(
            "only %d FFT bins lie in (fmin, fmax); need at least res=%d"
            % (freqs.shape[0], res))
    used = per_row * res

    nframes = length // step
    spec = zeros((res, nframes), dtype=complex)
    radius = 3 * p          # the Gaussian is truncated at 3*p samples
    p2 = float(p) ** 2
    frame = zeros(n)
    for i in range(nframes):
        centre = i * step
        lo = int(floor(centre - radius))
        # +1 makes the support symmetric, [c - 3p, c + 3p] inclusive.
        hi = int(floor(centre + radius)) + 1
        lo = lo if lo > 0 else 0
        hi = hi if hi < n else n

        # Only the samples under the window are non-zero; this is a speed
        # shortcut over windowing the whole signal.
        frame.fill(0.0)
        k = arange(lo, hi)
        frame[lo:hi] = x[lo:hi] * exp(-(k - centre) ** 2 / p2)

        s = fft.rfft(frame)[:half][ind]
        spec[:, i] = s[:used].reshape(res, per_row).mean(1)

    # Label each row with the mean frequency of the bins averaged into it,
    # so the result always has exactly `res` entries aligned with `spec`.
    return spec, freqs[:used].reshape(res, per_row).mean(1)

