import itertools
import enum
import math
from copy import deepcopy
from dataclasses import dataclass, field
from collections import defaultdict

import networkx as nx

from matplotlib import colormaps
import matplotlib.pyplot as plt
import numpy as np


# Colormaps used by GraphState.nx_draw: node colours are looked up by node type
# and edge colours by edge type.
NODE_COLORMAP = colormaps['Set3']
EDGE_COLORMAP = colormaps['tab10']


def relabel(edge, index):
    """Shift the endpoints of ``edge`` after node ``index`` has been deleted.

    Every endpoint greater than ``index`` moves down by one. The edge is
    assumed to be ordered (``edge[0] < edge[1]``), so whenever the first
    endpoint moves the second one moves too. An edge that needs no change is
    returned as the same object.
    """
    low, high = edge[0], edge[1]
    low_moves = low > index
    if low_moves or high > index:
        # The larger endpoint always moves when anything moves.
        return low - int(low_moves), high - 1
    return edge


@dataclass
class GraphState:
    """Mutable undirected graph with integer node types and edge types.

    A node is identified by its position in ``node_types`` and
    ``edge_types[k]`` is the type of ``edge_list[k]``. The attributes
    ``num_nodes``, ``num_edges``, ``degree`` and the ``_edge_set`` lookup are
    derived in ``__post_init__`` and kept in step by the mutating methods.
    """

    node_types: list[int] = field(default_factory=list)
    edge_types: list[int] = field(default_factory=list)
    # Each edge is expected in the form (smaller index, larger index).
    edge_list: list[tuple[int, int]] = field(default_factory=list)

    def __post_init__(self):
        """Derive the counters, the edge lookup set and the degree table."""
        self.num_nodes: int = len(self.node_types)
        self.num_edges: int = len(self.edge_list)
        self._edge_set: set[tuple[int, int]] = set(self.edge_list)
        self._set_degree()

    @staticmethod
    def _canonical(i, j):
        """Return the pair ordered the way edges are stored."""
        return (i, j) if i < j else (j, i)

    def _set_degree(self):
        """Recompute ``self.degree`` from scratch out of ``edge_list``."""
        counts = [0 for _ in range(self.num_nodes)]
        for u, v in self.edge_list:
            counts[u] += 1
            counts[v] += 1
        self.degree = counts

    def add_node(self, node_type: int):
        """Append an isolated node of the given type."""
        self.num_nodes += 1
        self.degree.append(0)
        self.node_types.append(node_type)

    def add_edge(self, i: int, j: int, edge_type: int):
        """Add an edge between ``i`` and ``j``; the edge must not exist yet."""
        key = self._canonical(i, j)
        assert key not in self._edge_set
        self._edge_set.add(key)
        self.edge_list.append(key)
        self.edge_types.append(edge_type)
        self.num_edges += 1
        # Degree bookkeeping comes last, as it is the only step that can fail
        # (out-of-range node index).
        self.degree[i] += 1
        self.degree[j] += 1

    def remove_node(self, source):
        """Delete node ``source`` with its incident edges and renumber the rest."""
        kept_positions = [
            pos for pos, edge in enumerate(self.edge_list) if source not in edge
        ]
        self.node_types.pop(source)
        surviving_edges = []
        surviving_types = []
        for pos in kept_positions:
            surviving_edges.append(relabel(self.edge_list[pos], source))
            surviving_types.append(self.edge_types[pos])
        self.edge_list = surviving_edges
        self.edge_types = surviving_types
        # Rebuild every derived attribute from the new lists.
        self.__post_init__()

    def remove_edge(self, i, j):
        """Delete the edge between ``i`` and ``j`` (``KeyError`` if absent)."""
        key = self._canonical(i, j)
        self._edge_set.remove(key)
        position = self.edge_list.index(key)
        del self.edge_list[position]
        del self.edge_types[position]
        self.num_edges -= 1
        self.degree[i] -= 1
        self.degree[j] -= 1

    def get_non_edge_list(self):
        """Return every ordered node pair ``(i, j)``, ``i < j``, that has no edge."""
        return [
            pair
            for pair in itertools.combinations(range(self.num_nodes), 2)
            if pair not in self._edge_set
        ]

    def to_nx(self):
        """Build a networkx graph with ``node_type`` / ``edge_type`` attributes."""
        graph = nx.from_edgelist(self.edge_list)
        # from_edgelist only knows nodes that have an edge; adding each node
        # again creates the isolated ones and attaches the type attribute.
        for node, kind in enumerate(self.node_types[:self.num_nodes]):
            graph.add_node(node, node_type=kind)
        attributes = {}
        for position, edge in enumerate(self.edge_list):
            attributes[edge] = {'edge_type': self.edge_types[position]}
        nx.set_edge_attributes(graph, attributes)
        return graph

    def nx_draw(self, figsize=(3, 3), with_labels=True):
        """Draw the graph in a new figure, coloured by node and edge type."""
        graph = self.to_nx()
        draw_options = {
            'node_color': [
                NODE_COLORMAP(graph.nodes[node]['node_type']) for node in graph.nodes
            ],
            'edge_color': [
                EDGE_COLORMAP(graph.edges[edge]['edge_type']) for edge in graph.edges
            ],
            'with_labels': with_labels,
        }
        plt.figure(figsize=figsize)
        return nx.draw(graph, **draw_options)


