from PySpice.Spice.Netlist import Circuit
from PySpice.Unit import *
import math
import numpy as np
import matplotlib.pyplot as plt
from .loader.datasets.analogenie_dataset import node2pins, eletric_nodes

# Scale factor applied to a vertex feature (or to the default value) before it
# becomes a component value; looked up with str(node_type) in build_netlist.
mult = {
    # Passive elements: type 0 is emitted as a resistor, type 1 as a capacitor.
    '0': 0.1,
    '1': 10,
    # Types 2-5 are emitted as VCCS elements; a negative factor yields a
    # negative transconductance value.
    '2': 0.1,
    '3': -0.1,
    # 4 & 5: in & out are reversed
    '4': 0.1,
    '5': -0.1,
    # Types 8 and 9 are the circuit input / output vertices.
    '8': 1,
    '9': 1,
}

class simulation:
    """Small-signal AC simulation of an op-amp topology described as a graph.

    Instantiating the class builds a PySpice ``Circuit`` from ``graph`` and
    immediately runs an AC sweep on it.

    Vertex ``type`` codes as used by this class:
        0       emitted as a resistor
        1       emitted as a capacitor
        2 - 5   emitted as a VCCS with an RC load from its output to ground
        6, 7    input / output pin of a VCCS vertex (``with_pins=True``)
        8, 9    circuit input / output, named ``IN`` / ``OUT`` in the netlist
        10      skipped, like 6-9

    Attributes set by the methods:
        circuit:  the PySpice circuit (``build_netlist``).
        analysis: the AC analysis result (``simulate``).
        gain:     output magnitude relative to the input, in dB.
        phase:    output phase in radians.
        bw, ugw, pm, mod: only when ``compute_features`` is true, see ``simulate``.
    """

    def __init__(self, graph, features=None, with_pins=False, compute_features=True):
        """Store the options, then build the netlist and run the simulation.

        Args:
            graph: graph exposing ``vs['type']``, ``vs['feat']``, ``vcount()``,
                ``neighbors()`` and ``distances()``.
            features: per-vertex feature values, ``None`` to read
                ``graph.vs['feat']``, or the string ``'default'`` to use the
                built-in default component values.
            with_pins: whether VCCS vertices are connected through explicit
                input (type 6) and output (type 7) pin vertices.
            compute_features: whether ``simulate`` also extracts the bandwidth,
                unity-gain and phase figures.
        """
        self.graph = graph
        self.with_pins = with_pins
        self.compute_features = compute_features
        self.build_netlist(features)
        self.simulate()

    def print_circuit(self):
        """Print the netlist held in ``self.circuit``."""
        print(self.circuit)

    @staticmethod
    def _net_name(node, in_idx, out_idx):
        # Map a vertex index to its net label: the circuit I/O vertices get
        # fixed names, every other vertex becomes ``net_<index>``.
        if node == in_idx:
            return 'IN'
        if node == out_idx:
            return 'OUT'
        return f'net_{node}'

    def build_netlist(self, features):
        """Translate ``self.graph`` into a PySpice circuit stored in ``self.circuit``.

        Args:
            features: see ``__init__``. Allows a different set of features to be
                supplied when a simulation fails; ``'default'`` selects default
                vertex parameter values.

        Raises:
            AssertionError: if a component vertex does not have exactly two
                neighbours, or (``with_pins``) a VCCS vertex is not placed
                between one type-6 and one type-7 pin.
            IndexError: if the graph has no type-8 or no type-9 vertex.
        """
        # Requires directed edges

        if features is None:
            features = self.graph.vs['feat']
        # The isinstance guard keeps array-like features away from the string comparison.
        default = isinstance(features, str) and features == 'default'

        n_elt = range(self.graph.vcount())
        types = self.graph.vs['type']

        # Identify I/O nodes
        in_idx, out_idx = ([i for i in n_elt if types[i] == 8][0],
                           [i for i in n_elt if types[i] == 9][0])

        # Netlist construction
        circuit = Circuit('Op-Amp')
        circuit.SinusoidalVoltageSource('in', 'IN', circuit.gnd, amplitude=1@u_mV, frequency=1@u_kHz, ac_magnitude=1@u_mV)

        for node_idx, node_type in enumerate(types):
            # Pins, I/O vertices and type 10 do not produce a component.
            if node_type in [6, 7, 8, 9, 10]:
                continue

            ### Up to date code version is with_pins = True, the rest is legacy code (no longer need for node equivalences)

            pins = self.graph.neighbors(node_idx)
            # Check the count before indexing so a malformed vertex reports this
            # message instead of a bare IndexError.
            assert len(pins) == 2, 'Wrong number of pin connections'
            idx_first_pin, idx_sec_pin = pins[0], pins[1]
            in_node, out_node = idx_first_pin, idx_sec_pin

            if node_type in [2, 3, 4, 5]:

                if self.with_pins:
                    # Verify that each gm is located between one input pin and one output pin
                    type_first_pin, type_sec_pin = types[idx_first_pin], types[idx_sec_pin]
                    assert {type_first_pin, type_sec_pin} == {6, 7}
                    inpt_pin = idx_first_pin if type_first_pin == 6 else idx_sec_pin
                    out_pin = idx_first_pin if type_first_pin == 7 else idx_sec_pin

                    # The electrical node is the pin's neighbour that is not this vertex.
                    in_node = [n for n in self.graph.neighbors(inpt_pin) if n != node_idx][0]
                    out_node = [n for n in self.graph.neighbors(out_pin) if n != node_idx][0]
                else:
                    if (pins[0] == in_idx) or (pins[1] == out_idx):
                        idx_first_pin, idx_sec_pin = pins[0], pins[1]
                    elif (pins[0] == out_idx) or (pins[1] == in_idx):
                        idx_first_pin, idx_sec_pin = pins[1], pins[0]
                    else:
                        # In this case the input and output nodes to a gm are determined by their spd to circuit i/o nodes
                        distances = self.graph.distances(source=[in_idx, out_idx], target=[pins[0], pins[1]], mode='all')
                        is_first_pin_inpt = (distances[0][0] <= distances[0][1]) & (distances[1][0] >= distances[1][1])
                        idx_first_pin = pins[0] if is_first_pin_inpt else pins[1]
                        idx_sec_pin = pins[1] if is_first_pin_inpt else pins[0]
                    in_node, out_node = idx_first_pin, idx_sec_pin

                    # Types 4 and 5 have their input and output reversed.
                    if node_type in [4, 5]:
                        in_node, out_node = idx_sec_pin, idx_first_pin

            in_node = self._net_name(in_node, in_idx, out_idx)
            out_node = self._net_name(out_node, in_idx, out_idx)

            if not default:
                feat = np.round(mult[str(node_type)] * features[node_idx], 1)

            if node_type == 0:
                circuit.R(node_idx, in_node, out_node, mult[str(node_type)] * 1 @u_MΩ if default else feat@u_MΩ)
            elif node_type == 1:
                if not default:
                    # Rescale the already-scaled feature by 1e-3 before using it as a value in pF.
                    feat = np.round(1e-3 * feat, 3)
                circuit.C(node_idx, in_node, out_node, mult[str(node_type)] * 1 @u_pF if default else feat@u_pF)
            else:
                # A value must always be given for the VCCS, otherwise an error is returned;
                # the default is mult * 10 mS.
                circuit.VCCS(node_idx, out_node, circuit.gnd, in_node, circuit.gnd, mult[str(node_type)] * 10 @u_mS if default else feat@u_mS)
                # Fixed RC load from the VCCS output to ground.
                circuit.C(f'{str(node_idx)}_gnd', out_node, circuit.gnd, 0.05@u_pF)
                circuit.R(f'{str(node_idx)}_gnd', out_node, circuit.gnd, 1@u_MΩ)

        self.circuit = circuit

    def simulate(self):
        """Run the AC sweep and derive gain, phase and (optionally) summary figures.

        Sets ``analysis``, ``gain`` (dB, relative to the input) and ``phase``
        (radians). When ``self.compute_features`` is true it also sets:

            bw:  frequency of the first point whose gain is more than 3 dB
                 below the gain of the first point.
            ugw: frequency of the first point whose gain is below 0 dB.
            pm:  phase at that point converted to degrees, plus 180.
            mod: number of phase steps below -6 rad minus the number of steps
                 above +6 rad, counted before the 0 dB point.

        Raises:
            IndexError: with ``compute_features``, if the gain never drops by
                more than 3 dB or never falls below 0 dB within the sweep.
        """
        simulator = self.circuit.simulator(temperature=25, nominal_temperature=25)
        self.analysis = simulator.ac(start_frequency=1@u_Hz, stop_frequency=100@u_GHz, number_of_points=100,  variation='dec')

        self.gain = 20 * np.log10(np.absolute(self.analysis.out)) - 20 * np.log10(np.absolute(self.analysis['in']))
        self.phase = np.angle(self.analysis.out, deg=False)

        if self.compute_features:
            bw_idx = np.where((self.gain[0] - self.gain) > 3)[0][0]
            self.bw = self.analysis.frequency[bw_idx]

            unit_idx = np.where(self.gain < 0)[0][0]
            self.pm = self.phase[unit_idx] * 180 / np.pi + 180
            self.ugw = self.analysis.frequency[unit_idx]

            # Total phase lag
            # A step larger than 6 rad between consecutive points can only be a
            # wrap of the angle, which np.angle keeps within [-pi, pi].
            mod_diff_thresh = 6
            diff = np.diff(self.phase[:unit_idx])
            self.mod = (diff < -mod_diff_thresh).sum() - (diff > mod_diff_thresh).sum()

    def plot_bode(self):
        """Show gain and phase against frequency, plotting every 10th sweep point."""
        ##
        # Originates from PySpice documentation
        ##

        figure, (ax1, ax2) = plt.subplots(2, figsize=(20, 10))

        # Plot one sample out of `res`.
        res = 10

        # NOTE(review): plt.title / plt.yticks act on the current axes, which is
        # presumably the phase axes (ax2) here - confirm the title placement is intended.
        plt.title("Bode Diagram of an Operational Amplifier")
        frequency = self.analysis.frequency

        ax1.semilogx(frequency[::res], self.gain[::res], marker='.', color='blue', linestyle='-')
        ax1.grid(True)
        ax1.grid(True, which='minor')
        ax1.set_xlabel("Frequency [Hz]")
        ax1.set_ylabel("Gain [dB]")

        ax2.semilogx(frequency[::res], self.phase[::res], marker='.', color='blue', linestyle='-')
        ax2.set_ylim(-math.pi, math.pi)
        ax2.grid(True)
        ax2.grid(True, which='minor')
        ax2.set_xlabel("Frequency [Hz]")
        ax2.set_ylabel("Phase [rads]")
        plt.yticks((-math.pi, -math.pi/2,0, math.pi/2, math.pi),
                      (r"$-\pi$", r"$-\frac{\pi}{2}$", "0", r"$\frac{\pi}{2}$", r"$\pi$"))

        plt.tight_layout()
        plt.show()



