from brl.utils import *
from brl.envs.seq2seq import *
import torch


class HeuristicFunction:
    """Next-token scorer backed by a pretrained encoder-decoder model.

    Calling an instance with a search state ``(X, Y)`` -- the source token ids
    and the target token ids generated so far -- returns the log-probabilities
    of the next target token as a 1-D tensor.

    Attributes:
        encoder_decoder: the loaded model, switched to eval mode.
        src_tokenizer, trg_tokenizer: tokenizers supplied by the caller.
        max_len_ratio: once ``len(Y)`` exceeds ``len(X) * max_len_ratio`` the
            end-of-sequence token is forced.
        encoder_output, cache: reserved for incremental decoding; unused for now.
    """

    def __init__(self,
                 model_fname,  # a model file saved by pytorch 1.7.1
                 src_tokenizer,
                 trg_tokenizer,
                 max_len_ratio,
                 use_cpu=False
                 ):
        """Load the model from `model_fname` onto the CPU (`use_cpu=True`) or onto CUDA."""
        # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
        self.encoder_decoder = torch.load(model_fname, map_location=torch.device('cpu' if use_cpu else 'cuda'))
        # hack: for loading single-model files in longtu's code, which wrap the model in a dict
        if isinstance(self.encoder_decoder, dict):
            self.encoder_decoder = self.encoder_decoder['model']
        self.encoder_decoder = self.encoder_decoder.eval()

        self.src_tokenizer, self.trg_tokenizer = src_tokenizer, trg_tokenizer
        #assert len(self.src_tokenizer) == self.transformer.src_token_space_size
        #assert len(self.trg_tokenizer) == self.transformer.trg_token_space_size
        self.max_len_ratio = max_len_ratio

        # for caching purpose, not used in current code
        self.encoder_output = None
        self.cache = None

    def __call__(self, state) -> torch.Tensor:
        """Return next-token log-probabilities for `state`.

        Args:
            state: indexable whose item 0 is the source token ids ``X`` and
                whose item 1 is the non-empty list of target token ids ``Y``
                generated so far.

        Returns:
            A 1-D tensor on the model's device. When ``len(Y)`` has reached the
            size of the target positional table, or exceeds
            ``len(X) * max_len_ratio``, the distribution is forced onto EOS.

        Raises:
            AssertionError: if ``Y`` is empty; callers are expected to have fed
                the first target token already.
        """
        X, Y = state[0], state[1]
        t = len(Y)
        encoder_decoder = self.encoder_decoder

        if t == 0:
            # Raised explicitly (rather than `assert False`) so the check survives `python -O`.
            raise AssertionError('HeuristicFunction called with an empty target prefix')

        if t >= len(encoder_decoder.trg_emb.pe) or t > len(X) * self.max_len_ratio:
            # Out of positional range or too long relative to the source: force termination.
            return self._forced_token_lprobs(self.trg_tokenizer.eos_id)

        # Wrap each sequence in a batch of one and transpose to a column: shape (len, 1).
        X_tensor = torch.as_tensor([X], device=encoder_decoder.device).T
        Y_tensor = torch.as_tensor([Y], device=encoder_decoder.device).T
        with torch.no_grad():
            logits = encoder_decoder.forward_batch(X_tensor, Y_tensor)
        # We only need the output for the last target position; [0] picks the single batch entry.
        # NOTE(review): presumably logits are (target_len, batch, vocab) -- confirm against forward_batch.
        return logits[-1].log_softmax(-1)[0]

    def _forced_token_lprobs(self, token_id) -> torch.Tensor:
        """Return the log of the one-hot encoding of `token_id` over the target vocabulary.

        NOTE(review): `onehot_encoding` is defined outside this file; presumably it
        yields 1 at `token_id` and 0 elsewhere, so the log is 0 / -inf -- confirm.
        """
        probs = onehot_encoding(token_id, len(self.trg_tokenizer))
        with np.errstate(divide='ignore'):  # silence numpy's divide-by-zero warning for log(0)
            return torch.as_tensor(np.log(probs), dtype=torch.float, device=self.encoder_decoder.device)

    # TODO: implement model caching inside self.encoder_decoder so that decoding can run
    # incrementally: encode X once, then feed only the last target token at each step
    # (self.encoder_output / self.cache are reserved for that purpose).

