import json
import logging
import math
from math import log
import os
import argparse
from collections import defaultdict
import random
import numpy as np
import pickle

import wandb
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from tqdm import tqdm
from evaluate import evaluate_f1, open_reference_file

# Module-level logger; used in this file for checkpoint-saving messages.
logger = logging.getLogger(__name__)
# Uncomment to drop every log record at WARNING severity or below.
# logging.disable(logging.WARNING)

from trainer import Trainer
from options import setup_args
from utils import (
    Dialprocessor,
    load_raw_dataset,
    Profiler
)
from transformers import WEIGHTS_NAME, AutoTokenizer, BertConfig
from metrics import sequence_loss, bleu_metric, f1_metric, distinct_metric
from rouge import Rouge
from evaluation.evaluate_wrapper import run_KQA


# Turn off the HuggingFace `tokenizers` library's internal parallelism; commonly
# done to avoid fork-related warnings/deadlocks when DataLoader worker
# processes (num_workers > 0 in load_examples) are forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Checkpoint filename used for saving/loading; this rebinds the WEIGHTS_NAME
# imported from transformers above to a fixed literal.
WEIGHTS_NAME = "pytorch_model.bin"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous,
# which helps pinpoint CUDA errors but slows training; likely a debug leftover.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied identically to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def to_list(tensor):
    """Return ``tensor`` as a (nested) Python list, detached from autograd and copied to host memory."""
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()

def initialize_model(args, entity_embeddings=None):
    """Build the tokenizer and the knowledge-augmented generation model.

    Only ``args.lm_type == 't5'`` is supported. As a side effect the tokenizer
    is stored on ``args.tokenizer`` (``evaluate`` reads it from there).

    Args:
        args: Parsed options; read for ``lm_type`` and passed to the model.
        entity_embeddings: Optional entity embedding table forwarded to the
            model constructor.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: For any ``lm_type`` other than ``'t5'``.
    """
    if args.lm_type != 't5':
        raise NotImplementedError

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    args.tokenizer = tokenizer

    # Imported inside the function, so the model module is only loaded here.
    from models.models.modeling_t5 import T5ForKnowledgeAugmentedGeneration
    return T5ForKnowledgeAugmentedGeneration(args, entity_embeddings), tokenizer

