import json
import logging
import os
import argparse
from collections import defaultdict
import random
import numpy as np
import pickle

import wandb
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import WEIGHTS_NAME, AutoTokenizer, BertConfig

logger = logging.getLogger(__name__)
# logging.disable(logging.WARNING) 

from trainer import Trainer
from options import setup_args

# from models.modeling_bert import BertForExtractiveQA
from transformers.models.bert.modeling_bert import BertForQuestionAnswering

# The sibling helper modules are imported absolutely when this file is run
# directly as a script, and package-relatively when it is imported as part of
# a package. Both branches bind exactly the same names.
if __name__ == "__main__":
    from squad_metrics import (SquadResult, compute_predictions_logits,
                            squad_evaluate)
    # from record_eval import evaluate as evaluate_on_record
    from utils import (
        QAProcessorForTest,
        convert_examples_to_features,
        save_dataset
    )
else:
    from .squad_metrics import (SquadResult, compute_predictions_logits,
                            squad_evaluate)
    # from record_eval import evaluate as evaluate_on_record
    from .utils import (
        QAProcessorForTest,
        convert_examples_to_features,
        save_dataset
    )

# File name of the model weights inside a checkpoint directory. This rebinds
# the WEIGHTS_NAME imported from transformers at the top of the file.
WEIGHTS_NAME = "pytorch_model.bin"

def set_seed(seed):
    """Seed every random number generator this script relies on.

    The same value is applied, in order, to Python's ``random`` module, NumPy,
    PyTorch's default generator, and PyTorch's generators for all CUDA devices.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def to_list(tensor):
    """Convert a tensor into plain Python data.

    The tensor is detached from the autograd graph and copied to the CPU
    before being converted with ``tolist()``.

    Args:
        tensor: A ``torch.Tensor`` on any device.

    Returns:
        The result of ``tolist()``: a (possibly nested) list, or a Python
        scalar for a zero-dimensional tensor.
    """
    host_copy = tensor.detach().cpu()
    return host_copy.tolist()

def initialize_model(args):
    """Create the BERT question-answering model and its tokenizer.

    The tokenizer and model configuration always come from
    ``bert-base-uncased``. As a side effect the tokenizer is also stored on
    ``args.tokenizer``.

    Args:
        args: Options namespace. ``args.checkpoint`` is read; ``args.tokenizer``
            is written.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    base_name = "bert-base-uncased"

    tokenizer = AutoTokenizer.from_pretrained(base_name)

    if args.checkpoint is not None:
        # A checkpoint was given: build the architecture from the config only;
        # no pretrained weights are loaded in this function.
        bert_config = BertConfig.from_pretrained(base_name)
        model = BertForQuestionAnswering(bert_config)
    else:
        # No checkpoint: start from the pretrained base weights.
        model = BertForQuestionAnswering.from_pretrained(base_name)

    args.tokenizer = tokenizer
    return model, tokenizer

def run(args):
    """Evaluate a QA model and write its metrics to ``results.json``.

    Seeds the RNGs, builds the model, loads the checkpoint weights when a
    checkpoint directory is given, evaluates on the test or dev fold, and
    dumps the metrics as JSON into ``args.output_dir``.

    Args:
        args: Options namespace. Reads ``seed``, ``checkpoint``,
            ``output_dir`` and ``evaluate_on_test``; writes ``device`` (and,
            through ``initialize_model``, ``tokenizer``).

    Returns:
        The metrics object returned by ``evaluate``.
    """
    set_seed(args.seed)
    # Use the GPU when there is one, but stay runnable on CPU-only hosts
    # instead of failing at ``model.to('cuda')``.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Predictions and results are written here; make sure the folder exists.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Evaluate
    model, _ = initialize_model(args)
    if args.checkpoint is not None:
        # Load onto the CPU first, then move the whole model to the target
        # device in one step below.
        weights_path = os.path.join(args.checkpoint, WEIGHTS_NAME)
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    # With no checkpoint, initialize_model already loaded the pretrained base
    # weights, so there is nothing more to load.
    model.to(args.device)

    output_file = os.path.join(args.output_dir, "predictions.json")
    fold = "test" if args.evaluate_on_test else "dev"
    results = evaluate(args, model, fold=fold, output_file=output_file)

    with open(os.path.join(args.output_dir, "results.json"), "w") as f:
        json.dump(results, f)

    print(f"SEED: {args.seed}")
    print(f"Checkpoint: {args.checkpoint}")
    print(results)
    return results