def greedy_search(env, heuristic):
    """Decode one episode by always taking the highest-scoring action.

    Args:
        env: environment exposing ``reset()`` (returns a ``(state, _)`` pair) and
            ``step(action)`` (returns ``(state, reward, done, info)``).
        heuristic: callable mapping a state to a 1-D tensor of action scores.

    Returns:
        The reward returned by the final ``env.step`` call of the episode.
    """
    state, _ = env.reset()
    done = False
    while not done:
        # topk(1) is the same call beam_search makes when beam_size == 1, which keeps
        # the two decoders comparable.
        _, best_action = heuristic(state).topk(1)
        state, reward, done, _ = env.step(int(best_action))
    return reward


def beam_search(original_env, heuristic, beam_size, Node, max_len=2000):
    """Generic beam search whose ranking criterion is supplied by the `Node` class.

    Args:
        original_env: environment to decode; it is reset and fed ``BOS`` here.
        heuristic: callable mapping an env state to a 1-D tensor of next-token
            log-probabilities.
        beam_size: number of hypotheses kept per step.
        Node: class constructed as ``Node(env=..., action=..., lprob=...)`` (or with
            ``nregret=...`` when its instances carry an ``nregret`` attribute, in
            which case ``Node.alpha`` is also read). Instances must define ``<``.
        max_len: upper bound passed to ``interval(1, max_len)`` for the step counter.

    Returns:
        The env of the best node once every node in the beam is finished.

    Raises:
        AssertionError: if the step limit is exhausted before the beam finishes,
            or if one of the internal invariants is violated.
    """
    original_env.reset()
    original_env.step(original_env.BOS)
    pool = [Node(env=original_env, action=None, lprob=0.)]  # (X,[BOS])

    for t in interval(1, max_len):
        assert t == 1 or len(pool) == beam_size

        # Phase 1: propose children. A child shares its parent's env and only records the
        # action to take; the env is stepped later, and only for the survivors.
        # NOTE(review): Filter is defined outside this file; presumably it retains the
        # max_size best nodes under Node.__lt__ -- confirm.
        candidates = Filter(max_size=beam_size)
        for node in pool:
            if node.env.cur_done:
                # Finished hypotheses compete unchanged.
                assert node.action is None
                candidates.append(node)
            else:
                top_tokens = heuristic(node.env.cur_state).topk(beam_size)
                assert len(top_tokens.values) == beam_size
                for i in range(beam_size):
                    lprob_t, y_t = float(top_tokens.values[i]), int(top_tokens.indices[i])
                    if hasattr(node, 'nregret'):  # collect sufficient statistics for regret-based beam search
                        nregret_t = lprob_t - Node.alpha * float(top_tokens.values[0])
                        child_node = Node(env=node.env, action=y_t, nregret=node.nregret + nregret_t)
                    else:
                        child_node = Node(env=node.env, action=y_t, lprob=node.lprob + lprob_t)
                    candidates.append(child_node)
                    #if lprob_t == 0.: break  # all the rest child nodes will have lprob = -inf

        assert len(candidates) == beam_size

        # Phase 2: materialize the survivors and gather termination statistics.
        pool = []
        best_node = None
        num_complete_translations = 0
        for node in candidates:
            if node.env.cur_done:
                assert node.action is None
            else:
                # Siblings still reference the same parent env, so step a copy of it.
                env = copy.copy(node.env)
                env.step(node.action)
                node.env, node.action = env, None
            pool.append(node)

            # collect sufficient statistics for termination judgement
            best_node = node if best_node is None or best_node < node else best_node
            if node.env.cur_done: num_complete_translations += 1

        # termination judgement
        if num_complete_translations == beam_size:
        #if best_node.env.cur_done:
            return best_node.env

    # the translation should not have reached the max_len limit
    # (raised explicitly so the failure is not silenced into `return None` under `python -O`)
    raise AssertionError('beam search reached the max_len limit without finishing')

def vanilla_beam_search(original_env, heuristic, beam_size):
    """Beam search that ranks hypotheses purely by accumulated log-probability."""

    class Node:
        """Beam entry; a node is 'less than' another when its log-probability is lower."""

        def __init__(self, env, action, lprob):
            self.env = env
            self.action = action
            self.lprob = lprob

        def __lt__(self, other):
            return self.lprob < other.lprob

    best_env = beam_search(original_env, heuristic, beam_size, Node)
    return best_env

