import argparse
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy

import numpy as np
import torch
from tqdm import tqdm

from scipy import optimize
from analysis.loss_landscape import (
    calculate_loss_contours,
    calculate_model_interpolation,
)
from analysis.sharpness import sharpness, flatten_parameters
from analysis.utils import flatten_gradients

from transformers import AutoConfig, BertTokenizer
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DataCollator, DefaultDataCollator

from lang_exps.main_text import (
    ModelArguments,
    task_sequences,
    process_args,
    get_datasets,
    setup_logging,
)
from lang_exps.data.processors.data import tasks_num_labels
from lang_exps.data.util import never_split
from lang_exps.models.distilbert import DistilBertForSequenceClassification
from lang_exps.models.bert import BertForSequenceClassification
from lang_exps.models.roberta import RobertaForSequenceClassification

logger = logging.getLogger(__name__)


def create_eval_fn(task, calculate_gradient=False):
    """Build an evaluation closure for a single task.

    Args:
        task: Task identifier forwarded to the model as ``task=``.
        calculate_gradient: When true, every batch loss is back-propagated
            and the result of ``flatten_gradients(model)`` is added to the
            returned metrics under ``"gradients"``.

    Returns:
        A function ``eval_fn(model, dataloader, device)`` that returns a dict
        with the keys ``"loss"`` and ``"accuracy"`` (plus ``"gradients"``
        when ``calculate_gradient`` is true).
    """

    # Implementation specific to DistilBERT
    def eval_fn(model, dataloader, device):
        """Evaluate ``model`` on every batch of ``dataloader`` on ``device``."""
        model.eval()
        model.zero_grad()
        num_examples = len(dataloader.dataset)
        total_loss = 0
        num_correct = 0

        # Use set_grad_enabled as a context manager so the process-wide
        # autograd mode is restored afterwards. Calling it as a plain function
        # would leave gradients switched off (or on) for all later code.
        with torch.set_grad_enabled(calculate_gradient):
            for inputs in dataloader:
                batch = {name: tensor.to(device) for name, tensor in inputs.items()}

                outputs = model(**batch, task=task, reduction="sum", device=None)
                # The model is asked for a summed batch loss; dividing by the
                # dataset size makes the per-batch contributions (and their
                # gradients) add up to a dataset-level average.
                loss = outputs[0] / num_examples
                preds = torch.argmax(outputs[1], dim=1)

                num_correct += (preds == batch["labels"]).sum().item()
                if calculate_gradient:
                    loss.backward()
                total_loss += loss.item()

        metrics = {"loss": total_loss, "accuracy": num_correct / num_examples}
        if calculate_gradient:
            metrics["gradients"] = flatten_gradients(model)

        return metrics

    return eval_fn


