from typing import Dict, Any, List
from mlcirc_pipelines.bayes_opt.bo_utils2 import BayesOpt as BO
from mlcirc_pipelines.bayes_opt.bo_utils2 import CircuitBayesOptConfig
from mlcirc_pipelines.bayes_opt.bo_eval_models2 import MOOPACircuitPerformance
from mlcirc_pipelines.interface_ops import \
                construct_stdgraph_pickle_from_rl_stdgraph_pickle
from mlcirc_utils import StandardGraph, ComponentEnum, EdgeTerminalEnum
from mlcirc_utils.branch_detection import ConnectivityUtils as cu
import torch 
import copy
import time

def make_phys_def():
    """
    Return the per-component-type physical parameter definitions.

    Keys are component-type values (matched against ``ComponentEnum.value``
    elsewhere in this file). Each value is a list of feature specs with
    ``feature``, ``bounds`` and ``parameter_type`` keys. Type 5 has no tunable
    features. A fresh dict is built on every call, so callers may mutate it.
    """
    def _spec(feature, bounds, parameter_type):
        return {"feature": feature, "bounds": bounds, "parameter_type": parameter_type}

    definitions = {}
    definitions[0] = [_spec("nf", [2, 100], 'INT')]
    definitions[1] = [_spec("nf", [2, 100], 'INT')]
    definitions[2] = [_spec("l", [1e-12, 1], 'FLOAT')]
    definitions[3] = [_spec("m", [1, 10], 'INT')]
    definitions[5] = []
    definitions[6] = [_spec("bias", [0, 1.8], 'FLOAT')]
    return definitions

def infer_parameters_from_stdgraph_pickle(params, standard_graph, mapping, symmetric_pairs=None, num_bias_pins=0):
    """
    Build the Bayesian-optimisation search space for a StandardGraph.

    Every non-net node contributes one parameter per feature spec listed in
    ``params[component_type.value]``. If the node's number (looked up in
    ``mapping``) appears in one of ``symmetric_pairs``, both nodes of that pair
    share a single parameter (``component_list`` holds both names) and the
    partner node is not visited again. Finally ``num_bias_pins`` bias-voltage
    variables ``vb<i>_val`` targeting ``VB_<i>`` are appended.

    Args:
        params: component-type value -> list of feature specs with
            ``feature``, ``bounds`` and ``parameter_type`` keys.
        standard_graph: graph whose ``get_nodes()`` result is iterated as
            items whose ``[0]`` is the node name and which can be indexed by
            node name to get that node's attribute dict (presumably a
            networkx-style node data view — confirm against StandardGraph).
        mapping: node number -> node name, used to resolve ``symmetric_pairs``.
        symmetric_pairs: sequence of 2-element node-number pairs; ``None``
            (the default) means no pairing.
        num_bias_pins: number of bias-voltage variables to append.

    Returns:
        dict: ``{'x0': spec, 'x1': spec, ...}`` numbered in insertion order.
    """
    if symmetric_pairs is None:
        symmetric_pairs = []

    params_dict = {}
    graph_nodes = standard_graph.get_nodes()
    num_param = 0
    covered_names = set()  # names already assigned a parameter (incl. pair partners)
    for n in graph_nodes:
        name = n[0]
        if name in covered_names:
            continue

        # Default: the node gets its own parameters; a symmetric pair overrides
        # this so both devices are sized together.
        name_list = [name]
        if symmetric_pairs:
            node_num = next((k for k, v in mapping.items() if v == name), None)
            for pair in symmetric_pairs:
                if node_num in pair:
                    name_list = [mapping[pair[0]], mapping[pair[1]]]
                    break
        covered_names.update(name_list)

        comp_type = graph_nodes[name]['component_type']
        if comp_type == ComponentEnum.NET:
            continue
        for ps in params[comp_type.value]:
            params_dict[f'x{num_param}'] = {'path': [ps['feature']],
                                            'component_list': name_list,
                                            'bounds': ps['bounds'],
                                            'parameter_type': ps['parameter_type']}
            num_param += 1

    for bias in range(num_bias_pins):
        params_dict[f'x{num_param}'] = {'variable': f'vb{bias}_val',
                                        'component_list': [f'VB_{bias}'],
                                        'bounds': [0, 1.8],
                                        'parameter_type': 'FLOAT'}
        num_param += 1

    return params_dict

