import os
import copy
import json
import torch
import random
from RetroUtils.smiles2dgl import smiles_to_dglGraph
from hyperGraph_utils.utils import OurHypergraph


class LGCEnv:
    """Leaving-group completion environment over a leaving-group hypergraph.

    An episode is started for one reaction record with ``reset``. ``step`` then
    repeatedly adds hypergraph vertices (leaving groups) to a selection. The
    episode ends when the stop action ``-1`` is taken, or when the step counter
    exceeds ``max_len`` after a vertex-adding action.

    State layout (a list that callers index positionally):
        [0] product_dgl -- graph returned by ``smiles_to_dglGraph`` for the product SMILES
        [1] product_fp  -- float32 tensor reshaped to (1, -1)
        [2] LgIdx_list  -- sorted list of selected hypergraph vertex indices
        [3] mask        -- tensor of shape (1, num_v + 1); the last column is the stop action
        [4] t           -- number of steps taken so far

    Judging from how masks are built here, 1 appears to mark a blocked action
    and 0 an allowed one. NOTE(review): confirm against the policy that
    consumes the mask.
    """

    def __init__(self, data_name, ProcessedDataFile_path, max_len=4):  # ProcessedDataFile_path='./data/ProcessedData'
        """Load the leaving-group hypergraph of dataset ``data_name``.

        Args:
            data_name: sub-directory of ``ProcessedDataFile_path`` that holds
                ``hypergraph.json``.
            ProcessedDataFile_path: root directory of the processed data,
                e.g. ``'./data/ProcessedData'``.
            max_len: after a vertex-adding action, the episode ends once the
                step counter ``t`` exceeds this value.
        """
        LG_hypergraph_info_path = os.path.join(ProcessedDataFile_path, data_name, 'hypergraph.json')
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(LG_hypergraph_info_path, 'r', encoding='utf-8') as f:
            self.hg_json = json.load(f)
        self.hypergraph_dhg = OurHypergraph(
            num_v=self.hg_json['num_v'],  # int
            e_list=self.hg_json['e_list'],  # [[], [], ...]
            x_tensor=torch.tensor(self.hg_json['v_fp'], dtype=torch.float32),  # num_v x 2048
        )
        self.max_len = max_len

    def reset(self, data_dict):
        """Start a new episode for one reaction record.

        Args:
            data_dict: record read with the keys 'rxn', 'product', 'reactants',
                'rc_node_idx', 'lg_idx' and 'product_fp'.

        Returns:
            A deep copy of the initial state ``[product_dgl, product_fp, [], mask, 0]``.
        """
        self.data_dict = data_dict
        self.rxn = data_dict['rxn']
        self.product = data_dict['product']
        self.reactants = data_dict['reactants']
        self.label_RcNodeIdx = data_dict['rc_node_idx']  # stored; not used elsewhere in this class
        self.label_LgIdx = data_dict['lg_idx']  # ground-truth leaving-group vertex indices

        # Components of the state that stay fixed for the whole episode.
        self.product_dgl = smiles_to_dglGraph(smiles=self.product, seed=0, add_self_loop=False, Conformer=False)
        self.num_nodes_product = self.product_dgl.num_nodes()
        self.num_nodes_hypergraph = self.hg_json['num_v']
        self.product_fp = torch.tensor(data_dict['product_fp'], dtype=torch.float32).reshape(1, -1)

        # Nothing is selected yet: every vertex column is 0 and the stop column is 1.
        mask = self.compute_mask(node_set=[])  # 1 x (v + 1)

        self.state = [self.product_dgl, self.product_fp, [], mask, 0]
        return copy.deepcopy(self.state)

    def infer_reset(self, infer_init_state):
        """Replace the current state with a caller-supplied one.

        Presumably used at inference time to continue from an externally held
        state. The state is stored as-is, without copying. ``step`` still reads
        ``product_dgl``, ``product_fp`` and ``label_LgIdx`` set by the most
        recent ``reset``.
        """
        self.state = infer_init_state

    def compute_Placed_tensor(self, num_nodes, node_set):  # num_nodes: int, node_set: [idx, idx, ...]
        """Return a length-``num_nodes`` 0/1 tensor with ones at ``node_set``."""
        assert -1 not in node_set, '-1 in node_set'
        placed = torch.zeros(num_nodes)  # n
        placed[node_set] = 1
        return placed  # n

    def compute_hypergraph_one_hop_neighbor(self, node_set):  # node_set: [idx, idx, ...]
        """Return the sorted one-hop hypergraph neighbours of ``node_set``.

        A vertex is a neighbour if it shares at least one hyperedge with some
        vertex of ``node_set``. The vertices of ``node_set`` themselves are
        excluded.

        ``hypergraph_dhg.N_e(v)`` and ``N_v(e)`` are assumed to return the
        incident hyperedges / member vertices as tensors. NOTE(review):
        confirm against OurHypergraph.
        """
        one_hop_edge_neighbor = set()
        for node in node_set:
            one_hop_edge_neighbor.update(self.hypergraph_dhg.N_e(node).tolist())
        one_hop_node_neighbor = set()
        for edge in one_hop_edge_neighbor:
            one_hop_node_neighbor.update(self.hypergraph_dhg.N_v(edge).tolist())
        return sorted(one_hop_node_neighbor - set(node_set))  # [idx, idx, ...]

    def compute_mask(self, node_set):  # node_set: [idx, idx, ...]
        """Build the action mask for the current leaving-group selection.

        Columns ``0..num_v-1`` are hypergraph vertices; the last column is the
        stop action.

        * Empty ``node_set``: all vertex columns are 0 and the stop column is 1.
        * Otherwise: the one-hop neighbours of ``node_set`` and the stop column
          are 0; every other column, including the already selected vertices,
          is 1.

        Returns:
            Tensor of shape (1, num_v + 1).
        """
        # Read the size from hg_json (set in __init__) rather than from
        # num_nodes_hypergraph, which only exists after reset().
        num_v = self.hg_json['num_v']
        if len(node_set) == 0:
            mask = torch.zeros(num_v + 1)
            mask[-1] = 1
        else:
            assert -1 not in node_set, '-1 in node_set'
            one_hop_neighbor = self.compute_hypergraph_one_hop_neighbor(node_set)  # [idx, idx, ...]
            if len(one_hop_neighbor) != 0:
                assert max(one_hop_neighbor) < num_v, \
                    'error at one_hop_neighbor, max(one_hop_neighbor) >= num_nodes_hypergraph'
            mask = torch.ones(num_v + 1)
            mask[one_hop_neighbor] = 0
            mask[-1] = 0
        return mask.reshape(1, -1)  # 1x(v+1)

    def _terminal_reward(self, LgIdx_list):
        """Return 1 if the selection equals the labelled leaving groups as a set, else 0."""
        return 1 if set(LgIdx_list) == set(self.label_LgIdx) else 0

    def step(self, action):
        """Apply one action and advance the episode.

        Args:
            action: a hypergraph vertex index in ``[0, num_v)`` to add to the
                selection, or ``-1`` for the stop action.

        Returns:
            ``(reward, next_state, done)``:

            * ``reward`` is 1 only on a terminal step whose selection equals
              ``label_LgIdx`` as a set; otherwise it is 0.
            * ``next_state`` is a deep copy of the new state.
            * ``done`` is 1 after the stop action, or after a vertex-adding
              action that makes ``t`` exceed ``max_len``; otherwise it is 0.

        Raises:
            AssertionError: if ``action`` is outside ``[0, num_v)`` (and is not
                -1), or if the vertex is already selected.
        """
        t = self.state[-1] + 1
        LgIdx_list = copy.deepcopy(self.state[2])

        if action != -1:  # add a leaving-group vertex
            # The lower bound rejects negative indices other than the stop action.
            assert 0 <= action < self.hg_json['num_v'], 'action out of num_nodes_hypergraph'
            assert action not in LgIdx_list, f'action: {action} in LgIdx_list: {LgIdx_list}'
            LgIdx_list.append(action)
            LgIdx_list.sort()
            mask = self.compute_mask(node_set=LgIdx_list)
            done = 1 if t > self.max_len else 0
        else:  # lg-stop-action: the selection is final
            mask = torch.zeros(self.hg_json['num_v'] + 1).reshape(1, -1)  # 1x(v+1)
            done = 1

        reward = self._terminal_reward(LgIdx_list) if done else 0

        # product_dgl and product_fp are the same for every step of the episode.
        next_state = [self.product_dgl, self.product_fp, LgIdx_list, mask, t]
        self.state = copy.deepcopy(next_state)
        return reward, copy.deepcopy(next_state), done

    def generate_random_gt_trajectory(self):
        """Sample a random action sequence that rebuilds the labelled leaving groups.

        The first vertex is drawn uniformly from ``label_LgIdx``. Each following
        vertex is drawn uniformly from the labelled vertices that are still
        missing and lie in the one-hop hypergraph neighbourhood of those already
        chosen. Each chosen vertex is therefore unmasked in the mask that
        ``compute_mask`` builds for the preceding selection. The trajectory
        always ends with the stop action ``-1``.

        Returns:
            A list such as ``[v_a, v_b, ..., -1]``, or ``[-1]`` when there are
            no labelled leaving groups. NOTE: ``reset`` blocks the stop column
            at t=0, so a mask-following policy could not reproduce ``[-1]``.

        Raises:
            IndexError: if a labelled vertex cannot be reached through the
                hypergraph from the vertices chosen so far.
        """
        remaining = copy.deepcopy(self.label_LgIdx)  # never mutate the label itself
        if len(remaining) == 0:
            return [-1]

        lg_trajectory = []
        action = random.choice(remaining)
        while True:
            lg_trajectory.append(action)
            remaining.pop(remaining.index(action))
            if len(remaining) == 0:
                break
            one_hop_neighbor = self.compute_hypergraph_one_hop_neighbor(lg_trajectory)
            candidate_nodes = set(one_hop_neighbor) & set(remaining)
            if not candidate_nodes:
                raise IndexError(
                    f'labelled leaving groups {remaining} are not reachable from {lg_trajectory} in the hypergraph')
            action = random.choice(list(candidate_nodes))
        lg_trajectory.append(-1)
        return lg_trajectory

    def to_dict(self):
        """Return ``vars(self)``: the live instance ``__dict__``, not a copy."""
        return vars(self)



