# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-generation/run_generation.py

import argparse
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import string
import json
import os
from nltk.tokenize import sent_tokenize

from transformers import (
    CTRLTokenizer,
    AutoTokenizer,
)
from src.ctrl import PrefixCTRL
from src.gpt2 import PrefixGPT
from src.gen_utils import sort_score, save
from task_mapping import *

# Process-wide logging setup: timestamped records at INFO level and above,
# plus the module-level logger used throughout this file.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


class SuperGenGenerator():

    def __init__(self, args):
        """Load the tokenizer and prefix model, then resolve task-specific settings.

        Args:
            args: Parsed command-line namespace. This method reads
                ``model_name_or_path``, ``label``, ``device``, ``fp16``,
                ``seed``, ``temperature``, ``pretrain_corpus_dir`` and
                ``num_gen`` (and ``n_gpu`` through :meth:`set_seed`). It
                overwrites ``args.task`` and ``args.label_list`` with the
                values stored in the loaded model's config.

        Raises:
            AssertionError: If the temperature or corpus arguments do not match
                what the task type requires (see the checks below).
        """
        self.args = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        self.tokenizer.model_max_length = 512
        self.linebreak_idx = self.tokenizer.convert_tokens_to_ids('\n')
        self.model = PrefixCTRL.from_pretrained(args.model_name_or_path)
        self.default_mode = self.model.default_mode
        # The task and its label set come from the checkpoint config, not the CLI.
        args.task = self.model.config.task
        args.label_list = self.model.config.label_list
        self.label_map = {label: i for i, label in enumerate(args.label_list)}
        print(f"task: {args.task}; label: {args.label}")
        self.model.to(args.device)
        if args.fp16:
            self.model.half()
        # Seed before the corpus sampling below so the sampled texts are reproducible.
        self.set_seed(args.seed)

        self.task_type = task_type_mapping[args.task]
        self.stop_token = self.tokenizer.eos_token
        self.control_code = control_code_mapping[args.task] if 'ctrl' in args.model_name_or_path else None
        self.prompt = prompt_mapping[args.task][args.label]
        self.repetition = repetition_mapping[args.task][args.label]
        self.bad_tokens = bad_tokens_mapping[args.task]
        self.fix_start = fix_start_mapping[args.task] if args.task in fix_start_mapping else None
        self.extract_remain = args.task in extract_remaining
        self.allow_new_line = args.task in allow_start_new_line

        if self.extract_remain:
            # The label whose prompt is "..." is filled from the text that
            # follows the generation; any other label supplies the real prompt.
            for label in prompt_mapping[args.task]:
                if prompt_mapping[args.task][label] == "...":
                    self.remain_label = label
                else:
                    self.prompt = prompt_mapping[args.task][label]
                    self.prompt_label = label

        if self.task_type == "pair":
            assert args.temperature == 0, "pair tasks require temperature == 0"
            assert args.pretrain_corpus_dir is not None, "pair tasks require pretrain_corpus_dir"
            # Context manager so the corpus file handle is always released.
            with open(args.pretrain_corpus_dir) as f:
                texts = [line.strip() for line in f]
            chosen_idx = np.random.choice(len(texts), args.num_gen, replace=False)
            self.sampled_texts = [texts[i] for i in chosen_idx]
        else:
            if args.task not in vary_temperature:
                assert args.temperature > 0, "temperature must be positive for this task"
            self.sampled_texts = None

        self.prompt_list = self.prompt if isinstance(self.prompt, list) else [self.prompt]

        if args.task in vary_temperature:
            # A list given on the command line wins over the per-task list.
            if isinstance(args.temperature, list):
                self.temp = args.temperature
            else:
                self.temp = vary_temperature[args.task]
            self.do_sample = True
        elif args.temperature == 0:
            # Temperature 0 turns sampling off; 1 is passed as the neutral value.
            self.temp = 1
            self.do_sample = False
        else:
            self.temp = args.temperature
            self.do_sample = True

    def set_seed(self, seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        if self.args.n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    def prepare_input(self, prompt_text):
        encoded_prompt = self.tokenizer.encode(prompt_text, add_special_tokens=False)
        if not any(encoded_prompt[0] == x for x in self.tokenizer.control_codes.values()):
            logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
        return prompt_text

    def compose_one(self, example, end_char=None):
        """Render one demonstration example as prompt text followed by ``end_char``.

        Args:
            example: Object with ``label``, ``text_a`` and ``text_b`` attributes;
                ``text_b`` may be ``None`` for single-sequence examples.
            end_char: Separator appended after the rendered example. When
                ``None``, two copies of the end-of-sequence token surrounded by
                spaces are used.

        Returns:
            The rendered demonstration string.

        Raises:
            AssertionError: If the chosen prompt is a list but does not hold
                exactly two parts or ``example.text_b`` is ``None``.
        """
        if end_char is None:
            # ``self.eos`` is not assigned in ``__init__`` (its setup is
            # commented out), so fall back to ``self.stop_token``.
            eos = getattr(self, "eos", None) or self.stop_token
            end_char = ' ' + eos + ' ' + eos + ' '

        prompt_list = prompt_mapping[self.args.task][example.label]
        if not isinstance(prompt_list, list):
            prompt_list = [prompt_list]
        choice_idx = np.random.choice(len(prompt_list), 1)
        prompt = prompt_list[choice_idx[0]]

        has_text_b = example.text_b is not None
        if isinstance(prompt, list):
            # Two-part prompt: one piece before text_a, one between the texts.
            assert len(prompt) == 2 and has_text_b
            start_prompt, conj_prompt = prompt
            lowercase_text_a = True
        elif has_text_b:
            start_prompt, conj_prompt = None, prompt
            lowercase_text_a = False
        else:
            start_prompt, conj_prompt = prompt, None
            lowercase_text_a = False

        start = ''
        if start_prompt:
            start += start_prompt + ' '

        text_a = example.text_a
        if lowercase_text_a:
            # Slicing (not indexing) keeps an empty text from raising IndexError.
            text_a = text_a[:1].lower() + text_a[1:]
        start += text_a + ' '

        # append conjunction prompt if any
        if conj_prompt:
            start += conj_prompt + ' '

        if has_text_b:
            text_b = example.text_b
            if conj_prompt is not None:
                text_b = text_b[:1].lower() + text_b[1:]
            start += text_b
        return start + end_char

    def generate_one(self, seed, label, examples=None, sample_text=None):
        """Generate a single text for ``label``; return a result dict or ``None``.

        Args:
            seed: Seed passed to :meth:`set_seed` before anything random happens.
            label: Label name; a key of the task's entries in ``prompt_mapping``
                and ``repetition_mapping`` and of ``self.label_map``.
            examples: Optional iterable of demonstrations; each one is rendered
                with :meth:`compose_one` and prepended to the prompt.
            sample_text: Optional first sequence. When given, the prompt is used
                as a conjunction after it and the result uses the
                ``text1``/``text2`` layout.

        Returns:
            ``None`` when the generation is rejected (unwanted leading line
            break, missing stop token, a bad token, empty text, or too little
            trailing text for ``extract_remain`` tasks). Otherwise a dict with
            ``label`` and ``start_prompt`` plus either ``text`` or
            ``text1``/``text2``/``conj_prompt``, and ``extra`` for
            ``extract_remain`` tasks.

        Raises:
            AssertionError: If the prompt is a list but does not hold exactly
                two parts or ``sample_text`` is ``None``.
        """
        self.set_seed(seed)
        start = ''
        # always start with control codes (when generator is CTRL).
        # The class name lives on the type; a model instance has no ``__name__``.
        is_ctrl_model = 'ctrl' in type(self.model).__name__.lower()
        if self.default_mode == 'full' and is_ctrl_model and self.control_code is not None:
            start = self.control_code + ' '
        for example in (examples if examples is not None else ()):
            start += self.compose_one(example)

        prompt = prompt_mapping[self.args.task][label]
        repetition = repetition_mapping[self.args.task][label]
        if len(repetition) == 2:
            repetition_reward, repetition_penalty = repetition
        else:
            repetition_penalty = repetition_reward = repetition[0]

        if isinstance(prompt, list):
            assert len(prompt) == 2 and sample_text is not None
            start_prompt, conj_prompt = prompt
        elif sample_text is None:
            start_prompt, conj_prompt = prompt, None
        else:
            start_prompt, conj_prompt = None, prompt

        if self.default_mode != 'full':
            start_prompt = None
        if 'no-prompt' in self.default_mode:
            conj_prompt = '[BOS]'

        # append start prompt if any
        if start_prompt:
            start += start_prompt + ' '
        prompt_text = start

        # append sample text if any
        reward_span = None
        if sample_text is not None:
            start += sample_text + ' '
            if repetition_penalty != repetition_reward:
                # Token span [end of prompt, end of sample text) handed to the
                # model as ``reward_span``.
                encoded_prompt = self.tokenizer.encode(
                    prompt_text, add_special_tokens=False, return_tensors="pt",
                )
                encoded_first_sent = self.tokenizer.encode(
                    start, add_special_tokens=False, return_tensors="pt",
                )
                reward_span = torch.tensor([len(encoded_prompt[0]), len(encoded_first_sent[0])])

        # append conjunction prompt if any
        infix_pos = []
        if conj_prompt:
            start_after = start + conj_prompt + ' '
            if 'infix' in self.default_mode:
                # Mark tokens before the conjunction with 0 and the conjunction's tokens with 1.
                encoded_prompt_before = self.tokenizer.encode(
                    start, add_special_tokens=False, return_tensors="pt",
                )
                infix_pos += [0] * len(encoded_prompt_before[0])
                encoded_prompt_after = self.tokenizer.encode(
                    start_after, add_special_tokens=False, return_tensors="pt",
                )
                infix_pos += [1] * (len(encoded_prompt_after[0]) - len(encoded_prompt_before[0]))
            start = start_after

        # append fixed start tokens if any
        start_words = None
        if self.fix_start is not None:
            choice_idx = np.random.choice(len(self.fix_start), 1)
            start_words = self.fix_start[choice_idx[0]]
            start += start_words + ' '

        encoded_start = self.tokenizer.encode(
            start, add_special_tokens=False, return_tensors="pt",
        ).to(self.args.device)
        prompt_len = encoded_start.size(-1)
        if prompt_len == 0:
            input_ids = None
            # No prompt tokens means no infix mask either (previously an empty
            # list leaked through to ``generate``).
            infix_pos = None
        else:
            input_ids = encoded_start
            if 'infix' in self.default_mode:
                # Pad the mask with 0 for any tokens appended after the conjunction.
                infix_pos += [0] * (prompt_len - len(infix_pos))
                infix_pos = torch.tensor(infix_pos).unsqueeze(0).to(input_ids)
            else:
                infix_pos = None

        # With a sample text the length budget is counted on top of the prompt.
        # ``prompt_len`` is used so an empty prompt cannot index into ``None``.
        if sample_text is not None:
            max_len = prompt_len + self.args.max_len
        else:
            max_len = self.args.max_len

        if isinstance(self.temp, list):
            choice_idx = np.random.choice(len(self.temp), 1)
            temp = float(self.temp[choice_idx[0]])
        else:
            temp = self.temp

        cat_label = torch.tensor([self.label_map[label]]).to(self.model.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            infix_pos=infix_pos,
            cat_label=cat_label,
            reward_span=reward_span,
            max_length=max_len,
            temperature=temp,
            top_k=self.args.k,
            top_p=self.args.p,
            repetition_penalty=repetition_penalty,
            repetition_reward=repetition_reward,
            do_sample=self.do_sample,
            num_return_sequences=1,
            output_scores=True,
            return_dict_in_generate=True,
        )
        generated_sequence = outputs["sequences"][0].tolist()

        # Decode text
        text = self.tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        decoded_start = self.tokenizer.decode(encoded_start[0], clean_up_tokenization_spaces=True)
        start_len = len(decoded_start)

        if not self.allow_new_line and text[start_len:].startswith(("\n", " \n")):
            return None

        # Search only the continuation: a stop token inside the prompt (e.g.
        # between demonstrations) must not end the generated text.
        final_stop_idx = text.find(self.stop_token, start_len)
        if final_stop_idx == -1:
            return None

        # Keep only the text between the prompt and the stop token.
        total_sequence = text[start_len:final_stop_idx].strip()
        if any(bad_token in total_sequence for bad_token in self.bad_tokens):
            return None
        if len(total_sequence) == 0:
            return None

        extra_sequence = None
        if self.extract_remain:
            # Restores the extraction that used to define ``extra_sequence``:
            # pick one sentence (never the last one) from the text that follows
            # the stop token; reject the sample when there are too few sentences.
            remain_text = text[final_stop_idx + len(self.stop_token):].strip()
            sents = sent_tokenize(remain_text)
            if len(sents) <= 1:
                return None
            select_idx = np.random.choice(len(sents) - 1, 1)
            extra_sequence = sents[select_idx[0]].strip()
            if any(bad_token in extra_sequence for bad_token in self.bad_tokens):
                return None

        if start_words is not None:
            gen_text = start_words + ' ' + total_sequence
        else:
            gen_text = total_sequence

        if sample_text is not None:
            res = {
                "text1": sample_text,
                "text2": gen_text,
                "label": label,
                "start_prompt": prompt_text,
                "conj_prompt": conj_prompt,
            }
        else:
            res = {
                "text": gen_text,
                "label": label,
                "start_prompt": prompt_text,
            }
        if self.args.print_res:
            print(res)
        if self.extract_remain:
            res["extra"] = extra_sequence
        return res

    def save_res(self, gen_res, sort):
        os.makedirs(self.args.save_dir, exist_ok=True)
        if self.extract_remain:
            gen_prompt_res = []
            gen_extra_res = []
            for res in gen_res:
                prompt_res = {k: v for k, v in res.items() if k != "extra"}
                prompt_res["label"] = self.prompt_label
                gen_prompt_res.append(prompt_res)
                extra_res = {k: v for k, v in res.items() if k != "extra"}
                extra_res["label"] = self.remain_label
                extra_res["text2"] = res["extra"]
                gen_extra_res.append(extra_res)
            save_name = os.path.join(self.args.save_dir, f"{self.args.task}_{self.prompt_label}_{self.args.num_gen}")
            with open(f"{save_name}.json", 'w') as f:
                res = json.dumps(gen_prompt_res)
                f.write(res)
                f.close()
            if sort:
                new_dict = sort_score(f"{save_name}.json")
                save(f"{save_name}_sorted.json", new_dict)
                print(f"saved to {save_name}_sorted.json")
            save_name = os.path.join(self.args.save_dir, f"{self.args.task}_{self.remain_label}_{self.args.num_gen}")
            with open(f"{save_name}.json", 'w') as f:
                res = json.dumps(gen_extra_res)
                f.write(res)
                f.close()
            if sort:
                new_dict = sort_score(f"{save_name}.json")
                save(f"{save_name}_sorted.json", new_dict)
                print(f"saved to {save_name}_sorted.json")
        else:
            save_name = os.path.join(self.args.save_dir, f"{self.args.task}_{self.args.label}_{self.args.num_gen}")
            with open(f"{save_name}.json", 'w') as f:
                res = json.dumps(gen_res)
                f.write(res)
                f.close()
            if sort:
                new_dict = sort_score(f"{save_name}.json")
                save(f"{save_name}_sorted.json", new_dict)
                print(f"saved to {save_name}_sorted.json")

    def generate_all_demo_all(self, label, examples=None, use_demo=1, random=False, sort=True):
        """Generate ``self.args.num_gen`` texts with demonstrations from every label, then save them.

        Args:
            label: Label to generate for; forwarded to :meth:`generate_one`.
            examples: Optional mapping from demonstration label to a list of
                examples. For each generation, ``use_demo`` examples are drawn
                without replacement from every label's list.
            use_demo: Number of demonstrations drawn per label.
            random: When true, the drawn demonstrations are shuffled.
            sort: Forwarded to :meth:`save_res`.
        """
        gen_res = []
        for seed in tqdm(range(self.args.num_gen)):
            if self.sampled_texts is None:
                sample_text = None
            else:
                sample_text = self.sampled_texts[seed]
            if examples is not None and len(examples) > 0:
                all_examples = []
                for demo_label in examples:
                    demo_examples = examples[demo_label]
                    choice_idx = np.random.choice(len(demo_examples), use_demo, replace=False)
                    all_examples += [demo_examples[i] for i in choice_idx]
                if random:
                    # Shuffle indices rather than the objects themselves so the
                    # examples are never coerced into a NumPy array.
                    order = np.random.permutation(len(all_examples))
                    all_examples = [all_examples[i] for i in order]
                res = self.generate_one(seed, label, examples=all_examples, sample_text=sample_text)
            else:
                res = self.generate_one(seed, label, sample_text=sample_text)
            if res is not None:
                gen_res.append(res)
        # save_res takes the sort flag; it was previously omitted here.
        self.save_res(gen_res, sort)

    def generate_all(self, label, sort=True, examples={}, use_demo=1):
        """Run :meth:`generate_one` for seeds ``0 .. num_gen - 1`` and save the kept results.

        ``examples`` and ``use_demo`` are accepted but not used by this method.
        Generations that come back as ``None`` are dropped before the results
        are handed to :meth:`save_res` together with ``sort``.
        """
        collected = []
        for seed in tqdm(range(self.args.num_gen)):
            sample_text = None if self.sampled_texts is None else self.sampled_texts[seed]
            outcome = self.generate_one(seed, label, sample_text=sample_text)
            if outcome is None:
                continue
            collected.append(outcome)
        self.save_res(collected, sort)


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line arguments and generate texts for one label or for all labels."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain_corpus_dir', default=None,)
    parser.add_argument('--task', default='mnli',)
    parser.add_argument('--label', default='entailment',)
    parser.add_argument('--model_type', default='ctrl',)
    parser.add_argument('--model_name_or_path', default='ctrl',)
    parser.add_argument('--temperature', default='0.2')
    parser.add_argument('--repetition_reward', default=None, type=float)
    parser.add_argument('--repetition_penalty', default=None, type=float)
    parser.add_argument('--p', default=1.0, type=float)
    parser.add_argument('--k', default=10, type=int)
    parser.add_argument('--seed', default=42, type=int)
    # Accept both the bare flag ("--fp16") and an explicit value
    # ("--fp16 False"); an unparsed string "False" would otherwise be truthy.
    parser.add_argument('--no_cuda', default=False, type=_str_to_bool, nargs='?', const=True)
    parser.add_argument('--fp16', default=False, type=_str_to_bool, nargs='?', const=True)
    parser.add_argument('--num_gen', default=10, type=int)
    parser.add_argument('--max_len', default=60, type=int)
    parser.add_argument('--save_dir', default='temp_gen')
    parser.add_argument('--print_res', action='store_true')
    args = parser.parse_args()
    print(args)
    args.task = args.task.lower()
    # NOTE(security): eval() executes arbitrary code from the command line. It
    # is kept so a number or a list literal (e.g. "[0.1, 0.3]") still parses;
    # consider ast.literal_eval if this is ever fed untrusted input.
    args.temperature = eval(args.temperature)
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

    # Generate texts for all labels
    if args.label == "all":
        for label in prompt_mapping[args.task]:
            args.label = label
            generator = SuperGenGenerator(args)
            # generate_all requires the label; it was previously called without it.
            generator.generate_all(label)
            # If texts of all labels are generated in one pass
            # (by varying temperatures or extracting from the same generated text),
            # no need to redo generation for each label
            if args.task in vary_temperature or args.task in extract_remaining:
                break
    else:
        generator = SuperGenGenerator(args)
        generator.generate_all(args.label)


# Script entry point: run the generation CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
    