from brl.utils import *
import gym
import sacrebleu


class Seq2Seq_v1(gym.Env):
    """
    An MDP formulation for sequence-transduction tasks, which generally ask to respond to an "input sequence"
        X = [BOS, x_1, ..., x_N, EOS]
    with an "output sequence"
        Y = [BOS, y_1, ..., y_L, EOS]
    where N may vary between inputs, and L is determined by the agent (after receiving X). x_t and y_t belong to given
    "token spaces" V_in and V_out, respectively. The quality of the response is evaluated based on the pair (X,Y),
    either automatically by a program routine, or by a human evaluator.

    In this environment, the output sequence Y is generated through a sequential decision episode that starts with a
    "dummy" initial state s_0 = (X, []), under which any action a_0 leads to s_1 = (X, [BOS]). Based on s_1, the agent
    generates a token a_1 in V_out, which then transitions the state to s_2 = (X, [BOS, a_1]), from which the agent
    gives a_2. The episode terminates when a_t* = EOS at some time step t*>0, after which the agent receives the
    terminal state
        s_{t*+1} = (X, [BOS, a_1, ..., a_{t*-1}, EOS])
    along with a once-per-episode reward
        r_{t*+1} = R(X, [BOS, a_1, ..., a_{t*-1}, EOS]).
    The "episode length" of such an episode is defined as T=t*+1, which is also the length of the BOS/EOS-delimited
    output sequence
        Y = [BOS, a_1, ..., a_{T-2}, EOS].
    Note that an episode of length T therefore contains L=T-2 content tokens, with y_t = a_t for 1<=t<=L.

    Our goal is to maximize the payoff averaged over a fixed distribution of input sequence X, as prescribed by reset().

    This class only documents the formulation; it defines no methods of its own.
    """


