import copy
import deepspeed
from transformers import Trainer
from torch.optim import Adam, SGD
from .losses import get_loss, idk_loss, me_loss, gd_loss, ap_loss

# from trainer.custom_optimizer import AdamDecouple, AdamWDecouple, OptimizedAdamWDecouple
from trainer.custom_optimizer import AdamWDecouple8bit, AdamWDecoupleNormal

import torch
import torch.nn as nn
import torch.optim as optim
import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import json
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

# ``import importlib.metadata`` does not guarantee that the ``importlib.util``
# submodule is loaded, so import it explicitly instead of relying on another
# import having pulled it in as a side effect.
import importlib.util

# True when the optional ``smdistributed`` package can be located.
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None


def _load_env_json_object(var_name):
    """Return the JSON object stored in env var ``var_name``, or None if unusable.

    An unset variable is treated as an empty object. None is returned when the
    value is not valid JSON or decodes to something other than a JSON object.
    """
    try:
        parsed = json.loads(os.getenv(var_name, "{}"))
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def is_sagemaker_mp_enabled():
    """Report whether SageMaker model parallelism is configured and available.

    Returns:
        False unless all of the following hold, in which case the value of
        ``_smdistributed_available`` is returned:

        * ``SM_HP_MP_PARAMETERS`` holds a JSON object containing "partitions";
        * ``SM_FRAMEWORK_PARAMS`` holds a JSON object whose
          "sagemaker_mpi_enabled" entry is truthy.

        Malformed JSON, or JSON that is not an object, yields False instead of
        raising.
    """
    # The "partitions" field is required for model parallelism.
    smp_options = _load_env_json_object("SM_HP_MP_PARAMETERS")
    if smp_options is None or "partitions" not in smp_options:
        return False

    # The framework parameters must have MPI enabled.
    mpi_options = _load_env_json_object("SM_FRAMEWORK_PARAMS")
    if mpi_options is None or not mpi_options.get("sagemaker_mpi_enabled", False):
        return False

    # Lastly, check that the `smdistributed` module is present.
    return _smdistributed_available