class EqElt:
    """A set of element names that are considered equivalent.

    Attributes:
        eq: set of all names in the group.
        first_elt: the name the group was created with; used as the
            representative when neither 'IN' nor 'OUT' belongs to the group.
    """

    def __init__(self, elt_name):
        self.eq = {elt_name}
        self.first_elt = elt_name

    def eqto(self, b):
        """Absorb every name of the group ``b`` into this group."""
        self.eq = self.eq.union(b.eq)

    @property
    def default_name(self):
        """Representative name: 'IN' first, then 'OUT', otherwise ``first_elt``."""
        if 'IN' in self.eq:
            return 'IN'
        if 'OUT' in self.eq:
            return 'OUT'
        return self.first_elt


class EqClass:
    """Registry of equivalences between element names.

    ``all_eqs`` maps every known name to the ``EqElt`` group it belongs to;
    all names of one group share the same ``EqElt`` object.
    """

    def __init__(self):
        self.all_eqs = {}

    def add_new_eq(self, a, b):
        """Record that the names ``a`` and ``b`` are equivalent.

        Unknown names are registered on the fly. When both names already belong
        to groups, the group of ``b`` is merged into the group of ``a``.
        """
        known_a = a in self.all_eqs
        known_b = b in self.all_eqs

        if known_a and known_b:
            keep, absorbed = self.all_eqs[a], self.all_eqs[b]
            if keep is not absorbed:
                keep.eqto(absorbed)
                # Re-point every member of the absorbed group, not only `b`;
                # otherwise the remaining members keep a stale group and
                # report a different representative.
                for name in absorbed.eq:
                    self.all_eqs[name] = keep
        elif known_a:
            self.all_eqs[a].eqto(EqElt(b))
            self.all_eqs[b] = self.all_eqs[a]
        else:
            if not known_b:
                self.all_eqs[b] = EqElt(b)
            self.all_eqs[b].eqto(EqElt(a))
            self.all_eqs[a] = self.all_eqs[b]

    @property
    def equivalences(self):
        """Mapping of every known name to the representative name of its group."""
        return {k: v.default_name for k, v in self.all_eqs.items()}



