import os
import copy
import json
import math
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from tqdm import tqdm

import torch
import transformers
import utils
import deepspeed
from torch.utils.data import Dataset
from transformers import Trainer
from typing import List, Optional, Tuple, Union
from datasets import concatenate_datasets, load_from_disk

import deepspeed.comm as dist

from transformers.utils import is_sagemaker_mp_enabled

from torch.optim import Optimizer


# Conventional "ignore" label id (-100 is torch CrossEntropyLoss's default
# ignore_index); not referenced by the code visible in this file.
IGNORE_INDEX: int = -100
# Fallback special tokens: train() adds each one only when the tokenizer
# does not already define that token.
DEFAULT_PAD_TOKEN: str = "[PAD]"
DEFAULT_EOS_TOKEN: str = "</s>"
DEFAULT_BOS_TOKEN: str = "<s>"
DEFAULT_UNK_TOKEN: str = "<unk>"


@dataclass
class ModelArguments:
    """Command-line options selecting the base model and its neuron mask."""

    # Hub id or local directory handed to from_pretrained() for tokenizer and model.
    model_name_or_path: Optional[str] = ""
    # File loaded with torch.load(); train() passes the result to MaskTrainer,
    # which indexes it by parameter name.
    neuron_mask_path: Optional[str] = "none"