class ActionType(enum.Enum):
    """Kinds of graph edit understood by ``GraphEnv.step``."""

    Stop = 1        # leave the graph unchanged
    AddNode = 2     # append a node, optionally linked to an existing one
    AddEdge = 3     # connect two existing nodes
    RemoveNode = 4  # delete a node together with its incident edges
    RemoveEdge = 5  # delete a single edge


@dataclass
class GraphAction:
    """One edit applied to a graph state.

    Which fields are meaningful depends on ``type``:

    * ``AddNode``: ``node_type`` is required. ``source`` (existing node to
      attach to) and ``edge_type`` are given together, or both left as
      ``None`` when the node is added without a connecting edge (the first
      node of an empty graph).
    * ``AddEdge``: ``source``, ``target`` and ``edge_type``.
    * ``RemoveNode``: ``source`` is the node to delete.
    * ``RemoveEdge``: ``source`` and ``target`` are the edge endpoints.
    * ``Stop``: no further fields.
    """

    type: ActionType = None
    source: int = None
    target: int = None
    node_type: int = None
    edge_type: int = None

    def is_sane(self):
        """Assert that the fields required for ``type`` are set.

        Returns:
            ``True`` when every check passes.

        Raises:
            AssertionError: if a required field is ``None``.
        """
        if self.type == ActionType.AddNode:
            assert self.node_type is not None, "AddNode requires node_type"
            # ``source`` is deliberately not required: an AddNode without a
            # source is the valid way to create the first node of a graph.
        elif self.type == ActionType.AddEdge:
            assert self.edge_type is not None, "AddEdge requires edge_type"
            assert self.source is not None, "AddEdge requires source"
            assert self.target is not None, "AddEdge requires target"
        return True

@dataclass
class Children:
    """Successor states of a graph state and the actions that produce them.

    ``states[k]`` is the state reached by ``actions[k]``. ``duplicates[k]``
    collects further actions that were merged into entry ``k`` (used when the
    environment removes isomorphic duplicates).
    """

    # Hints are quoted so they are not evaluated when the class is created.
    states: "list[GraphState]" = field(default_factory=list)
    actions: "list[GraphAction]" = field(default_factory=list)
    duplicates: "defaultdict[int, list[GraphAction]]" = field(
        default_factory=lambda: defaultdict(list)
    )

    def __len__(self):
        """Return the number of stored states."""
        return len(self.states)

    def find_action_index(self, action):
        """Return the index of the entry that ``action`` belongs to.

        The representative ``actions`` are searched first, then the merged
        ``duplicates``; for a duplicate the index of its representative entry
        is returned.

        Raises:
            ValueError: if ``action`` is stored nowhere.
        """
        try:
            return self.actions.index(action)
        except ValueError:
            pass
        for idx, merged in self.duplicates.items():
            if action in merged:
                return idx
        raise ValueError(f"{action} is not found.")


@dataclass
class Parents(Children):
    """Predecessor states of a graph state; same layout as ``Children``."""
    

def num_cycles(state: GraphState) -> float:
    """Reward: one plus the number of cycles in networkx's cycle basis."""
    basis = nx.cycle_basis(state.to_nx())
    cycle_count = len(basis)
    return 1.0 + float(cycle_count)


