from typing import Any, Union, Tuple, Optional, Mapping, cast, Dict, Type

# import matplotlib.pyplot as plt
import numpy as np
import copy
import yaml 
from pathlib import Path
import asyncio

from x.simulation.core import TestbenchManager
from x.simulation.data import AnalysisType
from x.simulation.cache import SimulationDB, DesignInstance, SimResults, MeasureResult
from x.simulation.measure import MeasurementManager, MeasInfo
from x.math.interpolate import LinearInterpolator
from x.io.file import write_yaml

from x.simulation.data import SimData, AnalysisType

from x.concurrent.util import GatherHelper
import matplotlib.pyplot as plt

from x3_testbenches.measurement.data.tran import EdgeType
from x3_testbenches.measurement.tran.digital import DigitalTranTB
from x3_testbenches.measurement.pnoise.base import PNoiseTB
from x3_testbenches.measurement.pac.base import PACTB
from x3_testbenches.measurement.data.tran import interp1d_no_nan, EdgeType, get_first_crossings

class ComparatorMM(MeasurementManager):
    """Measurement manager for a clocked comparator with outputs ``outn``/``outp``.

    Only :meth:`async_measure_performance` is used; the synchronous
    ``initialize``/``get_sim_info``/``process_output`` hooks raise ``RuntimeError``.
    Depending on ``specs['analysis']`` it runs:

    * ``'offset'``: transient sims swept over ``specs['sweep_info']['v_offset']`` to find
      the ``v_offset`` at which the output decision flips from 0 to 1;
    * ``'noise'``: PNoise sim(s), optionally swept over ``specs['swp_info']['v_vcm']``;
    * ``'delay'``: one transient sim for resolve delay and average supply power.

    Noise and delay are skipped when the offset search marks the circuit as non-functional.
    A summary is written (best effort) to ``specs['result_file']`` as YAML.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._tbm_info: Optional[Tuple[DigitalTranTB, Mapping[str, Any]]] = None

    def initialize(self, sim_db: SimulationDB, dut: DesignInstance) -> Tuple[bool, MeasInfo]:
        raise RuntimeError('Unused')

    def get_sim_info(self, sim_db: SimulationDB, dut: DesignInstance, cur_info: MeasInfo
                     ) -> Tuple[Union[Tuple[TestbenchManager, Mapping[str, Any]],
                                      MeasurementManager], bool]:
        raise RuntimeError('Unused')

    def process_output(self, cur_info: MeasInfo, sim_results: Union[SimResults, MeasureResult]
                       ) -> Tuple[bool, MeasInfo]:
        raise RuntimeError('Unused')

    def setup_tbm(self, sim_db: SimulationDB, dut: DesignInstance,
                  analysis: Union[Type[DigitalTranTB], Type[PNoiseTB]],
                  flip_dm: bool = False,
                  sim_params: Optional[Mapping[str, Any]] = None,
                  ) -> Union[DigitalTranTB, PNoiseTB]:
        """Build a testbench manager of class *analysis* from ``self.specs``.

        Parameters
        ----------
        sim_db : SimulationDB
            Database whose ``make_tbm`` creates the testbench manager.
        dut : DesignInstance
            Device under test; its pin names are copied into ``tbm_specs['dut_pins']``.
        analysis : Type[DigitalTranTB] or Type[PNoiseTB]
            Testbench class. ``DigitalTranTB`` gets ``specs['delay_stimuli']`` appended to
            its ``pulse_list`` (and an emptied ``src_list``); any other class gets
            ``specs['noise_in_stimuli']`` appended to its ``src_list``.
        flip_dm : bool
            For ``DigitalTranTB`` only: negate ``sim_params['v_dm']``.
        sim_params : Optional[Mapping[str, Any]]
            Per-call overrides merged into a copy of ``tbm_specs['sim_params']``.
            ``self.specs`` is never modified by this method.

        Returns
        -------
        The testbench manager returned by ``sim_db.make_tbm``.

        Raises
        ------
        RuntimeError
            If an entry of ``specs['swp_info']`` has an unknown ``type``.
        """
        specs = self.specs
        noise_in_stimuli = specs['noise_in_stimuli']
        delay_stimuli = specs['delay_stimuli']
        # Deep copy so that nothing below leaks back into self.specs.
        tbm_specs = copy.deepcopy(dict(**specs['tbm_specs']))
        tbm_specs['dut_pins'] = list(dut.pins.keys())
        if sim_params:
            tbm_specs['sim_params'] = {**tbm_specs['sim_params'], **sim_params}

        # Validate/normalize specs['swp_info'] (raises on unknown sweep types).
        # NOTE(review): the normalized list is never stored in tbm_specs, so these sweeps
        # are not applied by this method -- confirm whether that is intended.
        swp_info = []
        for k, v in specs.get('swp_info', dict()).items():
            if isinstance(v, list):
                swp_info.append((k, dict(type='LIST', values=v)))
            else:
                _type = v['type']
                if _type == 'LIST':
                    swp_info.append((k, dict(type='LIST', values=v['values'])))
                elif _type == 'LINEAR':
                    swp_info.append((k, dict(type='LINEAR', start=v['start'], stop=v['stop'], num=v['num'])))
                elif _type == 'LOG':
                    swp_info.append((k, dict(type='LOG', start=v['start'], stop=v['stop'], num=v['num'])))
                else:
                    raise RuntimeError

        if analysis is DigitalTranTB:
            tbm_specs['src_list'] = []
            tbm_specs['pulse_list'] = list(tbm_specs['pulse_list'])
            tbm_specs['pulse_list'].extend(delay_stimuli)
            if flip_dm:
                tbm_specs['sim_params']['v_dm'] = -1 * tbm_specs['sim_params']['v_dm']
        else:
            tbm_specs['src_list'] = list(tbm_specs['src_list'])
            tbm_specs['src_list'].extend(noise_in_stimuli)

        tbm = cast(analysis, sim_db.make_tbm(analysis, tbm_specs))
        return tbm

    @staticmethod
    def process_output_noise(noise_sim_results: Union[SimResults, MeasureResult],
                             gain: float = 50.0) -> MeasInfo:
        """Integrate the PNoise ``out`` data into a total output noise per sweep point.

        For every sweep point the squared ``out`` values are linearly interpolated over
        ``freq`` and integrated across the full interpolation range; the square root of
        that integral is reported as the output noise.

        Parameters
        ----------
        noise_sim_results : SimResults
            PNoise results whose ``pnoise`` group contains ``freq`` and ``out``
            (frequency assumed to be the last axis of ``out``).
        gain : float
            Divisor used to refer the output noise to the input (default 50, the value
            previously hard-coded here).

        Returns
        -------
        MeasInfo('done', {'Output noise': arr, 'Input Ref noise': arr / gain}), where ``arr``
        has the shape of ``out`` without its last axis.
        """
        data_noise = cast(SimResults, noise_sim_results).data
        data_noise.open_group('pnoise')
        freq = data_noise['freq']
        noise = data_noise['out']
        swp_shape = noise.shape[:-1]
        # int(): np.prod(()) is the float 1.0, which reshape rejects.
        num_swps = int(np.prod(swp_shape))
        # Flatten ALL sweep axes (including the leading one) so every point is integrated;
        # squaring only noise[0] broke the reshape whenever the leading axis was > 1.
        noise_sq = np.square(noise).reshape(num_swps, noise.shape[-1])
        tot_noise_list = []
        for n in noise_sq:
            noise_fit = LinearInterpolator([freq], n, [0])
            f_lo, f_hi = noise_fit.input_ranges[0][0], noise_fit.input_ranges[0][1]
            tot_noise_list.append(np.sqrt(noise_fit.integrate(f_lo, f_hi)))

        tot_noise = np.array(tot_noise_list).reshape(swp_shape)
        return MeasInfo('done', {'Output noise': tot_noise,
                                 'Input Ref noise': tot_noise / gain})

    @staticmethod
    def process_output_delay(sim_results: Union[SimResults, MeasureResult], tbm: DigitalTranTB,
                             t_per: float) -> MeasInfo:
        """Extract resolve delay and average supply power from a transient simulation.

        * ``td``: from the first rising ``clk`` crossing to the first crossing of
          ``|outn - outp|``, searched between ``2.25*t_per`` and ``t_sim``.
        * ``avg_power``: for each run (first axis of ``VDC_VDD:p``), ``data['VDD'][0][0]``
          times the time-average of ``|VDC_VDD:p|`` from ``2*t_per`` to the last sample,
          using left-rectangle integration.

        Returns
        -------
        MeasInfo('done', {'td': ..., 'avg_power': np.ndarray with one entry per run}).
        """
        data = cast(SimResults, sim_results).data
        t_d = process_differential(tbm, data, 'clk', 'outn', 'outp', EdgeType.RISE,
                                   EdgeType.CROSS, '2.25*t_per', 't_sim')

        t_begin = t_per * 2
        vdd = data['VDD'][0][0]
        time_vec = data['time']
        avg_power_list = []
        for current in data['VDC_VDD:p']:
            interp_current = interp1d_no_nan(time_vec, current)
            times = [t_begin]
            # abs() is applied to every sample, including the interpolated starting one
            # (previously the first sample kept its sign, unlike all the others).
            values = [abs(interp_current(t_begin))]
            for t, i in zip(time_vec, current):
                if t > times[-1]:
                    times.append(t)
                    values.append(abs(i))
            charge = np.sum(np.diff(times) * np.array(values[:-1]))
            avg_power_list.append(vdd * charge / (times[-1] - times[0]))

        # NOTE: a check_functionality() gate used to sit here, but its result was always
        # overwritten (and its tuple return is always truthy), so it never had an effect.
        result = dict(td=t_d, avg_power=np.array(avg_power_list))
        return MeasInfo('done', result)

    @staticmethod
    async def _run_sim(name: str, sim_db: SimulationDB, sim_dir: Path, dut: DesignInstance,
                       tbm: Union[DigitalTranTB, PNoiseTB]):
        """Simulate *tbm* on *dut* under ``sim_dir / name`` and return the sim results."""
        sim_id = f'{name}'
        print(dut)
        sim_results = await sim_db.async_simulate_tbm_obj(sim_id, sim_dir / sim_id,
                                                          dut, tbm, {}, tb_name=sim_id)
        return sim_results

    async def get_noise(self, name: str, sim_db: SimulationDB, dut: DesignInstance,
                        sim_dir: Path, vcm: float):
        """Run one PNoise simulation at common-mode *vcm* and return the processed noise dict.

        The ``v_vcm`` override is passed to the testbench only; ``self.specs`` is not
        modified, so later simulations keep the configured ``v_vcm``.
        """
        tbm_noise = self.setup_tbm(sim_db, dut, PNoiseTB, sim_params={'v_vcm': vcm})
        noise_results = await self._run_sim(name + '_noise', sim_db, sim_dir, dut, tbm_noise)
        return self.process_output_noise(noise_results).prev_results

    async def get_offset(self, name: str, sim_db: SimulationDB, dut: DesignInstance,
                         sim_dir: Path, voff: float):
        """Run one transient at ``v_offset = voff`` (no noise) and classify the output.

        Returns
        -------
        ``[functional, swing_side, voff]`` where ``functional``/``swing_side`` come from
        :func:`check_functionality` (``swing_side`` is None when not functional).
        """
        tbm = self.setup_tbm(sim_db, dut, DigitalTranTB, sim_params={'v_offset': voff})
        results = await self._run_sim(name + '_offset', sim_db, sim_dir, dut, tbm)
        data = cast(SimResults, results).data
        functional, swing_side = check_functionality(tbm, data, 'clk', 'outn', 'outp', EdgeType.RISE,
                                                     EdgeType.CROSS, '2.25*t_per', 't_sim', 't_per')
        return [functional, swing_side, voff]

    @staticmethod
    def _find_offset(offset_results: list, v_offset_list: list) -> Tuple[float, bool]:
        """Locate the 0 -> 1 decision flip in ``[functional, state, voff]`` results sorted by voff.

        Returns ``(offset, functional)``:

        * first ``voff`` whose state is 1 right after a state-0 point -> ``(voff, True)``;
        * first point already in state 1 (flip below the sweep) -> ``(2*min, True)``;
        * no flip found (flip above the sweep) -> ``(2*max, True)``;
        * a non-functional point reached before any flip -> ``(2*max, False)``.
        """
        v_min, v_max = min(v_offset_list), max(v_offset_list)
        first_func, first_state, _ = offset_results[0]
        if not first_func:
            return 2 * v_max, False
        # Previously the below-range result was always overwritten by 2*max.
        fallback = 2 * v_min if first_state == 1 else 2 * v_max
        for (_, prev_state, _), (cur_func, cur_state, cur_voff) in zip(offset_results,
                                                                         offset_results[1:]):
            if not cur_func:
                return 2 * v_max, False
            if prev_state == 0 and cur_state == 1:
                return cur_voff, True
        return fallback, True

    async def _measure_offset(self, name: str, sim_dir: Path, sim_db: SimulationDB,
                              dut: DesignInstance, swp: Mapping[str, Any]) -> Tuple[float, bool]:
        """Run the v_offset sweep concurrently and return ``(measured_offset, functional)``.

        Side effect: ``self.specs['tbm_specs']['sim_params']['v_offset']`` is set to
        ``measured_offset + 1e-3`` so that the following noise/delay sims use it.
        """
        v_offset_list = sorted(
            list(np.linspace(float(swp['start_lin']), float(swp['stop_lin']), swp['num']))
            + list(swp.get('add_points', [])))
        tasks = []
        for voff in v_offset_list:
            # Sim-name suffix from the fractional part of voff, in thousandths.
            decimal_part = int(round((voff - int(voff)) * 1000))
            tasks.append(self.get_offset(name + f'_voff{decimal_part:03d}', sim_db, dut, sim_dir, voff))
        offset_results = sorted(await asyncio.gather(*tasks), key=lambda res: res[2])

        measured_offset, functional = self._find_offset(offset_results, v_offset_list)
        print("Measured offset", measured_offset)
        self.specs["tbm_specs"]['sim_params']['v_offset'] = measured_offset + 1e-3
        print("offset check", abs(measured_offset), max(v_offset_list))
        if abs(measured_offset) >= abs(max(v_offset_list)):
            functional = False
        return measured_offset, functional

    async def _measure_noise(self, name: str, sim_dir: Path, sim_db: SimulationDB,
                             dut: DesignInstance):
        """Run PNoise (swept over ``swp_info['v_vcm']`` if present); placeholder values on failure."""
        try:
            vcm_swp_info = self.specs.get('swp_info', {}).get('v_vcm')
            if vcm_swp_info is not None:
                helper = GatherHelper()
                for vcm in np.linspace(float(vcm_swp_info['start']), float(vcm_swp_info['stop']),
                                       vcm_swp_info['num']):
                    helper.append(self.get_noise(name + f'_noise_vcm={vcm:.2f}', sim_db, dut, sim_dir, vcm))
                return await helper.gather_err()
            tbm_noise = self.setup_tbm(sim_db, dut, PNoiseTB)
            noise_results = await self._run_sim(name + '_noise', sim_db, sim_dir, dut, tbm_noise)
            return self.process_output_noise(noise_results).prev_results
        except Exception as err:
            # Exception, not a bare except: task cancellation must still propagate.
            print("noise measurement failed, using placeholder values:", repr(err))
            return {'Input Ref noise': np.array([1]),
                    'Output noise': np.array([1])}

    def _dump_results(self, results: Dict[str, Any], functional_circuit: bool,
                      measured_offset: Optional[float]) -> None:
        """Best-effort YAML summary to ``specs['result_file']``; failures are reported, not raised."""
        try:
            result_yaml = self.specs['result_file']
            if functional_circuit:
                results_dict = {
                    'input_noise': float(results['noise']['Input Ref noise'].flatten()[0]),
                    'delay': float(results['delay']['td'].flatten()[0]),
                    'avg_power': float(results['delay']['avg_power'].flatten()[0]),
                    'offset': float(measured_offset),
                }
            else:
                results_dict = {'input_noise': 1, 'delay': 1,
                                'avg_power': 1, 'offset': float(measured_offset)}
            print(results_dict)
            with open(result_yaml, "w") as f:
                yaml.dump(results_dict, f, default_flow_style=False)
        except Exception as err:
            print("results will not be dumped to yaml", repr(err))

    async def async_measure_performance(self, name: str, sim_dir: Path, sim_db: SimulationDB,
                                        dut: Optional[DesignInstance]) -> Dict[str, Any]:
        """Run the analyses listed in ``specs['analysis']`` and return their results.

        Returns
        -------
        dict with optional keys ``'noise'`` (noise dict, or a list of them for a v_vcm sweep)
        and ``'delay'`` (``{'td', 'avg_power'}``). The offset result is only written to the
        YAML summary and to the shared ``v_offset`` sim parameter.
        """
        results: Dict[str, Any] = dict()
        functional_circuit = True
        measured_offset: Optional[float] = None
        analysis = self.specs['analysis']

        sweep_info = self.specs.get('sweep_info', {})
        if 'offset' in analysis and 'v_offset' in sweep_info:
            measured_offset, functional_circuit = await self._measure_offset(
                name, sim_dir, sim_db, dut, sweep_info['v_offset'])

        if functional_circuit and 'noise' in analysis:
            results['noise'] = await self._measure_noise(name, sim_dir, sim_db, dut)

        if functional_circuit and 'delay' in analysis:
            tbm_delay = self.setup_tbm(sim_db, dut, DigitalTranTB)
            delay_results = await self._run_sim(name + '_delay', sim_db, sim_dir, dut, tbm_delay)
            t_per = self.specs['tbm_specs']['sim_params']['t_per']
            results['delay'] = self.process_output_delay(delay_results, tbm_delay, t_per).prev_results

        self._dump_results(results, functional_circuit, measured_offset)
        return results

class ComparatorDelayMM(ComparatorMM):
    """Comparator measurement manager that only runs the delay/power analysis.

    Also provides plotting helpers for transient and swept delay results.
    """

    def commit(self) -> None:
        """Restrict ``specs['analysis']`` to the delay analysis."""
        self.specs['analysis'] = ['delay']

    @classmethod
    def plot_sig(cls, sim_data, axis):
        """Plot inn/inp/clk on ``axis[0]`` and outn/outp on ``axis[1]``, then add grid and legend."""
        time_vec = sim_data['time']
        for sig_name in ['inn', 'inp', 'clk']:
            axis[0].plot(time_vec, sim_data[sig_name], linewidth=2, label=sig_name)
        for sig_name in ['outn', 'outp']:
            axis[1].plot(time_vec, sim_data[sig_name], label=sig_name)
        # Plain loops: these calls are made only for their side effects.
        for _ax in axis:
            _ax.grid()
        for _ax in axis:
            _ax.legend()

    @classmethod
    def plot_vcm_vdm(cls, sim_data, td, axis):
        """Plot resolve time versus ``v_dm``, one curve per ``v_cm`` sweep value.

        ``td`` is indexed as ``td[0, vcm_idx, :]``.

        Raises
        ------
        RuntimeError
            If ``v_cm`` and ``v_dm`` are not both sweep parameters of *sim_data*.
        """
        # NOTE(review): other code in this file names the common-mode parameter 'v_vcm';
        # confirm 'v_cm' is the sweep name used where this helper is called.
        if 'v_cm' not in sim_data.sweep_params or 'v_dm' not in sim_data.sweep_params:
            raise RuntimeError("plot_vcm_vdm requires both 'v_cm' and 'v_dm' as sweep parameters")
        for idx, vcm in enumerate(sim_data['v_cm']):
            axis.plot(sim_data['v_dm'], td[0, idx, :], label=f'vcm={vcm}V')
        axis.set_xlabel('v_dm')
        axis.set_ylabel('Resolve Time')
        axis.grid()
        axis.legend()

    @staticmethod
    def plot_vcm(sim_data, td, axis, tbm: DigitalTranTB):
        """Plot resolve time versus ``v_vcm`` at the testbench's fixed ``v_dm``.

        *axis* may be a matplotlib ``Axes`` or the ``matplotlib.pyplot`` module.

        Raises
        ------
        RuntimeError
            If ``v_dm`` is also a sweep parameter.
        """
        if 'v_dm' in sim_data.sweep_params:
            raise RuntimeError("Only sweep vcm, vdm is also in sweep params now")
        vdm = tbm.get_sim_param_value('v_dm')
        axis.plot(sim_data['v_vcm'], td[0], label=f'vdm={vdm}')
        # Axes objects have set_xlabel/set_ylabel (calling .xlabel on them raised
        # AttributeError); the pyplot module has xlabel/ylabel.
        set_xlabel = getattr(axis, 'set_xlabel', None) or axis.xlabel
        set_ylabel = getattr(axis, 'set_ylabel', None) or axis.ylabel
        set_xlabel('v_cm')
        set_ylabel('Resolve Time')
        axis.grid()
        axis.legend()

class ComparatorPNoiseMM(ComparatorMM):
    """Comparator measurement manager that only runs the PNoise analysis."""

    def commit(self) -> None:
        """Overwrite the requested analyses so that only noise is measured."""
        noise_only = ['noise']
        self.specs['analysis'] = noise_only

def process_differential(tbm, data: SimData, in_name: str, outn_name: str, outp_name: str,
                         in_edge: EdgeType,
                         out_edge: EdgeType, t_start: Union[np.ndarray, float, str] = 0,
                         t_stop: Union[np.ndarray, float, str] = float('inf'),
                         thres_delay: float = 0.5):
    """Measure delay from an input edge to the differential output decision.

    The output signal is ``|outn - outp|``. The delay is the first crossing of that signal
    through ``thres_delay * (out_1 - out_0)``, minus the first crossing of the input through
    ``in_0 + thres_delay * (in_1 - in_0)``. Supply values come from
    ``tbm.get_pin_supply_values`` for the input pin and for the ``outn`` pin.

    Parameters
    ----------
    tbm : testbench manager
        Provides ``specs`` (optional ``rtol``/``atol``), pin supply values and a calculator.
    data : SimData
        Simulation data; its transient analysis is opened by this function.
    in_name, outn_name, outp_name : str
        Signal names of the input and the two outputs.
    in_edge, out_edge : EdgeType
        Edge types for the input and output crossings.
    t_start, t_stop : float, np.ndarray or str
        Search window; string values are evaluated with ``tbm.get_calculator(data)``.
    thres_delay : float
        Fraction of the swing used as crossing threshold (default 0.5).

    Returns
    -------
    Output crossing time minus input crossing time.
    """
    specs = tbm.specs
    rtol: float = specs.get('rtol', 1e-8)
    atol: float = specs.get('atol', 1e-22)
    in_0, in_1 = tbm.get_pin_supply_values(in_name, data)
    out_0, out_1 = tbm.get_pin_supply_values(outn_name, data)
    data.open_analysis(AnalysisType.TRAN)
    tvec = data['time']
    in_vec = data[in_name]
    out_vec = abs(data[outn_name] - data[outp_name])

    # Evaluate expression-valued t_start/t_stop.
    if isinstance(t_start, str) or isinstance(t_stop, str):
        calc = tbm.get_calculator(data)
        if isinstance(t_start, str):
            t_start = calc.eval(t_start)
        if isinstance(t_stop, str):
            t_stop = calc.eval(t_stop)

    vth_in = (in_1 - in_0) * thres_delay + in_0
    # |outn - outp| is bounded by [0, out_1 - out_0] when both outputs stay within their
    # rails, so its threshold is measured from 0; adding out_0 was wrong whenever out_0 != 0.
    vth_out = (out_1 - out_0) * thres_delay
    in_c = get_first_crossings(tvec, in_vec, vth_in, etype=in_edge, start=t_start, stop=t_stop,
                               rtol=rtol, atol=atol)
    out_c = get_first_crossings(tvec, out_vec, vth_out, etype=out_edge, start=t_start,
                                stop=t_stop, rtol=rtol, atol=atol)
    return out_c - in_c

def check_functionality(tbm, data: SimData, in_name: str, outn_name: str, outp_name: str,
                        in_edge: EdgeType,
                        out_edge: EdgeType, t_start: Union[np.ndarray, float, str] = 0,
                        t_stop: Union[np.ndarray, float, str] = float('inf'),
                        t_per: Union[np.ndarray, float, str] = 0):
    """Heuristically check that a comparator transient resolves and resets.

    ``VDD``/``VSS`` are ``out_1[0]``/``out_0[0]`` from
    ``tbm.get_pin_supply_values(outn_name, data)``. Checks, in order:

    1. Swing: over the flattened samples from index ``len(time) // 2`` onward, at least
       one of outn/outp must swing by ``0.9 * VDD``.
    2. Decision: ``0.45 * t_per`` after the first input crossing, if the input is at VDD,
       ``|outn - outp|`` must be at least ``0.9 * VDD``.
    3. Reset: ``0.1 * t_per`` before the first output crossing, if the input is at VSS,
       ``|outn - outp|`` must be at most ``0.1 * VDD``.

    String-valued ``t_start``/``t_stop``/``t_per`` are evaluated with the tbm calculator.

    Returns
    -------
    ``(False, None)`` if any check fails; otherwise ``(True, eval_out)`` where ``eval_out``
    is 0 when mean(outn) > mean(outp) and 1 otherwise.
    """
    thres_delay = 0.5
    specs = tbm.specs
    rtol: float = specs.get('rtol', 1e-8)
    atol: float = specs.get('atol', 1e-22)
    in_0, in_1 = tbm.get_pin_supply_values(in_name, data)
    out_0, out_1 = tbm.get_pin_supply_values(outn_name, data)
    data.open_analysis(AnalysisType.TRAN)
    tvec = data['time']
    in_vec = data[in_name]
    outn_vec = data[outn_name]
    outp_vec = data[outp_name]
    out_vec = abs(outn_vec - outp_vec)

    # NOTE(review): assumes outn/outp hold a single sweep point; with more, the flattened
    # slice below does not correspond to the second half of the run in time.
    half = len(tvec) // 2
    outn_tail = outn_vec.flatten()[half:]
    outp_tail = outp_vec.flatten()[half:]
    swing_outn = outn_tail.max() - outn_tail.min()
    swing_outp = outp_tail.max() - outp_tail.min()

    # Evaluate expression-valued t_start/t_stop/t_per
    # (the old `isinstance(t_per)` was missing its type argument and raised TypeError).
    if isinstance(t_start, str) or isinstance(t_stop, str) or isinstance(t_per, str):
        calc = tbm.get_calculator(data)
        if isinstance(t_start, str):
            t_start = calc.eval(t_start)
        if isinstance(t_stop, str):
            t_stop = calc.eval(t_stop)
        if isinstance(t_per, str):
            t_per = calc.eval(t_per)
    vth_in = (in_1 - in_0) * thres_delay + in_0
    # Threshold on |outn - outp| is measured from 0, not from out_0.
    vth_out = (out_1 - out_0) * thres_delay

    print("check swing magnitude")
    VDD = out_1[0]
    VSS = out_0[0]
    if swing_outn < 0.9 * VDD and swing_outp < 0.9 * VDD:
        return False, None

    in_c = get_first_crossings(tvec, in_vec, vth_in, etype=in_edge, start=t_start, stop=t_stop,
                               rtol=rtol, atol=atol)
    out_c = get_first_crossings(tvec, out_vec, vth_out, etype=out_edge, start=t_start,
                                stop=t_stop, rtol=rtol, atol=atol)

    in_vec_interp = interp1d_no_nan(tvec, in_vec)
    outn_vec_interp = interp1d_no_nan(tvec, outn_vec)
    outp_vec_interp = interp1d_no_nan(tvec, outp_vec)

    # Decision check: shortly after the input edge, with the input high, outputs must be split.
    # np.isclose instead of exact float equality on interpolated values.
    in_ht = in_c + 0.45 * t_per
    if np.isclose(in_vec_interp(in_ht), VDD):
        print("check decision")
        if abs(outn_vec_interp(in_ht) - outp_vec_interp(in_ht)) < 0.9 * VDD:
            return False, None

    # Reset check: before the output edge, with the input low, outputs must be (nearly) equal.
    in_lt = out_c - 0.1 * t_per
    if np.isclose(in_vec_interp(in_lt), VSS):
        reset_diff = abs(outn_vec_interp(in_lt) - outp_vec_interp(in_lt))
        print("check reset", reset_diff)
        if reset_diff > 0.1 * VDD:
            return False, None

    eval_out = 0 if np.mean(outn_vec.flatten()) > np.mean(outp_vec.flatten()) else 1
    return True, eval_out