from collections import OrderedDict
import numpy as np
import copy
import pickle
from networks.gcn_policy import GCNNestedPolicyPG
import gym
from infrastructure import pytorch_util as ptu
from typing import Dict, Tuple, List, Any, Union
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import torch.multiprocessing as mp
import torch
import time
from pathlib import Path
import networkx as nx
from mlcirc_utils import VisualizationUtils 
############################################
############################################


def _graph_obs_to_arrays(graph_ob):
    """Return ``[node_type, edge_index, edge_attr]`` of ``graph_ob`` as NumPy arrays.

    Each attribute that is ``None`` stays ``None``; every other attribute is
    passed through ``.cpu()`` and wrapped in ``np.array``.
    """
    return [
        None if field is None else np.array(field.cpu())
        for field in (graph_ob.node_type, graph_ob.edge_index, graph_ob.edge_attr)
    ]


def sample_trajectory(
    env: gym.Env, policy: GCNNestedPolicyPG, max_length: int, 
    random_start: bool=False, 
    random_circuit: bool = False,
    replay_buffer = None,
    insert = False,
    domain = True, 
    circuit_index = None,
    nodes_sample = 0, 
    queue = None, 
    rank = None,
    iteration = None, 
    traj_no = None,
    sample_nodes = None
) -> Union[Dict[str, Any], None]:
    """Sample one rollout in the environment from a policy.

    Args:
        env: Environment. ``reset`` is called with ``random_start``,
            ``random_circuit``, ``circuit_index``, ``nodes_sample`` and the
            ``iteration`` / ``traj_no`` / ``sample_nodes`` keywords and must
            return ``(ob, ground_graph, graph_sampled, mask_obj,
            symmetric_nodes)``; ``step(ac, domain)`` must return an 8-tuple
            (see the unpacking in the loop below).
        policy: Object whose ``get_action(ob, mask_obj, sym=...)`` returns
            ``(action, mask_array)``.
        max_length: Step limit checked after every step.
        replay_buffer, insert: When ``insert`` is true and a buffer is given,
            every transition is also inserted into the buffer.
        queue: When given, the rollout is put on the queue (followed by a
            ``None`` item) instead of being returned.
        rank: Accepted but not used by this function.

    Returns:
        The rollout dictionary when ``queue`` is ``None``; otherwise ``None``.
        Observations are stored as ``[node_type, edge_index, edge_attr]``
        NumPy triples rather than as graph objects.
    """
    print("sampling", sample_nodes)
    ob, ground_graph, graph_sampled, mask_obj, symmetric_nodes = env.reset(
        random_start, random_circuit, circuit_index, nodes_sample,
        iteration=iteration, traj_no=traj_no, sample_nodes=sample_nodes)
    obs, acs, rewards, next_obs, terminals = [], [], [], [], []
    graphs, valid, masks, symmetries = [], [], [], []
    steps = 0
    env_step_prev = 0
    while True:
        # use the most recent ob to decide what to do
        ac, mask_array = policy.get_action(ob, mask_obj, sym=symmetric_nodes)

        # take that action and get reward and next ob
        (next_ob, rew, done, info, env_steps, graph,
         next_mask_obj, next_symmetric_nodes) = env.step(ac, domain)
        # a step counts as valid only if the environment's step counter advanced
        val = 1 if env_steps > env_step_prev else 0

        # rollout can end due to done, or due to max_length
        steps += 1
        # NOTE(review): the check runs after the increment, so up to
        # max_length + 1 steps are recorded; kept as-is for compatibility.
        rollout_done = done or steps > max_length

        # record result of taking that action
        graphs.append(copy.deepcopy(graph))
        obs.append(ob)
        acs.append(ac)
        rewards.append(rew)
        next_obs.append(next_ob)
        terminals.append(done)
        valid.append(val)
        symmetries.append(symmetric_nodes)
        masks.append(mask_array)

        env_step_prev = env_steps
        if insert and replay_buffer is not None:
            replay_buffer.insert(
                observation=ob,
                action=np.asarray(ac),
                reward=rew,
                next_observation=next_ob,
                done=done,
                valid=val,
                mask=mask_array,
                mask_obj=mask_obj,
            )

        # jump to next timestep
        ob = next_ob
        mask_obj = next_mask_obj
        symmetric_nodes = next_symmetric_nodes

        if rollout_done:
            break

    # "l" and "r" include the invalid steps
    episode_statistics = {"l": steps, "r": np.sum(rewards)}
    if "episode" in info:
        episode_statistics.update(info["episode"])

    return_dict = {
        "observation": [_graph_obs_to_arrays(o) for o in obs],
        "reward": np.array(rewards, dtype=np.float32),
        "action": np.array(acs, dtype=np.float32),
        "next_observation": [_graph_obs_to_arrays(no) for no in next_obs],
        "terminal": np.array(terminals, dtype=np.float32),
        "graph": graphs,
        "graph_sampled": copy.deepcopy(graph_sampled),
        "ground_truth": ground_graph,
        "valid": np.array(valid),
        "mask": masks,
        "symmetric_nodes": symmetries,
        "episode_statistics": episode_statistics,
    }
    env.close()

    if queue is None:
        return return_dict

    # Local import under an alias: the ``queue`` parameter shadows the name of
    # the stdlib module that defines the "queue is full" exception.
    from queue import Full as QueueFull

    try:
        queue.put(return_dict)
        queue.put(None)
    except QueueFull:
        print("Queue is full! Skipping.")

    # task_done() exists only on joinable queues; plain queues do not have it.
    task_done = getattr(queue, "task_done", None)
    if task_done is not None:
        task_done()
    return None