def diameter(state: "GraphState") -> float:
    """Reward: the diameter of the graph as a float.

    An empty graph (no nodes) yields ``0.0``, the same value a single-node
    graph gets, because ``nx.diameter`` has no answer for a graph without
    nodes and raises instead.

    Raises:
        networkx.NetworkXError: if the graph is not connected.
    """
    if state.num_nodes == 0:
        return 0.0
    return float(nx.diameter(state.to_nx()))


def uniform(state: GraphState) -> float:
    """Reward: the constant ``1.0``; ``state`` is ignored."""
    constant_reward = 1.0
    return constant_reward


class GraphEnv:
    """Environment that grows typed graphs one action at a time.

    A trajectory starts from the empty ``GraphState`` returned by ``new``.
    ``children`` / ``parents`` enumerate the states reachable by one forward /
    backward action, and ``log_reward`` scores a state.

    Args:
        num_node_types: number of distinct node types (types are ``0..n-1``).
        num_edge_types: number of distinct edge types (types are ``0..n-1``).
        max_nodes: no node is added once a state has this many nodes.
        max_edges: no node or edge is added once a state has this many edges.
        max_degree: nodes with this degree are not offered as edge endpoints.
        remove_duplicates: merge isomorphic children/parents into one entry.
        min_reward: lower clamp applied to the reward before the logarithm.
        reward_exponent: factor multiplied onto the log reward.
        reward_name: one of ``'num_cycles'``, ``'diameter'``, ``'uniform'``,
            ``'random_cycle'``.

    Raises:
        ValueError: if ``reward_name`` is not recognised.
    """

    def __init__(
        self,
        num_node_types=1,
        num_edge_types=1,
        max_nodes=20,
        max_edges=30,
        max_degree=5,
        remove_duplicates=False,
        min_reward=1e-8,
        reward_exponent=1.0,
        reward_name='num_cycles',
    ):
        assert max_nodes > 0
        assert max_edges > 0

        self.num_node_types = num_node_types
        self.num_edge_types = num_edge_types
        self.max_nodes = max_nodes
        self.max_edges = max_edges
        self.max_degree = max_degree
        self.remove_duplicates = remove_duplicates
        self.min_reward = min_reward
        self.reward_exponent = reward_exponent
        self.reward_name = reward_name

        if reward_name == 'num_cycles':
            self._reward_fn = num_cycles
        elif reward_name == 'diameter':
            self._reward_fn = diameter
        elif reward_name == 'uniform':
            self._reward_fn = uniform
        elif reward_name == 'random_cycle':
            from gflownet.monitor import all_states_info, smiles_hash
            import random

            # NOTE(review): a reward function is installed before calling
            # all_states_info(self), presumably because that helper evaluates
            # rewards through this environment - confirm against
            # gflownet.monitor before removing this line.
            self._reward_fn = num_cycles
            states_info = all_states_info(self)
            # Shuffle the rewards among all states with a non-empty hash.
            rewards = [states_info[smi]['reward'] for smi in states_info if smi != '']
            smiles = [smi for smi in states_info if smi != '']
            random.shuffle(rewards)
            smi2rewards = dict(zip(smiles, rewards))

            def random_reward(state):
                smi = smiles_hash(state)
                return smi2rewards.get(smi, 0.0)

            self._reward_fn = random_reward
        else:
            raise ValueError(f'reward_name={reward_name}')

    def new(self):
        """Return a fresh, empty graph state."""
        return GraphState()

    def stop_action(self):
        """Return the action that ends a trajectory."""
        return GraphAction(ActionType.Stop)

    def step(self, state: GraphState, action: GraphAction) -> GraphState:
        """Apply ``action`` to a deep copy of ``state`` and return the copy.

        ``state`` itself is never modified. A ``Stop`` action returns an
        unchanged copy.

        Raises:
            ValueError: if ``action.type`` is not a known ``ActionType``.
        """
        next_state = deepcopy(state)
        if action.type == ActionType.AddNode:
            assert action.node_type < self.num_node_types
            next_state.add_node(action.node_type)
            if action.source is not None:
                # The new node is always the last one; link it to ``source``.
                target = next_state.num_nodes - 1
                next_state.add_edge(action.source, target, action.edge_type)
        elif action.type == ActionType.AddEdge:
            assert action.edge_type < self.num_edge_types
            next_state.add_edge(action.source, action.target, action.edge_type)
        elif action.type == ActionType.RemoveNode:
            next_state.remove_node(action.source)
        elif action.type == ActionType.RemoveEdge:
            next_state.remove_edge(action.source, action.target)
        elif action.type == ActionType.Stop:
            pass
        else:
            raise ValueError(f'unknown action type: {action.type}')
        return next_state

    def reverse_action(self, state, action):
        """Return the action that undoes ``action`` when applied after it.

        ``state`` is the state ``action`` is applied to (before the step).
        ``Stop`` and unrecognised action types are returned unchanged.

        Raises:
            ValueError: if a ``RemoveNode`` targets a node without any
                incident edge in a graph with more than one node.
        """
        if action.type == ActionType.RemoveEdge:
            return GraphAction(
                ActionType.AddEdge,
                source=action.source,
                target=action.target,
                edge_type=action.edge_type,
            )
        if action.type == ActionType.AddEdge:
            return GraphAction(
                ActionType.RemoveEdge,
                source=action.source,
                target=action.target,
                edge_type=action.edge_type,
            )
        if action.type == ActionType.AddNode:
            # Without a source the node is reversed as index 0 (first node);
            # otherwise the added node gets index ``state.num_nodes``.
            removed = 0 if action.source is None else state.num_nodes
            return GraphAction(ActionType.RemoveNode, source=removed)
        if action.type == ActionType.RemoveNode:
            if state.num_nodes == 1:
                return GraphAction(ActionType.AddNode, node_type=state.node_types[0])
            # Re-attach through the first edge that touches the removed node.
            for edge_type, edge in zip(state.edge_types, state.edge_list):
                if action.source in edge:
                    neighbour = edge[1] if edge[0] == action.source else edge[0]
                    break
            else:
                raise ValueError(
                    f'node {action.source} has no incident edge to re-attach to'
                )
            node_type = state.node_types[action.source]
            # Indices above the removed node shift down by one after removal.
            source = neighbour - int(action.source < neighbour)
            return GraphAction(
                ActionType.AddNode,
                source=source,
                node_type=node_type,
                edge_type=edge_type,
            )
        return action

    def log_reward(self, state: GraphState) -> float:
        """Return ``reward_exponent * log(max(reward, min_reward))``."""
        reward = self._reward_fn(state)
        reward = max(reward, self.min_reward)
        return self.reward_exponent * math.log(reward)

    def parents_actions(self, state: GraphState) -> list[GraphAction]:
        """List the backward actions available from ``state``.

        A single-node graph can only lose its node. Otherwise a node may be
        removed when its degree is exactly one, and an edge may be removed
        when both endpoints have degree above one and the graph stays
        connected without it.
        """
        if state.num_nodes == 1:
            action = GraphAction(ActionType.RemoveNode, source=0)
            return [action]

        actions = []

        # ActionType.RemoveNode
        for source in range(state.num_nodes):
            if state.degree[source] == 1:
                action = GraphAction(ActionType.RemoveNode, source=source)
                actions.append(action)

        # ActionType.RemoveEdge
        for idx, (u, v) in enumerate(state.edge_list):
            if state.degree[u] > 1 and state.degree[v] > 1:
                action = GraphAction(
                    ActionType.RemoveEdge,
                    source=u,
                    target=v,
                    edge_type=state.edge_types[idx],
                )
                new_state = self.step(state, action)
                new_graph = new_state.to_nx()
                if nx.algorithms.is_connected(new_graph):
                    actions.append(action)
        return actions

    def parents(self, state: GraphState) -> Parents:
        """Return the predecessor states of ``state`` with their actions."""
        return self._expand(state, self.parents_actions(state), Parents)

    def children_actions(self, state: GraphState) -> list[GraphAction]:
        """List the forward actions available from ``state``.

        From the empty graph the only options are adding a first node, one
        action per node type. Otherwise ``Stop`` is always offered, followed
        by ``AddNode`` / ``AddEdge`` actions that respect ``max_edges``,
        ``max_nodes`` and ``max_degree``.
        """
        if state.num_nodes == 0:
            return [
                GraphAction(ActionType.AddNode, node_type=node_type)
                for node_type in range(self.num_node_types)
            ]

        actions = [self.stop_action()]

        # Both AddNode (which also adds a connecting edge) and AddEdge
        # consume one edge, so neither is offered at the edge limit.
        if state.num_edges < self.max_edges:
            allowable_nodes = [
                i for i, deg in enumerate(state.degree) if deg < self.max_degree
            ]

            # AddNode
            if state.num_nodes < self.max_nodes:
                for i in allowable_nodes:
                    for node_type in range(self.num_node_types):
                        for edge_type in range(self.num_edge_types):
                            actions.append(GraphAction(
                                ActionType.AddNode,
                                source=i,
                                node_type=node_type,
                                edge_type=edge_type,
                            ))
            # AddEdge
            for a, b in itertools.combinations(allowable_nodes, 2):
                if (a, b) not in state._edge_set:
                    for edge_type in range(self.num_edge_types):
                        actions.append(GraphAction(
                            ActionType.AddEdge,
                            source=a,
                            target=b,
                            edge_type=edge_type,
                        ))
        return actions

    def children(self, state: GraphState) -> Children:
        """Return the successor states of ``state`` with their actions."""
        return self._expand(state, self.children_actions(state), Children)

    @staticmethod
    def _same_attributes(a, b):
        """Match nodes/edges whose attribute dictionaries are equal."""
        return a == b

    def _expand(self, state, actions, container_cls):
        """Apply every action to ``state`` and collect the results.

        With ``remove_duplicates`` set, a result isomorphic (including node
        and edge types) to an earlier one is not stored again; its action is
        recorded under ``duplicates`` at the index of the earlier entry.
        """
        next_states = [self.step(state, act) for act in actions]
        if not self.remove_duplicates:
            return container_cls(states=next_states, actions=actions)

        collected = container_cls()
        unique_graphs = []
        for action, new_state in zip(actions, next_states):
            new_graph = new_state.to_nx()
            for i, seen in enumerate(unique_graphs):
                if nx.is_isomorphic(
                    seen,
                    new_graph,
                    node_match=self._same_attributes,
                    edge_match=self._same_attributes,
                ):
                    collected.duplicates[i].append(action)
                    break
            else:
                unique_graphs.append(new_graph)
                collected.states.append(new_state)
                collected.actions.append(action)
        return collected
    