#####
### AnalogGenie methods -> netlists are built directly as text files
#####

def format_netlist(netlist, ports, filename):
    """Render a component dictionary as the text of a SPICE deck.

    Args:
        netlist: mapping of element name (e.g. ``'M0'``, ``'R1'``) to the list
            of fields written after the name on the element line. The mapping
            and its lists are left untouched.
        ports: list of port / net names. Names are matched by substring to
            decide which sources to emit and which nodes to print.
        filename: label written in the title comment line.

    Returns:
        The whole deck as a single string.
    """
    ## --- Identify output nodes (VOUT or net)
    output_nodes = [p for p in ports if ('VOUT' in p) or ('net' in p)]

    # --- Find MOS nodes: the first four fields of every 'M' element ---
    mos_nodes = set()
    for name, fields in netlist.items():
        if name.startswith('M'):
            mos_nodes.update(fields[:4])

    # --- Netlist construction ---
    lines = [f"* Circuit {filename}\n"]

    # Models
    models = '.model nmos4 NMOS (LEVEL=1 VTO=0.7 KP=120u LAMBDA=0.1 GAMMA=0.5)\n' + \
    '.model pmos4 PMOS (LEVEL=1 VTO=-0.7 KP=60u LAMBDA=0.1 GAMMA=0.5)\n' + \
    '.model npn NPN (BF=100 IS=1e-16 VAF=100 IKF=1m)\n' + \
    '.model pnp PNP (BF=100 IS=1e-16 VAF=100 IKF=1m)\n' + \
    '.model diode D (IS=1e-14 N=1.0 RS=10)\n\n'
    lines.append(models)

    # Sources for all ports present in the netlist
    ports_dc_dict = {'VDD': 1.8, 'VSS': 0, 'VB': 0.9, 'VCONT': 0.9, 'VCM': 0.9, 'VREF': 0.6, 'VIF': 0.5, 'VLO': 0.3, 'VRF': 0.8}
    for port in ports:
        p = port.upper()
        # Digital inputs or clock: pulse source
        if ('VIN' in port) or ('VCLK' in port):
            lines.append(f"{p} {p} 0 PULSE(0 1.8 0 1n 1n 10n 20n)\n")
        # Bias and references: one DC source per matching key
        for (k, v) in ports_dc_dict.items():
            if k in p:
                lines.append(f"{p} {p} 0 DC {v}\n")
        # Currents
        if ('IB' in port) or ('IIN' in port) or ('IOUT' in port) or ('IREF' in port):
            lines.append(f"{p} 0 {p} DC 10u\n")

    # Pull-down resistors on the internal nets touched by a MOS element.
    # Sorted so that the generated text does not depend on set iteration order.
    for node in sorted(n for n in mos_nodes if 'net' in n):
        lines.append(f"RLEAK_{node} {node} 0 1G\n")

    # All components; R, C and L elements get a default value appended.
    default_values = {'R': '10k', 'C': '1p', 'L': '1n'}
    for name, fields in netlist.items():
        # Work on a copy: appending to the caller's list would add the value
        # again on every call.
        parts = list(fields)
        value = default_values.get(name[:1])
        if value is not None:
            parts.append(value)
        lines.append(f"{name} " + " ".join(str(part) for part in parts) + "\n")

    # Netlist closing
    # NOTE(review): `.end` is emitted before the analysis and control lines, as
    # in the original output - confirm the target simulator accepts this order.
    lines.append(".end\n")

    # Minimal transient analysis
    lines.append(".tran 1p 50n\n")

    # Print commands for output_nodes
    if output_nodes:
        lines.append(".control\n" + "run\n")
        for node in output_nodes:
            lines.append(f"print v({node.lower()})\n")
        lines.append("quit\n" + ".endc\n")

    return ''.join(lines)