def run(rank, args):
    """Train the model in this process and, on the main process, test it.

    Args:
        rank: Device index of this process for distributed training, or -1
            for a single-process run. Stored in ``args.local_rank``.
        args: Parsed options from ``setup_args``; mutated in place
            (``local_rank``, ``device`` and fields set by the loaders).

    Returns:
        On the main process (``local_rank`` 0 or -1): the test-set metrics
        returned by ``evaluate``. On other ranks: the dict of dev results,
        which stays empty because only the main process evaluates.

    Side effects:
        Every rank writes ``training_args.bin`` to ``args.output_dir``. The
        main process also writes the best checkpoint (``WEIGHTS_NAME``) and
        ``results.json`` there.
    """
    args.local_rank = rank
    if args.world_size > 1:
        print(f"Use GPU: {rank} for training")
        env_dict = {
            key: os.environ[key]
            for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
        }
        print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(rank)

    set_seed(args.seed)

    # Single-process runs use device 0; distributed runs use their own rank.
    args.device = 0 if args.local_rank == -1 else rank

    entity_embeddings, _ = load_entity_embeddings_memory(args)
    model, tokenizer = initialize_model(args, entity_embeddings)
    model.to(args.device)

    train_dataloader, _, _ = load_examples(args, "train")

    num_train_steps_per_epoch = len(train_dataloader)
    num_train_steps = int(num_train_steps_per_epoch * args.num_train_epochs)

    # Single-element lists so the nested callback can update them in place.
    # The score starts at -inf so the first dev evaluation always records a
    # checkpoint, even when its kqa-f1 is exactly 0.0.
    best_dev_score = [float("-inf")]
    best_weights = [None]
    results = {}
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))  # Save args beforehand

    def cpu_state_dict(m):
        """Return a CPU copy of the state dict, unwrapping a ``.module`` wrapper if present."""
        core = m.module if hasattr(m, "module") else m
        return {k: v.to("cpu").clone() for k, v in core.state_dict().items()}

    def step_callback(model, global_step):
        """On the main process, every ``eval_frequency`` epochs: evaluate on dev and keep the best weights."""
        if global_step % (num_train_steps_per_epoch * args.eval_frequency) == 0 and args.local_rank in [0, -1]:
            epoch = int(global_step / num_train_steps_per_epoch - 1)

            dev_results = evaluate(args, model, fold="dev", global_step=global_step)

            tqdm.write("dev: " + str(dev_results))
            results.update({f"dev_epoch{epoch}": dev_results})
            if dev_results["kqa-f1"] > best_dev_score[0]:
                best_weights[0] = cpu_state_dict(model)
                best_dev_score[0] = dev_results["kqa-f1"]
                results["best_epoch"] = epoch
                # Intermediate save
                logger.info("Saving the model checkpoint to %s", args.output_dir)
                torch.save(best_weights[0], os.path.join(args.output_dir, WEIGHTS_NAME))
            # evaluate() puts the model in eval mode; switch back for training.
            model.train()

    trainer = Trainer(
        args,
        model=model,
        dataloader=train_dataloader,
        num_train_steps=num_train_steps,
        step_callback=step_callback,
    )
    trainer.train()

    if args.local_rank in [0, -1]:
        print(results)
        if best_weights[0] is None:
            # No best dev checkpoint was recorded (e.g. no dev evaluation ran).
            # Save the final weights instead of torch.save(None), which would
            # make load_state_dict below fail.
            logger.warning("No best dev checkpoint was recorded; saving the final model weights")
            best_weights[0] = cpu_state_dict(model)
        logger.info("Saving the model checkpoint to %s", args.output_dir)
        torch.save(best_weights[0], os.path.join(args.output_dir, WEIGHTS_NAME))

        # Evaluate the saved checkpoint on the test split with a freshly built model.
        if 'graph' in args.knowledge:
            model, tokenizer = initialize_model(args, entity_embeddings)
        else:
            model, tokenizer = initialize_model(args, None)
        model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME), map_location="cpu"))
        model.to(args.device)

        results = evaluate(args, model, fold="test")

        with open(os.path.join(args.output_dir, "results.json"), "w") as f:
            json.dump(results, f)

        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
        print(f"SEED: {args.seed}")
        print(results)
    return results