def supply_connectivity(input_graph):
    """
    Check that every circuit node lies on some VDD-to-VSS conduction path.

    Works on a deep copy of ``input_graph``:
      1. Edges from a net to a FET gate-only terminal are dropped. Nets left
         connected only to gates (other than VDD/VSS) are recorded as
         removable.
      2. On FET edges only the terminals at one-hot indices 1 and 3 are kept.
         On other edges only indices 4 and 5 are kept.
      3. The remaining graph must form a single cluster, ignoring isolated
         removable nets.
      4. All simple VDD->VSS paths are enumerated. Path nodes and removable
         nets are deleted, and the check passes only if nothing is left.

    In its current iteration, circuits with signal inputs are reported as
    invalid.

    Returns:
        bool: True if the circuit passes. False if any step fails or if the
        VDD/VSS nets cannot be found.
    """
    graph = copy.deepcopy(input_graph)
    fet_types = (ComponentEnum.NFET, ComponentEnum.PFET)
    passive_types = (ComponentEnum.IND, ComponentEnum.RES, ComponentEnum.CAP)

    def _net_name(node_data):
        # Nets created by the environment may have no custom_features / node_name.
        try:
            return node_data["custom_features"]["node_name"]
        except (LookupError, TypeError):
            return None

    # Phase 1: drop gate-only FET edges; collect nets connected only to gates.
    node_removal = []
    for n in list(graph.get_nodes()):
        if n[1]["component_type"] != ComponentEnum.NET:
            continue
        g_only = True
        for e in list(graph.get_edges()):
            if n[0] not in e:
                continue
            attrs = graph.get_edge_features(e[0], e[1], 'one_hot_edge_attr')
            other = e[0] if e[0] != n[0] else e[1]
            comp_type = graph.get_node_features(other, ["component_type"])
            if comp_type in fet_types:
                if attrs[2] and attrs[1] == 0 and attrs[3] == 0:
                    graph.remove_edge(e[0], e[1])
                else:
                    g_only = False
            elif comp_type in passive_types:  # passives don't have ignorable nets
                g_only = False
        if g_only and _net_name(n[1]) not in ("VDD", "VSS"):
            node_removal.append(n[0])

    # Phase 2: keep only the conduction terminals (FET: indices 1/3, others: 4/5).
    for e in list(graph.get_edges()):
        attrs = graph.get_edge_features(e[0], e[1], 'one_hot_edge_attr')
        comp_type = graph.get_node_features(e[0], ["component_type"])
        if comp_type == ComponentEnum.NET:
            comp_type = graph.get_node_features(e[1], ["component_type"])
        if comp_type in fet_types:
            if attrs[1] == 0 and attrs[3] == 0:
                graph.remove_edge(e[0], e[1])
        elif attrs[4] == 0 and attrs[5] == 0:
            graph.remove_edge(e[0], e[1])

    # Phase 3: more than one real cluster fails (floating bulks can also cause
    # this, though the environment should not produce them).
    actual_clusters = [
        c for c in graph.get_clusters()
        if not (len(c) == 1 and c.issubset(node_removal))
    ]
    if len(actual_clusters) != 1:
        return False

    # Phase 4: locate the supply nets and enumerate VDD -> VSS paths.
    start = end = None
    for n in list(graph.get_nodes()):
        if n[1]["component_type"] != ComponentEnum.NET:
            continue
        name = _net_name(n[1])
        if name == "VDD":
            start = n[0]
        elif name == "VSS":
            end = n[0]
    if start is None or end is None:
        return False  # without both supply nets there is no supply continuity

    paths = dfs_all_paths(graph, start, end) or []
    nodes_in_paths = {node for p in paths for node in p}
    for node in nodes_in_paths:
        graph.remove_node(node)
    for node in node_removal:
        if node not in nodes_in_paths:  # avoid removing the same node twice
            graph.remove_node(node)

    # Any node left over is not on a supply path.
    return len(graph.get_clusters()) == 0