def evaluate(args, model, fold="dev", output_file=None):
    """Run the model over one data fold and score its answer predictions.

    Args:
        args: Options namespace. Reads ``output_dir``, ``tokenizer``,
            ``device``, ``lm_type`` and the decoding options passed on to
            ``compute_predictions_logits`` (``n_best_size``,
            ``max_answer_length``, ``do_lower_case``, ``verbose_logging``,
            ``version_2_with_negative``, ``null_score_diff_threshold``).
        model: Model called as ``model(**inputs)``; its first two outputs are
            used as start and end logits.
        fold: Name of the fold handed to ``load_examples``.
        output_file: Path for the predictions file. When ``None`` it defaults
            to ``<args.output_dir>/predictions.json``.

    Returns:
        The metrics returned by ``squad_evaluate``.
    """
    dataloader, examples, features, _ = load_examples(args, fold)

    save_dataset(examples, args.output_dir)

    tokenizer = args.tokenizer

    # Switching to eval mode once is enough; it does not change per batch.
    model.eval()
    drop_token_types = args.lm_type in ["roberta"]

    all_results = []
    for batch in tqdm(dataloader, desc="Eval"):
        inputs = {k: v.to(args.device) for k, v in batch.items() if k != "feature_indices"}
        if drop_token_types:
            inputs.pop("token_type_ids", None)

        with torch.no_grad():
            outputs = model(**inputs)

        # Copy each logits tensor to the host once per batch rather than once
        # per example.
        batch_start_logits = to_list(outputs[0])
        batch_end_logits = to_list(outputs[1])

        feature_indices = batch["feature_indices"].tolist()
        for row, feature_index in enumerate(feature_indices):
            unique_id = int(features[feature_index].unique_id)
            all_results.append(
                SquadResult(unique_id, batch_start_logits[row], batch_end_logits[row])
            )

    # Compute predictions
    if output_file is not None:
        # Honor the caller-supplied location instead of silently ignoring it.
        output_prediction_file = output_file
    else:
        output_prediction_file = os.path.join(
            args.output_dir, "predictions.json")
    output_nbest_file = os.path.join(
        args.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = None

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        args.n_best_size,
        args.max_answer_length,
        args.do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        args.verbose_logging,
        args.version_2_with_negative,
        args.null_score_diff_threshold,
        tokenizer,
    )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

def load_examples(args, fold):
    """Load the examples of one fold and wrap their features in a DataLoader.

    Dev and test features are always rebuilt from the examples. Train features
    are rebuilt (and written to a pickle cache) when ``args.read_data`` is
    set, otherwise they are read back from that cache.

    Args:
        args: Options namespace. Reads ``data_dir``, ``prediction_file``,
            ``eval_target``, ``lm_type``, ``read_data``, ``tokenizer``,
            ``max_seq_length``, ``doc_stride``, ``max_query_length``,
            ``pickle_folder``, ``local_rank``, ``train_batch_size`` and
            ``eval_batch_size``.
        fold: ``"train"``, ``"dev"``, or anything else for the test split.

    Returns:
        A ``(dataloader, examples, features, processor)`` tuple.
    """
    processor = QAProcessorForTest(args)
    if fold == "train":
        examples = processor.get_train_examples(args.data_dir, args.prediction_file, eval_target=args.eval_target)
    elif fold == "dev":
        examples = processor.get_dev_examples(args.data_dir, args.prediction_file, eval_target=args.eval_target)
    else:
        examples = processor.get_test_examples(args.data_dir, args.prediction_file, eval_target=args.eval_target)

    logger.info("Creating features from the dataset...")
    # Derive the cache name from the model type so it is always defined; for
    # "bert" this is the same "train_features_bert.pkl" as before.
    pickle_name = f"train_features_{args.lm_type}.pkl"

    if args.read_data or fold in ("dev", "test"):
        features = convert_examples_to_features(
            examples,
            args.tokenizer,
            args.max_seq_length,
            args.doc_stride,
            args.max_query_length,
            is_training=fold == "train",
        )
        if fold == "train":
            if args.pickle_folder:
                os.makedirs(args.pickle_folder, exist_ok=True)
            with open(os.path.join(args.pickle_folder, pickle_name), "wb") as f:
                pickle.dump(features, f)
    else:
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # Only point pickle_folder at caches produced by this script.
        with open(os.path.join(args.pickle_folder, pickle_name), "rb") as f:
            features = pickle.load(f)

    def collate_fn(batch):
        """Turn a list of ``(index, feature)`` pairs into a dict of tensors."""

        def pad_field(attr_name, padding_value):
            # Read one attribute from every feature and pad to a common length.
            tensors = [torch.tensor(getattr(o[1], attr_name), dtype=torch.long) for o in batch]
            return pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        ret = dict(
            input_ids=pad_field("input_ids", args.tokenizer.pad_token_id),
            attention_mask=pad_field("input_mask", 0),
            token_type_ids=pad_field("segment_ids", 0),
        )
        if fold == "train":
            ret["start_positions"] = torch.stack([torch.tensor(getattr(o[1], "start_position"), dtype=torch.long) for o in batch])
            ret["end_positions"] = torch.stack([torch.tensor(getattr(o[1], "end_position"), dtype=torch.long) for o in batch])
        else:
            # Keep each feature's position in `features` so callers can map
            # model outputs back to the feature they came from.
            ret["feature_indices"] = torch.tensor([o[0] for o in batch], dtype=torch.long)
        return ret

    indexed_features = list(enumerate(features))
    if fold == "train":
        if args.local_rank == -1:
            sampler = RandomSampler(indexed_features)
        else:
            sampler = DistributedSampler(indexed_features)
        dataloader = DataLoader(
            indexed_features, sampler=sampler, batch_size=args.train_batch_size, collate_fn=collate_fn
        )
    else:
        dataloader = DataLoader(indexed_features, batch_size=args.eval_batch_size, collate_fn=collate_fn)

    return dataloader, examples, features, processor

if __name__ == "__main__":
    # Script entry point: echo the raw command line, then parse the options
    # and run the evaluation.
    import sys

    print(sys.argv)
    args = setup_args()

    run(args)