from PySpice.Spice.Netlist import Circuit
from PySpice.Unit import *
import math
import numpy as np
import matplotlib.pyplot as plt

# Scale factor applied to a raw vertex feature before it becomes a component
# value in ``simulation.build_netlist``. Keys are vertex types as *strings*:
# '0' resistor (MΩ), '1' capacitor (pF, after an extra 1e-3 factor),
# '2'-'5' transconductors (mS; negative factors flip the sign), '8'/'9' I/O nodes.
_MULT_PAIRS = [
    (0, 0.1),
    (1, 10),
    (2, 0.1),
    (3, -0.1),
    (4, 0.1),
    (5, -0.1),
    (8, 1),
    (9, 1),
]
mult = {str(vertex_type): factor for vertex_type, factor in _MULT_PAIRS}
# 4 & 5: in & out are reversed

class simulation:
    """Build a PySpice netlist from a circuit graph and run an AC analysis on it.

    The graph is assumed to be an ``igraph.Graph``-like object (it must expose
    ``vs['type']``, ``vs['feat']``, ``vcount()``, ``neighbors()`` and
    ``distances()``). Vertex types handled here:

    * 0 resistor, 1 capacitor
    * 2, 3 transconductors (VCCS); 4, 5 the same with input/output reversed
    * 6 / 7 transconductor input / output pin (used when ``with_pins=True``)
    * 8 / 9 circuit input / output node
    * 10 net vertex

    After construction ``circuit``, ``analysis``, ``gain`` (dB) and ``phase``
    (rad) are set; with ``compute_features`` also ``bw``, ``pm``, ``ugw`` and
    ``mod``.
    """

    # Vertex types that are not devices: pins (6, 7), circuit I/O (8, 9), nets (10).
    _SKIPPED_TYPES = (6, 7, 8, 9, 10)
    # Transconductor (VCCS) vertex types; 4 and 5 have input/output reversed.
    _GM_TYPES = (2, 3, 4, 5)
    # Attributes produced by build_netlist*/simulate; snapshotted by compare_default_vs_original.
    _RESULT_ATTRS = ('circuit', 'analysis', 'gain', 'phase', 'bw', 'pm', 'ugw', 'mod')

    def __init__(self, graph, features=None, with_pins=False, compute_features=True):
        """Build the netlist for ``graph`` and simulate it immediately.

        Args:
            graph: circuit graph (see class docstring).
            features: per-vertex sizing values indexed by vertex id, ``None`` to
                use ``graph.vs['feat']``, or the string ``'default'`` for fixed values.
            with_pins: whether transconductors connect through pin vertices (6/7).
            compute_features: whether ``simulate`` derives bw/pm/ugw/mod.
        """
        self.graph = graph
        self.with_pins = with_pins
        self.compute_features = compute_features
        self.build_netlist(features)
        self.simulate()

    def print_circuit(self):
        """Print the SPICE netlist of the current circuit."""
        print(self.circuit)

    # ------------------------------------------------------------------
    # Netlist construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _io_indices(types):
        """Return the vertex indices of the circuit input (type 8) and output (type 9).

        Raises:
            IndexError: if either I/O vertex is missing.
        """
        in_idx = [i for i, t in enumerate(types) if t == 8][0]
        out_idx = [i for i, t in enumerate(types) if t == 9][0]
        return in_idx, out_idx

    @staticmethod
    def _net_name(vertex_idx, in_idx, out_idx):
        """Map a graph vertex index to its SPICE node name ('IN', 'OUT' or 'net_<idx>')."""
        if vertex_idx == in_idx:
            return 'IN'
        if vertex_idx == out_idx:
            return 'OUT'
        return f'net_{vertex_idx}'

    def _resolve_terminals(self, node_idx, node_type, types, in_idx, out_idx):
        """Return the (input, output) SPICE node names of device vertex ``node_idx``.

        Two-terminal devices use their two neighbours in ``neighbors()`` order.
        Transconductors are oriented either through their pin vertices
        (``with_pins=True``) or, in the legacy mode, by adjacency to / graph
        distance from the circuit I/O vertices; types 4 and 5 are then reversed.

        Raises:
            AssertionError: if the vertex does not have exactly two neighbours, or a
                pinned transconductor is not connected to one input and one output pin.
        """
        pins = self.graph.neighbors(node_idx)
        # Check arity before indexing so malformed graphs fail with this message.
        assert len(pins) == 2, 'Wrong number of pin connections'
        first, second = pins[0], pins[1]
        in_node, out_node = first, second

        if node_type in self._GM_TYPES:
            if self.with_pins:
                # Each gm must sit between exactly one input pin (6) and one output pin (7).
                type_first, type_second = types[first], types[second]
                assert {type_first, type_second} == {6, 7}, \
                    'Transconductor must connect one input pin (6) and one output pin (7)'
                inpt_pin = first if type_first == 6 else second
                out_pin = first if type_first == 7 else second
                # The net on the far side of each pin is the actual circuit node.
                in_node = [n for n in self.graph.neighbors(inpt_pin) if n != node_idx][0]
                out_node = [n for n in self.graph.neighbors(out_pin) if n != node_idx][0]
            else:
                if first == in_idx or second == out_idx:
                    pass  # already ordered input-side first
                elif first == out_idx or second == in_idx:
                    first, second = second, first
                else:
                    # Orient the gm by shortest-path distance to the circuit I/O nodes.
                    distances = self.graph.distances(source=[in_idx, out_idx], target=[first, second], mode='all')
                    first_is_input = (distances[0][0] <= distances[0][1]) and (distances[1][0] >= distances[1][1])
                    if not first_is_input:
                        first, second = second, first
                in_node, out_node = first, second
                if node_type in (4, 5):
                    in_node, out_node = second, first

        return self._net_name(in_node, in_idx, out_idx), self._net_name(out_node, in_idx, out_idx)

    def _build_circuit(self, component_value):
        """Assemble and return the 'Op-Amp' netlist for the graph.

        Args:
            component_value: callable ``(node_idx, node_type) -> PySpice unit value``
                giving the resistance, capacitance or transconductance of each device.
        """
        types = self.graph.vs['type']
        in_idx, out_idx = self._io_indices(types)

        circuit = Circuit('Op-Amp')
        # AC stimulus on IN; simulate() measures the gain relative to this node.
        circuit.SinusoidalVoltageSource('in', 'IN', circuit.gnd, amplitude=1@u_mV, frequency=1@u_kHz, ac_magnitude=1@u_mV)

        for node_idx, node_type in enumerate(types):
            if node_type in self._SKIPPED_TYPES:
                continue

            in_node, out_node = self._resolve_terminals(node_idx, node_type, types, in_idx, out_idx)
            value = component_value(node_idx, node_type)

            if node_type == 0:
                circuit.R(node_idx, in_node, out_node, value)
            elif node_type == 1:
                circuit.C(node_idx, in_node, out_node, value)
            else:
                # VCCS driving out_node from the voltage at in_node, plus a small
                # capacitance and a 1 MΩ resistance from out_node to ground.
                circuit.VCCS(node_idx, out_node, circuit.gnd, in_node, circuit.gnd, value)
                circuit.C(f'{node_idx}_gnd', out_node, circuit.gnd, 0.05@u_pF)
                circuit.R(f'{node_idx}_gnd', out_node, circuit.gnd, 1@u_MΩ)

        return circuit

    # ------------------------------------------------------------------
    # Public netlist builders / simulation
    # ------------------------------------------------------------------

    def build_netlist(self, features):
        """Build ``self.circuit`` from the graph topology and per-vertex features.

        Each raw feature is multiplied by the module-level ``mult`` factor of its
        vertex type and rounded to one decimal. The result is used as MΩ for
        resistors, as pF after a further ``1e-3`` factor (rounded to three
        decimals) for capacitors, and as mS for transconductors.

        Args:
            features: sequence indexed by vertex id; ``None`` uses
                ``graph.vs['feat']`` (useful to retry with a different set when a
                simulation fails); ``'default'`` uses fixed values
                (``mult`` x 1 MΩ, ``mult`` x 1 pF, ``mult`` x 10 mS).
        """
        if features is None:
            features = self.graph.vs['feat']
        use_default = isinstance(features, str) and features == 'default'

        def component_value(node_idx, node_type):
            scale = mult[str(node_type)]
            if use_default:
                if node_type == 0:
                    return (scale * 1) @ u_MΩ
                if node_type == 1:
                    return (scale * 1) @ u_pF
                # Default VCCS value: 10 mS scaled by mult (i.e. ±1 mS); a value is mandatory.
                return (scale * 10) @ u_mS
            feat = np.round(scale * features[node_idx], 1)
            if node_type == 0:
                return feat @ u_MΩ
            if node_type == 1:
                feat = np.round(1e-3 * feat, 3)
                return feat @ u_pF
            return feat @ u_mS

        self.circuit = self._build_circuit(component_value)

    def simulate(self):
        """Run an AC sweep on ``self.circuit`` and store the frequency response.

        Sweeps 1 Hz to 100 GHz with 100 points per decade and sets ``analysis``,
        ``gain`` (dB of V(OUT) relative to V(IN)) and ``phase`` (rad, from
        ``np.angle``). When ``compute_features`` is true it also sets:

        * ``bw``: first frequency where the gain is more than 3 dB below ``gain[0]``;
        * ``ugw``: first frequency where the gain is below 0 dB;
        * ``pm``: phase at that frequency in degrees, plus 180;
        * ``mod``: count of phase steps below -6 rad minus steps above +6 rad
          before ``ugw`` (presumably counting 2π wrap-arounds of ``np.angle``).

        Raises:
            IndexError: with ``compute_features``, if the gain never drops 3 dB or
                never crosses 0 dB within the sweep.
        """
        simulator = self.circuit.simulator(temperature=25, nominal_temperature=25)
        self.analysis = simulator.ac(start_frequency=1@u_Hz, stop_frequency=100@u_GHz, number_of_points=100, variation='dec')

        self.gain = 20 * np.log10(np.absolute(self.analysis.out)) - 20 * np.log10(np.absolute(self.analysis['in']))
        self.phase = np.angle(self.analysis.out, deg=False)

        if self.compute_features:
            bw_idx = np.where((self.gain[0] - self.gain) > 3)[0][0]
            self.bw = self.analysis.frequency[bw_idx]

            unit_idx = np.where(self.gain < 0)[0][0]
            self.pm = self.phase[unit_idx] * 180 / np.pi + 180
            self.ugw = self.analysis.frequency[unit_idx]

            # Total phase lag
            mod_diff_thresh = 6
            diff = np.diff(self.phase[:unit_idx])
            self.mod = (diff < -mod_diff_thresh).sum() - (diff > mod_diff_thresh).sum()

    def plot_bode(self):
        """Plot gain and phase versus frequency (adapted from the PySpice documentation)."""
        _figure, (ax1, ax2) = plt.subplots(2, figsize=(20, 10))

        res = 10  # plot every 10th sweep point
        frequency = self.analysis.frequency

        # Set title/ticks on explicit axes: plt.title would target the last-created (phase) axes.
        ax1.set_title("Bode Diagram of an Operational Amplifier")
        ax1.semilogx(frequency[::res], self.gain[::res], marker='.', color='blue', linestyle='-')
        ax1.grid(True)
        ax1.grid(True, which='minor')
        ax1.set_xlabel("Frequency [Hz]")
        ax1.set_ylabel("Gain [dB]")

        ax2.semilogx(frequency[::res], self.phase[::res], marker='.', color='blue', linestyle='-')
        ax2.set_ylim(-math.pi, math.pi)
        ax2.grid(True)
        ax2.grid(True, which='minor')
        ax2.set_xlabel("Frequency [Hz]")
        ax2.set_ylabel("Phase [rads]")
        ax2.set_yticks((-math.pi, -math.pi/2, 0, math.pi/2, math.pi))
        ax2.set_yticklabels((r"$-\pi$", r"$-\frac{\pi}{2}$", "0", r"$\frac{\pi}{2}$", r"$\pi$"))

        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # Default / generic sizing helpers
    # ------------------------------------------------------------------

    def get_default_features(self):
        """Return a default raw feature (1.0) for every vertex, in vertex order.

        These values go through ``build_netlist``'s feature path (scaled by
        ``mult`` and rounded), *not* the ``features='default'`` path.
        """
        default_values = {
            0: 1.0,    # Resistor: round(0.1 * 1.0, 1) -> 0.1 MΩ
            1: 1.0,    # Capacitor: round(10 * 1.0, 1) = 10, then round(1e-3 * 10, 3) -> 0.01 pF
            2: 1.0,    # Transconductor: 0.1 mS
            3: 1.0,    # Negative transconductor: -0.1 mS
            4: 1.0,    # Reversed transconductor: 0.1 mS
            5: 1.0,    # Reversed negative transconductor: -0.1 mS
            8: 1.0,    # Input node (feature not used by build_netlist)
            9: 1.0,    # Output node (feature not used by build_netlist)
        }
        return [default_values.get(node_type, 1.0) for node_type in self.graph.vs['type']]

    def build_netlist_default(self):
        """Build the netlist from ``get_default_features()`` instead of the graph features."""
        self.build_netlist(self.get_default_features())

    def simulate_default(self):
        """Build the netlist with default features and simulate it."""
        self.build_netlist_default()
        self.simulate()

    def reset_and_simulate_default(self):
        """Remember the graph's original features (once) and simulate with default features."""
        if not hasattr(self, 'original_features'):
            try:
                self.original_features = self.graph.vs['feat'].copy()
            except Exception:
                # Best effort: a graph without usable 'feat' values has nothing to restore.
                self.original_features = None
        self.simulate_default()

    def restore_original_simulation(self):
        """Rebuild and simulate with the stored original features, or the graph features."""
        if getattr(self, 'original_features', None) is not None:
            self.build_netlist(self.original_features)
        else:
            print("No original features stored. Using graph features.")
            self.build_netlist(None)  # falls back to self.graph.vs['feat']
        self.simulate()

    def _collect_results(self):
        """Snapshot gain/phase and the scalar metrics of the latest simulation."""
        return {
            'gain': self.gain.copy(),
            'phase': self.phase.copy(),
            # Metrics only exist when compute_features is true.
            'bw': getattr(self, 'bw', None),
            'pm': getattr(self, 'pm', None),
            'ugw': getattr(self, 'ugw', None),
        }

    def compare_default_vs_original(self):
        """Simulate with default and with original features and return both results.

        The circuit, analysis and all result attributes present before the call
        are restored afterwards, even if one of the simulations fails.

        Returns:
            dict with keys 'default' and 'original' (each holding gain, phase, bw,
            pm, ugw; metrics are None when not computed) and 'frequency'.
        """
        saved = {name: getattr(self, name) for name in self._RESULT_ATTRS if hasattr(self, name)}
        try:
            self.simulate_default()
            default_results = self._collect_results()

            self.restore_original_simulation()
            original_results = self._collect_results()
            frequency = self.analysis.frequency
        finally:
            for name, value in saved.items():
                setattr(self, name, value)

        return {
            'default': default_results,
            'original': original_results,
            'frequency': frequency,
        }

    def build_netlist_no_sizing(self):
        """Build ``self.circuit`` with fixed generic values, ignoring the graph features.

        Uses the same topology resolution as ``build_netlist`` (including pin and
        net vertices), but every device gets a generic value: 1 kΩ resistors,
        1 pF capacitors and ±1 mS transconductors.
        """
        generic_values = {
            0: 1@u_kΩ,      # resistor
            1: 1@u_pF,      # capacitor
            2: 1@u_mS,      # transconductor
            3: -1@u_mS,     # negative transconductor
            4: 1@u_mS,      # reversed transconductor
            5: -1@u_mS,     # reversed negative transconductor
        }
        self.circuit = self._build_circuit(lambda node_idx, node_type: generic_values[node_type])

    def simulate_no_sizing(self):
        """Build the netlist with generic component values (see build_netlist_no_sizing) and simulate."""
        self.build_netlist_no_sizing()
        self.simulate()