# from rdkit.Chem import QED
# from rdkit import Chem


# DEFAULT_ATOM_TYPES = ['C', 'N', 'O', 'F', 'P', 'S']
# DEFAULT_BOND_TYPES = [
#     Chem.rdchem.BondType.SINGLE, 
#     Chem.rdchem.BondType.DOUBLE, 
#     Chem.rdchem.BondType.TRIPLE
# ]

# class MolEnv(GraphEnv):
#     def __init__(
#         self,
#         atoms=DEFAULT_ATOM_TYPES,
#         num_bond_types=3,
#         charges=[0, 1, -1],
#         max_nodes=40,
#         max_edges=50,
#         remove_duplicates=False,
#         min_reward=1e-8,
#         reward_exponent=1.0,
#         reward_name='qed'
#     ):
#         self.atoms = atoms
#         self.bonds = DEFAULT_BOND_TYPES[:num_bond_types]
#         self.charges = charges
#         self.max_nodes = max_nodes
#         self.max_edges = max_edges
#         self.remove_duplicates = remove_duplicates
#         self.min_reward = min_reward
#         self.reward_exponent = reward_exponent
#         self.reward_name = reward_name
        
#         self.num_node_types = len(atoms) * len(charges)
#         self.num_edge_types = num_bond_types

#         pt = Chem.GetPeriodicTable()
#         self.max_atom_valence = {a: max(pt.GetValenceList(a)) for a in atoms}

