#!/usr/bin/env python
# -*- coding: utf-8 -*-
#from matplotlib.pyplot import *
from numpy import *
from multiprocessing import Process, Pool, cpu_count
from fftshift import fftshift
#import fftw3
import pyfftw
"""
Continuous Multi-Wavelet analysis 0.001
Description:
The continuous wavelet analysis of multiple signals is calculated. Instead
of a single one-dimensional wavelet, an n-dimensional Morlet wavelet with shifted components is created.
The moving scalar product with such wavelets helps to identify and amplify correctly time-shifted
signals and suppress the others.
"""
# Public API for "from <module> import *": only the cwt entry point is exported;
# the remaining helpers are not re-exported by a star import.
__all__ = ["cwt",]
def angularfreq(N, dt):
    """Compute the angular frequencies of an N-point FFT.

    The values are returned in standard FFT order: DC first, then the
    positive frequencies, then the negative ones. This is exactly
    ``2*pi*fft.fftfreq(N, dt)``, so it is also correct for odd ``N``.
    (A construction based on ``arange(-N/2., N/2.)`` yields half-integer
    frequency indices and a non-zero DC bin when ``N`` is odd.)

    :Parameters:
        N : integer
            number of data samples
        dt : float
            time step
    :Returns:
        angular frequencies : 1d numpy array of length N
    """
    # See (5) at page 64 (presumably Torrence & Compo, 1998 -- confirm).
    return 2 * pi * fft.fftfreq(N, dt)
def scales(N, dj, dt, s0):
    """Compute the wavelet scales s_j = s0 * 2**(j*dj), j = 0..J.

    J = floor(log2(N*dt / s0) / dj), so the largest scale does not exceed
    the signal duration N*dt.

    :Parameters:
        N : integer
            number of data samples
        dj : float
            scale resolution (consecutive scales differ by a factor 2**dj)
        dt : float
            time step
        s0 : float
            smallest scale
    :Returns:
        scales : 1d numpy array
            scales; empty when N*dt < s0
    """
    # int() is required: NumPy rejects a float array length.
    J = int(floor(dj**-1 * log2((N * dt) / s0)))
    # Clamp so a signal shorter than s0 yields an empty array instead of an error.
    return s0 * 2**(arange(max(J + 1, 0)) * dj)
def compute_s0(dt, p):
    """Return the smallest wavelet scale s0 for a Morlet wavelet.

    The Morlet Fourier period of scale s is 4*pi*s / (p + sqrt(2 + p**2)).
    The returned s0 = dt * (p + sqrt(2 + p**2)) / (2*pi) is the scale whose
    Fourier period equals 2*dt.

    :Parameters:
        dt : float
            time step
        p : float
            Morlet central angular frequency omega0 (this is the Morlet
            formula only; other wavelet families are not handled here)
    :Returns:
        s0 : float
    """
    fourier_factor = p + sqrt(2 + p**2)
    return dt * fourier_factor / (2 * pi)
def _aligned_empty(n):
    """Return an uninitialised, 16-byte-aligned complex64 array of length n.

    Prefers ``pyfftw.empty_aligned``; falls back to the older, deprecated
    ``pyfftw.n_byte_align_empty`` for pyFFTW releases that lack it.
    """
    if hasattr(pyfftw, 'empty_aligned'):
        return pyfftw.empty_aligned(n, dtype=csingle, n=16)
    return pyfftw.n_byte_align_empty(n, 16, dtype=csingle)