class Translation_v1(Seq2Seq_v1):
    """
    A parallel-corpus based neural machine translation (NMT) environment.

    The input sequence is a complete sentence of a source language, with tokens from the source-language vocabulary V_src.
    The output sequence is a complete translation of the source sentence into a target language, with tokens from the
    target-language vocabulary V_trg. For some language pairs, the vocabularies can overlap, or even be the same joint vocabulary.

    By default, we use sentence-level BLEU as the episode-wise payoff
        R(X,Y) = sBLEU( Z(X), [BOS, a_1, ..., a_L, EOS] )
    where Z(X) is a "reference translation" of the given source sentence X. Note that the reference Z can be a sequence
    of a different vocabulary from V_trg, and it's the reward function sBLEU's job to align Y and Z to the same
    vocabulary (which is again not necessarily V_trg) before comparing them.

    The distribution of X (i.e. the initial state distribution) is represented by a sample of source sentences, as given
    by the file 'src_fname', whose corresponding reference translations are given by the file 'ref_fname'. By default,
    we use WMT14-stanford data, a moses+BPE@37k joint vocabulary.

    Implementation note: reset() does not select a source sentence; it returns the placeholder state
    ([src BOS, src EOS], [BOS, EOS]). The next call to step() (with any action) is the dummy step 0, which selects the
    next source sentence X_i and returns s_1 = (X_i, [BOS]). Source sentences are visited in file order, epoch after epoch.
    """
    def __init__(self,
                 env_folder,
                 src_fname,
                 ref_fname,
                 vocab_src_fname,
                 vocab_trg_fname,
                 max_len=2048,
                 max_len_ratio=2,
                 ):
        """
        :param env_folder: folder holding the corpus files, the vocabulary files and the 'tokenizer' module
        :param src_fname: source-sentence file (one sentence per line), relative to env_folder; the text after its last
            '.' is taken as the source language code
        :param ref_fname: reference-translation file, line-aligned with src_fname, relative to env_folder; the text
            after its last '.' is taken as the target language code
        :param vocab_src_fname: source vocabulary file handed to Tokenizer, relative to env_folder
        :param vocab_trg_fname: target vocabulary file handed to Tokenizer, relative to env_folder
        :param max_len: an episode is forced to emit EOS once its time step reaches this value
        :param max_len_ratio: an episode is also forced to emit EOS once its time step exceeds this multiple of the
            source length (source BOS and EOS included)
        """
        self.src, self.ref, self.src_tok, self.trg_tok = self._load_data(
            env_folder,
            os.path.join(env_folder, src_fname),
            os.path.join(env_folder, ref_fname),
            os.path.join(env_folder, vocab_src_fname),
            os.path.join(env_folder, vocab_trg_fname)
        )

        self.num_instances = len(self.src)
        assert self.num_instances == len(self.ref)
        print('|vocab_src|: {} , BOS={} , EOS={}'.format(len(self.src_tok), self.src_tok.bos_id, self.src_tok.eos_id))
        print('|vocab_trg|: {} , BOS={} , EOS={}'.format(len(self.trg_tok), self.trg_tok.bos_id, self.trg_tok.eos_id))
        self.BOS = self.trg_tok.bos_id
        self.EOS = self.trg_tok.eos_id
        #for Z in self.ref: assert self.BOS == Z[0] and self.EOS == Z[-1]
        #for X in self.src: assert self.BOS == X[0] and self.EOS == X[-1]
        self.max_len = max_len
        self.max_len_ratio = max_len_ratio

        self.sample_ids = iter([])  # exhausted on purpose: the first episode starts a new epoch
        self.cur_sample_id = None
        self.cur_episode_time = None
        self.cur_state = ([], [self.BOS, self.EOS])
        self.cur_reward = 0.
        self.cur_done = True  # no episode in progress: the next step() starts one
        self.visual = None

    # TODO: implement a real seeding routine
    def seed(self, seed):
        """
        this environment has deterministic dynamics, so seeding only affects the randomness in reset()
        """
        pass

    def reset(self):
        """
        Prepare for a new episode without selecting a source sentence yet.

        The returned state is the placeholder ([src BOS, src EOS], [BOS, EOS]) and the environment is marked done, so
        the next call to step() -- whatever its action -- is the dummy step 0 that selects the next source sentence X_i
        and returns s_1 = (X_i, [BOS]).

        :return: (placeholder_state, info) with info = {'dummy step 0': True}
        """
        self.cur_sample_id = None
        self.cur_episode_time = 0
        self.cur_state = ([self.src_tok.bos_id, self.src_tok.eos_id], [self.BOS, self.EOS])
        self.cur_reward = 0.
        self.cur_done = True
        info = {'dummy step 0': True}
        return self.cur_state, info

    def step(self, action):
        """
        Advance the episode by one output token.

        If no episode is in progress (right after reset(), or after the previous episode terminated), this is the
        dummy step 0: 'action' is ignored, the next source sentence X_i is selected and s_1 = (X_i, [BOS]) is returned
        with reward 0 and done=False. Otherwise the state transitions deterministically from
        s_t = (X_i, [BOS,y_1...y_{t-1}]) to s_{t+1} = (X_i, [BOS, y_1...y_{t-1}, 'action']) for t>=1.

        The episode terminates when 'action' is EOS, at which point a sequence-level evaluation metric is used to give a
        one-shot reward to the whole episode. If the output becomes too long (t >= max_len, or
        t > max_len_ratio * len(X_i)), 'action' is replaced by EOS and info['time out'] is set to t.

        :param action: a token id in vocab_trg
        :return: (s_{t+1}, reward, done, info), where reward=0, done=False if action != EOS, otherwise done=True and
            reward is the sentence-BLEU score computed by _evaluate()
        """
        t = self.cur_episode_time
        info = {}

        if self.cur_done:
            # dummy step 0: 'action' is ignored and a new episode begins with s_1 = (X_i, [BOS])
            _, info = self._new_episode()
            return self.cur_state, self.cur_reward, self.cur_done, info

        # force termination of over-long outputs, both in absolute terms and relative to the source length
        if t >= self.max_len or t > len(self.cur_state[0]) * self.max_len_ratio:
            action = self.EOS
            info['time out'] = t

        s_next = (self.cur_state[0], self.cur_state[1] + [action])
        assert len(s_next[1]) == t + 1  # t=t*, so we are double checking that L+2=T

        # bool() keeps 'done' a plain Python bool even if 'action' is e.g. a numpy/torch scalar
        done = bool(action == self.EOS)

        reward = self._evaluate(s_next) if done else 0.

        self.cur_episode_time = t + 1
        self.cur_state = s_next
        self.cur_reward = reward
        self.cur_done = done
        return s_next, reward, done, info

    def render(self, mode='matplotlib'):  # mode='none'
        """
        visualize the current source sentence X_i, the reference sentence Z_i, the current (possibly partial)
        output sequence Y_{i,t}, as well as the episode-wise reward when the output Y is complete.

        mode: 'console', 'matplotlib', or 'none' (do nothing)
        """
        if mode == 'none':
            return

        if self.cur_state is None or self.cur_sample_id is None:
            # right after __init__() or reset() no source sentence has been selected yet
            text = 'nothing to show -- the environment is not initialized yet.'
        else:
            Y_t = self.cur_state[1]
            # strip the leading BOS and (if the output is complete) the trailing EOS for display
            if len(Y_t) > 0:
                assert Y_t[0] == self.BOS
                Y_t = Y_t[1:]
            if len(Y_t) > 0 and Y_t[-1] == self.EOS:
                assert len(Y_t) == 1 or Y_t[-2] != self.EOS
                Y_t = Y_t[:-1]

            text = 'X: {}\n'.format(self.src_strs[self.cur_sample_id])
            text += 'Z: ({})\n'.format(self.ref[self.cur_sample_id])
            text += 'Y: {}'.format(self.trg_tok.ids2str(Y_t))
            if self.cur_done:
                text += '\nevaluation: {}'.format(self._evaluate(self.cur_state))

        if mode == 'matplotlib':
            if self.visual is None:
                self.visual = LinePlot('instance_id', 'score',
                                       title='source sentence\n(reference translation)\nhypothesis translation\nscore')
            # only plot finished episodes that belong to a real sample
            if self.cur_done and self.cur_sample_id is not None:
                self.visual.add_point(self.cur_sample_id, self.cur_reward)
            self.visual.title(text)
            self.visual.show()
        elif mode == 'console':
            print('\n' + text)

    def close(self):
        """
        nothing special to do (maybe close the visualization window if there is one?)
        """
        pass

    @staticmethod
    def _read_lines(fname):
        """Read a UTF-8 text file into a list of lines with their trailing newline removed."""
        with open(fname, 'r', encoding='utf-8') as f:
            # rstrip('\n') (rather than [:-1]) keeps the last character of a final line without a newline
            return [line.rstrip('\n') for line in f]

    def _load_data(self, env_folder, src_fname, ref_fname, vocab_src_fname, vocab_trg_fname):
        """
        Build the tokenizers and load the parallel corpus; the results are also cached on the instance.

        :param env_folder: folder added to sys.path so that its 'tokenizer' module can be imported
        :param src_fname: path of the source-sentence file; the text after its last '.' is the source language code
        :param ref_fname: path of the reference file; the text after its last '.' is the target language code
        :param vocab_src_fname: path of the source vocabulary file
        :param vocab_trg_fname: path of the target vocabulary file
        :return: (src, ref, src_tok, trg_tok) where src holds token-id lists wrapped in the source BOS/EOS ids and
            ref holds the raw reference strings
        :raises ValueError: if the source or reference file contains no lines
        """
        self.src_lang = src_fname.split('.')[-1]
        self.trg_lang = ref_fname.split('.')[-1]
        # use the basename so that a leading './' in the path does not yield an empty corpus name
        corpus_name = os.path.basename(src_fname).split('.')[0]
        print('{}: {} -> {}'.format(corpus_name, self.src_lang, self.trg_lang))

        if env_folder not in sys.path:
            sys.path.append(env_folder)
        from tokenizer import Tokenizer
        self.src_tok = Tokenizer(self.src_lang, vocab_src_fname)
        self.trg_tok = Tokenizer(self.trg_lang, vocab_trg_fname)

        src_strs = self._read_lines(src_fname)
        ref_strs = self._read_lines(ref_fname)
        if not src_strs or not ref_strs:
            raise ValueError('empty corpus: {} ({} lines), {} ({} lines)'.format(
                src_fname, len(src_strs), ref_fname, len(ref_strs)))
        print('src[0] : ', src_strs[0])
        print('ref[0] : ', ref_strs[0])
        print('src[{}]: {}'.format(len(src_strs) - 1, src_strs[-1]))
        print('ref[{}]: {}'.format(len(ref_strs) - 1, ref_strs[-1]))

        src_idss = self.src_tok.str2ids(src_strs)
        src_idss = [[self.src_tok.bos_id] + ids + [self.src_tok.eos_id] for ids in src_idss]

        self.src = src_idss
        self.ref = ref_strs
        self.src_strs = src_strs
        self.src_fname = src_fname
        self.ref_fname = ref_fname
        self.vocab_src_fname = vocab_src_fname
        self.vocab_trg_fname = vocab_trg_fname
        return self.src, self.ref, self.src_tok, self.trg_tok

    def _new_episode(self):
        """
        initialize an episode by giving a source sequence; will (re-)shuffle the sequence-pairs at the beginning of a
        new epoch, and then go through all sequence pairs one by one

        :return: ((X_i, [BOS]), info), i.e. the state s_1 reached by the dummy step 0, where X_i is the "current"
            source sentence to be translated
        """
        i = next(self.sample_ids, None)
        if i is None:
            # beginning of a new epoch
            self.sample_ids = self._sample(self.num_instances, 'no_replace')
            i = next(self.sample_ids, None)
            assert i is not None

        self.cur_sample_id = i
        self.cur_episode_time = 1
        self.cur_state = (self.src[i], [self.BOS])
        self.cur_reward = 0.
        self.cur_done = False
        info = {}
        return self.cur_state, info

    # TODO: replace current dummy round-robin with a real sampling routine (supporting subset sampling with replacement)
    def _sample(self, size, replace='replace'):
        """Return an iterator over sample ids; currently only a full pass in file order is supported."""
        assert size == self.num_instances and replace == 'no_replace'
        return iter(range(self.num_instances))

    # TODO: give a "vector evaluation"
    def _evaluate(self, terminal_state):
        """
        Score a terminal state with the default sentence-BLEU (implemented by sacrebleu-1.5.1) against the reference
        translation of the current sample.

        :param terminal_state: (X, Y) with Y = [BOS, ..., EOS]
        :return: the sentence-BLEU score (the .score of sacrebleu's result)
        """
        Y = terminal_state[1]
        ref_str = self.ref[self.cur_sample_id]  # assuming Z(X) is plain text

        assert len(Y) >= 2 and Y[0] == self.BOS and Y[-1] == self.EOS
        # NOTE(review): the driver script indexes ids2str(...)[0], which suggests ids2str may return a list of strings;
        # confirm it returns a plain str here, since sentence_bleu expects a str hypothesis.
        hyp_str = self.trg_tok.ids2str(Y[1:-1]) if len(Y) > 2 else ''  # yttm.BPE cannot handle empty string
        bleu_s = sacrebleu.sentence_bleu(hyp_str, [ref_str])
        return bleu_s.score