def sample_trajectories_process(
    env: gym.Env,
    policy: GCNNestedPolicyPG,
    min_timesteps_per_batch: int,
    max_length: int,
    random_start=False,
    random_circuit=False,
    replay_buffer=None,
    insert=False,
    domain=True,
    circuit_index=None,
    nodes_sample=1,
    circuit_order=None,
    sample_order=None,
    check_curriculum=None,
    queue=None,
    rank=None,
    gpu_id=0,
    iteration=None,
    sample_nodes=None,
) -> Union[Tuple[list, List[Dict[str, Any]], int], None]:
    """Collect rollouts with ``policy`` until ``min_timesteps_per_batch`` steps are gathered.

    Rollouts are scheduled as follows:

    * while entries of ``circuit_order`` remain, each rollout uses the next
      ``circuit_order`` / ``sample_order`` pair and records the matching
      ``check_curriculum`` flag;
    * otherwise, if ``circuit_index`` is not ``None``, that circuit is sampled
      with ``nodes_sample`` / ``sample_nodes``;
    * otherwise ``sample_trajectory`` is called with its defaults.

    Args:
        circuit_order, sample_order, check_curriculum: Parallel sequences;
            ``None`` is treated as an empty sequence.
        queue: When given, the GPU is initialised for ``gpu_id`` and the
            result is put on the queue as a dict with the keys
            ``"check_curriculum"``, ``"traj"`` and ``"tsteps"``.
        rank: Only used in the start-up log line.

    Returns:
        ``([], trajs, timesteps_this_batch)`` when ``queue`` is ``None``,
        otherwise ``None``.
    """
    print(f"started rank {rank}", sample_nodes)
    # None replaces the former mutable ``[]`` defaults.
    circuit_order = [] if circuit_order is None else circuit_order
    sample_order = [] if sample_order is None else sample_order
    check_curriculum = [] if check_curriculum is None else check_curriculum

    if queue is not None:
        ptu.init_gpu(gpu_id=gpu_id)
    timesteps_this_batch = 0
    trajs = []

    t_start = time.time()
    traj_no = 0
    co_idx = 0
    check_curriculum_traj = []
    # positional prefix shared by every sample_trajectory call below
    common_args = (env, policy, max_length, random_start, random_circuit,
                   replay_buffer, insert, domain)
    while timesteps_this_batch < min_timesteps_per_batch:
        if co_idx < len(circuit_order):
            # NOTE(review): these assignments rebind the function's own
            # arguments, so once circuit_order is exhausted the ``elif``
            # branch keeps sampling the last scheduled circuit. Confirm that
            # this is intended.
            circuit_index = circuit_order[co_idx]
            sample_nodes = sample_order[co_idx]
            traj = sample_trajectory(*common_args, circuit_index,
                                     sample_nodes=sample_nodes,
                                     iteration=iteration, traj_no=traj_no)
            check_curriculum_traj.append(check_curriculum[co_idx])
            co_idx += 1
        elif circuit_index is not None:
            traj = sample_trajectory(*common_args, circuit_index,
                                     nodes_sample=nodes_sample,
                                     sample_nodes=sample_nodes,
                                     iteration=iteration, traj_no=traj_no)
            check_curriculum_traj.append(0)
        else:
            traj = sample_trajectory(*common_args)
            check_curriculum_traj.append(0)

        traj_no += 1
        trajs.append(traj)
        timesteps_this_batch += get_traj_length(traj)

    t_end = time.time()
    print("trajectory time: ", t_end-t_start)

    if queue is None:
        return [], trajs, timesteps_this_batch

    print("[Worker] Putting result on queue...")
    queue.put({"check_curriculum": check_curriculum_traj, "traj":trajs, "tsteps":timesteps_this_batch})
    print("[Worker] Put complete.")
    return None