def beam_search_with_length_penalty(original_env, heuristic, beam_size, lp_expoent):
    """
    Beam search with length normalization, as used in the "Transformer paper"
    (https://arxiv.org/abs/1706.03762); the length penalty itself comes from the
    "GNMT paper" (https://arxiv.org/abs/1609.08144).

    Hypotheses are ranked by ``lprob / ((len + 5) / 6) ** lp_expoent``.
    """

    class Node:
        """Beam entry ordered by its length-normalized log-probability."""

        def __init__(self, env, action, lprob):
            self.env = env
            self.action = action
            self.lprob = lprob

        def score(self):
            # target length, excluding BOS but including EOS
            target_len = len(self.env.cur_state[1]) - 1
            penalty = ((target_len + 5) / 6) ** lp_expoent
            return self.lprob / penalty

        def __lt__(self, other):
            return self.score() < other.score()

    best_env = beam_search(original_env, heuristic, beam_size, Node)
    return best_env

def beam_search_with_regret(original_env, heuristic, beam_size, alpha):
    """Beam search that ranks hypotheses by their accumulated ``nregret`` statistic.

    beam_search detects the ``nregret`` attribute on nodes and then accumulates
    ``lprob_t - Node.alpha * best_lprob_t`` at every step.
    """

    class Node:
        """Beam entry ordered by ``nregret``; ``lprob`` is accepted but not used for ranking."""

        # Placeholder; the real value is attached right after the class is created.
        alpha = None

        def __init__(self, env, action, lprob=None, nregret=0.):
            self.env = env
            self.action = action
            self.lprob = lprob
            self.nregret = nregret

        def __lt__(self, other):
            return self.nregret < other.nregret

    Node.alpha = alpha
    best_env = beam_search(original_env, heuristic, beam_size, Node)
    return best_env


# Number of children that Expand() creates under the root node (id == -1)
# in the simulated MCTS setting below.
N = 40000

class TreeNode:
    """One node of the simulated search tree used by the MCTS routines."""

    def __init__(self, id, reward=float('inf'), cnt=0):
        # identifier of the node
        self.id = id
        # running reward estimate; infinite until the first update
        self.reward = reward
        # number of updates applied so far
        self.cnt = cnt
        # tiny positive stand-in for sqrt(cnt) so that dividing by it is safe before any update
        self.sqrt_cnt = 1e-10
        # None until the node is expanded
        self.children = None

# simulated reward setting
def PlayOut(node):
    """Return the simulated reward of `node`: 2.0 for id 0, 1.0 for any other id below 10, else 0.0."""
    node_id = node.id
    if node_id == 0:
        return 2.0
    return 1.0 if node_id < 10 else 0.0

# simulated reward setting + vanilla expansion
def Expand(node):
    """Expand `node` if it is the root (id == -1), giving it N fresh children.

    Returns True when children were created, False when the node is left as a leaf.
    The node must not have been expanded before.
    """
    assert node.children is None
    if node.id != -1:
        return False
    node.children = [TreeNode(child_id) for child_id in range(N)]
    return True

# UCB selection
def Select(node):
    """Pick the child of an expanded `node` with the highest UCB score.

    A node that has never been updated simply yields its first child.
    """
    children = node.children
    assert children is not None
    if node.cnt == 0:
        return children[0]

    # exploration term shared by all children; each child divides it by sqrt of its own count
    exploration = 1.5 * math.sqrt(math.log(node.cnt))
    ucb_scores = [child.reward + exploration / child.sqrt_cnt for child in children]
    best_index = np.argmax(ucb_scores)
    return children[best_index]

# vanilla update
def Update(node, reward):
    """Fold `reward` into the running mean stored on `node` and bump its visit statistics."""
    visits = node.cnt
    if visits == 0:
        # first observation replaces the initial placeholder reward
        node.reward = reward
    else:
        node.reward = (node.reward * visits + reward) / (visits + 1)
    node.cnt = visits + 1
    node.sqrt_cnt = math.sqrt(node.cnt)