if __name__ == "__main__":
    import torch  # used directly below (torch.load / torch.no_grad / torch.tensor)

    use_cpu = False
    def deterministic(seed):
        # Set random seeds
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
    deterministic(1234)

    # for debugging only
    # DONE: comment out the following code block before check in (and switch the flag between 'DONE' and 'TO-DO')
    # use_cpu = True

    policy = sys.argv[1] if len(sys.argv) > 1 else 'greedy'  #'random_len=x' 'random_len=bs1_longtu' 'beam_search_longtu'
    print('policy = ', policy)
    render_mode = 'none'  # 'console', 'matplotlib'

    env = Translation_v1(
        env_folder='./wmt14_en2de_moses_bpe_37k',
        src_fname='newstest2014.en',
        ref_fname='newstest2014.de',
        vocab_src_fname='bpe.share.37000',
        vocab_trg_fname='bpe.share.37000'  # use shared en-de vocab for wmt'14 data
    )
    env.render(render_mode)

    # hack
    #sys.path.append('.')
    sys.path.append('./transformer_longtu')
    from transformer import *
    from data import *

    model_fname = './transformer_longtu/94500-100500.notok.ckpt'
    if use_cpu:
        transformer = torch.load(model_fname, map_location=torch.device('cpu')).eval()
    else:
        transformer = torch.load(model_fname).eval()
    transformer.src_tok, transformer.trg_tok = env.src_tok, env.trg_tok

    def transformer_exp(beam_size):
        """
        Translate every source sentence of 'env' once with the selected 'policy', write the hypotheses to
        '<policy>.<beam_size>.de', and print episode-reward statistics as well as corpus/sentence BLEU.
        """
        score_sbleu = RV('sentence-bleu')
        score_sprob = RV('sentence-likelihood (%)')
        hyps = []

        with open(policy + '.' + str(beam_size) + '.de', 'w', encoding='utf-8') as hyps_file:
            for i in range(env.num_instances):
                if i % 100 == 0: print(i, file=sys.stderr)

                # reset() only returns a placeholder state; the dummy step 0 (any action) selects the next source
                # sentence X_i and yields s_1 = (X_i, [BOS]), whose X_i is what the policies below translate
                env.reset()
                initial_state, reward, done, _ = env.step(env.BOS)
                env.render(render_mode)

                # hack: all policies below receive 'initial_state' as input and return 'hyp_ids' (with BOS and EOS)
                if policy == 'random_len=x':
                    idss = [initial_state[0][1:-1]]
                    hyp_ids = [env.BOS] + \
                              [random.randint(max(env.BOS, env.EOS)+1, len(env.trg_tok)-1) for _ in range(len(idss[0]))] \
                              + [env.EOS]

                elif policy == 'beam_search_longtu':
                    idss = [initial_state[0][1:-1]]
                    src_index = Dataset.pad(idss, env.src_tok, device=transformer.device)

                    with torch.no_grad():
                        lenpen = 0.6
                        score_type = 'score_lprob' # 'score_current'
                        max_len_ratio = 2
                        hyp_index = transformer.beam_search(src_index, beam_size, lenpen, score_type, max_len_ratio)
                    idss = Dataset.unpad(hyp_index, env.trg_tok)
                    hyp_ids = [env.BOS] + idss[0] + [env.EOS]

                elif policy == 'random_len=bs1_longtu':  # random token sequence of the same length with that of bs_longtu with bs=1
                    idss = [initial_state[0][1:-1]]
                    src_index = Dataset.pad(idss, env.src_tok, device=transformer.device)
                    with torch.no_grad():
                        hyp_index = transformer.beam_search(src_index,
                                                          k=1,  # beam size
                                                          lenpen=0.6,
                                                          score_type='score_lprob',
                                                          max_len_ratio=2)
                    idss = Dataset.unpad(hyp_index, env.trg_tok)
                    hyp_ids = [env.BOS] + \
                              [random.randint(max(env.BOS, env.EOS)+1, len(env.trg_tok)-1) for _ in range(len(idss[0]))] \
                              + [env.EOS]

                elif policy == 'greedy':
                    # bs1_longtu as baseline
                    with torch.no_grad():
                        idss = [initial_state[0][1:-1]]
                        src_index = Dataset.pad(idss, env.src_tok, device=transformer.device)
                        hyp_bs1 = transformer.beam_search(src_index,
                                                          k=1,  # beam size
                                                          lenpen=0.6,
                                                          score_type='score_lprob',
                                                          max_len_ratio=2)

                    # greedy policy: encode X once, then emit one argmax token per env step, starting from s_1 = (X, [BOS])
                    state = initial_state
                    X = state[0]
                    x_tensor = Dataset.pad([X[1:-1]], env.src_tok, device=transformer.device)
                    with torch.no_grad():
                        x_emb = transformer.src_emb(x_tensor)  # x_tensor is a column tensor with BOS and EOS
                        encoder_output = transformer.encoder(x_emb)
                        cache = transformer.decoder.init_cache(1, 512)
                    Y_lprob = 0.  # BOS is given by the environment, so it contributes log-prob 0
                    while not done:
                        t = len(state[1])
                        y_last = state[1][-1]
                        with torch.no_grad():
                            y_last_emb = transformer.trg_emb(torch.tensor([y_last], device=transformer.device), t-1)
                            decoder_output, cache = transformer.decoder.forward_step(y_last_emb, encoder_output, cache)
                            lprobs = transformer.proj(decoder_output).log_softmax(-1)[0][0]
                            lprob_t, y_t = lprobs.topk(1)
                            lprob_t, y_t = float(lprob_t), int(y_t)
                        if t > len(X)*2:
                            # the env forces EOS here: count the log-prob of EOS, but still pass the argmax token to
                            # simulate bs1_longtu, which will still output argmax_lprob instead of env.EOS here
                            lprob_t = float(lprobs[env.EOS])
                            state, reward, done, info = env.step(y_t)
                            done = True
                            Y_lprob += lprob_t
                            break

                        state, reward, done, info = env.step(y_t)
                        env.render(render_mode)
                        Y_lprob += lprob_t
                    Y = state[1]
                    score_sprob.append(math.exp(Y_lprob) * 100.)

                    hyp_index = torch.tensor([Y+[env.EOS]], device=transformer.device).T
                    if not hyp_index.equal(hyp_bs1):
                        print('i: ', i)
                        print('hyp_bs1_longtu: ', hyp_bs1)
                        print('hyp_greedy: ', hyp_index)
                    idss = Dataset.unpad(hyp_index, env.trg_tok)
                    hyp_ids = [env.BOS] + idss[0] + [env.EOS]

                else:
                    assert False

                hyp = env.trg_tok.ids2str(hyp_ids[1:-1])[0]
                hyps.append(hyp)
                print(hyp, file=hyps_file)

                if not done:
                    # replay the hypothesis to obtain its episode reward; hyp_ids[0] (BOS) is already part of s_1
                    for action in hyp_ids[1:]:
                        state, reward, done, info = env.step(action)
                        env.render(render_mode)
                        if done:
                            # EOS or an env time-out; stepping further would silently start the next episode
                            break
                    assert done
                score_sbleu.append(reward)

        if score_sprob.size() > 0: print(score_sprob)
        if score_sbleu.size() > 0: print(score_sbleu)

        # debug: re-score previously written hypotheses instead
        # hyps = []
        # hyps_file = open('transformer.' + str(beam_size) + '.de', 'r', encoding='utf-8')
        # for Y in hyps_file:
        #     hyps.append(Y[:-1])  # remove '\n'
        bleu_c = sacrebleu.corpus_bleu(hyps, [env.ref])
        print(bleu_c.format())

        bleu_s = [sacrebleu.sentence_bleu(hyp, [ref]).score for hyp, ref in zip(hyps, env.ref)]
        bleu_avg = sum(bleu_s) / len(bleu_s)
        print('corpus-bleu: {}\tsentence-bleu: {}'.format(bleu_c.score, bleu_avg))


    for k in [1,2,4,8,16,32]:#,64,128,256,512,1024]:
        print('\nbeam size = ', k)
        transformer_exp(beam_size=k)

    print('finished')

