import json
import logging
import os
import argparse
from collections import defaultdict
import random
import numpy as np
import pickle

import wandb
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import WEIGHTS_NAME, AutoTokenizer, BertConfig

logger = logging.getLogger(__name__)
# logging.disable(logging.WARNING) 

from trainer import Trainer
from options import setup_args
from squad_metrics import (SquadResult, compute_predictions_logits,
                           squad_evaluate)

# from models.modeling_bert import BertForExtractiveQA
from transformers.models.bert.modeling_bert import BertForQuestionAnswering

# from record_eval import evaluate as evaluate_on_record
from utils import (
    QAProcessor,
    convert_examples_to_features,
)

# Filename of the saved fine-tuned checkpoint inside args.output_dir (written and
# re-read by run()). Note: this rebinds the WEIGHTS_NAME imported from
# transformers at the top of the file, so this module-level value is the one used.
WEIGHTS_NAME = "pytorch_model.bin"

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator and all CUDA device generators, in that order.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def to_list(tensor):
    """Return the values of ``tensor`` as (nested) Python lists/scalars.

    The tensor is detached from the autograd graph and copied to host memory
    before conversion, so it works for GPU tensors and tensors requiring grad.
    """
    host_copy = tensor.detach().to("cpu")
    return host_copy.tolist()