class EqElt:
    """A set of mutually equivalent node labels with a preferred display name."""

    def __init__(self, elt_name):
        self.eq = {elt_name}
        # Label this element was created with; used as the name when neither IN nor OUT is present.
        self.first_elt = elt_name

    def eqto(self, b):
        """Merge the labels of EqElt ``b`` into this element."""
        self.eq = self.eq | b.eq

    @property
    def default_name(self):
        """Return 'IN' if present, else 'OUT' if present, else ``first_elt``."""
        if 'IN' in self.eq:
            return 'IN'
        if 'OUT' in self.eq:
            return 'OUT'
        return self.first_elt


class EqClass:
    """Union of node labels into equivalence classes (legacy netlist node merging)."""

    def __init__(self):
        # label -> EqElt of the class containing that label
        self.all_eqs = {}

    def add_new_eq(self, a, b):
        """Declare labels ``a`` and ``b`` equivalent, merging their classes.

        Every member of the merged class is re-pointed at the single surviving
        EqElt, so labels that joined b's class earlier also see the merge.
        """
        elt_a = self.all_eqs.get(a)
        elt_b = self.all_eqs.get(b)
        if elt_a is None and elt_b is None:
            merged = EqElt(b)
            merged.eqto(EqElt(a))
        elif elt_a is None:
            merged = elt_b
            merged.eqto(EqElt(a))
        elif elt_b is None:
            merged = elt_a
            merged.eqto(EqElt(b))
        else:
            merged = elt_a
            if elt_b is not elt_a:
                merged.eqto(elt_b)
        for member in merged.eq:
            self.all_eqs[member] = merged

    @property
    def equivalences(self):
        """Map every known label to its class's preferred name."""
        return {k: v.default_name for k, v in self.all_eqs.items()}
    