class CustomTrainerForgettingDO(Trainer):
    """Trainer that optimizes a forget loss and a retain loss in separate passes.

    ``loss_type`` is split on "+": the first part names the forget loss and the
    second part names the retain (regularization) loss. In dual mode (always
    enabled by ``__init__``) each training step back-propagates the forget
    loss, steps the optimizer, and then back-propagates the retain loss.
    """

    # Optimizer configurations backed by the decoupled AdamW optimizers.
    _DUAL_OPTIM_CFGS = ("dual_adam", "dual_adam8bit")
    # Every optimizer configuration this trainer accepts.
    _SUPPORTED_OPTIM_CFGS = _DUAL_OPTIM_CFGS + ("single_alo",)

    def __init__(
        self, optim_cfg="dual_adam", forget_lr=1e-5, retain_lr=1e-5, *args, **kwargs
    ):
        """Set up the trainer and prepare the reference model with DeepSpeed.

        Args:
            optim_cfg: One of "dual_adam", "dual_adam8bit" or "single_alo".
                "single_alo" keeps the Trainer's default optimizer.
            forget_lr: Learning rate associated with the forget loss.
            retain_lr: Learning rate associated with the retain loss.
            *args: Forwarded to ``Trainer.__init__``.
            **kwargs: Must contain ``loss_type``, ``ref_model``,
                ``forget_coeff``, ``regularization_coeff`` and ``beta``; these
                are removed and the rest is forwarded to ``Trainer.__init__``.

        Raises:
            KeyError: If one of the required keyword arguments is missing.
            ValueError: If ``optim_cfg`` is not a supported configuration.
            ZeroDivisionError: If ``retain_lr`` is zero.
        """
        self.loss_type = kwargs.pop("loss_type")
        self.ref_model = kwargs.pop("ref_model")
        self.forget_coeff = kwargs.pop("forget_coeff")
        self.regularization_coeff = kwargs.pop("regularization_coeff")
        self.beta = kwargs.pop("beta")

        # Optimizer configuration (dual_adam, dual_adam8bit, single_alo).
        # Validate it before the expensive Trainer / DeepSpeed setup below.
        if optim_cfg not in self._SUPPORTED_OPTIM_CFGS:
            raise ValueError(f"Unsupported optimizer type: {optim_cfg}")
        self.optim_cfg = optim_cfg
        self.forget_lr = forget_lr
        self.retain_lr = retain_lr
        self.dual = True

        # The decoupled optimizers receive the forget rate as a ratio of the
        # retain rate (see ``create_optimizer_and_scheduler``).
        self.forget_lr_ratio = self.forget_lr / self.retain_lr
        super().__init__(*args, **kwargs)

        # Prepare the reference model with DeepSpeed.
        self.ref_model = self.e_prepare_deepspeed(self.ref_model)

        if self.optim_cfg in self._DUAL_OPTIM_CFGS:
            print(
                f"Create forget optimizer using type {self.optim_cfg} with lr {self.forget_lr}"
            )
            # TODO: if using sgd
        # For "single_alo" the Trainer's default optimizer is used as is.

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the optimizer selected by ``optim_cfg`` and its LR scheduler.

        "dual_adam" builds an ``AdamWDecoupleNormal`` and "dual_adam8bit" an
        ``AdamWDecouple8bit``, both over ``self.model.parameters()`` with
        ``lr=retain_lr`` and ``lr_ratio_1=forget_lr / retain_lr``. Any other
        configuration falls back to ``self.create_optimizer()``.

        Args:
            num_training_steps: Total number of training steps, forwarded to
                ``self.create_scheduler``.
        """
        if self.optim_cfg in self._DUAL_OPTIM_CFGS:
            if self.optim_cfg == "dual_adam":
                label, optimizer_cls = "32bit", AdamWDecoupleNormal
            else:
                label, optimizer_cls = "8bit", AdamWDecouple8bit
            print(f"{label} optim > Using forget ratio: ", self.forget_lr_ratio)
            print("Retain lr: ", self.retain_lr)
            self.optimizer = optimizer_cls(
                self.model.parameters(),
                lr=self.retain_lr,
                lr_ratio_1=self.forget_lr_ratio,
            )
        else:
            self.create_optimizer()

        self.create_scheduler(
            num_training_steps=num_training_steps, optimizer=self.optimizer
        )

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Compute the weighted forget and regularization losses for a batch.

        Args:
            model: Model being trained.
            inputs: Batch passed through to ``get_loss``.
            return_outputs: When True, return ``(loss, None)`` instead of
                ``loss``.
            num_items_in_batch: Accepted for compatibility with transformers
                versions that pass it; unused here.

        Returns:
            In dual mode the tuple ``(forget_coeff * forget_loss,
            regularization_coeff * regularization_loss)``; otherwise the sum
            of those two terms. Wrapped as ``(result, None)`` when
            ``return_outputs`` is True.
        """
        forget_loss, regularization_loss = get_loss(
            model, self.ref_model, inputs, self.loss_type, self.beta
        )
        computed_forget_loss = self.forget_coeff * forget_loss
        computed_regularization_loss = self.regularization_coeff * regularization_loss

        if self.dual:
            result = (computed_forget_loss, computed_regularization_loss)
        else:
            result = computed_forget_loss + computed_regularization_loss
        return (result, None) if return_outputs else result

    def compute_loss_by_type(self, model, inputs, type="forget"):
        """Compute only the weighted forget loss or only the weighted retain loss.

        Args:
            model: Model being trained.
            inputs: Batch passed through to ``get_loss``.
            type: "forget" or "retain".

        Returns:
            ``forget_coeff * forget_loss`` for "forget", computed with the part
            of ``loss_type`` before "+"; ``regularization_coeff *
            regularization_loss`` for "retain", computed with the part after
            "+".

        Raises:
            NotImplementedError: If ``type`` is neither "forget" nor "retain".
            IndexError: If "retain" is requested but ``loss_type`` has no "+".
        """
        if type not in ("forget", "retain"):
            raise NotImplementedError(f"Unknown loss type: {type!r}")

        loss_names = self.loss_type.split("+")
        if type == "forget":
            forget_loss, _ = get_loss(
                model, self.ref_model, inputs, loss_names[0], self.beta
            )
            return self.forget_coeff * forget_loss

        _, regularization_loss = get_loss(
            model, self.ref_model, inputs, loss_names[1], self.beta
        )
        return self.regularization_coeff * regularization_loss

    def _backward_loss_by_type(self, model, inputs, loss_kind):
        """Compute the ``loss_kind`` loss, back-propagate it and return it."""
        # Same forward context and multi-GPU reduction as the non-dual path.
        with self.compute_loss_context_manager():
            loss = self.compute_loss_by_type(model, inputs, type=loss_kind)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        self.accelerator.backward(loss)
        return loss

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """Run one training step on a batch.

        In dual mode the forget loss is back-propagated and the optimizer is
        stepped, then gradients are cleared and the retain loss is
        back-propagated (the step for the retain gradients is not taken here).

        Args:
            model: Model to train.
            inputs: Batch; passed through ``self._prepare_inputs`` first.
            num_items_in_batch: Accepted for compatibility with transformers
                versions that pass it; unused here.

        Returns:
            The detached loss divided by ``args.gradient_accumulation_steps``.
            In dual mode this is the sum of the weighted forget and retain
            losses.

        Raises:
            ValueError: If ``optim_cfg`` is not a supported configuration.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if not self.dual:
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)

            del inputs
            if (
                self.args.torch_empty_cache_steps is not None
                and self.state.global_step % self.args.torch_empty_cache_steps == 0
            ):
                torch.cuda.empty_cache()

            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            self.accelerator.backward(loss)
            return loss.detach() / self.args.gradient_accumulation_steps

        if self.optim_cfg not in self._SUPPORTED_OPTIM_CFGS:
            raise ValueError(f"Unsupported optimizer type: {self.optim_cfg}")

        # Forget pass: backward followed by an immediate optimizer step.
        self.optimizer.zero_grad()
        forget_loss = self._backward_loss_by_type(model, inputs, "forget")
        self.optimizer.step()

        # Retain pass: start from clean gradients.
        if self.optim_cfg in self._DUAL_OPTIM_CFGS:
            model.zero_grad()
        self.optimizer.zero_grad()
        normal_loss = self._backward_loss_by_type(model, inputs, "retain")

        # Both losses already carry their coefficient (applied in
        # ``compute_loss_by_type``), so they are summed without rescaling.
        loss_new = forget_loss.detach() + normal_loss.detach()
        return loss_new / self.args.gradient_accumulation_steps

    def e_prepare_deepspeed(self, model):
        """Wrap ``model`` in a DeepSpeed engine set to eval mode with frozen weights.

        The accelerator's DeepSpeed config is deep-copied, so the plugin's own
        config is left untouched. Any ZeRO stage other than 3 is replaced by
        stage 0, and the config's optimizer entry is overwritten.

        Args:
            model: Reference model, or None.

        Returns:
            The first value returned by ``deepspeed.initialize``, in eval mode
            with ``requires_grad`` disabled on every parameter; None when
            ``model`` is None.
        """
        if model is None:
            # ``deepspeed.initialize`` cannot wrap a missing model.
            return None

        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)

        if hasattr(model, "config"):
            hidden_size = (
                max(model.config.hidden_sizes)
                if getattr(model.config, "hidden_sizes", None)
                else getattr(model.config, "hidden_size", None)
            )
            if (
                hidden_size is not None
                and config_kwargs["zero_optimization"]["stage"] == 3
            ):
                # NOTE(review): these are flat dotted keys at the top level of
                # the config, not entries nested under "zero_optimization" --
                # confirm that DeepSpeed picks them up as intended.
                config_kwargs.update(
                    {
                        "zero_optimization.reduce_bucket_size": hidden_size
                        * hidden_size,
                        "zero_optimization.stage3_param_persistence_threshold": 10
                        * hidden_size,
                        "zero_optimization.stage3_prefetch_bucket_size": 0.9
                        * hidden_size
                        * hidden_size,
                    }
                )

        # If ZeRO-3 is used, we shard both the active and reference model.
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0

        # Disable optimizer in DeepSpeed since we are using custom optimizers
        config_kwargs["optimizer"] = {"type": None}

        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        return model