def dfs_all_paths(graph, start, end, path=None, all_paths=None, visited_paths=None):
    """
    Return every simple path from ``start`` to ``end`` found by depth-first search.

    Neighbours are taken from ``graph.get_adjacency_matrix()`` (entries equal
    to 1), indexed in the order of ``graph.get_nodes()``. No node is repeated
    within a path.

    Args:
        graph: object exposing ``get_nodes()`` (items whose ``[0]`` is the node
            id) and ``get_adjacency_matrix()``.
        start, end: node ids.
        path: optional prefix prepended to every path.
        all_paths: optional list that found paths are appended to (and returned).
        visited_paths: optional set of path tuples already explored.

    Returns:
        list[list]: the found paths. It is always a list, including when
        ``start == end`` (then ``[path + [start]]``).
    """
    if path is None:
        path = []
    if all_paths is None:
        all_paths = []
    if visited_paths is None:
        visited_paths = set()

    # The graph is not modified during the search, so the node order and
    # adjacency matrix are computed once instead of on every recursive call.
    node_ids = [n[0] for n in graph.get_nodes()]
    adjacency = graph.get_adjacency_matrix()
    index_of = {node: i for i, node in enumerate(node_ids)}

    def _walk(current, current_path):
        current_path = current_path + [current]
        key = tuple(current_path)
        if key in visited_paths:
            return  # skip duplicate paths
        visited_paths.add(key)

        if current == end:
            all_paths.append(current_path)
            return

        for j, val in enumerate(adjacency[index_of[current]]):
            if val == 1:
                neighbor = node_ids[j]
                if neighbor not in current_path:  # no cycles
                    _walk(neighbor, current_path)

    _walk(start, path)
    return all_paths

def find_neighbors(graph, start):
    """
    Return the column indices whose entry in adjacency-matrix row ``start`` equals 1.

    ``start`` is a row index into ``graph.get_adjacency_matrix()``, not a node id.
    """
    row = graph.get_adjacency_matrix()[start]
    neighbors = []
    for column, weight in enumerate(row):
        if weight == 1:
            neighbors.append(column)
    return neighbors

