import os
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
import time
import torch
import collections
from packaging import version

from transformers import Trainer
from transformers import logging
from transformers.trainer_utils import (
    denumpify_detensorize,
    ShardedDDPOption,
    PREFIX_CHECKPOINT_DIR,
)
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, is_torch_tpu_available, is_sagemaker_mp_enabled
from transformers.trainer_pt_utils import (
    find_batch_size,
    nested_numpify,
    nested_concat,
    IterableDatasetShard,
    nested_detach,
    reissue_pt_warnings
)
from transformers.trainer_utils import EvalPrediction

from torch.utils.data.dataloader import DataLoader


if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

from tqdm.auto import tqdm
# Integrations must be imported before ML frameworks:
from transformers.integrations import (  # isort: split
    hp_params
)

from transformers import __version__
from torch import nn
from torch.utils.data.distributed import DistributedSampler


from transformers.trainer_callback import TrainerState

import math
import warnings

from transformers.file_utils import is_sagemaker_dp_enabled, is_apex_available
if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as dist
    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
else:
    import torch.distributed as dist

if is_apex_available():
    from apex import amp

if TYPE_CHECKING:
    import optuna

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat

from torch.optim import LBFGS
# from torch.optim import SdLBFGS
from pca import PCA
from kfac import KFAC
from pyhessian import hessian
import numpy as np
import random
import wandb
import time
from .util import clip_grad_norm_

# File names used inside each checkpoint directory.
# The optimizer / scheduler / scaler names are written by `_save_checkpoint`;
# the training-args and trainer-state names are currently not written by this file.
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"

# Module-level logger obtained from transformers' logging helpers.
logger = logging.get_logger(__name__)