#         if reward_name == 'qed':
#             self._reward_fn = QED.qed
#         else:
#             raise ValueError(f'reward_name={reward_name}')        

            
#     def _get_atom_valence(self, state: GraphState) -> list[int]:
#         valence = [0] * state.num_nodes
#         for t, (u, v) in zip(state.edge_types, state.edge_list):
#             valence[u] += t + 1
#             valence[v] += t + 1
#         return valence
        

#     def _node_type_to_atom(self, node_type: int) -> tuple[str, int]:
#         atom_idx, charge_idx = divmod(node_type, len(self.charges))
#         atom_str, charge = self.atoms[atom_idx], self.charges[charge_idx]
#         return atom_str, charge

    
#     def _node_type_to_max_valence(self, node_type: int) -> int:
#         atom_str, charge = self._node_type_to_atom(node_type)
#         if atom_str == 'N' and charge == 1: # special rule for nitrogen
#             return 5
#         return self.max_atom_valence.get(atom_str) - abs(charge)
                
    

#     def state_to_mol(self, state: GraphState) -> Chem.Mol:
#         m = Chem.RWMol()
#         for t in state.node_types:
#             atom_str, charge = self._node_type_to_atom(t)
#             atom = Chem.Atom(atom_str)
#             atom.SetFormalCharge(charge)
#             m.AddAtom(atom)
#         for t, (u, v) in zip(state.edge_types, state.edge_list):
#             bond_type = DEFAULT_BOND_TYPES[t]
#             m.AddBond(u, v, order=bond_type)
#         return m
    

