import os
import copy
import json
import torch
import random
from RetroUtils.smiles2dgl import smiles_to_dglGraph
from hyperGraph_utils.utils import OurHypergraph


class RetroEnv:
    def __init__(self, data_name, RawDataFile_path, max_len=8):
        """Load the hypergraph description from disk and build the hypergraph object.

        Reads ``<RawDataFile_path>/<data_name>/hypergraph.json`` and keeps the
        parsed JSON on ``self.hg_json``; its ``num_v``, ``e_list`` and ``v_fp``
        entries are handed to ``OurHypergraph``.

        Args:
            data_name: Sub-directory of ``RawDataFile_path`` that holds
                ``hypergraph.json``.
            RawDataFile_path: Root directory of the raw data
                (e.g. '../data/RawData').
            max_len: Step budget stored on ``self.max_len``.
        """
        json_path = os.path.join(RawDataFile_path, data_name, 'hypergraph.json')
        with open(json_path, 'r') as handle:
            self.hg_json = json.load(handle)

        vertex_count = self.hg_json['num_v']  # int
        hyperedges = self.hg_json["e_list"]  # [[], [], ...]
        # Per-vertex feature matrix (noted as num_v x 2048 by the original author).
        vertex_features = torch.tensor(self.hg_json['v_fp'], dtype=torch.float32)

        self.hypergraph_dhg = OurHypergraph(
            num_v=vertex_count,
            e_list=hyperedges,
            x_tensor=vertex_features,
        )
        self.max_len = max_len

    def reset(self, data_dict):
        """Begin a new episode for one reaction record and return the initial state.

        Reads the keys ``'rxn'``, ``'product'``, ``'reactants'``,
        ``'rc_node_idx'``, ``'lg_idx'`` and ``'product_fp'`` from ``data_dict``
        and stores them (and the graph built from the product SMILES) on the
        instance.

        The state is the list
        ``[product_dgl, product_fp, Placed_RcNodeIdx, RcNodeIdx_list,
        LgIdx_list, tag, mask, t]``.

        Args:
            data_dict: Mapping describing one reaction (see keys above).

        Returns:
            A deep copy of the freshly initialised state list.
        """
        self.data_dict = data_dict
        self.rxn = data_dict['rxn']
        self.product = data_dict['product']
        self.reactants = data_dict['reactants']
        self.label_RcNodeIdx = data_dict['rc_node_idx']
        self.label_LgIdx = data_dict['lg_idx']

        # Graph view of the product and the sizes of both action spaces.
        self.product_dgl = smiles_to_dglGraph(smiles=self.product, seed=0, add_self_loop=False, Conformer=False)
        self.num_nodes_product = self.product_dgl.num_nodes()  # n
        self.num_nodes_hypergraph = self.hg_json['num_v']  # v
        fingerprint = torch.tensor(data_dict['product_fp'], dtype=torch.float32)
        self.product_fp = fingerprint.reshape(1, -1)  # 1xd

        # Initial mask, 1x(n+1): zeros everywhere except the trailing stop slot.
        initial_mask = torch.zeros(1, self.num_nodes_product + 1)
        initial_mask[0, -1] = 1

        self.state = [
            self.product_dgl,
            self.product_fp,
            torch.zeros(self.num_nodes_product),  # Placed_RcNodeIdx, n
            [],  # RcNodeIdx_list
            [],  # LgIdx_list
            0,  # tag
            initial_mask,
            0,  # t
        ]
        return copy.deepcopy(self.state)

    def compute_Placed_tensor(self, num_nodes, node_set):  # num_nodes: int, node_set: [idx, idx, ...]
        assert -1 not in node_set, '-1 in node_set'
        placed = torch.zeros(num_nodes)  # n
        placed[node_set] = 1
        return placed  # n

    def compute_graph_one_hop_neighbor(self, node_set):  # node_set: [idx, idx, ...]
        one_hop_neighbor = set()
        for node in node_set:
            one_hop_nodes = self.product_dgl.successors(node)
            one_hop_neighbor = one_hop_neighbor | set(one_hop_nodes.tolist())
        one_hop_neighbor = one_hop_neighbor - set(node_set)
        one_hop_neighbor = list(one_hop_neighbor)
        one_hop_neighbor.sort()
        return one_hop_neighbor  # [idx, idx, ...]

    def compute_hypergraph_one_hop_neighbor(self, node_set):  # node_set: [idx, idx, ...]
        one_hop_node_neighbor = set()
        one_hop_edge_neighbor = set()
        for node in node_set:
            one_hop_edges = self.hypergraph_dhg.N_e(node)
            one_hop_edge_neighbor = one_hop_edge_neighbor | set(one_hop_edges.tolist())
        for edge in list(one_hop_edge_neighbor):
            one_hop_nodes = self.hypergraph_dhg.N_v(edge)
            one_hop_node_neighbor = one_hop_node_neighbor | set(one_hop_nodes.tolist())
        one_hop_node_neighbor = one_hop_node_neighbor - set(node_set)
        one_hop_node_neighbor = list(one_hop_node_neighbor)
        one_hop_node_neighbor.sort()
        return one_hop_node_neighbor  # [idx, idx, ...]

    def compute_mask(self, node_set, tag):  # ode_set: [idx, idx, ...], tag: int
        if tag == 0:  # tag = 0  stage: reaction center identification
            if len(node_set) == 0:
                mask = torch.zeros(self.num_nodes_product + 1)  # (n+1)
                mask[-1] = 1
                mask = mask.reshape(1, -1)  # 1x(n+1)
            else:
                assert -1 not in node_set, '-1 in node_set'
                one_hop_neighbor = self.compute_graph_one_hop_neighbor(node_set)  # [idx, idx, ...]
                if len(one_hop_neighbor) != 0:
                    assert max(
                        one_hop_neighbor) < self.num_nodes_product, f'error at one_hop_neighbor, one_hop_neighbor: {one_hop_neighbor}, num_nodes_product: {self.num_nodes_product}'
                mask = torch.ones(self.num_nodes_product + 1)  # (n+1)
                mask[one_hop_neighbor] = 0
                mask[-1] = 0
                mask = mask.reshape(1, -1)  # 1x(n+1)

        elif tag == 1:  # tag = 1  stage: leaving groups completion
            if len(node_set) == 0:
                mask = torch.zeros(self.hg_json['num_v'] + 1)
                mask[-1] = 1
                mask = mask.reshape(1, -1)  # 1x(v+1)  # !
            else:
                assert -1 not in node_set, '-1 in node_set'
                one_hop_neighbor = self.compute_hypergraph_one_hop_neighbor(node_set)  # [idx, idx, ...]
                if len(one_hop_neighbor) != 0:
                    assert max(one_hop_neighbor) < self.hg_json[
                        'num_v'], 'error at one_hop_neighbor, max(one_hop_neighbor) >= num_nodes_hypergraph'
                mask = torch.ones(self.hg_json['num_v'] + 1)  # (v+1)
                mask[one_hop_neighbor] = 0
                mask[-1] = 0
                mask = mask.reshape(1, -1)  # 1x(v+1)
        else:
            raise 'tag not equal 0 / 1'
        return mask

    def step(self, action):
        if self.state[-3] == 0:  # tag = 0  stage: reaction center identification
            if action != -1:  # not rc-stop-action
                assert action < self.num_nodes_product, 'action out of num_nodes_product'
                # change: Placed_RcNodeIdx, RcNodeIdx_list, mask, t
                RcNodeIdx_list = copy.deepcopy(self.state[3])  # [idx, idx, ...]
                assert action not in RcNodeIdx_list, f'action: {action} in RcNodeIdx_list: {RcNodeIdx_list}'
                RcNodeIdx_list.append(action)
                RcNodeIdx_list.sort()  # !
                Placed_RcNodeIdx = self.compute_Placed_tensor(num_nodes=self.num_nodes_product,
                                                              node_set=RcNodeIdx_list)  # !
                mask = self.compute_mask(node_set=RcNodeIdx_list, tag=self.state[-3])  # !
                t = copy.deepcopy(self.state[-1]) + 1  # !

                # not change: product_dgl, product_fp, LgIdx_list, tag
                LgIdx_list = copy.deepcopy(self.state[4])  # !
                tag = copy.deepcopy(self.state[5])  # !

                # next_state
                next_state = [self.product_dgl, self.product_fp,
                              Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list,
                              tag, mask, t]

                if next_state[-1] > self.max_len:
                    done = 1
                    reward = 0
                else:
                    done = 0
                    reward = 0

                self.state = copy.deepcopy(next_state)
                return reward, copy.deepcopy(next_state), done

            elif action == -1:  # rc-stop-action
                # change: tag, mask, t
                tag = 1  # !
                mask = torch.zeros(self.hg_json['num_v'] + 1)
                mask[-1] = 1
                mask = mask.reshape(1, -1)  # 1x(v+1)  # !
                t = copy.deepcopy(self.state[-1]) + 1  # !

                # not change: product_dgl, product_fp, Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list
                Placed_RcNodeIdx = copy.deepcopy(self.state[2])
                RcNodeIdx_list = copy.deepcopy(self.state[3])
                LgIdx_list = copy.deepcopy(self.state[4])

                next_state = [self.product_dgl, self.product_fp,
                              Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list,
                              tag, mask, t]

                if next_state[-1] > self.max_len:
                    done = 1
                    reward = 0
                else:
                    done = 0
                    reward = 0

                self.state = copy.deepcopy(next_state)
                return reward, copy.deepcopy(next_state), done

        elif self.state[-3] == 1:  # tag = 1  stage: leaving groups completion
            if action != -1:  # not lg-stop-action
                assert action < self.hg_json['num_v'], 'action out of num_nodes_hypergraph'
                # change: LgIdx_list, mask, t
                LgIdx_list = copy.deepcopy(self.state[4])
                assert action not in LgIdx_list, f'action: {action} in LgIdx_list: {LgIdx_list}'
                LgIdx_list.append(action)
                LgIdx_list.sort()  # !
                assert self.state[-3] == 1, 'error at self.state[-3] == 1'
                mask = self.compute_mask(node_set=LgIdx_list, tag=self.state[-3])  # !
                t = copy.deepcopy(self.state[-1]) + 1  # !

                # not change: product_dgl, product_fp, Placed_RcNodeIdx, RcNodeIdx_list, tag
                Placed_RcNodeIdx = copy.deepcopy(self.state[2])  # !
                RcNodeIdx_list = copy.deepcopy(self.state[3])  # !
                tag = copy.deepcopy(self.state[5])  # !

                next_state = [self.product_dgl, self.product_fp,
                              Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list,
                              tag, mask, t]

                if next_state[-1] > self.max_len:
                    done = 1
                else:
                    done = 0

                if done:
                    if set(RcNodeIdx_list) == set(self.label_RcNodeIdx) and set(LgIdx_list) == set(self.label_LgIdx):
                        reward = 1
                    else:
                        reward = 0
                else:
                    reward = 0

                self.state = copy.deepcopy(next_state)
                return reward, copy.deepcopy(next_state), done

            elif action == -1:  # lg-stop-action
                # change: tag, mask, t
                tag = 2  # !
                mask = torch.zeros(self.hg_json['num_v'] + 1)
                mask = mask.reshape(1, -1)  # 1x(v+1)  # !
                t = copy.deepcopy(self.state[-1]) + 1  # !

                # not change: product_dgl, product_fp, Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list
                Placed_RcNodeIdx = copy.deepcopy(self.state[2])
                RcNodeIdx_list = copy.deepcopy(self.state[3])
                LgIdx_list = copy.deepcopy(self.state[4])

                next_state = [self.product_dgl, self.product_fp,
                              Placed_RcNodeIdx, RcNodeIdx_list, LgIdx_list,
                              tag, mask, t]

                done = 1
                if set(RcNodeIdx_list) == set(self.label_RcNodeIdx) and set(LgIdx_list) == set(self.label_LgIdx):
                    reward = 1
                else:
                    reward = 0

                self.state = copy.deepcopy(next_state)
                return reward, copy.deepcopy(next_state), done

        else:
            raise f'tag not equal 0 / 1, tag = {self.state[-3]}'

    def generate_random_gt_trajectory(self):
        label_RcNodeIdx = copy.deepcopy(self.label_RcNodeIdx)  # [idx, idx, ...]
        label_LgIdx = copy.deepcopy(self.label_LgIdx)  # [idx, idx, ...]
        rc_trajectory = []
        lg_trajectory = []

        # rc
        action = random.choice(label_RcNodeIdx)
        rc_trajectory.append(action)
        label_RcNodeIdx.pop(label_RcNodeIdx.index(action))
        if len(label_RcNodeIdx) != 0:
            while True:
                one_hop_neighbor = self.compute_graph_one_hop_neighbor(rc_trajectory)
                candidate_nodes = set(one_hop_neighbor) & set(label_RcNodeIdx)
                action = random.choice(list(candidate_nodes))
                rc_trajectory.append(action)
                label_RcNodeIdx.pop(label_RcNodeIdx.index(action))
                if len(label_RcNodeIdx) == 0:
                    rc_trajectory.append(-1)
                    break
        elif len(label_RcNodeIdx) == 0:
            rc_trajectory.append(-1)

        # lg
        action = random.choice(label_LgIdx)
        lg_trajectory.append(action)
        label_LgIdx.pop(label_LgIdx.index(action))
        if len(label_LgIdx) != 0:
            while True:
                one_hop_neighbor = self.compute_hypergraph_one_hop_neighbor(lg_trajectory)
                candidate_nodes = set(one_hop_neighbor) & set(label_LgIdx)
                action = random.choice(list(candidate_nodes))
                lg_trajectory.append(action)
                label_LgIdx.pop(label_LgIdx.index(action))
                if len(label_LgIdx) == 0:
                    lg_trajectory.append(-1)
                    break
        elif len(label_LgIdx) == 0:
            lg_trajectory.append(-1)

        gt_trajectory = rc_trajectory + lg_trajectory
        return gt_trajectory

    def to_dict(self):
        """Return the instance's attribute dictionary.

        The returned mapping is the live ``__dict__`` of the object, not a
        copy, so changes to it are reflected on the instance.
        """
        return self.__dict__