def run_contour(data_args, model_args, training_args):
    """Compute loss-contour data from three consecutive task checkpoints.

    With ``i = model_args.analysis_start_task_idx``, the checkpoints stored
    after tasks ``i``, ``i + 1`` and ``i + 2`` are loaded, each is evaluated
    on task ``i``'s data, and all three are passed to
    ``calculate_loss_contours``. The result is written as JSON to
    ``<output_dir>/contour_data_<split>_<i>.json`` and returned.

    Args:
        data_args: Supplies ``task_name``, the key into ``task_sequences``.
        model_args: Supplies ``tokenizer_name``, ``cache_dir``, ``model_type``,
            ``analysis_start_task_idx`` and ``analysis_split``.
        training_args: Supplies ``output_dir``, ``device`` and
            ``eval_batch_size``; it is first passed through ``setup_logging``.

    Returns:
        The value returned by ``calculate_loss_contours``.

    Raises:
        IndexError: If the task sequence has no tasks at ``i + 1`` / ``i + 2``.
    """

    # Implementation specific to DistilBERT
    training_args = setup_logging(training_args)

    logger.info("Getting data for loss landscape analysis!")
    model_name_or_path = training_args.output_dir

    tokenizer = BertTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_name_or_path,
        cache_dir=model_args.cache_dir,
        never_split=never_split,
    )

    # Only the evaluation splits are used by this analysis.
    _, val_datasets, test_datasets = get_datasets(
        data_args=data_args, training_args=training_args, tokenizer=tokenizer
    )

    tasks = task_sequences[data_args.task_name]

    # Initialize model config
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    logger.info(f"Task seq: {','.join(tasks)}")

    start_idx = model_args.analysis_start_task_idx
    if start_idx + 2 >= len(tasks):
        # Same exception type plain indexing would raise, but with context.
        raise IndexError(
            f"Contour analysis needs three consecutive tasks starting at index "
            f"{start_idx}, but the sequence has only {len(tasks)} tasks"
        )
    contour_tasks = [tasks[start_idx + offset] for offset in range(3)]
    task1 = contour_tasks[0]

    # Any model type other than "distilbert" falls back to the BERT classifier.
    if model_args.model_type == "distilbert":
        model_cls = DistilBertForSequenceClassification
    else:
        model_cls = BertForSequenceClassification

    # Load the three checkpoints and place them on the evaluation device.
    models = []
    for number, task in enumerate(contour_tasks, start=1):
        model_dir = os.path.join(
            training_args.output_dir, f"{data_args.task_name}-after-{task}"
        )
        logger.info(f"Initializing model{number} from {model_dir}")
        model = model_cls.from_pretrained(
            model_dir,
            from_tf=False,
            config=config,
            cache_dir=model_args.cache_dir,
        )
        models.append(model.to(device=training_args.device))

    split = model_args.analysis_split

    # Every model is evaluated on the data of the first of the three tasks.
    if split == "validation":
        dataset_to_load = val_datasets[task1]
    else:
        dataset_to_load = test_datasets[task1]
    logger.info(f"No. of examples for evaluation: {split}-{len(dataset_to_load)}")

    data_collator = DefaultDataCollator()
    dataloader = DataLoader(
        dataset_to_load,
        sampler=None,
        batch_size=training_args.eval_batch_size,
        shuffle=False,
        collate_fn=data_collator.collate_batch,
    )

    eval_fn = create_eval_fn(task=task1)

    # Log each checkpoint's own loss/accuracy before building the contour.
    for number, model in enumerate(models, start=1):
        logger.info(f"Running evaluation on {task1} split {split} model{number}")
        metrics = eval_fn(
            model=model, dataloader=dataloader, device=training_args.device
        )
        logger.info(f"Loss {metrics['loss']}, accuracy {metrics['accuracy']}")

    model1, model2, model3 = models
    results = calculate_loss_contours(
        model1=model1,
        model2=model2,
        model3=model3,
        dataloader=dataloader,
        eval_fn=eval_fn,
        device=training_args.device,
    )

    output_file = os.path.join(
        training_args.output_dir,
        f"contour_data_{split}_{model_args.analysis_start_task_idx}.json",
    )
    with open(output_file, "w") as f:
        json.dump(results, f)

    return results


def run_lmi(data_args, model_args, training_args):
    """Collect linear model-interpolation data from one checkpoint to all later ones.

    With ``i = model_args.analysis_start_task_idx``, the checkpoint stored
    after task ``i`` is paired with the checkpoint stored after every later
    task ``j``. Each pair is passed to ``calculate_model_interpolation`` and
    evaluated on task ``i``'s data. Results are keyed ``"<i>_to_<j>"``, written
    as JSON to ``<output_dir>/lmc_data_<split>_<i>.json`` and returned.

    Args:
        data_args: Supplies ``task_name``, the key into ``task_sequences``.
        model_args: Supplies ``tokenizer_name``, ``cache_dir``, ``model_type``,
            ``analysis_start_task_idx`` and ``analysis_split``.
        training_args: Supplies ``output_dir``, ``device`` and
            ``eval_batch_size``; it is first passed through ``setup_logging``.

    Returns:
        Dict mapping ``"<i>_to_<j>"`` to a dict with the keys ``"losses"``,
        ``"accuracies"`` and ``"ts"``.
    """

    # Implementation specific to DistilBERT
    training_args = setup_logging(training_args)

    logger.info("Getting data for linear mode connectivity analysis!")
    model_name_or_path = training_args.output_dir

    tokenizer = BertTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_name_or_path,
        cache_dir=model_args.cache_dir,
        never_split=never_split,
    )

    # Only the evaluation splits are used by this analysis.
    _, val_datasets, test_datasets = get_datasets(
        data_args=data_args, training_args=training_args, tokenizer=tokenizer
    )

    tasks = task_sequences[data_args.task_name]

    # Initialize model config
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    task_idx_1 = model_args.analysis_start_task_idx
    task1 = tasks[task_idx_1]
    model1_identifier = f"{data_args.task_name}-after-{task1}"
    model1_dir = os.path.join(training_args.output_dir, model1_identifier)

    logger.info(f"Task seq: {','.join(tasks)}")
    logger.info(f"Task 1 - {task1}")
    logger.info(f"Model identifier 1 - {model1_identifier}")

    split = model_args.analysis_split

    if split == "validation":
        dataset_to_load = val_datasets[task1]
    else:
        dataset_to_load = test_datasets[task1]
    logger.info(f"No. of examples for evaluation: {split}-{len(dataset_to_load)}")

    data_collator = DefaultDataCollator()
    dataloader = DataLoader(
        dataset_to_load,
        sampler=None,
        batch_size=training_args.eval_batch_size,
        shuffle=False,
        collate_fn=data_collator.collate_batch,
    )

    # Any model type other than "distilbert" falls back to the BERT classifier.
    if model_args.model_type == "distilbert":
        model_cls = DistilBertForSequenceClassification
    else:
        model_cls = BertForSequenceClassification

    eval_fn = create_eval_fn(task=task1)

    results = {}
    logger.info(f"Starting task {task_idx_1}")
    for task_idx_2 in range(task_idx_1 + 1, len(tasks)):
        logger.info(f"Calculating lmc from {task_idx_1}->{task_idx_2}")

        # model1 is reloaded for every pair rather than reused.
        # NOTE(review): presumably because the interpolation helper may modify
        # its inputs in place - confirm before hoisting this out of the loop.
        logger.info(f"Initializing model1 from {model1_dir}")
        model1 = model_cls.from_pretrained(
            model1_dir,
            from_tf=False,
            config=config,
            cache_dir=model_args.cache_dir,
        )

        task2 = tasks[task_idx_2]
        model2_identifier = f"{data_args.task_name}-after-{task2}"
        model2_dir = os.path.join(training_args.output_dir, model2_identifier)

        logger.info(f"Initializing model2 from {model2_dir}")
        model2 = model_cls.from_pretrained(
            model2_dir,
            from_tf=False,
            config=config,
            cache_dir=model_args.cache_dir,
        )

        # The eval function moves every batch to `device`, so put both
        # endpoints there too, as the contour and sharpness analyses do.
        # Module.to leaves a model that is already on the device unchanged.
        model1 = model1.to(device=training_args.device)
        model2 = model2.to(device=training_args.device)

        losses, accuracies, ts = calculate_model_interpolation(
            model1=model1,
            model2=model2,
            dataloader=dataloader,
            eval_fn=eval_fn,
            device=training_args.device,
        )
        results[f"{task_idx_1}_to_{task_idx_2}"] = {
            "losses": losses,
            "accuracies": accuracies,
            "ts": ts,
        }

    output_file = os.path.join(
        training_args.output_dir,
        f"lmc_data_{split}_{model_args.analysis_start_task_idx}.json",
    )
    with open(output_file, "w") as f:
        json.dump(results, f)

    return results