#     def mol_to_state(self, mol):
#         Chem.Kekulize(mol, clearAromaticFlags=True)
#         node_types = []
#         for atom in mol.GetAtoms():
#             atom_str  = atom.GetSymbol()
#             charge = np.clip(atom.GetFormalCharge(), self.charges[0], self.charges[-1])
#             atom_idx = self.atoms.index(atom_str)
#             node_types.append(atom_idx * len(self.charges) + charge)        
#         edge_types = []
#         edge_list = []
#         for bond in mol.GetBonds():
#             edge_types.append(self.bonds.index(bond.GetBondType()))
#             i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
#             edge = (i, j) if i < j else (j, i)
#             edge_list.append(edge)
#         return GraphState(node_types, edge_types, edge_list)


#     def log_reward(self, state: GraphState) -> float:
#         mol = self.state_to_mol(state)
#         reward = self._reward_fn(mol)
#         return self.reward_exponent * math.log(reward + self.min_reward)
        

#     def children(self, state: GraphState) -> Children:

#         if state.num_nodes == 0:
#             children = Children()
#             for node_type in range(self.num_node_types):
#                 action = GraphAction(ActionType.AddNode, node_type=node_type)
#                 new_state = self.step(state, action)
#                 children.states.append(new_state)
#                 children.actions.append(action)
#             return children
        

#         action = self.stop_action()
#         stop_state = self.step(state, action)
#         children = Children(states=[stop_state], actions=[action])
#         offset = 1

#         def add_child():
#             new_state = self.step(state, action)
#             if self.remove_duplicates:
#                 new_graph = new_state.to_nx()
#                 for i, gp in enumerate(child_graphs):
#                     if nx.is_isomorphic(gp, new_graph, lambda a, b: a == b, lambda a, b: a == b):
#                         children.duplicates[i+offset].append(action)
#                         return
#                 child_graphs.append(new_graph)
#             children.states.append(new_state)
#             children.actions.append(action)

#         if state.num_edges < self.max_edges:
#             atom_val = self._get_atom_valence(state)
#             impl_val = [self._node_type_to_max_valence(t) - atom_val[i] for i, t in enumerate(state.node_types)]
#             allowable_nodes = [i for i, v in enumerate(impl_val) if v > 0]
            
#             # AddNode
#             if state.num_nodes < self.max_nodes:
#                 child_graphs = []
#                 for i in allowable_nodes:
#                     for node_type in range(self.num_node_types):
#                         max_val = self._node_type_to_max_valence(node_type)
#                         for edge_type in range(self.num_edge_types):
#                             if edge_type < min(impl_val[i], max_val):
#                                 action = GraphAction(ActionType.AddNode, source=i, node_type=node_type, edge_type=edge_type)
#                                 add_child()
            
#             # AddEdge
#             child_graphs = []
#             offset = len(children)
#             for a, b in itertools.combinations(allowable_nodes, 2):
#                 if (a, b) not in state._edge_set:
#                     for edge_type in range(self.num_edge_types):
#                         if edge_type < min(impl_val[a], impl_val[b]):
#                             action = GraphAction(ActionType.AddEdge, source=a, target=b, edge_type=edge_type)
#                             add_child()
#         return children

#     def parse_smiles(self, smiles: list[str]):
#         mols = []
#         for smi in smiles:
#             try:
#                 mols.append(Chem.MolFromSmiles(smi))
#             except:
#                 pass
#         states = []
#         for m in mols:
#             try:
#                 if m.GetNumAtoms() <= self.max_nodes and m.GetNumBonds() <= self.max_edges:
#                     states.append(self.mol_to_state(m))
#             except:
#                 pass

#         print(f'states: {len(states)}; mols: {len(mols)}; smiles: {len(smiles)}')
#         return states