import numpy as np
import copy
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Any, Union, Tuple, Optional, Mapping, cast, Dict, Type

from x.simulation.core import TestbenchManager, swp_info_from_struct
from x.simulation.data import SimData, AnalysisType
from x.simulation.cache import SimulationDB, DesignInstance, SimResults, MeasureResult
from x.simulation.measure import MeasurementManager, MeasInfo
from x.math.interpolate import LinearInterpolator
from x.io.file import write_yaml
from x.concurrent.util import GatherHelper

from x_testbenches.measurement.data.tran import EdgeType
from x_testbenches.measurement.tran.base import TranTB
from x_testbenches.measurement.tran.digital import DigitalTranTB
from x_testbenches.measurement.data.tran import interp1d_no_nan, EdgeType, get_first_crossings

from collections import OrderedDict
import yaml
import os

class Ring_OscillatorMM(MeasurementManager):
    """Measurement manager that simulates a ring oscillator and extracts its metrics.

    Only ``async_measure_performance`` is supported; the generic
    ``initialize`` / ``get_sim_info`` / ``process_output`` hooks raise
    ``RuntimeError('Unused')``.

    ``specs`` keys read by this class:
        tbm_specs:   mapping copied into the ``DigitalTranTB`` testbench specs.
        result_file: optional path; when present, the extracted frequency,
                     average power and duty cycle are written there as YAML.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Not used by the methods of this class; kept so external code that
        # reads it keeps working.
        self._tbm_info: Optional[Tuple[DigitalTranTB, Mapping[str, Any]]] = None

    def initialize(self, sim_db: SimulationDB, dut: DesignInstance) -> Tuple[bool, MeasInfo]:
        """Unsupported; use ``async_measure_performance`` instead."""
        raise RuntimeError('Unused')

    def get_sim_info(self, sim_db: SimulationDB, dut: DesignInstance, cur_info: MeasInfo
                     ) -> Tuple[Union[Tuple[TestbenchManager, Mapping[str, Any]],
                                      MeasurementManager], bool]:
        """Unsupported; use ``async_measure_performance`` instead."""
        raise RuntimeError('Unused')

    def process_output(self, cur_info: MeasInfo, sim_results: Union[SimResults, MeasureResult]
                       ) -> Tuple[bool, MeasInfo]:
        """Unsupported; use ``async_measure_performance`` instead."""
        raise RuntimeError('Unused')

    def setup_tbm(self,
                  sim_db: SimulationDB,
                  dut: DesignInstance,
                  analysis: Type[DigitalTranTB]) -> DigitalTranTB:
        """Build the testbench manager for ``dut``.

        Args:
            sim_db: database used to construct the testbench manager.
            dut: design under test; all of its pin names become ``dut_pins``.
            analysis: testbench manager class to instantiate.

        Returns:
            The testbench manager created by ``sim_db.make_tbm``.
        """
        # Deep copy so the specs stored on this manager are never mutated.
        tbm_specs = copy.deepcopy(dict(**self.specs['tbm_specs']))
        tbm_specs['dut_pins'] = list(dut.pins.keys())
        return cast(analysis, sim_db.make_tbm(analysis, tbm_specs))

    # NOTE: Simulation methods

    @staticmethod
    async def _run_sim(name: str, sim_db: SimulationDB, sim_dir: Path, dut: DesignInstance,
                       tbm: DigitalTranTB):
        """Simulate ``tbm`` on ``dut`` in ``sim_dir / name``.

        ``name`` is used as the simulation id, the sub-directory name and the
        testbench name.
        """
        sim_id = str(name)
        return await sim_db.async_simulate_tbm_obj(sim_id, sim_dir / sim_id,
                                                   dut, tbm, {}, tb_name=sim_id)

    async def async_measure_performance(self,
                                        name: str,
                                        sim_dir: Path,
                                        sim_db: SimulationDB,
                                        dut: Optional[DesignInstance]) -> Mapping[str, Any]:
        """Simulate ``dut`` and return every saved signal plus ``time``.

        When ``specs['result_file']`` is set, any existing file at that path is
        deleted before simulating, and afterwards the frequency, average power
        and duty cycle from ``process_output_frequency`` are written to it as
        YAML (best effort: a write failure is reported, not raised).

        Returns:
            Dict mapping each name in ``data.signals`` and ``'time'`` to its data.
        """
        result_yaml = self.specs.get('result_file')
        self._remove_stale_result(result_yaml)

        tbm = self.setup_tbm(sim_db, dut, DigitalTranTB)
        result_sim = await self._run_sim(name, sim_db, sim_dir, dut, tbm)
        data = cast(SimResults, result_sim).data

        results = {elem: data[elem] for elem in data.signals + ['time']}

        if result_yaml is None:
            print("results will not be dumped to yaml")
            return results

        # Errors during metric extraction propagate instead of being swallowed
        # (previously they surfaced later as an unrelated UnboundLocalError).
        measured_perf = self.process_output_frequency(result_sim, tbm).prev_results
        results_dict = {'frequency': self._first_scalar(measured_perf['freq']),
                        'avg_power': self._first_scalar(measured_perf['avg_power']),
                        'avg_duty': self._first_scalar(measured_perf['avg_duty'])}
        try:
            with open(result_yaml, "w") as f:
                yaml.dump(results_dict, f, default_flow_style=False)
        except OSError as err:
            print(f"results will not be dumped to yaml: {err}")
        print(results_dict)
        return results

    @staticmethod
    def _remove_stale_result(result_yaml: Optional[Union[str, Path]]) -> None:
        """Delete a result file left over from a previous run (best effort)."""
        if result_yaml is None:
            print("no yaml spec")
        elif os.path.exists(result_yaml):
            try:
                os.remove(result_yaml)
            except OSError as err:
                print(f"could not delete {result_yaml}: {err}")
            else:
                print(f"{result_yaml} deleted successfully.")
        else:
            print("File not found.")

    @staticmethod
    def _first_scalar(value: Any) -> float:
        """Return the first element of ``value`` (scalar or array) as a Python float."""
        # np.ravel avoids calling float() on an ndim > 0 array, which
        # NumPy >= 1.25 deprecates (the failure sentinels are 1-element arrays).
        return float(np.ravel(value)[0])

    @staticmethod
    def plot_outputs(results: Mapping[str, Any], filename: Union[str, Path] = 'ro.png') -> None:
        """Plot OUT1, OUT2 and OUT3 against time and save the figure.

        Args:
            results: mapping as returned by ``async_measure_performance``.
            filename: where the figure is saved (defaults to ``'ro.png'``).
        """
        time = results['time']
        # Look the traces up first so a missing key fails before a figure exists.
        traces = {label: results[label].flatten() for label in ('OUT1', 'OUT2', 'OUT3')}

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        for label, trace in traces.items():
            ax.plot(time, trace, label=label)
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Voltage [V]')
        ax.legend()
        ax.grid(True)
        # Save before show(): once an interactive window is closed the figure
        # may be torn down, and saving afterwards can produce an empty image.
        fig.savefig(filename)
        plt.show()
        plt.close(fig)

    @staticmethod
    def check_functionality(tbm: DigitalTranTB, data: SimData, out_name: str) -> bool:
        """Return False if ``out_name`` does not swing close to rail-to-rail.

        The peak-to-peak swing is measured on the second half of the first row
        of ``data[out_name]``. It must reach at least 90 % of the supply span
        reported by ``tbm.get_pin_supply_values`` for that pin.
        """
        out_0, out_1 = tbm.get_pin_supply_values(out_name, data)
        data.open_analysis(AnalysisType.TRAN)
        out_vec = data[out_name]

        # Use only the second half of the waveform, presumably to skip start-up.
        settled = out_vec[0][len(out_vec[0]) // 2:]
        swing_out = max(settled) - min(settled)

        vdd = out_1[0]
        vss = out_0[0]
        print("check swing", swing_out)
        if swing_out < 0.9 * (vdd - vss):
            return False
        return True

    @staticmethod
    def _failed_measurement() -> MeasInfo:
        """Sentinel result reported when no valid oscillation is measured."""
        return MeasInfo('done', dict(freq=np.array([-1]),
                                     avg_power=np.array([1]),
                                     avg_duty=np.array([-1])))

    @staticmethod
    def _duty_cycle(time, voltage, threshold) -> float:
        """Return the mean duty cycle of ``voltage`` in percent.

        A high pulse runs from an upward crossing of ``threshold`` to the next
        downward crossing. A falling edge seen while no rise with start time
        > 0 has been recorded is ignored. Each duty cycle is a pulse's high time
        divided by the rise-to-rise spacing to the next pulse. Returns 0.0 when
        fewer than two pulses are found.
        """
        pulses = []
        start = 0
        # Index-based on purpose: ``time`` is indexed with voltage positions.
        for i in range(1, len(voltage)):
            prev_v, cur_v = voltage[i - 1], voltage[i]
            if prev_v < threshold <= cur_v:
                start = time[i]
            elif prev_v >= threshold > cur_v and start > 0:
                pulses.append((start, time[i]))

        duty_cycles = []
        for (high_start, high_end), (next_start, _) in zip(pulses, pulses[1:]):
            period = next_start - high_start
            if period > 0:
                duty_cycles.append((high_end - high_start) / period * 100)
        return sum(duty_cycles) / len(duty_cycles) if duty_cycles else 0.0

    @staticmethod
    def _average_power(data: SimData, t_start: float) -> np.ndarray:
        """Return the average supply power per row of ``data['VDC_VDD:p']``.

        For each row, the magnitude of the current is integrated with a left
        Riemann sum from ``t_start`` to the last time point. The value at
        ``t_start`` is interpolated and only strictly increasing time points
        are used. The integral is multiplied by ``data['VDD'][0][0]`` and
        divided by the window length. An empty window (no samples after
        ``t_start``) yields NaN.
        """
        time = data['time']
        supply_current = data['VDC_VDD:p']
        # NOTE(review): the first entry of VDD is used for every row — confirm
        # this is intended if VDD itself is swept.
        vdd = data['VDD'][0][0]

        powers = []
        for run_current in supply_current:
            interp_current = interp1d_no_nan(time, run_current)
            sample_times = [t_start]
            # abs() on the interpolated start value too, so every term of the
            # sum is a current magnitude.
            sample_values = [abs(interp_current(t_start))]
            deltas = []
            for t, current in zip(time, run_current):
                if t > sample_times[-1]:
                    deltas.append(t - sample_times[-1])
                    sample_times.append(t)
                    sample_values.append(abs(current))
            if not deltas:
                powers.append(np.nan)
                continue
            abs_current_integral = np.sum(np.array(deltas) * np.array(sample_values[:-1]))
            powers.append(vdd * abs_current_integral / (sample_times[-1] - sample_times[0]))
        return np.array(powers)

    @staticmethod
    def process_output_frequency(sim_results: Union[SimResults, MeasureResult],
                                 tbm: DigitalTranTB) -> MeasInfo:
        """Extract frequency, duty cycle and average supply power of the oscillator.

        Frequency and duty cycle are measured on OUT3, using a threshold
        halfway between its minimum and maximum. The measurement is accepted
        only when all of the following hold:
            - at least two rising edges are found;
            - the duty cycle lies strictly between 30 and 70 percent;
            - OUT3, OUT2 and OUT1 all pass ``check_functionality``.

        Returns:
            ``MeasInfo('done', result)`` where ``result`` has these keys:
                ``freq``: 1 / mean rise-to-rise period.
                ``avg_power``: array, one entry per row of ``data['VDC_VDD:p']``,
                    averaged from two periods onward.
                ``avg_duty``: duty cycle in percent.
            On failure the values are ``freq=[-1]``, ``avg_power=[1]``,
            ``avg_duty=[-1]``.
        """
        data = cast(SimResults, sim_results).data
        time = data['time']
        # NOTE(review): flatten() concatenates every row of OUT3 while ``time``
        # covers one row — this assumes a single sweep point.
        voltage = data['OUT3'].flatten()

        # Midpoint between the observed extremes of OUT3.
        threshold = (np.max(voltage) + np.min(voltage)) / 2
        rising_edges = np.where((voltage[:-1] <= threshold) & (voltage[1:] > threshold))[0]
        periods = np.diff(time[rising_edges])

        if len(periods) == 0:
            print("No oscillation detected!")
            return Ring_OscillatorMM._failed_measurement()

        avg_period = np.mean(periods)
        frequency = 1 / avg_period
        print(f"Extracted Frequency: {frequency:.3f} Hz")
        avg_duty = Ring_OscillatorMM._duty_cycle(time, voltage, threshold)
        print(f"Extracted Duty Cycle: {avg_duty:.3f}")

        if not (frequency > -1 and 30 < avg_duty < 70):
            return Ring_OscillatorMM._failed_measurement()

        # Power is averaged from two periods onward, presumably to skip start-up.
        avg_power = Ring_OscillatorMM._average_power(data, avg_period * 2)

        outputs_ok = all(Ring_OscillatorMM.check_functionality(tbm, data, out)
                         for out in ('OUT3', 'OUT2', 'OUT1'))
        if not outputs_ok:
            return Ring_OscillatorMM._failed_measurement()
        return MeasInfo('done', dict(freq=frequency, avg_power=avg_power, avg_duty=avg_duty))
