Discharge/DischargeDatabase/Examples/22813/includes/analysis/Magnetics/0411Spectrograms_TO.ON/CWT.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#from matplotlib.pyplot import *
from numpy import *
from multiprocessing import Process, Pool, cpu_count
from fftshift import fftshift
#import fftw3
import pyfftw


"""
    Continuous Multi-Wavelet analysis 0.001
    
    Description:
    The continuous wavelet analysis of multiple signals is calculated. Instead
    of a single one-dimensional wavelet, an n-dimensional Morlet wavelet with
    shifted components is created. The moving scalar product with such wavelets
    helps to identify and amplify the correctly time-shifted signals and to
    suppress the others.



"""




# Public API of this module: a star-import exposes only the transform itself.
__all__ = ['cwt']


def angularfreq(N, dt):
    """Compute the angular frequencies of an N-point FFT.

    The values are returned in standard FFT order (zero, positive, then
    negative frequencies), i.e. ``2*pi*fftfreq(N, dt)``.  This grid holds for
    both even and odd ``N``; for odd ``N`` there is no Nyquist bin and all
    frequencies are integer multiples of ``2*pi/(N*dt)``.

    :Parameters:
      N : integer
        number of data samples
      dt : float
        time step

    :Returns:
      angular frequencies : 1d numpy array of length N
        in radians per unit of ``dt``
    """
    # See (5) at page 64.
    # fftfreq yields k/(N*dt) for k = 0..ceil(N/2)-1, -floor(N/2)..-1, which
    # is exactly the bin layout of fft.fft; int() accepts float-valued N.
    return 2 * pi * fft.fftfreq(int(N), dt)


def scales(N, dj, dt, s0):
    """Compute the dyadic set of wavelet scales.

    ``s_j = s0 * 2**(j*dj)`` for ``j = 0 .. J`` with
    ``J = floor(log2(N*dt / s0) / dj)``, so the largest scale never exceeds
    the record length ``N*dt``.

    :Parameters:
      N : integer
        number of data samples
      dj : float
        scale resolution (consecutive scales differ by a factor 2**dj)
      dt : float
        time step
      s0 : float
        smallest scale

    :Returns:
      scales : 1d numpy array
        J + 1 scales; empty when the record is shorter than ``s0``
        (``N*dt < s0``)
    """
    # J must be an integer count: NumPy rejects float array sizes, and the
    # old code only allocated an array here to read back its length.
    J = int(floor(dj**-1 * log2((N * dt) / s0)))

    # arange() of a non-positive count is simply empty.
    return s0 * 2**(arange(J + 1) * dj)


def compute_s0(dt, p):
    """Return the smallest wavelet scale s0 for the Morlet wavelet.

    ``s0 = dt * (p + sqrt(2 + p**2)) / (2*pi)``.  With the scale-to-frequency
    relation ``freq = (p + sqrt(2 + p**2)) / (4*pi*scale)`` used by ``cwt``,
    this scale corresponds to the Nyquist frequency ``1 / (2*dt)``.

    :Parameters:
      dt : float
        time step
      p : float
        Morlet omega0 (the formula is specific to the Morlet wavelet)

    :Returns:
      s0 : float
    """
    fourier_factor = p + sqrt(2 + p**2)
    return dt * fourier_factor / (2 * pi)
    
   

    
    
    
    
    
  

