import numpy as np
import copy
import matplotlib.pyplot as plt
import pickle
import yaml

from pathlib import Path
from typing import Any, Union, Tuple, Optional, Mapping, cast, Dict, Type

from x.simulation.core import TestbenchManager, swp_info_from_struct
from x.simulation.data import AnalysisType
from x.simulation.cache import SimulationDB, DesignInstance, SimResults, MeasureResult
from x.simulation.measure import MeasurementManager, MeasInfo
from x.math.interpolate import LinearInterpolator
from x.io.file import write_yaml
from x.concurrent.util import GatherHelper

from x_testbenches.measurement.data.tran import EdgeType
from x_testbenches.measurement.ac.base import ACTB
from x_testbenches.measurement.dc.base import DCTB
from x_testbenches.measurement.base import GenericTB

from collections import OrderedDict

class unified_opa_tb_MM(MeasurementManager):
    """Measurement manager that characterizes a differential op-amp.

    One DC simulation is run to obtain static power, and one AC simulation is
    run to obtain DC gain, -3 dB bandwidth, unity-gain bandwidth and a
    single-ended gain-mismatch figure. The metrics are written as YAML to
    ``specs['result_file']`` and also returned.

    ``specs`` keys read by this class:
        tbm_specs: base testbench-manager specs shared by the AC and DC benches.
        ac_sweep_options: mapping with a ``sweep_info`` entry for the AC bench.
        dc_sweep_options: mapping with ``sweep_info`` and ``sweep_var`` entries
            for the DC bench.
        result_file: path of the YAML file the metrics are written to.
    """

    # A design is flagged (reported with Ao_db = _REJECTED_AO_DB) when the
    # single-ended gain magnitudes differ by more than _MAX_GAIN_DIFFERENCE or
    # when its DC gain (linear) is below _MIN_DC_GAIN.
    _MAX_GAIN_DIFFERENCE: float = 1.75
    _MIN_DC_GAIN: float = 0.5
    _REJECTED_AO_DB: float = -1.0
    # |vin_diff| at or below this is treated as "no input" and its gain as 0.
    _VIN_EPS: float = 1e-12
    # Added before log10 so a zero gain does not produce -inf.
    _LOG_FLOOR: float = 1e-20

    # SECTION: Boilerplate
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._tbm_info: Optional[Tuple[ACTB, Mapping[str, Any]]] = None

    # The three hooks below belong to the step-wise MeasurementManager flow.
    # This manager overrides async_measure_performance directly instead, so
    # they are intentionally unusable.
    def initialize(self, sim_db: SimulationDB, dut: DesignInstance) -> Tuple[bool, MeasInfo]:
        raise RuntimeError('Unused')

    def get_sim_info(self, sim_db: SimulationDB, dut: DesignInstance, cur_info: MeasInfo
                     ) -> Tuple[Union[Tuple[TestbenchManager, Mapping[str, Any]],
                                      MeasurementManager], bool]:
        raise RuntimeError('Unused')

    def process_output(self, cur_info: MeasInfo, sim_results: Union[SimResults, MeasureResult]
                       ) -> Tuple[bool, MeasInfo]:
        raise RuntimeError('Unused')

    def setup_tbm(self, sim_db: SimulationDB, dut: DesignInstance) -> Tuple[ACTB, DCTB]:
        """Build the AC and DC testbench managers for ``dut``.

        Both benches start from a deep copy of ``specs['tbm_specs']`` with the
        DUT's pin names added under ``dut_pins``. The AC bench then gets
        ``specs['ac_sweep_options']['sweep_info']``. The DC bench gets
        ``specs['dc_sweep_options']['sweep_info']`` and ``['sweep_var']``.

        Returns:
            ``(ac_tbm, dc_tbm)``.
        """
        specs = self.specs
        tbm_specs = copy.deepcopy(dict(**specs['tbm_specs']))
        tbm_specs['dut_pins'] = list(dut.pins.keys())  # Collect a list of dut pins

        ac_tbm_specs = copy.deepcopy(tbm_specs)
        ac_tbm_specs['sweep_options'] = specs['ac_sweep_options']['sweep_info']
        ac_tbm = cast(ACTB, sim_db.make_tbm(ACTB, ac_tbm_specs))

        dc_tbm_specs = copy.deepcopy(tbm_specs)
        dc_tbm_specs['sweep_options'] = specs['dc_sweep_options']['sweep_info']
        dc_tbm_specs['sweep_var'] = specs['dc_sweep_options']['sweep_var']
        dc_tbm = cast(DCTB, sim_db.make_tbm(DCTB, dc_tbm_specs))

        return ac_tbm, dc_tbm

    @staticmethod
    async def _run_sim(name: str, sim_db: SimulationDB, sim_dir: Path, dut: DesignInstance,
                       tbm: GenericTB):
        """Simulate ``tbm`` on ``dut`` under ``sim_dir / name`` and return the result."""
        sim_id = f'{name}'
        sim_results = await sim_db.async_simulate_tbm_obj(sim_id, sim_dir / sim_id,
                                                          dut, tbm, {}, tb_name=sim_id)

        return sim_results

    # SECTION: Action Item
    async def async_measure_performance(self,
                                        name: str,
                                        sim_dir: Path,
                                        sim_db: SimulationDB,
                                        dut: Optional[DesignInstance]) -> Mapping[str, Any]:
        """Run the DC and AC simulations and compute op-amp metrics.

        The DC run is named ``f'{name}_dc'`` and the AC run ``name``. Both are
        placed under ``sim_dir``.

        Returns:
            A dict with keys ``dc_power``, ``Ao`` (linear DC gain), ``Ao_db``,
            ``bw_3db``, ``bw_ug`` and ``gain_difference``, all plain floats.
            Bandwidths are NaN when the gain never drops below the threshold
            within the sweep. ``Ao_db`` is ``-1`` when the design is flagged
            (see the class-level thresholds). The same dict is also written
            as YAML to ``specs['result_file']``.

        Raises:
            ValueError: if ``dut`` is None; both testbenches need its pins.
        """
        if dut is None:
            raise ValueError('unified_opa_tb_MM requires a DUT instance (dut is None)')

        ac_tbm, dc_tbm = self.setup_tbm(sim_db, dut)

        # SECTION: DC
        result_sim_dc = await self._run_sim(f'{name}_dc', sim_db, sim_dir, dut, dc_tbm)
        dc_power = self._dc_power(cast(SimResults, result_sim_dc).data)

        # SECTION: AC
        result_sim_ac = await self._run_sim(name, sim_db, sim_dir, dut, ac_tbm)
        result_dict: Dict[str, float] = {'dc_power': dc_power}
        result_dict.update(self._ac_metrics(cast(SimResults, result_sim_ac).data))

        self._write_results(result_dict)
        return result_dict

    @staticmethod
    def _dc_power(data: Any) -> float:
        """Return |I * V| from the DC results.

        I is ``data['VDC_VDD:p']`` (presumably the supply source's terminal
        current; confirm against the DC bench) and V is ``data['VDD']``.
        """
        current = data['VDC_VDD:p'].squeeze()
        vdd = data['VDD'].squeeze()
        return float(abs(current * vdd))

    def _ac_metrics(self, data: Any) -> Dict[str, float]:
        """Compute gain, bandwidth and mismatch metrics from the AC results."""
        freq = np.atleast_1d(np.asarray(data['freq']).squeeze())
        voutp = data['VOUTP']
        voutn = data['VOUTN']
        vin_diff = data['VINP'] - data['VINN']
        valid = np.abs(vin_diff) > self._VIN_EPS

        # np.where evaluates both branches eagerly, so the divisions still run
        # at masked-out points. Silence the resulting divide/invalid warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            gain = np.where(valid, (voutp - voutn) / vin_diff, 0)
            gain_n = np.where(valid, np.abs(voutn) / vin_diff, 0)
            gain_p = np.where(valid, np.abs(voutp) / vin_diff, 0)

        # atleast_1d keeps gain[0] valid for a single-point sweep.
        gain = np.atleast_1d(np.abs(gain).squeeze())
        gain_db = 20 * np.log10(gain + self._LOG_FLOOR)
        gain_difference = float(np.max(np.abs(gain_p - gain_n)))

        # The first sweep point is taken as the DC (low-frequency) gain.
        dc_gain = float(gain[0])
        dc_gain_db = float(gain_db[0])

        # First frequency at which the gain has fallen 3 dB below its DC value.
        bw_idx = np.nonzero(gain_db <= dc_gain_db - 3)[0]
        bw_3db = float(freq[bw_idx[0]]) if bw_idx.size > 0 else float('nan')

        # First frequency at which the linear gain is at or below unity.
        ugbw_idx = np.nonzero(gain <= 1.0)[0]
        ugbw = float(freq[ugbw_idx[0]]) if ugbw_idx.size > 0 else float('nan')

        rejected = (gain_difference > self._MAX_GAIN_DIFFERENCE
                    or dc_gain < self._MIN_DC_GAIN)

        return {
            'Ao': dc_gain,
            'Ao_db': self._REJECTED_AO_DB if rejected else dc_gain_db,
            'bw_3db': bw_3db,
            'bw_ug': ugbw,
            'gain_difference': gain_difference,
        }

    def _write_results(self, result_dict: Mapping[str, float]) -> None:
        """Write ``result_dict`` as YAML to ``specs['result_file']``, creating parent dirs."""
        results_path = Path(self.specs['result_file'])
        results_path.parent.mkdir(parents=True, exist_ok=True)
        with open(results_path, 'w', encoding='utf-8') as f:
            yaml.dump(dict(result_dict), f)

    def plot_ac_response(self, freq: np.ndarray, gain_db: np.ndarray,
                         ugbw: Optional[float] = None, bw_3db: Optional[float] = None,
                         filename: Union[str, Path] = 'test_ac_response.png') -> None:
        """Save a semilog-x plot of the AC gain in dB to ``filename``.

        Vertical markers are drawn for ``ugbw`` and ``bw_3db`` when each is
        given and not NaN. The figure is always closed afterwards, so repeated
        calls do not accumulate open figures.
        """
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.semilogx(freq, gain_db, label='AC Gain (dB)')
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Gain (dB)')
            ax.set_title('AC Response')
            ax.grid(True, which='both', linestyle='--', alpha=0.7)

            # Optional: mark UGBW and 3dB BW
            if ugbw is not None and not np.isnan(ugbw):
                ax.axvline(x=ugbw, color='r', linestyle='--', label=f'UGBW: {ugbw:.2e} Hz')
            if bw_3db is not None and not np.isnan(bw_3db):
                ax.axvline(x=bw_3db, color='g', linestyle='--', label=f'BW -3dB: {bw_3db:.2e} Hz')

            ax.legend()
            fig.tight_layout()
            fig.savefig(filename)
        finally:
            plt.close(fig)

