#based on Dirigent/Devices/Oscilloscopes/Drivers/TektrMSO5/main58-a.py

import re
import sys
import pyvisa  as visa								# https://pyvisa.readthedocs.io/
import numpy as np
import time


#python==3.7
#numpy==1.20.2
#PyVISA==1.11.3


# Module-wide configuration shared by Connection and the script entry point.
osc_name = "TektrMSO58-a"                    # instrument name: embedded in the VISA address and used as log prefix
visa_address = f"TCPIP::{osc_name}::INSTR"   # VISA resource string handed to Connection()
buffer_len = 1 << 20                         # 1 MiB, passed as the size argument of read_raw()
hires_channels = tuple(range(2, 7))          # channels 2..6: also stored at full resolution by save_channels()


class Connection:
    """PyVISA session to the oscilloscope plus the acquisition helpers built on it.

    Every public method is also reachable as a command of the script entry
    point (see ``known_commands``), so public method names are part of the
    command-line interface.
    """

    def __init__(self, address, visa_backend='@py'):  # default backend: pyvisa-py
        """Open an exclusively locked VISA session.

        Args:
            address: VISA resource string of the instrument.
            visa_backend: backend selector passed to ``visa.ResourceManager``.
        """
        self.rm = visa.ResourceManager(visa_backend)
        self.scope = self.rm.open_resource(address, access_mode=visa.constants.AccessModes.exclusive_lock)
        self.scope.timeout = 10000  # 10 s

    def __del__(self):
        """Close the instrument session and the resource manager, if they exist."""
        # __init__ may have raised before either attribute was assigned, so
        # look them up defensively instead of assuming they are present.
        scope = getattr(self, 'scope', None)
        rm = getattr(self, 'rm', None)
        try:
            if scope is not None:
                scope.close()
        finally:
            # Release the resource manager even if closing the session failed.
            if rm is not None:
                rm.close()

    def identify(self):
        """Print the instrument's reply to the ``*IDN?`` query."""
        print(self.scope.query('*IDN?'))

    def save_screenshot(self, target_fname='screenshot.png', tmp_fname='C:/Temp.png'):
        """Save a screenshot on the instrument and copy it to a local file.

        Args:
            target_fname: local path the image bytes are written to.
            tmp_fname: path on the instrument's own disk used as temporary
                storage; it is deleted again after the transfer.
        """
        self.scope.write('SAVE:IMAGe "{}"'.format(tmp_fname))           # save image to instrument's local disk
        self.scope.query('*OPC?')                                       # wait for instrument to finish writing the image
        self.scope.write('FILESystem:READFile "{}"'.format(tmp_fname))  # ask instrument to send the image file
        # Fixed pause before reading back (presumably lets the instrument
        # start streaming the file -- TODO confirm it is still needed).
        time.sleep(0.5)
        img_data = self.scope.read_raw(buffer_len)
        with open(target_fname, 'wb') as fout:
            fout.write(img_data)
        self.scope.write('FILESystem:DELEte "{}"'.format(tmp_fname))    # delete temporary file

    def available_analog_channels(self):
        """Return the channel numbers found in the ``data:source:available?`` reply.

        Returns:
            list[int]: one entry per ``CH<n>`` token, in reply order; an empty
            list when the reply is empty.
        """
        ret = self.scope.query('data:source:available?')
        if not ret:
            return []
        # Raw string with a balanced group: "CH" followed by the channel digits.
        return [int(c) for c in re.findall(r'CH(\d+)', ret)]

    def save_channels(self, target_dt, *channels):
        """Download the given channels and store them as CSV and NPZ files.

        For each requested channel a ``ch<channel>.csv`` file with two
        columns (time, value) is written to the current directory; rows are
        decimated when ``target_dt`` is larger than the instrument's sample
        spacing. The undecimated time axis and the undecimated traces of the
        channels listed in ``hires_channels`` are stored in
        ``data_all_fullres.npz``.

        Args:
            target_dt: desired sample spacing of the CSV output, in the same
                units as the spacing returned by ``get_time_axis`` (number or
                numeric string).
            *channels: channel numbers (ints or numeric strings).
        """
        start_time = time.time()

        traces = []
        data_dict = {}
        for channel in channels:
            channel = int(channel)
            print(f'{osc_name} getting channel {channel}')
            d = self.get_channel(channel)
            time.sleep(0.5)
            traces.append(d)
            if channel in hires_channels:
                data_dict['ch{}'.format(channel)] = d

        time_ax, dt = self.get_time_axis()
        data_dict['time'] = time_ax
        # Column 0 is the time axis, column i+1 the i-th requested channel.
        data = np.column_stack([time_ax] + traces)
        downsample = int(np.rint(float(target_dt) / dt))

        if downsample > 1:
            print(f'{osc_name} downsampling by factor {downsample}')
            data = data[::downsample, :]
        for idx, channel in enumerate(channels):
            np.savetxt("ch{}.csv".format(channel), data[:, [0, idx + 1]],
                       fmt=(['%.6e'] * 2), delimiter=',')

        np.savez_compressed('data_all_fullres.npz', **data_dict)

        end_time = time.time()
        print(f"{osc_name} save_channels took {end_time - start_time:.1f} s")

    def get_channel(self, channel_i):
        """Download the full record of one channel and scale it.

        Based on script at
        https://www.tek.com/support/faqs/programing-how-get-and-plot-waveform-dpo-mso-mdo4000-series-scope-python

        Args:
            channel_i: channel number, inserted into ``DATA:SOUrce CH<n>``.

        Returns:
            numpy.ndarray: ``(raw - YOFF) * YMULT + YZERO`` for every sample.
        """
        self.scope.write('DATA:SOUrce CH{}'.format(channel_i))

        rec_len = int(self.scope.query('HORizontal:RECOrdlength?'))  # record length

        self.scope.write('DATa:START {}'.format(1))                  # download full record
        self.scope.write('DATa:STOP {}'.format(rec_len))

        # Two bytes per sample, signed, most significant byte first; this must
        # match datatype='h' / is_big_endian=True in the transfer below.
        self.scope.write('DATA:WIDTH 2')
        self.scope.write('DATA:ENCdg RIBINARY')

        ymult = float(self.scope.query('WFMOUTPRE:YMULT?'))
        yzero = float(self.scope.query('WFMOUTPRE:YZERO?'))
        # The offset has to come from YOFF, not YZERO: reusing YZERO here would
        # shift every sample by yzero*ymult. A note previously kept here says
        # the 5 Series MSO Programmer Manual documents YOFF? as always 0.0;
        # querying it is correct either way.
        yoff = float(self.scope.query('WFMOUTPRE:YOFF?'))

        # query_binary_values() parses the block header of the CURVE? reply
        # and decodes the payload, replacing the earlier manual read_raw()
        # based parsing.
        data_wfm = self.scope.query_binary_values('CURV?', datatype='h', is_big_endian=True, container=np.array)

        data_volts = (data_wfm - yoff) * ymult + yzero
        return data_volts

    def get_time_axis(self):
        """Build the time axis of the current record from the waveform preamble.

        Returns:
            tuple: ``(time_ax, xincr)`` where ``time_ax`` is a numpy array with
            ``NR_Pt`` entries computed as ``(i - PT_OFF) * XINCR + XZERO`` and
            ``xincr`` is the sample spacing as a float.
        """
        self.scope.write('DATA:SOUrce CH1')  # TODO? to be sure we are taking analog
        xincr = float(self.scope.query('WFMOUTPRE:XINCR?'))
        xzero = float(self.scope.query('WFMOUTPRE:XZERO?'))
        N_trig = int(self.scope.query('WFMOutpre:PT_OFF?'))
        N = int(self.scope.query('WFMOutpre:NR_Pt?'))

        time_ax = (np.arange(N) - N_trig) * xincr + xzero
        return time_ax, xincr




# Public (non-underscore) attribute names of Connection; these are the
# command names accepted on the command line.
known_commands = {name for name in dir(Connection) if not name.startswith('_')}


if __name__ == '__main__':
    # Usage: <script> <command> [args...]
    # <command> is the name of a public Connection method; the remaining
    # arguments are passed to it positionally as strings.
    #
    # The command line is validated before the instrument is contacted, so a
    # bad invocation neither needs the scope to be reachable nor takes the
    # exclusive VISA lock.
    try:
        command = sys.argv[1]
    except IndexError:
        print('supply command (Connection method name) as script argument')
        sys.exit(1)
    # Only public method names are accepted; this also rejects instance
    # attributes and dunder methods that getattr() alone would let through.
    if command not in known_commands:
        print('Unknown command "{}", known commands: {}'.format(command, known_commands))
        sys.exit(1)
    con = Connection(visa_address)
    method = getattr(con, command)
    method(*sys.argv[2:])
    sys.exit(0)