def MCTS(root, num_rollout):
    """Run `num_rollout` rounds of select / expand / play-out / update starting at `root`."""
    for rollout in range(num_rollout):
        # Selection: descend through expanded nodes until an unexpanded one is reached.
        current = root
        path = [current]
        while current.children is not None:
            current = Select(current)
            path.append(current)

        # Expansion: when the node gained children, step into one of them right away
        # (a variant that avoids calling pytorch in rollout 0).
        if Expand(current):
            current = Select(current)
            path.append(current)

        # Simulation, then back-propagation from the deepest node up to the root.
        reward = PlayOut(current)
        for visited in path[::-1]:
            Update(visited, reward)

        # logging
        if rollout % 1000 == 0:
            print("{} : {}".format(rollout, current.id))




if __name__ == "__main__":
    # Evaluation script: decode every instance of the translation environment with vanilla
    # beam search, write one hypothesis per line to a file, and report BLEU statistics.
    # usage: <script> [beam_size] [model_fname]
    use_cpu = False

    def deterministic(seed):
        """Seed the Python, NumPy and PyTorch random number generators with `seed`."""
        # Set random seeds
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    deterministic(1234)

    # for debugging only
    # DONE: comment out the following code block before check in (and switch the flag between 'DONE' and 'TO-DO')
    #use_cpu = True
    # root = TreeNode(-1)
    # MCTS(root, 1000000)

    beam_size = int(sys.argv[1]) if len(sys.argv) > 1 else 2
    model_fname = sys.argv[2] if len(sys.argv) > 2 else './transformer_longtu/94500-100500.notok.ckpt'
    print('beam size = ', beam_size)

    env = Translation_v1(
        env_folder='./wmt14_en2de_moses_bpe_37k',
        src_fname='newstest2014.en',
        ref_fname='newstest2014.de',
        vocab_src_fname='bpe.share.37000',
        vocab_trg_fname='bpe.share.37000'  # use shared en-de vocab for wmt'14 data
    )

    # Must happen before the model is loaded below.
    sys.path.append('./transformer_longtu')  # path of the transformer module
    heuristic = HeuristicFunction(model_fname, env.src_tok, env.trg_tok, max_len_ratio=2, use_cpu=use_cpu)

    rv_sbleu = RV('sentence-bleu')
    hyps = []

    # The context manager guarantees the hypothesis file is closed even if decoding fails midway.
    with open('vanilla_beam_search_bs'+str(beam_size)+'.de', 'w', encoding='utf-8') as hyps_file:
        for i in range(env.num_instances):
            ## for debugging
            #if i < 13:
            #    env.reset()
            #    continue
            #print(i)

            if beam_size == 1:
                # Sanity check: decode the same instance greedily first, then restore the
                # environment from the snapshot so that beam search starts from the same point.
                env_copy = copy.deepcopy(env)
                greedy_search(env, heuristic)
                assert env.cur_done
                Y_greedy = env.cur_state[1]
                env = env_copy

            env = vanilla_beam_search(env, heuristic, beam_size=beam_size)
            #env = beam_search_with_length_penalty(env, heuristic, beam_size=beam_size, lp_expoent=0.6)
            #env = beam_search_with_regret(env, heuristic, beam_size=beam_size, alpha=1.5)
            assert env.cur_done
            Y_bs = env.cur_state[1]

            if beam_size == 1:
                # With a beam of one, beam search is expected to reproduce greedy search; report any mismatch.
                if Y_greedy != Y_bs:
                    print('Y_gd[{}]: {} {}'.format(i, len(Y_greedy), Y_greedy))
                    print('Y_bs[{}]: {} {}'.format(i, len(Y_bs), Y_bs))
                    #assert False

            Y = Y_bs
            perf = env.cur_reward
            rv_sbleu.append(perf)
            hyp = '' if Y == [env.BOS, env.EOS] else env.trg_tok.ids2str(Y[1:-1])[0]  # yttm.BPE cannot handle empty string
            hyps.append(hyp)
            # flush per line so partial results survive an interrupted run
            print(hyp, file=hyps_file, flush=True)
            if i % 100 == 0: print(i, file=sys.stderr)
            if i < 3: print('Y[{}] = {}'.format(i, Y), file=sys.stderr)

    bleu_c = sacrebleu.corpus_bleu(hyps, [env.ref])
    print(bleu_c.format())
    print(rv_sbleu)
    print('corpus-bleu: {}\tsentence-bleu: {}'.format(bleu_c.score, rv_sbleu.mean()))

    print('finished')