def circuit_symmetry():
    """
    Placeholder for circuit symmetry detection; not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def relabeling_graph(graph, construction_config):
    """
    Convert an RL graph into a StandardGraph with canonical pin names.

    Steps:
      1. Build the StandardGraph with
         ``construct_stdgraph_pickle_from_rl_stdgraph_pickle``.
      2. Detect bias nets. These are the nets ``cu.make_ds_graph`` reports as
         removable (gate-only) that are not VDD/VSS or VIN*/VOUT* pins. They
         are renamed ``VB_<i>`` (both the ``custom_features.node_name`` feature
         and the node id).
      3. Rename VIN1/VIN2/VOUT1/VOUT2 to VINP/VINN/VOUTP/VOUTN. ``mapping`` is
         updated in place for every renamed node.
      4. Re-add every edge with the net node as the first endpoint.

    Returns:
        tuple: ``(std_graph, target_name, pin_order, mapping, num_bias_pins)``.
    """
    (std_graph, target_name, pin_order, mapping) = construct_stdgraph_pickle_from_rl_stdgraph_pickle(construction_config, graph)

    # NOTE(review): the constructor's target name and pin order are discarded in
    # favour of a fixed template with ten bias pins (VB_0..VB_9).
    target_name = "opa_extra_biases"
    pin_order = ["VOUTN", "VOUTP", "VDD", "VSS", "VB_9", "VB_8", "VB_7", "VB_6", "VB_5", "VB_4", "VB_3",
                 "VB_2", "VB_1", "VB_0", "VINN", "VINP"]

    # Bias pins are nets connected only to gates that are not input/output/supply pins.
    node_removal, _gate_nonremove, _ds_graph = cu.make_ds_graph(std_graph)

    gate_nodes = []
    for n in node_removal:
        try:
            # NOTE(review): node ids come from std_graph but the name is read from
            # the input RL graph; this assumes both graphs share node ids.
            node_name = graph.get_node_features(n, ["custom_features", "node_name"])
            is_named_pin = (node_name in ["VDD", "VSS"]
                            or str(node_name).startswith(("VOUT", "VIN")))
        except Exception:
            is_named_pin = False  # no custom features: assume it's a gate (bias) node
        if not is_named_pin:
            gate_nodes.append(n)

    # NOTE(review): hard-coded special case. Nets attached to the node "NFET10"
    # whose name contains "VB" are also treated as bias nets. Best-effort: any
    # lookup failure is ignored.
    std_node_ids = [n[0] for n in std_graph.get_nodes()]
    if "NFET10" in std_node_ids and "NFET10" not in gate_nodes:
        try:
            for e in graph.get_node_edges("NFET10"):
                other_component = e[0] if e[1] == "NFET10" else e[1]
                other_name = graph.get_node_features(other_component, ["custom_features", "node_name"])
                # Dedupe on the node id that is appended (not on its name).
                if "VB" in other_name and other_component not in gate_nodes:
                    gate_nodes.append(other_component)
        except Exception:
            pass

    num_bias_pins = len(gate_nodes)
    map_gate = {node: f'VB_{i}' for i, node in enumerate(gate_nodes)}

    for node, bias_name in map_gate.items():
        std_graph.update_node_feature_by_path(node_name=node,
                                              key_path=["custom_features", "node_name"],
                                              update_value=bias_name)

    # Build the node-id relabel map. Bias nets take precedence over the fixed
    # I/O renames. Renamed nodes also have their first matching `mapping` entry
    # updated.
    io_renames = {"VIN1": "VINP", "VIN2": "VINN", "VOUT1": "VOUTP", "VOUT2": "VOUTN"}
    name_map = {}
    for n in std_graph.get_nodes():
        node_id = n[0]
        rename = map_gate[node_id] if node_id in map_gate else io_renames.get(node_id)
        if rename is None:
            name_map[node_id] = node_id
            continue
        name_map[node_id] = rename
        for map_k, map_v in mapping.items():
            if map_v == node_id:
                mapping[map_k] = rename
                break

    # Need to relabel things because specific nodes will not be the default net parameter anymore
    std_graph.relabel_nodes(mapping=name_map)

    # Re-add every edge as (net, component) with terminal lists in the same order.
    for e in list(std_graph.get_edges()):
        node_1_terminal = std_graph.get_edge_features(e[0], e[1], 'node_1_terminal_feature')
        node_2_terminal = std_graph.get_edge_features(e[0], e[1], 'node_2_terminal_feature')
        node_1_terminal_list = [EdgeTerminalEnum(idx) for idx, i in enumerate(node_1_terminal) if i == 1]
        node_2_terminal_list = [EdgeTerminalEnum(idx) for idx, i in enumerate(node_2_terminal) if i == 1]
        net_first = EdgeTerminalEnum.NET in node_1_terminal_list
        net_terminal = node_1_terminal_list if net_first else node_2_terminal_list
        component_terminal = node_2_terminal_list if net_first else node_1_terminal_list
        net = e[0] if std_graph.get_node_features(e[0], ["component_type"]) == ComponentEnum.NET else e[1]
        component = e[1] if net == e[0] else e[0]

        phys_features = std_graph.get_edge_features(e[0], e[1], 'phys_features')
        custom_features = std_graph.get_edge_features(e[0], e[1], 'custom_features')
        std_graph.remove_edge(e[0], e[1])
        std_graph.add_edge(net, component, net_terminal, component_terminal,
                           phys_features=phys_features,
                           custom_features=custom_features,)

    return std_graph, target_name, pin_order, mapping, num_bias_pins

def bayes_opt(graph, bayes_opt_config, construction_config, symmetric_pairs=None, custom_str='', save_directory='', dud_pins=0):
    """
    Run Bayesian optimisation on ``graph`` and return its Pareto-front performances.

    Args:
        graph: RL graph, relabelled via ``relabeling_graph``.
        bayes_opt_config: BO configuration; its ``specs`` are the objectives.
        construction_config: passed through to ``relabeling_graph``.
        symmetric_pairs: node-number pairs that share sizing parameters;
            ``None`` (the default) means no pairing.
        custom_str, save_directory: forwarded to ``BO.bayes_opt_basic``.
        dud_pins: number of ``floating<i>`` pins appended to the pin order.
            This is for cases like a ring oscillator, where there are no inputs.

    Returns:
        tuple: ``(None, pareto_results)``, where ``pareto_results`` is the list
        of ``measured_performance`` entries of the Pareto-optimal designs.
    """
    if symmetric_pairs is None:
        symmetric_pairs = []

    std_graph, target_name, pin_order, mapping, num_bias_pins = relabeling_graph(graph, construction_config)

    # infer the parameters
    params = infer_parameters_from_stdgraph_pickle(make_phys_def(),
                 std_graph, mapping, symmetric_pairs, num_bias_pins=num_bias_pins)

    if dud_pins > 0:
        pin_order = pin_order + [f'floating{i}' for i in range(dud_pins)]

    _outcomes, experiment, cache = BO.bayes_opt_basic(bayes_opt_config, params, target_name, pin_order,
                                                      std_graph, N_INIT=5, N_BATCH=1, batch_size=4,
                                                      save_directory=save_directory, custom_str=custom_str)
    pareto_params_and_performance = BO.extract_pareto_results(experiment, bayes_opt_config.specs, params, cache)

    pareto_results = [p['measured_performance'] for p in pareto_params_and_performance]
    return None, pareto_results

def create_tensor_from_params(params_dict, nf_num, l, bias):
    """
    Build a parameter tensor holding one value per entry of ``params_dict``, in order.

    Entries with a ``path`` get ``nf_num`` for feature ``nf``, ``l`` for feature
    ``l``, and ``1`` for any other feature. Entries with a ``variable`` get
    ``bias``. Entries with neither key are skipped.
    """
    value_for_feature = {'nf': nf_num, 'l': l}
    values = []
    for spec in params_dict.values():
        if 'path' in spec:
            values.append(value_for_feature.get(spec['path'][0], 1))
        elif 'variable' in spec:
            values.append(bias)
    return torch.tensor(values)

def sim_graph(graph, bayes_opt_config, construction_config, dud_pins=0, symmetric_pairs=None, custom_str='', save_directory='',
              *, nf_num=50, l=5, bias=0.9):
    """
    Simulate ``graph`` once at fixed sizing, without Bayesian optimisation.

    Args:
        graph, construction_config: passed to ``relabeling_graph``.
        bayes_opt_config: supplies netlist, harness, PDK, run-dir, specs and GPU id.
        dud_pins: number of ``floating<i>`` pins appended to the pin order.
        symmetric_pairs: node-number pairs that share sizing parameters;
            ``None`` (the default) means no pairing.
        custom_str: forwarded to ``evaluate_default_graph``.
        save_directory: forwarded to the circuit model.
        nf_num, l, bias: values used for every ``nf`` feature, every ``l``
            feature, and every bias variable (see ``create_tensor_from_params``).

    Returns:
        dict | None: the result of ``evaluate_default_graph``, or ``None`` if
        any value has magnitude exactly 1. That is presumably a
        failed-simulation sentinel; confirm against MOOPACircuitPerformance.
    """
    if symmetric_pairs is None:
        symmetric_pairs = []

    std_graph, target_name, pin_order, mapping, num_bias_pins = relabeling_graph(graph, construction_config)
    params = infer_parameters_from_stdgraph_pickle(make_phys_def(),
                     std_graph, mapping, symmetric_pairs, num_bias_pins=num_bias_pins)

    if dud_pins > 0:
        pin_order = pin_order + [f'floating{i}' for i in range(dud_pins)]

    circuit_model = MOOPACircuitPerformance(data_path=bayes_opt_config.base_netlist,
                                    target_circuit=target_name,
                                    netlist_pin_order=pin_order,
                                    params_dict=params,
                                    objectives=bayes_opt_config.specs,
                                    base_harness_yaml_file=bayes_opt_config.base_harness_yaml_file,
                                    run_dir=bayes_opt_config.run_dir,
                                    mlcirc_tech_pdk=bayes_opt_config.mlcirc_tech_pdk,
                                    base_pdk_configuration=bayes_opt_config.base_pdk_configuration,
                                    pdk_tech_plugin=bayes_opt_config.pdk_tech_plugin,
                                    graph=std_graph,
                                    save_directory=save_directory).to(
        dtype=torch.double,
        device=torch.device("cuda:" + str(bayes_opt_config.gpu_id) if torch.cuda.is_available() else "cpu"),
    )

    set_params = create_tensor_from_params(params, nf_num, l, bias)

    circuit_model.update_graph(set_params)
    single_result = circuit_model.evaluate_default_graph(set_params, custom_str)
    if any(abs(v) == 1 for v in single_result.values()):
        return None
    return single_result


def _exact_spec_score(value, spec):
    """Return ``1 - relative_error`` for a spec that must hit ``spec['target']`` exactly."""
    target = spec['target']
    if 'bound' in spec:
        error = abs(value - target) / spec['bound']
    else:
        # abs() keeps the error non-negative for negative targets.
        error = abs(value - target) / (abs(target) / 2)
    return 1 - error


def _directional_spec_score(value, spec):
    """
    Score a minimize/maximize spec: 0 at target, positive when better, -1 at worst.

    Values far on the wrong side of the target (beyond ``|target|``) are clamped
    to -1 before any division. Those regions contain the points where the
    scoring denominators vanish.
    """
    target = spec['target']
    if value == target:
        return 0.0
    if spec["minimize"] == True:
        if value > target + abs(target):
            return -1
        if target > 0:
            score = (value - target) / (value + (-1) * 3 * target)
        else:
            score = (value - target) / (value + target)
    else:
        if value < target - abs(target):
            return -1
        if target > 0:
            score = (value - target) / (value + target)
        else:
            score = (value - target) / (value + (-1) * 3 * target)
    return -1 if score < -1 else score


def domain_reward(results: List[Dict[str, Any]], targets: Dict[str, Any], completion_reward: int=10):
    """
    Compute the best reward achieved by any design on the Pareto front.

    Only keys present in ``targets`` are scored. ``results`` itself is not
    modified.
    - ``exact`` specs add ``15 * (1 - relative_error)`` to the optimisation bonus.
    - Directional specs add their negative part to the base reward. For
      non-constraint specs, the positive part times 15 goes to the bonus.

    The bonus is added only when the base reward is non-negative and every bonus
    term is non-negative. A non-negative total then gets ``completion_reward``.

    Returns:
        tuple: ``(max_reward, idx_max)``.
        - Best reward positive: it is returned as is.
        - Best reward negative but within ``-0.8 * num_keys``: returns
          ``completion_reward``. Here ``num_keys`` is the number of directional
          specs in the last design.
        - Otherwise: returns ``3``.
        - No results: returns ``(3, None)``.
    """
    # Filter copies instead of popping keys from the caller's dicts.
    filtered = [{k: v for k, v in r.items() if k in targets} for r in results]

    reward_list = []
    num_keys = 0
    for design in filtered:
        reward = 0
        opt_reward = []
        num_keys = 0
        for k, v in design.items():
            spec = targets[k]
            if spec['exact']:
                opt_reward.append(_exact_spec_score(v, spec) * 15)
                continue
            score = _directional_spec_score(v, spec)
            reward += min(score, 0)
            num_keys += 1
            if spec["constraint"] == False:
                opt_reward.append(max(score, 0) * 15)

        # if all objectives are above target, give a bonus
        if reward >= 0 and all(o >= 0 for o in opt_reward):
            reward += sum(opt_reward)
        if reward >= 0:
            reward = reward + completion_reward
        reward_list.append(reward)

    if not reward_list:
        return 3, None

    max_reward = max(reward_list)
    idx_max = reward_list.index(max_reward)
    if max_reward > 0:
        pass
    elif -0.8 * num_keys < max_reward < 0:
        max_reward = completion_reward  # small reward to encourage iteration when you don't get 1 on everything
    else:
        max_reward = 3
    return max_reward, idx_max