def create_sharpness_fn(dataloader, task, device):
    """Return a one-argument criterion that evaluates a model with gradients.

    The returned callable takes only ``model`` and evaluates it on the given
    ``dataloader`` and ``device`` using a gradient-computing eval function
    for ``task``.
    """
    evaluate_with_grad = create_eval_fn(task, calculate_gradient=True)

    # Bind the fixed dataloader/device so callers only have to pass the model.
    return lambda model: evaluate_with_grad(model, dataloader, device)


def run_sharpness(
    data_args,
    model_args,
    training_args,
    p=0,
    sample_size=1280,
    epsilons=(5e-4, 1e-4, 5e-5),
):
    """Compute the sharpness of every task checkpoint on that task's own data.

    For each task in the sequence, the checkpoint stored after that task is
    loaded and ``sharpness`` is evaluated once per value in ``epsilons``.
    Results are written to ``<output_dir>/sharpness_<p>_<split>.json`` after
    every evaluation, so partial results survive an interrupted run, and once
    more at the end.

    Args:
        data_args: Supplies ``task_name``, the key into ``task_sequences``.
        model_args: Supplies ``tokenizer_name``, ``cache_dir``, ``model_type``
            and ``analysis_split``.
        training_args: Supplies ``output_dir``, ``device``, ``eval_batch_size``
            and ``seed``; it is first passed through ``setup_logging``.
        p: Number of columns of the random matrix ``A``. With ``p == 0`` no
            matrix is sampled and ``A`` is the scalar ``1``.
        sample_size: Maximum number of examples kept per task when the split
            is not ``"validation"``; smaller datasets are used in full.
        epsilons: Box sizes to evaluate; each scales the per-coordinate bounds.

    Returns:
        Mapping ``task_idx -> {epsilon: sharpness_value}``.
    """

    # Implementation specific to DistilBERT
    training_args = setup_logging(training_args)

    logger.info("Evaluating sharpness metric for our analysis!")
    model_name_or_path = training_args.output_dir

    tokenizer = BertTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_name_or_path,
        cache_dir=model_args.cache_dir,
        never_split=never_split,
    )

    # Only the evaluation splits are used by this analysis.
    _, val_datasets, test_datasets = get_datasets(
        data_args=data_args, training_args=training_args, tokenizer=tokenizer
    )

    tasks = task_sequences[data_args.task_name]
    epsilons = tuple(epsilons)

    # Initialize model config
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    split = model_args.analysis_split
    output_file = os.path.join(training_args.output_dir, f"sharpness_{p}_{split}.json")

    # Any model type other than "distilbert" falls back to the BERT classifier.
    if model_args.model_type == "distilbert":
        model_cls = DistilBertForSequenceClassification
    else:
        model_cls = BertForSequenceClassification

    def load_checkpoint(task):
        """Load the checkpoint that was saved after training on ``task``."""
        model_dir = os.path.join(
            training_args.output_dir, f"{data_args.task_name}-after-{task}"
        )
        logger.info(f"Initializing model from {model_dir}")
        return model_cls.from_pretrained(
            model_dir,
            from_tf=False,
            config=config,
            cache_dir=model_args.cache_dir,
        )

    # The first checkpoint is loaded up front only to learn the parameter
    # count, which fixes the number of rows of A.
    logger.info(f"Starting task {tasks[0]}")
    model = load_checkpoint(tasks[0])
    num_parameters = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    logger.info(f"Number of parameters - {num_parameters}")

    if p == 0:
        A = 1
    else:
        A = np.random.rand(num_parameters, p)
        logger.info("Sampled the A matrix!")
        A = A.astype(np.float32)
        # Normalise every column of A to unit Euclidean norm.
        A /= np.linalg.norm(A, axis=0, keepdims=True)
        logger.info("Done with the norm!")

    results = defaultdict(dict)

    # One progress tick per (task, epsilon) pair; the context manager closes
    # the bar even if an evaluation raises.
    with tqdm(total=len(epsilons) * len(tasks)) as progress:
        for task_idx, task in enumerate(tasks):

            if split == "validation":
                dataset_to_load = val_datasets[task]
            else:
                dataset_to_load = test_datasets[task]
            logger.info(
                f"No. of examples for evaluation: {task}:{split}-{len(dataset_to_load)}"
            )

            if split != "validation":
                # Randomly sample a seeded subset. The size is clamped because
                # sampling without replacement cannot draw more examples than
                # the dataset holds.
                num_kept = min(sample_size, len(dataset_to_load))
                shuffled_indices = np.random.RandomState(training_args.seed).choice(
                    len(dataset_to_load), num_kept, replace=False
                )
                dataset_to_load.features = [
                    dataset_to_load.features[idx] for idx in shuffled_indices
                ]

                logger.info(
                    f"No. of examples for evaluation (after sampling): {task}:{split}-{len(dataset_to_load)}"
                )

            data_collator = DefaultDataCollator()
            dataloader = DataLoader(
                dataset_to_load,
                sampler=None,
                batch_size=training_args.eval_batch_size,
                shuffle=False,
                collate_fn=data_collator.collate_batch,
            )

            logger.info(f"Starting task {task_idx}")
            model = load_checkpoint(task)
            model.to(training_args.device)

            # Per-coordinate scale of the search box: |value| + 1, taken in the
            # p-dimensional coordinates when a projection matrix is used.
            x = flatten_parameters(model)
            if p != 0:
                logger.info("Computing bounds!")
                b, _, _, _ = np.linalg.lstsq(A, x)
                bound = np.abs(b) + 1
                logger.info("Done with bounds!")
            else:
                bound = np.abs(x) + 1

            criterion_fn = create_sharpness_fn(
                dataloader=dataloader, task=task, device=training_args.device
            )

            for epsilon in epsilons:
                bounds = epsilon * bound
                sharpness_value = sharpness(
                    model=model,
                    criterion_fn=criterion_fn,
                    A=A,
                    epsilon=epsilon,
                    p=p,
                    bounds=optimize.Bounds(-bounds, bounds),
                )
                logger.info(
                    f"Sharpness evaulation on model after task {task} on task {task} Epsilon {epsilon} sharpness: {sharpness_value}"
                )
                results[task_idx][epsilon] = sharpness_value

                # Checkpoint partial results after every evaluation.
                with open(output_file, "w") as f:
                    json.dump(results, f)

                progress.update()

    with open(output_file, "w") as f:
        json.dump(results, f)

    return results


def main():
    """Parse the arguments, run the requested analysis and save its results.

    The analysis is chosen by ``model_args.analysis`` ("contour", "lmi" or
    "sharpness"); any other value raises ``ValueError``. The returned results
    are written as JSON to ``model_args.output_file``.
    """
    data_args, model_args, training_args = process_args()

    shared_kwargs = {
        "data_args": data_args,
        "model_args": model_args,
        "training_args": training_args,
    }

    # Deferred calls, so only the selected analysis runs and ``p_dim`` is read
    # only for the sharpness analysis.
    runners = {
        "contour": lambda: run_contour(**shared_kwargs),
        "lmi": lambda: run_lmi(**shared_kwargs),
        "sharpness": lambda: run_sharpness(p=model_args.p_dim, **shared_kwargs),
    }

    runner = runners.get(model_args.analysis)
    if runner is None:
        raise ValueError(f"Analysis type {model_args.analysis} not supported")
    results = runner()

    with open(model_args.output_file, "w") as f:
        json.dump(results, f)


# Run the selected analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