import os, pickle
def resample_sizes_from_graph_idx(dirp, dataset_path='datasets/CktBench101/raw/ckt_bench_101.pkl', max_tries=50):
    """Re-draw random sizing features for graphs whose simulation previously failed.

    For each split ('train', 'test'), read the failed graph indices from
    ``<dirp>/NEW_DS_failed_idx_<split>``. Each such graph gets up to
    ``max_tries`` random integer feature vectors in [1, 99], stopping at the
    first one for which ``simulation(graph)`` succeeds and some gain sample is
    positive. Successful graphs and their labels (index, low-frequency gain, ugw,
    pm) are pickled to ``NEW_DS_<split>_graphs_2`` and ``NEW_DS_<split>_labels_2``
    in the current working directory.

    Graphs are mutated in place: their 'feat' attribute is overwritten.
    NOTE(review): ``pickle.load`` can execute arbitrary code; only use trusted files.
    """
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)

    for key in ['train', 'test']:
        with open(os.path.join(dirp, f'NEW_DS_failed_idx_{key}'), 'rb') as f:
            failed = pickle.load(f)
        # Presumably dataset[0] is the train split and dataset[1] the test split.
        split = dataset[0 if key == 'train' else 1]

        output_labels = []
        output_graphs = []

        for idx in failed:
            graph = split[idx][1]
            for _ in range(max_tries):
                feats = np.random.randint(1, 100, size=len(graph.vs['feat']))
                try:
                    graph.vs['feat'] = feats
                    sim = simulation(graph)
                    if not (sim.gain > 0).any():
                        continue
                    output_labels.append({
                        'index': idx,
                        'gain': sim.gain[0],
                        'ugw': sim.ugw,
                        'pm': sim.pm})
                    output_graphs.append(graph)
                    break
                except Exception:
                    # This sizing did not simulate (or metrics could not be extracted); retry.
                    continue

        with open(f'NEW_DS_{key}_graphs_2', 'wb') as f:
            pickle.dump(output_graphs, f)
        with open(f'NEW_DS_{key}_labels_2', 'wb') as f:
            pickle.dump(output_labels, f)

if __name__ == "__main__":
    # Placeholder directory holding the NEW_DS_failed_idx_{train,test} pickles.
    dirp = '****'
    resample_sizes_from_graph_idx(dirp=dirp)