import copy
import json
import math
import os
import time

import torch
import tqdm
from rdkit.Chem import MolFromSmiles, MolFromSmarts

from LGCAgent import LGCAgentPPO
from LGCEnv import LGCEnv
from config import Config


class LGCBeamSearch:
    """Beam search over an LGC environment, guided by an agent's action policy.

    Args:
        lgcagent: agent exposing ``select_action_infer(one_state=..., additional_mask=...)``
            that returns ``(action, p_of_action, rst_policy)``. Only ``rst_policy`` is
            used here; it is ranked with ``torch.sort``, so it is presumably a 1-D tensor
            of per-action probabilities (TODO confirm against LGCAgentPPO).
        k (int): beam width; also the number of top actions expanded per live beam.
    """

    def __init__(self, lgcagent, k):
        self.agent = lgcagent
        self.k = k

    def length_penalty(self, length, alpha=0.6):
        """Compute the length-normalization factor ``((5 + length) / 6) ** alpha``.

        Args:
            length (int): length of the generated sequence.
            alpha (float): penalty hyper-parameter, typically in [0, 1]; a larger alpha
                makes the factor grow faster as the length increases.

        Returns:
            float: the length-penalty factor (exactly 1.0 when ``length == 1``).
        """
        return math.pow((5 + length) / 6, alpha)

    def search(self, retro_env, data_dict, additional_mask, unique=False):
        """Run beam search from the reset state of ``retro_env``.

        Args:
            retro_env: environment providing ``reset(data_dict=...)``, a ``state``
                attribute (``state[2]`` is the sequence used as the result identifier and
                for the length penalty), and ``step(act) -> (reward, next_state, done)``.
                It is reset but never stepped; every expansion steps a deep copy.
            data_dict: forwarded to ``retro_env.reset``.
            additional_mask: forwarded to the agent's ``select_action_infer``.
            unique (bool): if True, drop results whose ``state[2]`` equals that of a
                higher-scoring result.

        Returns:
            list[tuple]: up to ``k`` ``(score, env.state[2])`` pairs from the final beam,
            in descending score order.
        """
        retro_env.reset(data_dict=data_dict)
        # Each beam entry is (score, env, done).
        candidate = [(0, retro_env, 0)]

        # Terminates once every beam reports done (or the beam empties because all
        # expansions had zero probability); relies on the env eventually finishing.
        while not all(done for _, _, done in candidate):
            next_candidate = []
            for last_log_p, env, done in candidate:
                if done:
                    # Finished beams compete unchanged against the new expansions.
                    next_candidate.append((last_log_p, env, done))
                    continue

                _, _, rst_policy = self.agent.select_action_infer(
                    one_state=env.state, additional_mask=additional_mask)
                all_p, all_actions = torch.sort(rst_policy, descending=True)
                all_actions = all_actions.tolist()
                all_p = all_p.tolist()
                last_action = len(all_actions) - 1

                for act, p in zip(all_actions[:self.k], all_p[:self.k]):
                    if p == 0:
                        continue  # log(0) is undefined; never expand impossible actions
                    if act == last_action:
                        # The highest action index is sent to step() as -1
                        # (presumably a stop/terminate action -- TODO confirm in LGCEnv).
                        act = -1
                    # Copy only when the action is actually expanded.
                    child = copy.deepcopy(env)
                    _, next_state, child_done = child.step(act)
                    # NOTE(review): each step's log-prob is divided by the penalty at the
                    # new length (rather than normalizing the total score); kept as-is.
                    score = last_log_p + math.log(p) / self.length_penalty(len(next_state[2]))
                    next_candidate.append((score, child, child_done))

            candidate = sorted(next_candidate, key=lambda x: x[0], reverse=True)[:self.k]

        results = []
        seen = []  # list, not set: state[2] may be unhashable (e.g. a list)
        for score, env, _ in candidate:
            idx = env.state[2]
            if unique and idx in seen:
                continue
            seen.append(idx)
            results.append((score, idx))
        return results