# Vertex type -> (trailing model / subcircuit field of the element line,
#                 prefix used when generating the element name).
spice_mapping = {
    # MOS transistors
    'PM': ('pmos4', 'M'),
    'NM': ('nmos4', 'M'),
    # Bipolar transistors
    'NPN': ('npn', 'Q'),
    'PNP': ('pnp', 'Q'),
    # Passive elements
    'R': ('resistor', 'R'),
    'C': ('capacitor', 'C'),
    'L': ('inductor', 'L'),
    # Diode
    'DIO': ('diode', 'D'),
    # Entries using the 'X' prefix
    'XOR': ('XOR', 'X'),
    'PFD': ('PFD', 'X'),
    'INVERTER': ('INVERTER', 'X'),
    'TRANSMISSION_GATE': ('TRANSMISSION_GATE', 'X'),
}


def parse_graph(graph):
    """Extract a component dictionary and the list of electrical nodes from a graph.

    Every vertex first receives a ``name`` attribute made of a type prefix and
    a running number per prefix (NM -> M0, M1, ...); this is written back onto
    the graph.

    Args:
        graph: graph whose vertices carry a ``type`` attribute and which
            exposes ``vs`` and ``neighbors()``.

    Returns:
        ``(netlist, ports)``. ``netlist`` maps an element name to its list of
        connected node names (followed by the model field for device types
        listed in ``node2pins``); ``ports`` holds the names of all vertices
        whose type is in ``eletric_nodes``.

    Raises:
        NotImplementedError: when a pin has more than one connection besides
            its device, an R/L/C vertex has more than two connections, or a
            connection does not end on an electrical node.
    """
    # Pass 1 - give every vertex a unique name: <prefix><running number>.
    next_number = {}
    for vertex in graph.vs:
        kind = vertex["type"]
        prefix = spice_mapping[kind][-1] if kind in spice_mapping else kind
        number = next_number.get(prefix, 0)
        vertex["name"] = prefix + str(number)
        next_number[prefix] = number + 1

    # Pass 2 - collect the connections of every component.
    netlist, ports = {}, []
    for vertex in graph.vs:
        kind = vertex["type"]

        if kind in node2pins:
            # Device with explicit pin vertices: one slot per declared pin,
            # followed by the model field.
            pin_vertices = graph.neighbors(vertex.index)
            terminals = [""] * len(node2pins[kind]) + [spice_mapping[kind][0]]
            netlist[vertex["name"]] = terminals

            for pin in pin_vertices:
                others = graph.neighbors(pin)
                others.remove(vertex.index)

                # Invalid circuit - a pin cannot have more than 2 neighbors (incl. parent node)
                if len(others) > 1:
                    raise NotImplementedError("ERROR more than 1 connection for a pin")

                target = graph.vs[others[0]]
                if target["type"] not in eletric_nodes:
                    raise NotImplementedError("ERROR wrong pin connection (net or port required)")

                # The pin type decides which slot of the element line it fills.
                slot = node2pins[kind].index(graph.vs[pin]["type"])
                terminals[slot] = target["name"]

        elif kind in ("C", "L", "R"):
            # Two-terminal passive connected directly to electrical nodes.
            neighbours = graph.neighbors(vertex.index)

            # Invalid circuit - a two-terminal element cannot have more than 2 neighbors
            if len(neighbours) > 2:
                raise NotImplementedError("ERROR more than 2 connections for R, L or C")

            terminals = ["", ""]
            netlist[vertex["name"]] = terminals
            for position, neighbour in enumerate(neighbours):
                target = graph.vs[neighbour]
                if target["type"] not in eletric_nodes:
                    raise NotImplementedError("ERROR wrong node connection (net or port required)")
                terminals[position] = target["name"]

        elif kind in eletric_nodes:
            ports.append(vertex["name"])

    return netlist, ports