@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # Comma-separated list of directories; SupervisedDataset loads every entry
    # inside each one with `load_from_disk`. Annotated Optional because the
    # default is None (SupervisedDataset cannot run without a value).
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    # Not read by any code visible in this file.
    language: str = field(default='none', metadata={"help": "Language of training data."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """transformers.TrainingArguments plus the extra options this script reads."""

    # Passed as `cache_dir` to from_pretrained() for both tokenizer and model.
    cache_dir: Optional[str] = None
    # self.args is handed to Trainer.get_optimizer_cls_and_kwargs; MaskTrainer
    # keeps only the returned kwargs and always uses MaskAdamW_HF as the class.
    optim: str = "adamw_torch"
    # Forwarded as `model_max_length` when the tokenizer is created.
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Register special tokens on the tokenizer and resize the model's embeddings.

    The token embeddings are always resized to ``len(tokenizer)``. When new
    tokens were actually added, their rows in the input and output embedding
    matrices are initialised to the mean of the pre-existing rows of that
    same matrix.

    Note: the new vocabulary size is not rounded up, so the embedding size may
    not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    # Only the trailing `added` rows are written, so the mean over the leading
    # rows is unaffected even when input and output embeddings share storage.
    for module in (model.get_input_embeddings(), model.get_output_embeddings()):
        table = module.weight.data
        table[-added:] = table[:-added].mean(dim=0, keepdim=True)


class SupervisedDataset(Dataset):
    """Pre-tokenized examples loaded from saved Hugging Face datasets.

    Each item returns the stored ``input_ids`` together with an independent
    copy of them as ``labels``; no masking or padding is applied here.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
    ):
        """Load and concatenate every saved dataset found under ``data_path``.

        Args:
            data_path: Comma-separated list of directories. Every entry inside
                each directory is loaded with ``load_from_disk``.
            tokenizer: Stored on the instance; not used while loading.

        Raises:
            ValueError: If ``data_path`` names no directory or no dataset is found.
        """
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        self.tokenizer = tokenizer
        # Kept for attribute compatibility; this class never fills them.
        self.input_ids, self.attn_mask = [], []

        # Tolerate whitespace and empty entries (e.g. a trailing comma).
        roots = [dp.strip() for dp in data_path.split(',') if dp.strip()]
        if not roots:
            raise ValueError(f"No data directories given in data_path={data_path!r}")

        parts = []
        for dp in roots:
            # os.listdir order is arbitrary; sort so every run and every rank
            # concatenates the shards in the same order.
            for data_name in sorted(os.listdir(dp)):
                parts.append(load_from_disk(os.path.join(dp, data_name)))
        if not parts:
            raise ValueError(f"No datasets found under data_path={data_path!r}")

        self.sources = concatenate_datasets(parts)
        print(self.sources)

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        input_ids = self.sources[i]["input_ids"]
        # Labels are an independent copy of the inputs.
        return dict(input_ids=input_ids, labels=copy.deepcopy(input_ids))


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Stack pre-tokenized examples into batch tensors.

    No padding is performed: every instance in a batch must have ``input_ids``
    and ``labels`` of the same length, otherwise ``torch.tensor`` raises on
    the ragged nested list. ``data_args`` and ``tokenizer`` are stored but not
    used by ``__call__``.
    """

    data_args: DataArguments
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        batch = {}
        for key in ("input_ids", "labels"):
            batch[key] = torch.tensor([example[key] for example in instances])
        return batch


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Assemble the data-related keyword arguments for the Trainer.

    Returns a dict with ``train_dataset`` (a SupervisedDataset built from
    ``data_args.data_path``), ``eval_dataset`` (always None) and
    ``data_collator``, ready to be splatted into the trainer constructor.
    """
    module = {}
    module["train_dataset"] = SupervisedDataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
    )
    module["eval_dataset"] = None
    module["data_collator"] = DataCollatorForSupervisedDataset(
        data_args=data_args,
        tokenizer=tokenizer,
    )
    return module


class MaskAdamW_HF(Optimizer):
    """Hugging Face-style AdamW whose updates are gated by a per-group mask.

    Every param group must carry a ``"mask"`` tensor. On each step the mask is
    split into ``world_size`` equal contiguous chunks, this rank's chunk is
    multiplied into both the Adam update and the decoupled weight decay, so
    positions where the mask is zero are left untouched.

    NOTE(review): the rank slicing assumes each group's ``params[0]`` is this
    rank's flat partition of the parameter (as produced by a ZeRO-style
    sharded optimizer wrapper) - confirm against the DeepSpeed config in use.
    Only ``params[0]`` of each group is updated.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        """Create the optimizer and record this process's rank and world size.

        ``no_deprecation_warning`` is accepted for signature compatibility and
        ignored. Requires ``deepspeed.comm`` to be initialised.
        """
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

        self.print_details()
        self.world_size = dist.get_world_size()
        self.cur_rank = dist.get_rank()

    def print_details(self):
        """Print a short summary of the param groups (debug aid)."""
        print('WARNING: In print details!!!')
        print('Total Group: ', len(self.param_groups))
        print('Keys in group: ', self.param_groups[0].keys())
        sum_params = sum(len(group['params']) for group in self.param_groups)
        print('sum params: ', sum_params)

    def _rank_mask(self, group, param):
        """Return this rank's contiguous slice of ``group['mask']`` on ``param``'s device."""
        mask = group['mask']
        num_neural = mask.size(-1)
        if num_neural % self.world_size != 0:
            raise ValueError(
                f"Mask length {num_neural} is not divisible by world size {self.world_size}"
            )
        delta = num_neural // self.world_size
        st_idx = delta * self.cur_rank
        return mask[st_idx: st_idx + delta].to(param.device)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single masked optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure must be able to backprop.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if len(group['params']) == 0:
                continue
            # One mask per group, so only the group's first tensor is handled.
            p = group['params'][0]
            if p.grad is None:
                continue
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

            mask = self._rank_mask(group, p)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = group["betas"]

            state["step"] += 1

            # Decay the first and second moment running average coefficient
            # (moments are tracked for every position; only the update is masked).
            exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
            denom = exp_avg_sq.sqrt().add_(group["eps"])

            step_size = group["lr"]
            # correct_bias=False skips bias correction (original BERT Adam behaviour).
            if group["correct_bias"]:
                bias_correction1 = 1.0 - beta1 ** state["step"]
                bias_correction2 = 1.0 - beta2 ** state["step"]
                step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

            p.addcdiv_(exp_avg * mask, denom, value=-step_size)

            if group["weight_decay"] > 0.0:
                # Decoupled weight decay, masked too so frozen positions stay
                # exactly unchanged.
                p.add_(p * mask, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


class MaskTrainer(Trainer):
    """Trainer that optimizes with MaskAdamW_HF driven by per-parameter masks.

    ``params_mask`` is indexed by the names yielded by ``named_parameters()``
    of the model being optimized; each trainable parameter's mask is
    flattened and attached to that parameter's own param group.
    """

    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        optimizers = (None, None),
        preprocess_logits_for_metrics = None,
        params_mask = None
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        self.params_mask = params_mask

    def create_optimizer(self):
        """Create (once) a MaskAdamW_HF with one param group per trainable parameter.

        Returns:
            The trainer's optimizer; an already-created one is returned as is.

        Raises:
            ValueError: If no ``params_mask`` was given to the trainer.
            KeyError: If a trainable parameter has no entry in ``params_mask``.
        """
        print('is_sagemaker_mp_enabled: ', is_sagemaker_mp_enabled())
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            if self.params_mask is None:
                raise ValueError("MaskTrainer requires `params_mask` to build its optimizer")

            # Set for O(1) membership tests inside the loop.
            decay_parameters = set(self.get_decay_parameter_names(opt_model))
            optimizer_grouped_parameters = []

            for n, p in opt_model.named_parameters():
                if not p.requires_grad:
                    continue
                if n not in self.params_mask:
                    raise KeyError(f"No neuron mask found for trainable parameter {n!r}")
                # One group per parameter so each group can carry its own mask.
                optimizer_grouped_parameters.append({
                    "params": [p],
                    "weight_decay": self.args.weight_decay if n in decay_parameters else 0.0,
                    "mask": self.params_mask[n].reshape(-1),
                })

            if optimizer_grouped_parameters:
                first = optimizer_grouped_parameters[0]
                print(
                    'Total Training Parameters: ',
                    len(optimizer_grouped_parameters),
                    len(first['params']),
                    first['params'][0].size(),
                    first.keys(),
                )
            else:
                print('Total Training Parameters: ', 0)

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
            print(optimizer_kwargs)
            # Only the kwargs (lr, betas, eps, ...) come from the HF optimizer
            # selection; the class is always the masked AdamW.
            optimizer_cls = MaskAdamW_HF

            # NOTE(review): the two overrides below replace the masked groups
            # built above, after which MaskAdamW_HF cannot find group['mask'].
            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer



def train():
    """Parse CLI arguments, load tokenizer, data, model and masks, then train.

    Raises:
        ValueError: If ``--neuron_mask_path`` was left at its "none" default
            or is empty.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # "none" is ModelArguments' default sentinel. MaskTrainer cannot run
    # without masks, so validate and load them before the slow model load.
    mask_path = model_args.neuron_mask_path
    if not mask_path or mask_path.lower() == "none":
        raise ValueError("--neuron_mask_path must point to a saved neuron mask file")
    params_mask = torch.load(mask_path, map_location='cpu')

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    # Add only the special tokens the tokenizer is missing.
    special_tokens_dict = {}
    for attr, default in (
        ("pad_token", DEFAULT_PAD_TOKEN),
        ("eos_token", DEFAULT_EOS_TOKEN),
        ("bos_token", DEFAULT_BOS_TOKEN),
        ("unk_token", DEFAULT_UNK_TOKEN),
    ):
        if getattr(tokenizer, attr) is None:
            special_tokens_dict[attr] = default

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    trainer = MaskTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        params_mask=params_mask,
        **data_module
    )
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


# Start training only when run as a script, not when imported.
if __name__ == "__main__":
    train()
