
import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import LambdaLR
import math
from scipy.sparse.linalg import LinearOperator, eigsh

from torch.nn.utils import parameters_to_vector, vector_to_parameters

import transformers
from transformers.file_utils import is_datasets_available, is_in_notebook, is_torch_tpu_available
from transformers.integrations import (
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
)
from transformers.optimization import AdamW, get_linear_schedule_with_warmup, get_scheduler

from transformers.trainer_callback import (
    DefaultFlowCallback,
    ProgressCallback,
)
from transformers.trainer_utils import (
    default_compute_objective,
)
from transformers.training_args import TrainingArguments
from transformers.utils import logging
from transformers.trainer_utils import TrainOutput

from tqdm import tqdm, trange
from torch.optim import SGD

from src.linearhead_trainer import LinearHeadTrainer
from transformers.trainer_callback import TrainerState

import copy

# Mixed-precision backend flags; exactly one is switched on by the torch version check below.
_use_native_amp = False
_use_apex = False

# Default callback list and progress callback; optional integrations are appended below
# when their packages are available.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# Inside a notebook, swap in the notebook-friendly progress callback.
if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Pick the mixed-precision backend: Apex on torch < 1.6, native AMP (torch.cuda.amp) otherwise.
# NOTE: `_use_apex` becomes True on torch < 1.6 even when apex is not installed (`amp` is then
# left undefined); HessianTrainer.train re-checks apex availability before calling amp.
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

# Version gate for `_use_ddp_no_sync` (presumably DistributedDataParallel.no_sync support);
# the flag is not read anywhere in this file's visible code.
if version.parse(torch.__version__) < version.parse("1.2"):
    _use_ddp_no_sync = False
else:
    _use_ddp_no_sync = True

# Optional third-party integrations, imported only when available.
if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    from transformers.integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from transformers.integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from transformers.integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)