def _arrays_to_graph_obs(arrays):
    """Build a ``Data`` observation from a ``[node_type, edge_index, edge_attr]`` triple.

    When ``edge_index`` is ``None`` both ``edge_index`` and ``edge_attr`` are
    set to ``None``; otherwise both are converted with ``torch.tensor``.
    """
    data = Data()
    data.node_type = torch.tensor(arrays[0])
    if arrays[1] is None:
        data.edge_index = None
        data.edge_attr = None
    else:
        data.edge_index = torch.tensor(arrays[1])
        data.edge_attr = torch.tensor(arrays[2])
    return data


def _restore_graph_observations(traj):
    """Replace the array triples under ``observation`` / ``next_observation`` with ``Data`` objects, in place."""
    traj["observation"] = [_arrays_to_graph_obs(o) for o in traj["observation"]]
    traj["next_observation"] = [_arrays_to_graph_obs(no) for no in traj["next_observation"]]


def sample_trajectories(
    envs_list: Union[List[Any], gym.Env], #gym.Env
    policy: GCNNestedPolicyPG,
    min_timesteps_per_batch: int,
    max_length: int,
    random_start=False,
    random_circuit=False,
    replay_buffer=None,
    insert=False,
    domain=True, 
    circuit_index = None,
    nodes_sample = 1, 
    circuit_order = None,
    sample_order = None,
    check_curriculum_reward = False,
    num_workers=4, 
    gpu_id = 0, 
    iteration = None, 
    sample_nodes = None) -> Tuple[List[Dict[str, Any]], List[int], int, float]:
    """Collect rollouts using policy until we have collected min_timesteps_per_batch steps.

    With ``num_workers > 1`` one process per worker runs
    ``sample_trajectories_process`` on ``envs_list[rank]`` with a deep copy of
    the policy and ``min_timesteps_per_batch // num_workers`` steps each;
    results are gathered through a queue. Otherwise ``envs_list`` is used
    directly as the single environment and rollouts are sampled in-process.

    Args:
        circuit_order, sample_order: Optional parallel schedules. In the
            multi-worker case they are split into ``num_workers`` contiguous
            chunks whose sizes differ by at most one.
        check_curriculum_reward: When true, the second half of the schedule is
            flagged with 1 in the returned curriculum list, the first half
            with 0.

    Returns:
        ``(trajs, curriculum_check, timesteps_this_batch, elapsed_seconds)``.
        Observations inside ``trajs`` are converted back to ``Data`` objects.
    """
    timesteps_this_batch = 0
    trajs = []
    curriculum_check = []

    t_start = time.time()
    if num_workers>1:  # Number of parallel environments
        queue = mp.Queue()

        processes = []
        if circuit_order is not None:
            total = len(circuit_order)
            if check_curriculum_reward:
                check_curriculum_list = [0] * (total // 2) + [1] * (total - total // 2)
            else:
                check_curriculum_list = [0] * total

            # Split the schedules into num_workers contiguous chunks; the
            # first ``total % num_workers`` chunks get one extra entry.
            circuit_order_division = []
            sample_order_division = []
            check_curriculum_division = []
            chunk_sizes = [(total // num_workers) + (1 if x < total % num_workers else 0)
                           for x in range(num_workers)]
            start = 0
            for size in chunk_sizes:
                end = start + size
                circuit_order_division.append(circuit_order[start:end])
                sample_order_division.append(sample_order[start:end])
                check_curriculum_division.append(check_curriculum_list[start:end])
                start = end

        steps_per_worker = min_timesteps_per_batch // num_workers
        for rank in range(num_workers):
            pol = copy.deepcopy(policy)
            pol.load_state_dict(policy.state_dict())
            if circuit_order is not None:
                schedule_args = (None, 1, circuit_order_division[rank],
                                 sample_order_division[rank],
                                 check_curriculum_division[rank],
                                 queue, rank, gpu_id, iteration)
            elif circuit_index is not None:
                schedule_args = (circuit_index, nodes_sample, [], [], [],
                                 queue, rank, gpu_id, iteration, sample_nodes)
            else:
                schedule_args = (None, 1, [], [], [],
                                 queue, rank, gpu_id, iteration)
            p = mp.Process(
                target=sample_trajectories_process,
                args=(envs_list[rank], pol, steps_per_worker, max_length,
                      random_start, random_circuit, replay_buffer, insert, domain)
                + schedule_args)
            p.start()
            processes.append(p)

        f = 0
        try:
            for _ in range(num_workers):
                for _attempt in range(600):
                    # Only the blocking get is retried. A timeout raises
                    # queue.Empty; any other failure of get() is reported the
                    # same way and retried, as before.
                    try:
                        payload = queue.get(timeout=0.5)
                    except Exception as e:
                        print(f"waiting..... {e}")
                        time.sleep(2)
                        continue

                    # Post-processing happens outside the retry handler so a
                    # conversion error can no longer silently discard a
                    # payload that was already taken off the queue.
                    if isinstance(payload, dict):
                        for worker_traj in payload["traj"]:
                            _restore_graph_observations(worker_traj)
                        f += 1
                        trajs.extend(payload["traj"])
                        # count steps
                        timesteps_this_batch += payload["tsteps"]
                        curriculum_check += payload["check_curriculum"]
                    print(f"got one{f}")
                    break
        finally:
            for p in processes:
                p.join(timeout=10)  # Ensures that `join()` doesn't hang forever
                if p.is_alive():
                    print(f"Process {p.pid} is still running! Terminating...")
                    p.terminate()  # Force kill if still running
                    p.join()
            print("processes joined")

    else:
        traj_no = 0
        co_idx = 0
        while timesteps_this_batch < min_timesteps_per_batch:
            # collect rollout
            # envs_list is just the one gym env environment in this case
            if circuit_order is not None and co_idx < len(circuit_order):
                traj = sample_trajectory(envs_list, policy, max_length,
                                    random_start, random_circuit, replay_buffer, insert, domain,
                                    circuit_index=circuit_order[co_idx], sample_nodes=sample_order[co_idx],
                                    iteration=iteration, traj_no=traj_no)
                co_idx += 1
                print(f"circuit_order: {circuit_order[co_idx-1]}, sample_order: {sample_order[co_idx-1]}")
            elif circuit_index is not None:
                traj = sample_trajectory(envs_list, policy, max_length,
                                    random_start, random_circuit, replay_buffer, insert, domain,
                                    circuit_index=circuit_index, nodes_sample=nodes_sample,
                                    sample_nodes=sample_nodes,
                                    iteration=iteration, traj_no=traj_no)
            else:
                traj = sample_trajectory(envs_list, policy, max_length,
                                    random_start, random_circuit, replay_buffer, insert, domain, iteration=iteration,
                                    traj_no=traj_no, nodes_sample=nodes_sample)

            _restore_graph_observations(traj)

            traj_no += 1
            trajs.append(traj)
            timesteps_this_batch += get_traj_length(traj)

        # Without a schedule there is nothing to split into halves, so every
        # rollout is flagged 0 (previously len(None) raised a TypeError here).
        if check_curriculum_reward and circuit_order is not None:
            total = len(circuit_order)
            curriculum_check = [0] * (total // 2) + [1] * (total - total // 2)
        else:
            curriculum_check = [0] * len(trajs)
    t_end = time.time()
    print("trajectory total time: ", t_end-t_start)
    return trajs, curriculum_check, timesteps_this_batch, t_end-t_start

def sample_n_trajectories(
    env: gym.Env, policy: GCNNestedPolicyPG, ntraj: int, max_length: int, render: bool = False
):
    """Collect ntraj rollouts.

    ``render`` is accepted for interface compatibility but is not forwarded:
    ``sample_trajectory`` has no render parameter, and passing the flag
    positionally (as was done before) set its ``random_start`` argument
    instead.

    Returns:
        A list with the ``ntraj`` results of ``sample_trajectory``.
    """
    return [sample_trajectory(env, policy, max_length) for _ in range(ntraj)]


def compute_metrics(trajs, eval_trajs):
    """Summarise returns and episode lengths of training and evaluation rollouts.

    Each rollout must provide a ``"reward"`` array and an
    ``"episode_statistics"`` mapping with the length under ``'l'``.

    Returns:
        An ``OrderedDict`` holding the ``Eval_*`` entries first and the
        ``Train_*`` entries second (average / std / max / min return, then
        average / max episode length).
    """
    # per-split inputs; training rollouts are read first, evaluation second
    returns_by_split = {
        "Train": [rollout["reward"].sum() for rollout in trajs],
        "Eval": [rollout["reward"].sum() for rollout in eval_trajs],
    }
    lengths_by_split = {
        "Train": [rollout["episode_statistics"]['l'] for rollout in trajs],
        "Eval": [rollout["episode_statistics"]['l'] for rollout in eval_trajs],
    }

    # decide what to log
    logs = OrderedDict()
    for split in ("Eval", "Train"):
        split_returns = returns_by_split[split]
        split_lengths = lengths_by_split[split]
        logs[f"{split}_AverageReturn"] = np.mean(split_returns)
        logs[f"{split}_StdReturn"] = np.std(split_returns)
        logs[f"{split}_MaxReturn"] = np.max(split_returns)
        logs[f"{split}_MinReturn"] = np.min(split_returns)
        logs[f"{split}_AverageEpLen"] = np.mean(split_lengths)
        logs[f"{split}_MaxEpLen"] = np.max(split_lengths)

    return logs

def soft_update_target_model(target_model, new_model, tau):
    """Blend the parameters of ``new_model`` into ``target_model`` in place.

    Parameters are paired in the order given by ``parameters()``; each target
    parameter becomes ``target * (1 - tau) + new * tau``. ``new_model`` is
    left untouched.
    """
    parameter_pairs = zip(target_model.parameters(), new_model.parameters())
    for tgt, src in parameter_pairs:
        blended = tgt.data * (1.0 - tau) + src.data * tau
        tgt.data.copy_(blended)

def convert_listofrollouts(trajs):
    """
    Flatten a list of rollout dictionaries into per-field arrays.

    Returns a 6-tuple: the concatenated observations, actions, next
    observations, terminals and rewards across all rollouts, followed by the
    list of per-rollout reward arrays (not concatenated).
    """
    def _joined(field):
        # concatenate one field over all rollouts, in rollout order
        return np.concatenate([rollout[field] for rollout in trajs])

    return (
        _joined("observation"),
        _joined("action"),
        _joined("next_observation"),
        _joined("terminal"),
        _joined("reward"),
        [rollout["reward"] for rollout in trajs],
    )


def get_traj_length(traj):
    """Return the number of steps in a rollout, i.e. the length of its ``"reward"`` entry."""
    rewards = traj["reward"]
    return len(rewards)

def append_to_pickle(file_path, new_data):
    """Append ``new_data`` to the list pickled at ``file_path``.

    The whole file is read, extended by one entry and written back. A missing
    or empty file is treated as an empty list.
    """
    def _read_records():
        # FileNotFoundError: no file yet; EOFError: file exists but is empty
        try:
            with open(file_path, 'rb') as source:
                return pickle.load(source)
        except (EOFError, FileNotFoundError):
            return []

    records = _read_records()
    records.append(new_data)

    # rewrite the file with the extended list
    with open(file_path, 'wb') as sink:
        pickle.dump(records, sink)


def draw_graph(G, data_path, name, title, logger=None, itr=None, 
               a=None, r=None, mask=None, draw_bulk=False):
    """Render graph ``G`` to ``<data_path>/<name>.pdf`` via ``visualize_undir_stdgraph``.

    Args:
        G: Graph object providing ``get_num_nodes()``, ``get_edges()`` and
            ``remove_edge()``.
        data_path, name: Output directory and file stem of the PDF.
        title, logger, itr: Accepted but not used; the figure title is built
            from ``a`` and ``mask``.
        a: Optional action. Its first two entries (clamped to the last node
            index) are highlighted as an edge unless ``r == -12``, and
            ``a.flatten()`` becomes the first title line.
        r: Optional reward; the value ``-12`` suppresses the highlight.
        mask: Optional iterable of rows. Each row is thresholded at 0.5 and
            up to the first four rows are appended to the title.
        draw_bulk: When false, a deep copy of ``G`` is drawn from which edges
            whose only set one-hot bit is the first or the last entry have
            been removed.
    """
    custom_color_dict = {
        "NET": "orange",
        "PFET": "blue",
        "NFET": "green",
        "RES": "sienna",
        "CAP": "mediumturquoise",
        "IND": "gold",
        # "TXFORMER": "mediumseagreen"  ,
        # "CONNECTS": "grey",
        "INPUT": "pink",
        "OUTPUT": "darkviolet",
        "VDD":"red",
        "GND": "grey"
    } 

    red_edges = []
    if a is not None and r != -12:
        # clamp both endpoints of the action to the last node index
        last_node = G.get_num_nodes() - 1
        ac_edges = [n if n < last_node else last_node for n in a[0:2]]
        red_edges = [(ac_edges[0], ac_edges[1])]

    if draw_bulk:
        visgraph = G
    else:
        visgraph = copy.deepcopy(G)
        for e in list(visgraph.get_edges()):
            # Don't draw edges that are only connected to the bulk
            one_hot = e[2]["one_hot_edge_attr"]
            if (one_hot[0] == 1 or one_hot[-1] == 1) and sum(one_hot) == 1:
                visgraph.remove_edge(e[0], e[1])

    title_lines = []
    if a is not None:
        title_lines.append(str(a.flatten()))
    if mask is not None:
        prmask = [[0 if v < 0.5 else 1 for v in m] for m in mask]
        # As before, at most the first four rows are shown; shorter masks no
        # longer raise an IndexError.
        title_lines.extend(str(row) for row in prmask[:4])
    mask_str = "\n".join(title_lines) if title_lines else None

    visualize_undir_stdgraph(graph=visgraph, 
                             visualization_base_path=data_path, 
                             visualization_name=name, 
                             layout="spring",
                             node_label_path=["custom_features", "one_hot_node_type"],
                             edge_label_path=["one_hot_edge_attr"],
                             highlight_edges=red_edges,
                             custom_color_dict=custom_color_dict,
                             title=mask_str, 
                             is_rl=True)

    plt.close("all")

def visualize_undir_stdgraph(graph, 
                                    visualization_base_path: Path, 
                                    visualization_name: str, 
                                    layout: str ="spring",
                                    node_label_path: List[str] = None,
                                    edge_label_path: List[str] = None,
                                    highlight_edges: List[Any] = None,
                                    arc: bool = True,
                                    custom_color_dict: Dict[str, str]=None,
                                    title= None,
                                    is_rl=False):
    """Draw ``graph`` with networkx and save it as ``<base_path>/<name>.pdf``.

    Every edge of ``graph`` is added in both directions to a ``nx.DiGraph``.

    Args:
        graph: Object providing ``get_nodes()``, ``get_edges()``,
            ``get_node_features()`` and ``get_edge_features()``.
        visualization_base_path: Output directory, as ``str`` or ``Path``.
        visualization_name: File stem of the PDF.
        layout: Key into ``VisualizationUtils.LAYOUT_DICT``.
        node_label_path: Feature path used for node labels; ``None`` draws
            the graph without labels.
        edge_label_path: Feature path used for edge labels.
        highlight_edges: Edges (either orientation) whose label is drawn red.
        arc: Accepted but not used.
        custom_color_dict: Maps component names to colours; when ``None``,
            ``VisualizationUtils.COLOR_DICT`` is used.
        title: Optional axes title.
        is_rl: When true, one-hot features are decoded: node labels become
            the ``custom_color_dict`` key at the index of the first 1, "NET"
            nodes are coloured the same way, and edge labels list the letters
            of all set positions.

    Returns:
        The matplotlib figure, which has already been saved and closed.
    """
    def _first_hot(one_hot):
        # index of the first entry equal to 1
        return [i for i, x in enumerate(one_hot) if x == 1][0]

    G = nx.DiGraph()
    nodes = list(graph.get_nodes())
    color_keys = list(custom_color_dict.keys()) if custom_color_dict is not None else None

    node_color_list = []
    node_labels = {}
    for n in nodes:
        node_id = n[0]
        if node_label_path is not None:
            node_features = graph.get_node_features(node_id, node_label_path)
            if is_rl:
                node_features = color_keys[_first_hot(node_features)]
            node_labels[node_id] = f"{node_id}: {node_features}"
        else:
            node_features = None
            node_labels[node_id] = f"{node_id}"
        G.add_node(node_id, features=node_features)

        comp_type = graph.get_node_features(node_id, ["component_type"])
        if custom_color_dict is None:
            color = VisualizationUtils.COLOR_DICT[comp_type.name]
        elif is_rl and comp_type.name == "NET":
            # NET nodes are coloured by their one-hot node type instead
            one_hot = graph.get_node_features(node_id, ["custom_features", "one_hot_node_type"])
            color = custom_color_dict[color_keys[_first_hot(one_hot)]]
        else:
            color = custom_color_dict[comp_type.name]
        node_color_list.append(color)

    rl_order = {0:'B', 1:'D', 2:'G', 3:'S', 4:'P', 5:'M', 6:'TB'}
    edge_labels = {}
    for e in graph.get_edges():
        G.add_edge(e[0], e[1])
        G.add_edge(e[1], e[0])
        label = graph.get_edge_features(e[0], e[1], edge_label_path)
        if is_rl:
            label = [rl_order[i] for i, x in enumerate(label) if x == 1]
        edge_labels[(e[0], e[1])] = label
        edge_labels[(e[1], e[0])] = label

    # Draw graph
    fig = plt.figure(figsize=(10, 10) if len(nodes) < 10 else (15, 15))
    ax = fig.gca()
    pos = VisualizationUtils.LAYOUT_DICT[layout](G)

    nx.draw_networkx(G, pos, ax=ax, with_labels=node_label_path is not None,
                     labels=node_labels, node_color=node_color_list)

    # Split the labels by colour so each colour needs a single draw call.
    red_edges = [] if highlight_edges is None else highlight_edges
    highlighted, regular = {}, {}
    for edge, feat in edge_labels.items():
        is_red = edge in red_edges or (edge[1], edge[0]) in red_edges
        (highlighted if is_red else regular)[edge] = f"{feat}"
    if highlighted:
        nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=highlighted,
                                     font_color='red')
    if regular:
        nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=regular,
                                     font_color='k')

    if title is not None:
        ax.set_title(title)

    # Path join works for both str and Path inputs; string concatenation
    # raised a TypeError when a Path was passed, as the annotation invites.
    fig.savefig(Path(visualization_base_path) / f"{visualization_name}.pdf", format='pdf')

    plt.close(fig)
    return fig