def cwt(args):
    """Continuous Morlet wavelet transform of a single signal.

    The transform is evaluated in the frequency domain.  The positive half of
    the signal spectrum is multiplied by the Morlet wavelet spectrum
    ``exp(-(s*w - p)**2 / 2)`` at every scale ``s``.  Each product is then
    brought back to the time domain with one inverse FFT per scale.

    The function takes ONE tuple argument so it can be handed directly to
    ``map`` / ``Pool.map``::

        spec, scale, freq = cwt((x, dt, dj, p, res, fmin, fmax))

    x    - transformed signal (indexed along axis 0)
    dt   - time step
    dj   - scale resolution: consecutive scales differ by a factor 2**dj
    p    - Morlet omega0; trade-off between time and frequency resolution
    res  - requested number of time points of the spectrogram
    fmin, fmax - frequency range; only scales with fmin < freq < fmax are kept

    Returns (spec, scale, freq):
    spec  - complex64 array, one row per kept scale and ``ceil(n / step)``
            columns, where ``step = max(n // res, 1)`` and ``n`` is the
            (truncated) signal length
    scale - kept wavelet scales
    freq  - their frequencies, (p + sqrt(2 + p**2)) / (4*pi*scale)
    """
    # Python 3 removed tuple parameters (PEP 3113); unpacking here keeps the
    # single-tuple call convention used with map()/Pool.map.
    x, dt, dj, p, res, fmin, fmax = args

    n_in = size(x, 0)

    # Truncate to a multiple of a power of two (~n_in/256) to speed up the
    # FFTs.  max(..., 1) avoids log2(0) -> int(-inf) for signals shorter than
    # 256 samples; '//' keeps the integer arithmetic the code relies on.
    div = 2**int(log2(max(n_in // 256, 1)))
    n = div * (n_in // div)
    x = csingle(x[:n])

    w = angularfreq(n, dt)
    s0 = compute_s0(dt, p)
    scale = scales(n, dj, dt, s0)
    freq = (p + sqrt(2.0 + p**2)) / (4 * pi * scale)

    ind = (freq > fmin) & (freq < fmax)
    scale = scale[ind]
    freq = freq[ind]

    # Keep only the positive-frequency half of the spectrum, so the inverse
    # FFTs below give a complex (analytic) wavelet response.
    x = fft.fft(x, axis=0)[:n // 2]
    # Flip the sign of every other bin.  By the FFT shift theorem this
    # circularly shifts the inverse-FFT output by n/2 samples (and negates
    # it).  It is presumably paired with the fftshift applied to each output
    # row below; kept as in the original ("magic"), so verify before changing.
    x[::2] *= -1

    step = max(int(n // res), 1)
    hres = len(w[0:n:step])
    spec = zeros((len(scale), hres), dtype=csingle)

    # NOTE: this cache serves the pyfftw.interfaces wrappers; the FFTW object
    # built below does not go through it.  Enabled once (it used to be done
    # twice) to keep the module's global pyfftw setup unchanged.
    pyfftw.interfaces.cache.enable()
    pyfftw.interfaces.cache.set_keepalive_time(30)

    n_fft = len(w)
    if hasattr(pyfftw, 'empty_aligned'):
        wft = pyfftw.empty_aligned(n_fft, dtype=csingle)
        out = pyfftw.empty_aligned(n_fft, dtype=csingle)
    else:
        # older pyFFTW releases only provide the n_byte_align_* helpers
        wft = pyfftw.n_byte_align_empty(n_fft, 16, dtype=csingle)
        out = pyfftw.n_byte_align_empty(n_fft, 16, dtype=csingle)

    # Positive-frequency half of the angular-frequency grid, matching x.
    w = w[:n_fft // 2]
    n_pos = len(w)

    # One plan, reused for the inverse FFT at every scale.
    fft_backward = pyfftw.FFTW(wft, out, direction='FFTW_BACKWARD',
                               flags=['FFTW_ESTIMATE', 'FFTW_DESTROY_INPUT'])

    for i, s in enumerate(scale):
        # Bins where |s*w - p| < 5.  Outside them the Gaussian wavelet
        # spectrum is below exp(-12.5) and is treated as zero.
        i2 = min((5 + p) / s * (n * dt) / (2 * pi), n_pos)
        i1 = max((p - 5) / s * (n * dt) / (2 * pi), 0)
        slc = slice(int(i1), int(i2))

        # FFTW_DESTROY_INPUT allows the plan to clobber wft, so rebuild it.
        wft[:] = 0
        wavelet = exp(-(s * w[slc] - p)**2 / 2) * sqrt(s / dt) / n
        wft[slc] = wavelet * x[slc]

        fft_backward()

        # Decimate in time to about `res` points.
        spec[i, :] = fftshift(out[::step])

    return spec, scale, freq
    
    

def DFT(x, y, nf):
    """Fourier coefficients of samples taken at arbitrary angles.

    Builds ``E[k, j] = exp(-1j * m_k * x_j)`` for the ``nf`` integer mode
    numbers ``m = -nf//2 ... nf//2 - 1`` (for nf = 16: -8 ... 7).  It returns
    the least-squares (minimum-norm) coefficients ``c`` of
    ``y_j ~ (1/nf) * sum_k c_k * exp(1j * m_k * x_j)``, computed through the
    Moore-Penrose pseudo-inverse.  For equally spaced angles
    ``x = 2*pi*arange(nf)/nf`` this is the ordinary DFT evaluated at those
    modes.

    :Parameters:
      x : 1d array
        sample angles in radians, length M
      y : array
        samples at the angles x; last axis has length M
      nf : integer
        number of Fourier modes

    :Returns:
      array like y with the last axis replaced by the nf coefficients
    """
    # Floor division keeps the mode numbers integer for odd nf on Python 3
    # (with plain '/' odd nf yields half-integer modes); (-nf)//2 matches
    # what Python 2's '-nf/2' produced.
    modes = arange(-nf // 2, nf // 2)

    E = exp(-1j * outer(modes, x))
    # pinv(conj(E)) == conj(pinv(E)): least-squares solve of y = c.conj(E)/nf
    return dot(y, conj(linalg.pinv(E))) * len(modes)
	

    
def main(plot=False):
    """Run a demo and timing of ``cwt`` on synthetic multi-channel data.

    Sixteen channels carry a chirp ``sin(2*pi*1e8 * t**2)`` sampled at 1 MHz
    plus small uniform noise.  They are transformed across channels with
    ``DFT`` and ``fftshift``.  ``cwt`` is then applied to every resulting
    mode in a process pool, and the elapsed wall time is printed.

    :Parameters:
      plot : bool, optional
        if True, show the spectrogram of the first mode with matplotlib
        (imported only in that case)
    """
    import time

    random.seed(1)
    n_samples = 10000
    n_channels = 16
    dt = 1e-6

    tvec = arange(n_samples) * dt
    # Integer sizes: recent NumPy rejects float dimensions such as 1e4.
    x = random.rand(n_samples, n_channels) / 1.e5
    x += sin(tvec**2 * 2 * pi * 1e4**2)[:, None]

    x = DFT(2 * pi * arange(n_channels) / float(n_channels), x, n_channels)
    x = fftshift(x, axes=1)

    dj = 0.05
    omega0 = 20
    res = 500
    fmin = 0
    fmax = 1e6 / 2

    t = time.time()
    jobs = [(x[:, i], dt, dj, omega0, res, fmin, fmax)
            for i in range(n_channels)]
    pool = Pool(cpu_count())
    try:
        # Pool.map (not the builtin map) so the channels actually run in
        # parallel; it also returns a list on Python 3, so out[0] works.
        out = pool.map(cwt, jobs)
    finally:
        # Release the worker processes even if a transform fails.
        pool.close()
        pool.join()

    spec, _, f = out[0]
    print(time.time() - t)

    if not plot:
        return

    from matplotlib.pyplot import figure, show
    from matplotlib.ticker import MultipleLocator

    fig = figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.85])
    # The 'nonposy' keyword no longer exists in recent matplotlib; all
    # frequencies here are > fmin >= 0, so no clipping is needed anyway.
    ax.set_yscale('log')
    ax.imshow(abs(spec), extent=[tvec[0], tvec[-1], f[-1], f[0]],
              aspect='auto', interpolation='bicubic')
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.axis([tvec[0], tvec[-1], amin(f), amax(f)])
    ax.set_xlabel('time [s]')
    ax.set_ylabel('Frequency [Hz]')

    show()
    
    
    
    
    
 
if __name__ == '__main__':
    # Execute the demo only when run as a script, never on import.
    main()