def cwt(args):
    """Continuous Morlet wavelet transform of one 1-D signal.

    All arguments are packed into a single tuple so the function can be
    passed directly to ``map`` / ``Pool.map``:

        args = (x, dt, dj, p, res, fmin, fmax)

    x - transformed signal (1-D; it is truncated, see below)
    dt - time step
    dj - scale resolution (consecutive scales differ by a factor 2**dj)
    p - Morlet central angular frequency omega0; controls the trade-off
        between time and frequency resolution
    res - requested number of time points of the spectrogram; the actual
          number is len(range(0, lenght, step)) with step = max(lenght//res, 1)
    fmin, fmax - frequency range (both bounds exclusive)

    Returns ``(spec, scale, freq)``. ``spec`` is a complex64 array of shape
    (len(scale), number of time points). ``scale`` and ``freq`` are the
    retained scales and their Fourier frequencies.
    """
    x, dt, dj, p, res, fmin, fmax = args

    lenght_ = size(x, 0)
    # Truncate the signal to a multiple of a power of two (about lenght_/256)
    # to speed up the FFT. Signals shorter than 512 samples get div = 1;
    # int(log2(lenght_/256)) raised OverflowError below 256 samples.
    div = 1 << max(int(lenght_ // 256).bit_length() - 1, 0)
    lenght = div * (lenght_ // div)
    half = lenght // 2
    x = csingle(x[:lenght])

    w = angularfreq(lenght, dt)
    s0 = compute_s0(dt, p)
    scale = scales(lenght, dj, dt, s0)
    # Fourier frequency of each Morlet scale; keep only those in (fmin, fmax).
    freq = (p + sqrt(2.0 + p**2)) / (4 * pi * scale)
    ind = (freq > fmin) & (freq < fmax)
    scale = scale[ind]
    freq = freq[ind]

    # Keep only the first (non-negative-frequency) half of the spectrum.
    x = fft.fft(x, axis=0)[:half]
    # Flip the sign of bins 0, 2, 4, ...: up to a global sign this is a
    # circular time shift by N/2. NOTE(review): presumably undone by the
    # fftshift applied to each output row below -- confirm intent.
    x[::2] *= -1

    step = max(int(lenght / res), 1)
    hres = len(w[0:lenght:step])
    spec = zeros((len(scale), hres), dtype=csingle)

    # These settings only affect pyfftw.interfaces calls, not the FFTW
    # object built below; kept for compatibility.
    pyfftw.interfaces.cache.enable()
    pyfftw.interfaces.cache.set_keepalive_time(30)

    wft = _aligned_empty(len(w))
    out = _aligned_empty(len(w))
    w = w[:half]
    # FFTW.__call__ normalises the backward transform by default.
    fft_backward = pyfftw.FFTW(wft, out, direction='FFTW_BACKWARD',
                               flags=['FFTW_ESTIMATE', 'FFTW_DESTROY_INPUT'])
    for i, s in enumerate(scale):
        # Only bins where |s*w - p| <= 5 (Gaussian envelope within 5 sigma)
        # contribute; the index of angular frequency w_k is k*(lenght*dt)/(2*pi).
        i2 = min((5 + p) / s * (lenght * dt) / (2 * pi), len(w))
        i1 = max((p - 5) / s * (lenght * dt) / (2 * pi), 0)
        slc = slice(int(i1), int(i2))
        wft[:] = 0
        wavelet = exp(-(s * w[slc] - p)**2 / 2) * sqrt(s / dt) / lenght
        wft[slc] = wavelet * x[slc]
        fft_backward()
        spec[i, :] = fftshift(out[::step])
    return spec, scale, freq
def DFT(x, y, nf):
    """Fourier coefficients of samples ``y`` taken at positions ``x``.

    Builds E[k, j] = exp(-1j * Ftheta[k] * x[j]) for the nf integer
    harmonics Ftheta = -nf//2, ..., nf//2 - 1. Returns
    ``dot(y, conj(pinv(E))) * nf``, contracting the last axis of ``y``.
    For nf equally spaced positions x = 2*pi*arange(nf)/nf this reduces to
    the plain DFT  sum_j y[..., j] * exp(-1j * Ftheta[k] * x[j]).

    :Parameters:
        x : 1d array
            sample positions (angles in radians)
        y : array
            samples; its last axis must have length len(x)
        nf : integer
            number of harmonics
    :Returns:
        complex array of shape y.shape[:-1] + (nf,)
    """
    # Floor division keeps integer harmonics for odd nf (and matches the
    # Python 2 behaviour of the original "/" on ints).
    Ftheta = arange(-nf // 2, nf // 2)
    E = exp(-1j * outer(Ftheta, x))
    return dot(y, conj(linalg.pinv(E))) * len(Ftheta)
def main():
    """Benchmark ``cwt`` on a synthetic 16-channel chirp.

    Builds 10000 x 16 samples: small uniform noise plus the chirp
    sin(2*pi*1e8*t**2), identical in every channel. The channels are
    transformed with ``DFT`` and reordered with ``fftshift`` along axis 1.
    The wavelet transform of each of the 16 columns is then computed in a
    process pool. Prints the elapsed wall time and returns the list of
    ``(spec, scale, freq)`` tuples, one per column.
    """
    import time

    n_samples = 10000
    n_channels = 16
    dt = 1e-6
    dj = 0.05
    fmin = 0
    fmax = 1e6 / 2
    omega0 = 20
    res = 500

    random.seed(1)
    x = random.rand(n_samples, n_channels) / 1.e5
    tvec = arange(n_samples) * dt
    x += sin(tvec**2 * 2 * pi * 1e4**2)[:, None]
    x = DFT(2 * pi * arange(n_channels) / float(n_channels), x, n_channels)
    x = fftshift(x, axes=1)

    t = time.time()
    pool = Pool(cpu_count())
    try:
        # Use the pool (the builtin map ran every transform serially).
        out = list(pool.map(cwt, [(x[:, i], dt, dj, omega0, res, fmin, fmax)
                                  for i in range(n_channels)]))
    finally:
        pool.close()
        pool.join()
    print(time.time() - t)
    return out


if __name__ == "__main__":
    main()