class HessianTrainer(LinearHeadTrainer):
    """Trainer whose ``train`` measures curvature of the training loss instead of optimizing.

    For the Hessian H of the training loss with respect to the parameters that receive
    gradients, ``train`` estimates:

    * Tr(H) with Hutchinson's estimator (``hutchinson``),
    * the top eigenvalues of H with ARPACK ``eigsh`` driven by Hessian-vector products
      (``lanczos`` + ``hvp``),
    * the effective rank Tr(H) / ||H||_2.

    Both estimators weight each batch loss by |batch| / n (``_batch_weight``), so they
    describe the same matrix. If the model returns a batch-mean loss (TODO confirm for
    the models used here), that matrix is the Hessian of the dataset-average loss.
    """

    def train(self, model_path=None, dev_objective=None):
        """Compute curvature statistics of the training loss; no parameters are updated.

        Args:
            model_path: optional checkpoint directory. If it holds both ``optimizer.pt``
                and ``scheduler.pt``, their states are loaded.
            dev_objective: unused here.

        Returns:
            dict with keys:
            * ``trace``: Tr(H) estimate.
            * ``op_norm``: largest |eigenvalue| found.
            * ``effective_rank``: trace / op_norm.
            * ``num_tr_estimates``: number of Hutchinson probe vectors used.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)

        # The optimizer is never stepped. It is created so checkpoint state can be
        # restored and so apex.amp.initialize below has an optimizer to wrap.
        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        logger.info("***** Computing the Hessian *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))

        self.state = TrainerState()
        self.state.global_step = 0
        self.epoch = 0

        if self.args.gradient_checkpointing:
            model.gradient_checkpointing_enable()

        self.n = len(self.train_dataset)

        # One backward pass on the first batch shows which parameters take part in the
        # loss; only those are differentiated by hutchinson() / hvp().
        inputs = next(iter(train_dataloader))
        with self.compute_loss_context_manager():
            inputs = self._prepare_inputs(inputs)
            loss, logits = model(**inputs)
        loss.backward()
        self.diffable_params = [p for p in model.parameters() if p.grad is not None]
        self.p = np.sum([p.data.numel() for p in self.diffable_params])
        # hutchinson() / hvp() read self.model, so they use the (possibly wrapped) model.
        self.model = model

        logger.info("  Num parameters = %d", self.p)
        self._clear_grads(model)
        tr_estimate, num_tr_estimates = self.hutchinson()
        self._clear_grads(model)
        evals, evecs = self.lanczos(self.hvp, self.p, neigs=16)

        # eigsh returns the largest-*magnitude* eigenvalues. ||H||_2 is the largest |lambda|,
        # which equals evals[0] (largest algebraic value) only when that eigenvalue is positive.
        op_norm = np.max(np.abs(evals))
        effective_rank = tr_estimate / op_norm

        logger.info(f" Top eigenvalues: {evals}")
        logger.info(f"  Tr(H) = {tr_estimate}")
        logger.info(f"  Operator norm of H = {op_norm}")
        logger.info(f"  Effective rank of H = {effective_rank}")

        return {'trace': tr_estimate,
                'op_norm': op_norm,
                'effective_rank': effective_rank,
                'num_tr_estimates': num_tr_estimates}

    @staticmethod
    def _clear_grads(model):
        """Zero, then drop, every ``.grad`` of ``model`` so no stale gradients remain."""
        model.zero_grad()
        for param in model.parameters():
            param.grad = None

    def _batch_weight(self, inputs):
        """Return |batch| / n, the weight of this batch's loss in the dataset average.

        |batch| is read from dim 0 of the first tensor in ``inputs``. This assumes
        batch-first tensors (TODO confirm for non-standard collators). If no tensor is
        found, ``self.args.train_batch_size`` is used. Reading the real size keeps a
        partial final batch from being over-weighted.
        """
        batch_size = self.args.train_batch_size
        for value in inputs.values():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                batch_size = value.shape[0]
                break
        return batch_size / float(self.n)

    def hutchinson(self):
        """Estimate Tr(H) with Hutchinson's estimator, mean of v^T H v for v ~ N(0, I).

        Steps:
        1. Accumulate the weighted loss (same weighting as ``hvp``) over the whole
           training set, keeping its autograd graph.
        2. Differentiate it once with ``create_graph=True``.
        3. Probe the result with ``self.args.num_hvp_vecs`` Gaussian vectors.

        Returns:
            (mean of the v^T H v samples, number of samples drawn)
        """
        model = self.model
        epoch_iterator = tqdm(self.get_train_dataloader(), desc="Iteration", disable=True)

        loss = 0.0
        for step, inputs in tqdm(enumerate(epoch_iterator)):
            inputs = self._prepare_inputs(inputs)
            _loss, logits = model(**inputs)
            # Same weighting as hvp(), so the trace and eigenvalues describe one matrix.
            loss += _loss * self._batch_weight(inputs)

        grads = torch.autograd.grad(loss, inputs=self.diffable_params, create_graph=True)
        values = []
        for _ in range(self.args.num_hvp_vecs):
            # randn_like draws v on each parameter's own device and dtype.
            v = [torch.randn_like(p) for p in self.diffable_params]
            gv = sum(torch.sum(g * vi) for g, vi in zip(grads, v))
            hv = torch.autograd.grad(gv, self.diffable_params, retain_graph=True)
            values.append(sum(torch.sum(h * vi) for h, vi in zip(hv, v)).item())

        return np.mean(values), len(values)

    def hvp(self, z=None):
        """Return H @ z, accumulated batch by batch over the training set.

        Args:
            z: flat tensor of length ``self.p`` on ``self.args.device``, ordered like
                ``parameters_to_vector(self.diffable_params)``.

        Returns:
            Flat float32 tensor of length ``self.p`` on ``self.args.device``.
        """
        model = self.model
        epoch_iterator = tqdm(self.get_train_dataloader(), desc="Iteration", disable=True)
        hvp = torch.zeros(self.p, dtype=torch.float, device=self.args.device)

        for step, inputs in tqdm(enumerate(epoch_iterator)):
            inputs = self._prepare_inputs(inputs)
            loss, logits = model(**inputs)
            loss = loss * self._batch_weight(inputs)

            grads = torch.autograd.grad(loss, inputs=self.diffable_params, create_graph=True)
            # parameters_to_vector uses .view(-1), which requires contiguous tensors.
            dot = parameters_to_vector([g.contiguous() for g in grads]).mul(z).sum()
            hgrads = torch.autograd.grad(dot, self.diffable_params, retain_graph=True)
            hvp += parameters_to_vector([g.contiguous() for g in hgrads])

        return hvp

    def lanczos(self, matrix_vector, dim: int, neigs: int, maxiter: int = 16):
        """Find the ``neigs`` largest-magnitude eigenpairs of a symmetric operator with ARPACK.

        Args:
            matrix_vector: callable mapping a flat torch tensor on ``self.args.device``
                to the operator applied to it (e.g. ``self.hvp``).
            dim: operator dimension.
            neigs: number of eigenpairs to compute.
            maxiter: ARPACK iteration limit passed to ``eigsh`` (default matches the old value).

        Returns:
            (eigenvalues, eigenvectors as columns), both in reverse of eigsh's output order.
        """
        device = self.args.device

        def mv(vec):
            vec = torch.tensor(vec, dtype=torch.float, device=device)
            return matrix_vector(vec).detach().cpu().numpy()

        # An explicit dtype stops LinearOperator from calling matvec(zeros) to infer it,
        # which would cost one extra full pass over the training set.
        operator = LinearOperator((dim, dim), matvec=mv, dtype=np.float32)
        evals, evecs = eigsh(operator, neigs, maxiter=maxiter)
        return evals[::-1], np.flip(evecs, -1)