def evaluate(args, model, fold="dev", global_step=-1, visualize=False):
    """Generate responses for ``fold`` and compute text-generation metrics.

    Writes the following under ``args.output_dir``:
    - one decoded prediction per line to a candidate file;
    - the matching decoded labels to ``{fold}_reference.txt``;
    - per-example profiles (via ``Profiler.write_profile``) under ``profiles/``.

    Args:
        args: Parsed options. ``args.tokenizer`` must already be set (see
            ``initialize_model``).
        model: The knowledge-augmented model. A wrapper exposing ``.module``
            is unwrapped for direct sampler/generator access.
        fold: Split name passed to ``load_examples`` / ``load_raw_dataset``.
        global_step: When > 0, output files get a ``_step{N}`` suffix and the
            candidate file is written under ``candidates/``.
        visualize: Unused; kept for interface compatibility.

    Returns:
        Dict with ``ppl``, ``bleu-1``..``bleu-4``, ``distinct-1``/``-2``,
        ``f1``, ``kqa-em``, ``kqa-f1``, plus the averaged ROUGE scores from
        ``Rouge.get_scores``.

    Raises:
        ValueError: If the split produced no evaluation batches.

    The model is left in eval mode; callers that continue training must call
    ``model.train()`` themselves.
    """
    dataloader, _, _ = load_examples(args, fold)
    dataset = load_raw_dataset(args, fold)

    tokenizer = args.tokenizer
    # Unwrap a DDP-style wrapper once instead of re-checking for every batch.
    core_model = model.module if hasattr(model, "module") else model

    os.makedirs(os.path.join(args.output_dir, "candidates"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "profiles"), exist_ok=True)

    if global_step > 0:
        pred_file = os.path.join(args.output_dir, "candidates", "{}_candidate_step{}.txt".format(fold, global_step))
        profile_file = os.path.join(args.output_dir, "profiles", "{}_profile_step{}.txt".format(fold, global_step))
    else:
        pred_file = os.path.join(args.output_dir, "{}_candidate.txt".format(fold))
        profile_file = os.path.join(args.output_dir, "profiles", "{}_profile.txt".format(fold))
    ref_file = os.path.join(args.output_dir, "{}_reference.txt".format(fold))

    # Decoding settings. num_beams=1 with do_sample=False is greedy decoding,
    # so (assuming the HuggingFace generate() API) top_p/top_k are inert here.
    generation_kwargs = {
        "max_length": 128,
        "num_beams": 1,
        "length_penalty": 1.0,
        "repetition_penalty": 1,
        "early_stopping": True,
        "use_cache": True,
        "do_sample": False,
        "top_p": 0.95,
        "top_k": 50,
        "return_dict_in_generate": True,
    }
    sampler_keys = ("mention_positions", "nodes", "edge_index", "edge_attr", "graph_batch", "local_indicator")
    loss_pad_id = 0 if 'graph' in args.knowledge else tokenizer.pad_token_id

    profiler = Profiler(args)
    rouge = Rouge()
    n_token, eval_loss = 0, 0.0  # accumulated for perplexity
    test_hyp, test_ref = [], []
    dataset_ptr = 0  # index into `dataset`, advanced once per example

    model.eval()
    core_model.knowledge_sampler.print_average_len()

    # Context managers guarantee the output files are closed even if decoding,
    # scoring or profiling raises part-way through.
    with open(pred_file, "w") as pred_fw, \
            open(ref_file, "w") as ref_fw, \
            open(profile_file, "w") as profile_fw:
        for batch in tqdm(dataloader, desc="Eval"):
            gen_inputs = {k: v.to(args.device) for k, v in batch.items()
                          if k in ('input_ids', 'attention_mask')}
            recon_inputs = {k: v.to(args.device) for k, v in batch.items()
                            if k != 'labels'}
            labels = batch['labels'].to(args.device)
            gen_inputs.update(generation_kwargs)
            sampler_inputs = {k: v.to(args.device) for k, v in batch.items()
                              if k in sampler_keys}
            sampler_inputs["k"] = 1

            # Inference only: the sampler, generator and loss need no autograd graph.
            with torch.no_grad():
                sampled = core_model.knowledge_sampler(
                    gen_inputs["input_ids"],
                    gen_inputs["attention_mask"],
                    **sampler_inputs,
                )
                input_ids = sampled[0]
                attention_mask = sampled[1]
                probs = sampled[2]
                graph_inputs = sampled[4]
                gen_inputs["graph_inputs"] = graph_inputs
                gen_inputs["input_ids"] = input_ids  # Overwrite
                gen_inputs["attention_mask"] = attention_mask  # Overwrite

                outputs = core_model.generator.generate(**gen_inputs)

                logits = model(**recon_inputs)[0]
                loss = sequence_loss(logits, labels, loss_pad_id)
                n_token += loss.size(0)
                eval_loss += loss.sum().item()

            # Move to CPU once per batch for decoding. Any -100 (ignore-index)
            # label positions are mapped to padding so the tokenizer can decode them.
            pred_sequences = outputs.sequences.cpu()
            label_ids = labels.masked_fill(labels == -100, tokenizer.pad_token_id).cpu()

            batch_size = input_ids.size(0)
            for i in range(batch_size):
                pred_response_token = tokenizer.decode(pred_sequences[i],
                                                       skip_special_tokens=True,
                                                       clean_up_tokenization_spaces=False)
                label_token = tokenizer.decode(label_ids[i],
                                               skip_special_tokens=True,
                                               clean_up_tokenization_spaces=False)

                test_hyp.append(pred_response_token)
                test_ref.append(label_token)
                pred_fw.write(pred_response_token.strip() + "\n")
                pred_fw.flush()
                ref_fw.write(label_token.strip() + "\n")
                ref_fw.flush()

                profiler.write_profile(profile_fw,
                                       dataset[dataset_ptr],
                                       input_ids[i],
                                       probs[i],
                                       pred_response_token,
                                       graph_inputs,
                                       i  # index within the batch
                                       )
                dataset_ptr += 1

    if n_token == 0:
        raise ValueError("No evaluation batches were produced for fold {!r}".format(fold))
    mean_loss = eval_loss / n_token
    try:
        ppl = math.exp(mean_loss)
    except OverflowError:
        # math.exp overflows for very large mean losses; report infinite perplexity.
        ppl = float("inf")

    f1 = f1_metric(test_hyp, test_ref)
    # KQA scoring via run_KQA(pred_file, fold) is disabled, so kqa-f1 mirrors
    # the token-level F1. run() selects checkpoints on kqa-f1.
    kqa_em, kqa_f1 = 0.0, f1
    b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
    d1, d2 = distinct_metric(test_hyp)
    # Substitute a placeholder for empty hypotheses before ROUGE scoring (the
    # `rouge` package rejects empty hypotheses). The other metrics above still
    # see the raw predictions.
    rouge_hyps = [hyp if hyp.strip() else "dialogue:" for hyp in test_hyp]
    rouge_score = rouge.get_scores(hyps=rouge_hyps, refs=test_ref, avg=True)

    results = {
        'ppl': ppl,
        'bleu-1': b1, 'bleu-2': b2, 'bleu-3': b3, 'bleu-4': b4,
        'distinct-1': d1, 'distinct-2': d2,
        'f1': f1,
        'kqa-em': kqa_em, 'kqa-f1': kqa_f1,
    }
    results.update(rouge_score)
    return results