def initialize_model(args, lm_ckpt="bert-base-uncased"):
    """Load a pretrained tokenizer and a ``BertForQuestionAnswering`` model.

    Side effect: the tokenizer is stored on ``args.tokenizer``; ``load_examples``
    and ``evaluate`` read it from there.

    Args:
        args: run configuration namespace (mutated as described above).
        lm_ckpt: Hugging Face model id or local path of the BERT checkpoint to
            load. Defaults to the checkpoint this script has always used.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(lm_ckpt)
    model = BertForQuestionAnswering.from_pretrained(lm_ckpt)

    args.tokenizer = tokenizer
    return model, tokenizer

def _cpu_state_dict(model):
    """Return a detached CPU copy of ``model``'s weights.

    If the model is wrapped (exposes ``.module``), the inner model's state dict
    is used so the saved keys load into a plain, unwrapped model.
    """
    inner = model.module if hasattr(model, "module") else model
    return {k: v.to("cpu").clone() for k, v in inner.state_dict().items()}


def _train_and_save(args):
    """Fine-tune a fresh model and save the best-scoring weights to ``args.output_dir``.

    Evaluation runs at every epoch boundary on the main process; the weights with
    the highest F1 are kept. If no evaluation ran, the final weights are saved so
    a loadable checkpoint always exists.

    Returns:
        Dict of per-epoch evaluation metrics (and ``best_epoch`` when set).

    Raises:
        ValueError: if the training dataloader yields no batches.
    """
    model, _ = initialize_model(args)
    model.to(args.device)

    train_dataloader, _, _, _ = load_examples(args, "train")

    num_train_steps_per_epoch = len(train_dataloader)
    if num_train_steps_per_epoch == 0:
        raise ValueError("The training set produced no batches; check args.data_dir and the feature cache.")
    num_train_steps = int(num_train_steps_per_epoch * args.num_train_epochs)

    best_dev_score = -1
    best_weights = None
    results = {}

    def step_callback(model, global_step):
        nonlocal best_dev_score, best_weights
        # Only act at epoch boundaries, and only on the main process.
        if global_step % num_train_steps_per_epoch != 0 or args.local_rank not in (0, -1):
            return
        epoch = global_step // num_train_steps_per_epoch - 1

        # NOTE(review): checkpoint selection uses the "test" fold although it is
        # logged as "dev"; if a separate dev split exists, selecting on it avoids
        # tuning on the test set.
        dev_results = evaluate(args, model, fold="test")

        tqdm.write("dev: " + str(dev_results))
        results.update({f"dev_{k}_epoch{epoch}": v for k, v in dev_results.items()})
        if dev_results["f1"] > best_dev_score:
            best_weights = _cpu_state_dict(model)
            best_dev_score = dev_results["f1"]
            results["best_epoch"] = epoch
        # evaluate() switched the model to eval mode; restore training mode.
        model.train()

    trainer = Trainer(
        args,
        model=model,
        dataloader=train_dataloader,
        num_train_steps=num_train_steps,
        step_callback=step_callback,
    )
    trainer.train()

    print(results)

    # Only the main process writes the checkpoint; other ranks never record a
    # best snapshot and would otherwise overwrite the file.
    if args.local_rank in (-1, 0):
        weights = best_weights if best_weights is not None else _cpu_state_dict(model)
        os.makedirs(args.output_dir, exist_ok=True)
        logger.info("Saving the model checkpoint to %s", args.output_dir)
        torch.save(weights, os.path.join(args.output_dir, WEIGHTS_NAME))

    # Make sure the checkpoint exists before any other rank tries to load it.
    if args.local_rank != -1 and torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

    return results


def run(args):
    """Optionally fine-tune, then evaluate the saved checkpoint on the test fold.

    Unless ``args.do_eval`` is set, a model is trained and its best weights are
    written to ``<output_dir>/WEIGHTS_NAME``. The checkpoint is then reloaded into
    a fresh model and evaluated. The metrics are written to ``results.json`` and
    ``args`` to ``training_args.bin``.

    Returns:
        The metrics dict produced by ``evaluate`` on the test fold.
    """
    set_seed(args.seed)
    # Fall back to CPU so the script still runs on machines without a GPU.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    if not args.do_eval:
        _train_and_save(args)

    # Evaluate the saved checkpoint with a freshly initialised model
    # (initialize_model also (re)sets args.tokenizer for evaluate()).
    model, _ = initialize_model(args)
    model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME), map_location="cpu"))
    model.to(args.device)

    output_file = os.path.join(args.output_dir, "predictions.json")
    results = evaluate(args, model, fold="test", output_file=output_file)

    with open(os.path.join(args.output_dir, "results.json"), "w") as f:
        json.dump(results, f)

    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    print(f"SEED: {args.seed}")
    print(f"Checkpoint: {args.checkpoint}")
    print(results)
    return results

def evaluate(args, model, fold="dev", output_file=None):
    """Run ``model`` over ``fold`` and score its span predictions.

    Args:
        args: run configuration; reads ``device``, ``lm_type``, ``output_dir``,
            ``tokenizer`` and the ``compute_predictions_logits`` options
            (``n_best_size``, ``max_answer_length``, ``do_lower_case``,
            ``verbose_logging``, ``version_2_with_negative``,
            ``null_score_diff_threshold``).
        model: QA model whose outputs at index 0 and 1 are start/end logits.
        fold: split name passed to ``load_examples``.
        output_file: where to write the best-answer predictions JSON. Defaults
            to ``<output_dir>/predictions.json`` (previously this argument was
            ignored and that default was always used).

    Returns:
        The metrics returned by ``squad_evaluate`` (``run`` reads its "f1" key).
    """
    dataloader, examples, features, _ = load_examples(args, fold)
    tokenizer = args.tokenizer
    drop_token_type_ids = args.lm_type in ["roberta"]

    model.eval()
    all_results = []
    for batch in tqdm(dataloader, desc="Eval"):
        inputs = {k: v.to(args.device) for k, v in batch.items() if k != "feature_indices"}
        if drop_token_type_ids:
            del inputs["token_type_ids"]

        with torch.no_grad():
            outputs = model(**inputs)
        # Copy the whole batch of logits to the host once instead of per example.
        batch_start_logits = to_list(outputs[0])
        batch_end_logits = to_list(outputs[1])

        for row, feature_index in enumerate(batch["feature_indices"]):
            eval_feature = features[feature_index.item()]
            all_results.append(
                SquadResult(int(eval_feature.unique_id), batch_start_logits[row], batch_end_logits[row])
            )

    # Compute predictions
    if output_file is None:
        output_file = os.path.join(args.output_dir, "predictions.json")
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = None

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        args.n_best_size,
        args.max_answer_length,
        args.do_lower_case,
        output_file,
        output_nbest_file,
        output_null_log_odds_file,
        args.verbose_logging,
        args.version_2_with_negative,
        args.null_score_diff_threshold,
        tokenizer,
    )

    # Compute the F1 and exact scores.
    return squad_evaluate(examples, predictions)

def _train_features_cache_path(args):
    """Path of the pickled training-feature cache for ``args.lm_type``."""
    return os.path.join(args.pickle_folder, f"train_features_{args.lm_type}.pkl")


def load_examples(args, fold):
    """Load ``fold``'s examples, convert them to features and wrap them in a DataLoader.

    Training features are cached as a pickle in ``args.pickle_folder``. When
    ``args.read_data`` is set they are rebuilt and re-written; otherwise they are
    loaded from the cache. Non-training folds are always rebuilt from the raw
    examples.

    Each batch is a dict with ``input_ids``, ``attention_mask`` and
    ``token_type_ids``. Training batches add ``start_positions`` and
    ``end_positions``; other batches add ``feature_indices``, which are indices
    into the returned ``features`` list.

    Args:
        args: run configuration (data paths, tokenizer, feature/batch sizes).
        fold: "train", "dev", or anything else (treated as the test split).

    Returns:
        ``(dataloader, examples, features, processor)``.
    """
    is_training = fold == "train"

    processor = QAProcessor(args)
    if is_training:
        examples = processor.get_train_examples(args.data_dir)
    elif fold == "dev":
        examples = processor.get_dev_examples(args.data_dir)
    else:
        examples = processor.get_test_examples(args.data_dir)

    logger.info("Creating features from the dataset...")
    # Only the training fold uses the cache. Previously the cache name was bound
    # only for lm_type == "bert" (NameError otherwise), and an unknown fold
    # wrongly loaded the *training* cache.
    if args.read_data or not is_training:
        features = convert_examples_to_features(
            examples,
            args.tokenizer,
            args.max_seq_length,
            args.doc_stride,
            args.max_query_length,
            is_training=is_training,
        )
        if is_training:
            os.makedirs(args.pickle_folder, exist_ok=True)
            with open(_train_features_cache_path(args), "wb") as f:
                pickle.dump(features, f)
    else:
        # NOTE: the cache is not keyed on max_seq_length / doc_stride /
        # max_query_length; rerun with read_data after changing them. Only unpickle
        # caches you created yourself — pickle can execute arbitrary code.
        with open(_train_features_cache_path(args), "rb") as f:
            features = pickle.load(f)

    def collate_fn(batch):
        # batch is a list of (feature_index, feature) pairs from enumerate(features).
        def padded(attr, padding_value):
            return pad_sequence(
                [torch.tensor(getattr(feature, attr), dtype=torch.long) for _, feature in batch],
                batch_first=True,
                padding_value=padding_value,
            )

        ret = dict(
            input_ids=padded("input_ids", args.tokenizer.pad_token_id),
            attention_mask=padded("input_mask", 0),
            token_type_ids=padded("segment_ids", 0),
        )
        if is_training:
            ret["start_positions"] = torch.tensor([f.start_position for _, f in batch], dtype=torch.long)
            ret["end_positions"] = torch.tensor([f.end_position for _, f in batch], dtype=torch.long)
        else:
            ret["feature_indices"] = torch.tensor([idx for idx, _ in batch], dtype=torch.long)
        return ret

    indexed_features = list(enumerate(features))
    if is_training:
        sampler = RandomSampler(features) if args.local_rank == -1 else DistributedSampler(features)
        dataloader = DataLoader(
            indexed_features, sampler=sampler, batch_size=args.train_batch_size, collate_fn=collate_fn
        )
    else:
        dataloader = DataLoader(indexed_features, batch_size=args.eval_batch_size, collate_fn=collate_fn)

    return dataloader, examples, features, processor

if __name__ == "__main__":
    # Script entry point: parse CLI options and launch training/evaluation.
    run(setup_args())