class NewTextClassificationTrainer(Trainer):
    """HuggingFace ``Trainer`` subclass with a hand-written training loop.

    Differences from the stock ``Trainer`` visible in this class:

    * ``training_step`` does not call ``backward``; it returns ``(loss, logits, labels)`` so the
      loop can both optimise and collect training-set predictions.
    * ``train`` supports closure-based optimizers (``LBFGS`` / ``PCA``), a KFAC preconditioner
      (``args.optimizer in ["kfac", "kfacw"]``), gradient clipping before the update
      (``args.pre_clipping == 1``) or, for KFAC only, after preconditioning
      (``args.pre_clipping == 0``), and logs training-set metrics every ``args.logging_steps``.
    * ``_save_checkpoint`` writes model / optimizer / scheduler / RNG state but not the
      trainer-state JSON, and never saves the AMP grad scaler.

    NOTE(review): relies on custom training-argument fields ``optimizer``, ``eps``,
    ``update_freq``, ``gamma``, ``pre_clipping`` and ``pyhessian`` — confirm against the
    arguments class used by callers.
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> tuple:
        """
        Run a forward pass on one batch and return the loss together with detached predictions.

        Unlike ``Trainer.training_step`` this does NOT call ``backward``: the caller decides
        whether (and how) to back-propagate, which lets ``train`` reuse it inside optimizer
        closures and under ``torch.no_grad()``.

        Args:
            model (`nn.Module`):
                The model to run. It is switched to training mode.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. They go through ``self._prepare_inputs``
                and then ``self.compute_loss``; targets are looked up under ``self.label_names``.

        Return:
            `tuple`: ``(loss, logits, labels)`` where

            * ``loss`` is the (non-detached) loss tensor, averaged with ``mean()`` when
              ``args.n_gpu > 1``;
            * ``logits`` holds every model output except ``loss``, detached (unwrapped to a
              single tensor when there is exactly one);
            * ``labels`` holds the detached target tensor(s) (unwrapped when there is exactly one
              label name), or ``None`` when any label is missing from the prepared inputs.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # Everything the model returned except the loss counts as predictions.
        if isinstance(outputs, dict):
            logits = tuple(v for k, v in outputs.items() if k not in ["loss"])
        else:
            logits = outputs[1:]
        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        # Only collect labels when all of them are present (compute_loss may pop "labels", e.g.
        # with label smoothing); nested_detach would otherwise fail on a None entry.
        has_labels = all(inputs.get(name) is not None for name in self.label_names)
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point (custom loop; see the class docstring for how it differs).

        Args:
            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
                If a :obj:`str`, local path to a checkpoint saved by a previous instance of
                :class:`~transformers.Trainer`. If :obj:`True`, the last checkpoint in
                ``args.output_dir`` is used. Only optimizer and scheduler states are restored
                (via ``self._load_optimizer_and_scheduler``); this loop does NOT restore model
                weights or the epoch / step counters.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (:obj:`List[str]`, `optional`)
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Accepted for signature compatibility with ``Trainer.train``; ignored.

        Returns:
            None.

        Raises:
            AssertionError: if ``args.gradient_accumulation_steps != 1`` (the optimizer is stepped
                on every batch).
            ValueError: if ``resume_from_checkpoint is True`` but ``args.output_dir`` contains no
                checkpoint.
        """
        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args
        assert args.gradient_accumulation_steps == 1, "gradient_accumulation_steps must be 1 !"

        # `True` means "latest checkpoint in output_dir"; resolve it to a path before it is used
        # as one by _load_optimizer_and_scheduler.
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            from transformers.trainer_utils import get_last_checkpoint

            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        self.is_in_train = True
        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Training-control variables:
        #   num_update_steps_per_epoch - optimizer steps per epoch
        #   num_train_epochs           - number of epochs the outer loop runs
        #   max_steps                  - total optimizer steps (also sizes the LR schedule)
        # NOTE: len(train_dataloader) is required, i.e. the dataloader must be sized.
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
        num_update_steps_per_epoch = max(len(train_dataloader) // args.gradient_accumulation_steps, 1)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)

        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Wrap the model for multi-GPU training; `self.model` stays the unwrapped model.
        # NOTE(review): with stock TrainingArguments, n_gpu is 1 per process when local_rank != -1,
        # so the DDP branch may be unreachable — confirm how n_gpu / local_rank are set.
        if self.args.n_gpu > 1:
            if self.args.local_rank == -1:
                model = nn.DataParallel(self.model)
            else:
                print("using ddp")
                model = nn.parallel.DistributedDataParallel(self.model)
        else:
            model = self.model

        if self.args.optimizer in ["kfac", "kfacw"]:
            self.preconditioner = KFAC(
                model,
                eps=self.args.eps,
                update_freq=self.args.update_freq,
                alpha=self.args.gamma,
                distributed=(self.args.n_gpu > 1),
                world_size=self.args.n_gpu,
            )
            print(f"create KFAC with M.A.={self.args.gamma}, update_freq={self.args.update_freq}")

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # Train!
        num_examples = (
            self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        epochs_trained = 0

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        self.state.trial_params = hp_params(trial) if trial is not None else None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        # Training-set losses / predictions / labels gathered since the last metrics log.
        all_losses = None
        all_preds = None
        all_labels = None

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            epoch_iterator = train_dataloader

            steps_in_epoch = (
                len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            epoch_start_time = time.time()
            for step, inputs in enumerate(epoch_iterator):
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                torch.cuda.empty_cache()

                # Closure-based optimizers (LBFGS / PCA) evaluate the loss themselves inside
                # step(closure); every other optimizer gets one forward/backward pass here.
                uses_closure = isinstance(self.optimizer, (LBFGS, PCA))
                loss = None
                if uses_closure:
                    def closure():
                        self.optimizer.zero_grad()
                        loss, logits, labels = self.training_step(model, inputs)
                        loss.backward()
                        return loss
                else:
                    self.optimizer.zero_grad()
                    loss, logits, labels = self.training_step(model, inputs)
                    loss.backward()

                # Periodic console trace of the pre-update loss. Closure-based optimizers have not
                # computed a loss for this batch yet, so they are skipped.
                if step % 200 == 0 and loss is not None:
                    print(f"step:{step}  loss:{loss.item()}\n")

                # Optional clipping of the raw gradients before the update.
                # NOTE(review): for closure-based optimizers the gradients present here are from the
                # previous step (the closure recomputes them), so pre-clipping has no effect on them.
                if args.pre_clipping == 1:
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping
                        if step % 100 == 1:
                            print("do pre clipping\n")
                        if hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            total_norm, clipped = clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                args.max_grad_norm,
                            )
                            if wandb.run is not None:
                                wandb.log({"total_norm": total_norm, "pre_clipped": int(clipped)})

                if uses_closure:
                    self.optimizer.step(closure)
                elif self.args.optimizer in ["kfac", "kfacw"]:
                    # KFAC preconditioning runs before the optimizer update; presumably it rewrites
                    # the .grad tensors in place — confirm against the kfac module.
                    self.preconditioner.step()
                    # Optional clipping of the preconditioned gradients.
                    if args.pre_clipping == 0:
                        if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                            # deepspeed does its own clipping
                            if step % 100 == 1:
                                print("do post clipping\n")
                            if hasattr(self.optimizer, "clip_grad_norm"):
                                # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                                self.optimizer.clip_grad_norm(args.max_grad_norm)
                                if step % 100 == 1:
                                    print("do optimizer clipping\n")
                            elif hasattr(model, "clip_grad_norm_"):
                                # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                                model.clip_grad_norm_(args.max_grad_norm)
                                if step % 100 == 1:
                                    print("do model clipping\n")
                            else:
                                # Revert to normal clipping otherwise, handling Apex or full precision
                                total_norm, clipped = clip_grad_norm_(
                                    amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                    args.max_grad_norm,
                                )
                                if wandb.run is not None:
                                    wandb.log({"total_norm": total_norm, "post_clipped": int(clipped)})

                                if step % 100 == 1:
                                    print("in else clipping\n")

                    self.optimizer.step()
                else:
                    self.optimizer.step()

                if not self.deepspeed and self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                # Re-run the just-updated model on the same batch without gradients; this loss and
                # these predictions are what get accumulated for logging.
                with torch.no_grad():
                    loss, logits, labels = self.training_step(model, inputs)
                tr_loss += loss.detach()
                batch_size = find_batch_size(inputs)
                if loss is not None:
                    losses = self._nested_gather(loss.repeat(batch_size))
                    all_losses = losses if all_losses is None else torch.cat((all_losses, losses), dim=0)
                if labels is not None:
                    labels = self._pad_across_processes(labels)
                    labels = self._nested_gather(labels)
                    all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                if logits is not None:
                    logits = self._pad_across_processes(logits)
                    logits = self._nested_gather(logits)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)

                # Every `logging_steps` steps, compute and log training-set metrics over the
                # predictions accumulated since the last log, then reset the accumulators.
                if (self.state.global_step + 1) % args.logging_steps == 0:
                    all_losses = nested_numpify(all_losses)
                    all_preds = nested_numpify(all_preds)
                    all_labels = nested_numpify(all_labels)

                    # Metrics!
                    metric_key_prefix = "train"
                    if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
                        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
                    else:
                        metrics = {}

                    # To be JSON-serializable, we need to remove numpy types or zero-d tensors
                    metrics = denumpify_detensorize(metrics)

                    # Prefix all keys with metric_key_prefix + '_'
                    for key in list(metrics.keys()):
                        if not key.startswith(f"{metric_key_prefix}_"):
                            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
                    self.log(metrics)
                    all_losses, all_preds, all_labels = None, None, None

                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                # `eigen_info` is only logged at epoch end (below); discard any attached mid-epoch.
                if hasattr(self.state, "eigen_info"):
                    delattr(self.state, "eigen_info")
                self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

                # Honour callbacks (e.g. reaching max_steps, early stopping) like the stock Trainer.
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break

            epoch_duration = time.time() - epoch_start_time
            if wandb.run is not None:
                wandb.log({"train_epoch_time": epoch_duration})
            print(f"Epoch {epoch} training time: {epoch_duration}s")

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            # With --pyhessian, an epoch-end callback may attach `eigen_info` to the state; log it
            # once on the main process and remove it.
            if args.pyhessian and self.state.is_world_process_zero:
                if hasattr(self.state, "eigen_info"):
                    print("log eigen info")
                    self.log(self.state.eigen_info)
                    delattr(self.state, "eigen_info")
            assert not hasattr(self.state, "eigen_info")
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if self.control.should_training_stop:
                break

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.local_rank != -1:
                dist.barrier()

            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )

            best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
            if os.path.exists(best_model_path):
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(best_model_path, map_location="cpu")
                # Load into the unwrapped model: a DataParallel/DDP wrapper prefixes its keys with
                # "module.", so with strict=False nothing would be loaded into it. The wrapper
                # shares parameters with self.model, so it sees the update too.
                self.model.load_state_dict(state_dict, strict=False)
            else:
                logger.warning(
                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                    "on multiple nodes, you should activate `--save_on_each_node`."
                )

            if self.deepspeed:
                self.deepspeed.load_checkpoint(
                    self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
                )

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return

    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Save a checkpoint to ``<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``.

        Writes the model (``self.save_model``), the DeepSpeed checkpoint when DeepSpeed is enabled,
        the optimizer and LR-scheduler state dicts, and this process's RNG states. It also updates
        ``state.best_metric`` / ``state.best_model_checkpoint`` from ``metrics``, optionally pushes
        to the Hub, and rotates old checkpoints. Unlike ``Trainer._save_checkpoint`` it does not
        write ``trainer_state.json`` and always saves under ``args.output_dir`` (no per-trial run dir).

        NOTE(review): sets ``self.use_amp = False`` as a persistent side effect, so the grad scaler
        is never saved and the flag stays False afterwards — confirm this is intended.

        Args:
            model: Unused here; kept for signature compatibility with ``Trainer``.
            trial: Unused here; kept for signature compatibility with ``Trainer``.
            metrics (`dict`, *optional*): Evaluation metrics. ``eval_<metric_for_best_model>`` is read
                from it when ``args.metric_for_best_model`` is set (KeyError if absent).
        """
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        run_dir = self.args.output_dir
        self.store_flos()

        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_fp16_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        self.use_amp = False

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            # Re-issue outside the recording context, as in the other branches.
            reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            if smp.dp_rank() == 0:
                # Consolidate the state dict on all processed of dp_rank 0
                opt_state_dict = self.optimizer.state_dict()
                # Save it and the scheduler on the main process
                if self.args.should_save:
                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
                    with warnings.catch_warnings(record=True) as caught_warnings:
                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                    reissue_pt_warnings(caught_warnings)
                    if self.use_amp:
                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.use_amp:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save RNG states; in distributed training each process writes its own file.
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)
        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
        if local_rank == -1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