def load_entity_embeddings_memory(args):
    """Load the entity/relation codebooks and allocate the entity embedding table.

    This is used with pre-computed entity embeddings. It reads
    ``entity_codebook.pkl`` and ``relation_codebook.pkl`` from
    ``args.data_dir``. Both are pickled mappings, and only their values are
    used, in iteration order.

    Side effects on ``args``:
        label_map: relation value -> 0-based index.
        wikidata_to_memory_map: entity value -> 1-based row index. Index 0 is
            never assigned here; load_examples maps unknown entities to 0.
        initialize_embedding: set to True.

    Returns:
        ``(entity_embeddings, wikidata_to_memory_map)``. ``entity_embeddings``
        is a zero tensor of shape ``(num_rows, args.entity_embed_size)``, sized
        so that every index in the map is a valid row.
    """
    # NOTE: pickle.load can execute arbitrary code; data_dir must be trusted.
    memory_path = os.path.join(args.data_dir, "entity_codebook.pkl")
    label_path = os.path.join(args.data_dir, "relation_codebook.pkl")
    with open(memory_path, 'rb') as f:
        entity_embeddings_memory = pickle.load(f)
    with open(label_path, 'rb') as f:
        label_memory = pickle.load(f)

    # If a value repeats, its last position wins.
    args.label_map = {value: idx for idx, value in enumerate(label_memory.values())}

    wikidata_to_memory_map = {
        value: idx for idx, value in enumerate(entity_embeddings_memory.values(), start=1)
    }
    args.wikidata_to_memory_map = wikidata_to_memory_map

    # Size the table by the largest assigned index, not len(map). If the
    # codebook repeats a value, the map is shorter than the codebook but still
    # uses indices up to len(codebook), which would otherwise be out of range.
    num_rows = max(wikidata_to_memory_map.values(), default=0) + 1
    entity_embeddings = torch.zeros(num_rows, args.entity_embed_size)
    args.initialize_embedding = True
    print(f"The number of entities: {entity_embeddings.shape[0]}")
    return entity_embeddings, wikidata_to_memory_map

def load_examples(args, fold):
    """Build the DataLoader, raw features and processor for one split.

    Args:
        args: Parsed options. Reads:
            - ``knowledge``, ``data_dir`` and ``local_rank``;
            - ``train_batch_size`` or ``eval_batch_size``, depending on ``fold``;
            - ``wikidata_to_memory_map`` (set by
              ``load_entity_embeddings_memory``) when ``'graph'`` is in
              ``args.knowledge``;
            - optional ``num_workers`` (default 16).
        fold: One of ``"train"``, ``"dev"``, ``"test"``, ``"toy"``.

    Returns:
        ``(dataloader, features, processor)``. Only the train split is shuffled
        (``RandomSampler``, or ``DistributedSampler`` when
        ``args.local_rank != -1``); other splits keep their original order.

    Raises:
        ValueError: If ``fold`` is not a known split.
    """
    if 'graph' in args.knowledge:
        wikidata_to_memory_map = args.wikidata_to_memory_map
    else:
        wikidata_to_memory_map = None

    if fold not in ("train", "dev", "test", "toy"):
        raise ValueError(f"Unknown fold {fold!r}; expected 'train', 'dev', 'test' or 'toy'")

    processor = Dialprocessor(args)
    # Dispatches to get_train_examples / get_dev_examples / get_test_examples / get_toy_examples.
    features = getattr(processor, f"get_{fold}_examples")(args.data_dir)

    def collate_fn(batch):
        """Collate ``(user_ids, wizard_ids, kg, gold_kg)`` examples into a dict of model inputs.

        The per-example graphs are merged into one disjoint graph
        (torch_geometric-style batching): node lists are concatenated, edge
        endpoints are shifted by the number of nodes already merged, and
        ``graph_batch`` records which example each node came from.
        """
        def create_padded_sequence(target, padding_value):
            """Pad a batch of id sequences into one ``(batch, max_len, ...)`` tensor."""
            if isinstance(target, str):
                tensors = [torch.tensor(getattr(o[1], target), dtype=torch.long) for o in batch]
            elif isinstance(target, tuple):
                # Assumed to already hold tensors (pad_sequence requires tensors).
                tensors = target
            else:
                tensors = [torch.tensor(o, dtype=torch.long) for o in target]
            return pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        def retrieve(key):
            """Map an entity id to its memory row, or 0 if it is not in the map."""
            return wikidata_to_memory_map.get(key, 0)

        def merge_graphs(graphs):
            """Concatenate per-example graphs into flat node/edge lists plus a node -> example index."""
            nodes, edge_index, edge_attr, node_batch = [], [], [], []
            for batch_idx, item in enumerate(graphs):
                offset = len(node_batch)  # number of nodes merged so far
                item_nodes = [retrieve(node) for node in item.wikidata_ids]
                edge_index += [[offset + edge[0], offset + edge[1]] for edge in item.edge_index]
                node_batch += [batch_idx] * len(item_nodes)
                nodes += item_nodes
                edge_attr += item.edge_attr  # Mapping to memory is done in utils.py
            return nodes, edge_index, edge_attr, node_batch

        user_ids, wizard_ids, kgs, gold_kgs = zip(*batch)

        user_ids = create_padded_sequence(user_ids, 0)
        wizard_ids = create_padded_sequence(wizard_ids, 0)
        src_wizard_ids = wizard_ids[:, :-1]  # The "input" target sentence with <s> token (for teacher forcing input)
        trg_wizard_ids = wizard_ids[:, 1:]  # The "label" target sentence without <s> token (for labels)
        # NOTE(review): this replaces 0 with 0, so no value changes and padding
        # stays 0 rather than T5's -100 ignore index. Its only effect is
        # returning a new tensor instead of a view of wizard_ids. evaluate()
        # passes 0 as the pad id to sequence_loss when 'graph' is in
        # args.knowledge, so keep the two in sync if this changes to -100.
        trg_wizard_ids = trg_wizard_ids.masked_fill(trg_wizard_ids == 0, 0)

        # 1 for real tokens, 0 for padding (assumes non-negative ids).
        enc_mask = torch.sign(user_ids)
        dec_mask = torch.sign(src_wizard_ids)
        # The start token has id 0, so sign() would mask it; force it to be attended.
        dec_mask[:, 0] = 1

        batch_nodes, batch_edge_index, batch_edge_attr, graph_batch = merge_graphs(kgs)
        batch_edge_label = [label for kg in kgs for label in kg.label]
        batch_local_indicator = [flag for kg in kgs for flag in kg.local_indicator]
        assert len(batch_edge_index) == len(batch_edge_attr)
        batch_ent_pos = [kg.ent_pos for kg in kgs]

        ret = dict(
            input_ids=user_ids,
            attention_mask=enc_mask,
            decoder_input_ids=src_wizard_ids,
            decoder_attention_mask=dec_mask,
            labels=trg_wizard_ids,
            mention_positions=create_padded_sequence(batch_ent_pos, -1),
            nodes=torch.tensor(batch_nodes, dtype=torch.long),
            # .t().reshape(2, -1) also turns an empty edge list into a (2, 0) tensor.
            edge_index=torch.tensor(batch_edge_index, dtype=torch.long).t().reshape(2, -1),
            edge_attr=torch.tensor(batch_edge_attr, dtype=torch.long),
            edge_labels=torch.tensor(batch_edge_label, dtype=torch.long),
            graph_batch=torch.tensor(graph_batch, dtype=torch.long),
            local_indicator=torch.tensor(batch_local_indicator, dtype=torch.long),
        )
        # Gold graphs are only added when every example in the batch has one.
        if all(kg is not None for kg in gold_kgs):
            gold_nodes, gold_edge_index, gold_edge_attr, gold_graph_batch = merge_graphs(gold_kgs)
            ret["gold_nodes"] = torch.tensor(gold_nodes, dtype=torch.long)
            ret["gold_edge_index"] = torch.tensor(gold_edge_index, dtype=torch.long).t().reshape(2, -1)
            ret["gold_edge_attr"] = torch.tensor(gold_edge_attr, dtype=torch.long)
            ret["gold_graph_batch"] = torch.tensor(gold_graph_batch, dtype=torch.long)
        return ret

    num_workers = getattr(args, "num_workers", 16)
    if fold == "train":
        if args.local_rank == -1:
            sampler = RandomSampler(features)
        else:
            sampler = DistributedSampler(features)
        batch_size = args.train_batch_size
    else:
        # sampler=None means DataLoader reads sequentially. evaluate() relies
        # on this order to align predictions with load_raw_dataset entries.
        sampler = None
        batch_size = args.eval_batch_size
    dataloader = DataLoader(
        features,
        sampler=sampler,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    return dataloader, features, processor

if __name__ == "__main__":
    # Entry point: a single process when one GPU is visible; otherwise one
    # process per GPU, expected to be started by a launcher that sets LOCAL_RANK.
    args = setup_args()
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node == 1:
        run(-1, args)
    elif "LOCAL_RANK" not in os.environ:
        # Several GPUs are visible but no distributed launcher set LOCAL_RANK.
        # Fall back to a single-process run instead of failing with KeyError.
        logger.warning(
            "LOCAL_RANK is not set; running a single process despite %d visible GPUs",
            ngpus_per_node,
        )
        run(-1, args)
    else:
        args.world_size = ngpus_per_node * args.world_size
        args.local_rank = int(os.environ["LOCAL_RANK"])
        run(args.local_rank, args)