# Adapted from https://github.com/huggingface/alignment-handbook 

import dataclasses
import math
import os
import subprocess
import sys
import threading
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Generator, Optional, Tuple, Union

import numpy as np
import optree
from optree.typing import PyTreeTypeVar
from tqdm import tqdm
from typing_extensions import TypeAlias
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
    get_scheduler,
    CONFIG_NAME,
    WEIGHTS_NAME,
)
from transformers.tokenization_utils import BatchEncoding
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from torch.utils.data.distributed import DistributedSampler
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
import torch.distributed as dist
from accelerate import Accelerator

from trl.import_utils import is_wandb_available
from trl.models import PreTrainedModelWrapper
from trl.trainer.utils import disable_dropout_in_model

# Pytree (optree) type whose leaves are ``torch.Tensor``; used to annotate ``to_device``.
TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', torch.Tensor)

# NOTE(review): this guard looks redundant — ``deepspeed.ops.adam`` is imported unconditionally
# above, so this module already fails to import when DeepSpeed is unavailable.
if is_deepspeed_available():
    import deepspeed


def get_subclasses(cls: type, memo: set[type] | None = None) -> [type, None, None]:
    """Get all subclasses of a class recursively."""
    if memo is None:
        memo = set()

    for subclass in cls.__subclasses__():
        if subclass in memo:
            continue

        memo.add(subclass)
        yield subclass
        yield from get_subclasses(subclass, memo=memo)


# One-time state for the optree node registrations (namespace 'LoT') done by
# __initialize_pytree_registry_once(): the flag is read without the lock as a fast path
# and re-checked while holding the lock before registering.
__PYTREE_INITIALIZED = False
__PYTREE_REGISTRY_LOCK = threading.Lock()


def __initialize_pytree_registry_once() -> None:
    """Register the container types traversed by ``to_device`` in optree's ``'LoT'`` namespace.

    Registers:

    * ``BatchEncoding``: its ``data`` dict is the child. ``encodings`` and ``n_sequences`` are
      kept as metadata so they survive a flatten/unflatten round trip.
    * ``ModelOutput``: its values are the children and its keys are the metadata.
    * every dataclass subclass of ``ModelOutput`` found at call time: a shallow
      ``{field_name: value}`` dict is the child, and the concrete class is the metadata used to
      rebuild it.

    Safe to call repeatedly and from several threads. Registration happens once, guarded by
    ``__PYTREE_REGISTRY_LOCK`` with a re-check of ``__PYTREE_INITIALIZED`` under the lock.
    """
    global __PYTREE_INITIALIZED  # pylint: disable=global-statement
    if __PYTREE_INITIALIZED:
        return

    with __PYTREE_REGISTRY_LOCK:
        if __PYTREE_INITIALIZED:
            return

        optree.register_pytree_node(
            BatchEncoding,
            lambda batch_encoding: (
                [batch_encoding.data],
                {'encoding': batch_encoding.encodings, 'n_sequences': batch_encoding.n_sequences},
            ),
            lambda metadata, children: BatchEncoding(children[0], **metadata),
            namespace='LoT',
        )
        optree.register_pytree_node(
            ModelOutput,
            # Snapshot the keys/values as tuples: dict views are live (they track later mutation
            # of the original object) and unhashable, which is a poor fit for treespec metadata.
            lambda model_output: (
                tuple(model_output.values()),
                tuple(model_output.keys()),
                tuple(model_output.keys()),
            ),
            lambda keys, values: ModelOutput(OrderedDict(zip(keys, values))),
            namespace='LoT',
        )

        for model_output_class in filter(dataclasses.is_dataclass, get_subclasses(ModelOutput)):
            optree.register_pytree_node(
                model_output_class,
                # Shallow field dict rather than ``dataclasses.asdict``: ``asdict`` deep-copies every
                # field value, which duplicates tensors and raises for non-leaf autograd tensors.
                lambda model_output: (
                    [
                        {
                            model_field.name: getattr(model_output, model_field.name)
                            for model_field in dataclasses.fields(model_output)
                        },
                    ],
                    type(model_output),
                ),
                lambda metadata, children: metadata(**children[0]),
                namespace='LoT',
            )

        __PYTREE_INITIALIZED = True


def to_device(batch: TensorTree, device: torch.device | str | int | None) -> TensorTree:
    """Return a copy of ``batch`` with ``.to(device)`` applied to every leaf.

    The 'LoT' optree registry is populated first (even when ``device`` is ``None``). This lets
    ``BatchEncoding`` and ``ModelOutput`` containers be traversed instead of treated as leaves.
    A ``None`` device short-circuits and hands back the very same ``batch`` object.
    """
    if not __PYTREE_INITIALIZED:
        __initialize_pytree_registry_once()

    if device is not None:
        def _move_leaf(leaf):
            return leaf.to(device)

        return optree.tree_map(_move_leaf, batch, namespace='LoT')

    return batch


def kl_div_logits(p, q, T):
    """Temperature-scaled KL divergence between the softmax distributions of two logit tensors.

    Returns ``KL(softmax(q / T) || softmax(p / T)) * T * T`` with ``reduction='batchmean'``.
    That is, the pointwise terms are summed over every element and divided by ``p.size(0)``.

    Args:
        p: Logits whose log-softmax is the KL ``input``.
        q: Logits whose log-softmax is the KL ``target`` (given as log-probabilities).
        T: Softmax temperature; the result is rescaled by ``T * T``.
    """
    log_probs_p = F.log_softmax(p / T, dim=-1)
    log_probs_q = F.log_softmax(q / T, dim=-1)
    divergence = F.kl_div(log_probs_p, log_probs_q, reduction='batchmean', log_target=True)
    return divergence * T * T


def get_all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average ``tensor`` in place over all ranks (``ReduceOp.AVG``) and return that same tensor.

    When no ``torch.distributed`` process group is initialized, ``tensor`` is returned untouched.
    """
    if not dist.is_initialized():
        return tensor
    dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
    return tensor


def is_main_process() -> bool:
    """Return ``True`` on rank 0, or whenever ``torch.distributed`` has not been initialized."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


class LoTTrainer(Trainer):
    """Trainer that jointly optimizes a teacher (``model``) and an optional ``student``.

    Each model minimizes its own cross-entropy loss. Once ``args.train_step >= args.start_step``,
    each also gets an ``alpha``-weighted, temperature-scaled KL term toward the other model's
    detached logits. Both models are wrapped in their own DeepSpeed engine; the student's
    optimizer is offloaded to CPU.

    Besides the standard ``TrainingArguments`` fields, ``args`` must provide ``alpha``, ``T`` and
    ``start_step``. This class also writes ``num_update_steps_per_epoch``,
    ``total_training_steps``, ``train_step`` and ``original_alpha`` onto ``args``.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        student: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        accelerator: Optional[Accelerator] = None,
    ):
        """Build the trainer and wrap the teacher (and student) in DeepSpeed engines.

        Args:
            model: Teacher model; also passed to ``Trainer.__init__``.
            student: Optional student model trained alongside the teacher.
            args: Training arguments (see the class docstring for the extra fields used).
            data_collator: Collate function for the training ``DataLoader``.
            label_pad_token_id: Label value treated as padding.
            train_dataset: Dataset sampled through a shuffling ``DistributedSampler``.
            eval_dataset: Forwarded to ``Trainer.__init__``.
            tokenizer: Forwarded to ``Trainer.__init__``; saved next to the weights by ``save``.
            is_encoder_decoder: Only needed when ``model`` is ``None``; otherwise read from
                ``model.config``.
            disable_dropout: If set, ``disable_dropout_in_model`` is applied to both models.
            accelerator: Unused; accepted for API compatibility.

        Raises:
            ValueError: if ``model`` is ``None`` and ``is_encoder_decoder`` is not given.
        """
        if getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
                if student:
                    student.enable_input_require_grads()
            else:
                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)
                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
                if student:
                    student.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = is_encoder_decoder

        self.student = student if student else None

        if disable_dropout:
            disable_dropout_in_model(model)
            if self.student is not None:
                disable_dropout_in_model(self.student)

        self.label_pad_token_id = label_pad_token_id

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
        )

        self.init_datasets()
        # Optimizer updates per epoch: ceil(batches / gradient_accumulation_steps).
        self.args.num_update_steps_per_epoch = (
            len(self.train_dataloader) + self.args.gradient_accumulation_steps - 1
        ) // self.args.gradient_accumulation_steps
        self.args.total_training_steps = self.args.num_train_epochs * self.args.num_update_steps_per_epoch
        self.args.train_step = 0
        # lot_loss overwrites args.alpha on every active step, so remember the configured value.
        self.args.original_alpha = self.args.alpha
        self.device = torch.device('cuda', self.args.local_rank)

        self.model = self._prepare_deepspeed(self.model, offload='none')
        if student:
            self.student = self._prepare_deepspeed(self.student, offload='cpu')

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper, offload='none'):
        """Wrap ``model`` in a DeepSpeed engine with its own Adam optimizer and LR scheduler.

        Args:
            model: Model to wrap. ``None`` is returned unchanged.
            offload: ``'cpu'`` uses ``DeepSpeedCPUAdam`` and sets ``offload_optimizer.device`` to
                ``'cpu'``. Any other value uses ``FusedAdam`` with ``offload_optimizer.device``
                set to ``'none'``.

        Returns:
            The DeepSpeed engine returned by ``deepspeed.initialize`` (or ``None``).

        Side effects:
            Sets ``self.ds_config`` to this call's (mutated) config. The last prepared model wins.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
        # self.ds_config aliases config_kwargs: every edit below is visible through both names.
        self.ds_config = config_kwargs
        print('ds_config', self.ds_config)

        if model is None:
            # Nothing to wrap; without this guard optimizer/lr_scheduler below would be unbound.
            return model

        use_cpu_offload = offload == 'cpu'
        self.ds_config['zero_optimization']['offload_optimizer']['device'] = 'cpu' if use_cpu_offload else 'none'
        optimizer_cls = DeepSpeedCPUAdam if use_cpu_offload else FusedAdam
        optimizer = optimizer_cls(
            model.parameters(),
            lr=self.args.learning_rate,
            betas=config_kwargs["optimizer"]['params']['betas'],
        )
        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=self.args.total_training_steps,
        )
        self.ds_config['scheduler']['params']['warmup_num_steps'] = self.args.warmup_steps
        self.ds_config['scheduler']['params']['total_num_steps'] = self.args.total_training_steps

        if hasattr(model, "config"):
            hidden_size = (
                max(model.config.hidden_sizes)
                if getattr(model.config, "hidden_sizes", None)
                else getattr(model.config, "hidden_size", None)
            )
            if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                # NOTE(review): these dotted keys land at the top level of the dict rather than
                # inside "zero_optimization" — confirm DeepSpeed actually honours them.
                config_kwargs.update(
                    {
                        "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                        "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                        "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                    }
                )
        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        # NOTE(review): this downgrade was written for a frozen reference model, but both models here
        # are trained; ZeRO-1/2 configs silently become stage 0 (which also makes save() take the
        # save_pretrained path). Confirm this is intended.
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=self.args,
            config=self.ds_config,
            lr_scheduler=lr_scheduler,
            dist_init_required=True,
        )

        return model

    def init_datasets(self) -> None:
        """Build ``self.train_dataloader`` over ``self.train_dataset`` with a shuffling ``DistributedSampler``."""
        self.train_dataloader = DataLoader(
            self.train_dataset,
            collate_fn=self.data_collator,
            sampler=DistributedSampler(self.train_dataset, shuffle=True),
            batch_size=self.args.per_device_train_batch_size,
        )

    def lot_loss(
        self,
        teacher_logits: torch.FloatTensor,
        student_logits: Optional[torch.FloatTensor],
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the teacher and student LoT (temperature-scaled KL) regularization terms.

        Args:
            teacher_logits: Teacher logits; batch/sequence dims must match ``labels``.
            student_logits: Student logits of the same shape, or ``None`` when there is no student.
            labels: Target ids. Only used to validate the logits shape and never modified.
            average_log_prob: Unused; kept for backward compatibility.

        Returns:
            ``(teacher_lot_loss, student_lot_loss)``. When a student exists and
            ``args.train_step >= args.start_step``, these are
            ``alpha * kl_div_logits(teacher_logits, student_logits.detach(), T)`` and
            ``alpha * kl_div_logits(student_logits, teacher_logits.detach(), T)``. Otherwise both
            are float32 zeros of shape ``(1,)`` on ``self.device``.

        Side effects:
            When the KL terms are active, ``self.args.alpha`` is overwritten twice, once per term,
            each time with fresh ``np.random.normal`` noise.

        Raises:
            ValueError: if ``teacher_logits.shape[:-1] != labels.shape``.
        """
        if teacher_logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not self.is_encoder_decoder:
            # Causal shift: logits at position t score token t + 1, so the last position has no target.
            teacher_logits = teacher_logits[:, :-1, :]
            if self.student:
                student_logits = student_logits[:, :-1, :]

        # NOTE(review): positions labelled label_pad_token_id are NOT masked out of the KL terms;
        # labels only feed the shape check above. Confirm whether masking was intended.

        if self.student and self.args.train_step >= self.args.start_step:
            # alpha ramps from ~0.5x to ~1.5x original_alpha over training, with small Gaussian
            # jitter drawn independently for the teacher and the student term.
            self.args.alpha = self.args.original_alpha * (0.5 + self.args.train_step / self.args.total_training_steps + np.random.normal(0, 0.01))
            teacher_lot_loss = self.args.alpha * kl_div_logits(teacher_logits, student_logits.detach(), self.args.T)
            self.args.alpha = self.args.original_alpha * (0.5 + self.args.train_step / self.args.total_training_steps + np.random.normal(0, 0.01))
            student_lot_loss = self.args.alpha * kl_div_logits(student_logits, teacher_logits.detach(), self.args.T)
        else:
            teacher_lot_loss = torch.zeros((1,), dtype=torch.float32, device=self.device)
            student_lot_loss = torch.zeros((1,), dtype=torch.float32, device=self.device)

        return teacher_lot_loss, student_lot_loss

    def ce_loss(
        self,
        model,
        input_ids: torch.LongTensor,  # size = (B, L)
        labels: torch.LongTensor,  # size = (B, L)
        attention_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run ``model`` with ``labels`` and return ``(logits, loss)`` from its output."""
        outputs: CausalLMOutputWithPast = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return (outputs.logits, outputs.loss)

    def train_step(
        self,
        input_ids: torch.LongTensor,  # size = (B, L)
        labels: torch.LongTensor,  # size = (B, L)
        attention_mask: torch.BoolTensor,  # size = (B, L)
    ) -> dict[str, Any]:
        """Performs a single training step for the teacher and, if present, the student.

        Args:
            input_ids (torch.LongTensor): input ids for causal inputs to complete with.
            labels (torch.LongTensor): labels for the full sequence.
            attention_mask (torch.BoolTensor): attention mask for the labels.

        Returns:
            dict[str, Any]: rank-averaged teacher loss and LoT loss, the teacher learning rate,
            and the rank-averaged student loss and LoT loss (zero when there is no student).
        """
        teacher_logits, teacher_ce_loss = self.ce_loss(
            model=self.model,
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )
        if self.student:
            student_logits, student_ce_loss = self.ce_loss(
                model=self.student,
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
            )
        else:
            student_logits = None
            student_ce_loss = torch.zeros_like(teacher_ce_loss)
        teacher_lot_loss, student_lot_loss = self.lot_loss(
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            labels=labels,
        )
        teacher_loss = teacher_ce_loss + teacher_lot_loss
        student_loss = student_ce_loss + student_lot_loss

        # Teacher updates first; the student's KL term only sees detached teacher logits.
        self.model.backward(teacher_loss)
        self.model.step()
        teacher_loss = get_all_reduce_mean(teacher_loss)
        teacher_lot_loss = get_all_reduce_mean(teacher_lot_loss)
        if self.student:
            self.student.backward(student_loss)
            self.student.step()
        student_loss = get_all_reduce_mean(student_loss)
        student_lot_loss = get_all_reduce_mean(student_lot_loss)

        return {
            'train/loss': teacher_loss.item(),
            'train/lot_loss': teacher_lot_loss.item(),
            'train/lr': self.model.optimizer.param_groups[0]['lr'],
            'train/student_loss': student_loss.item(),
            'train/student_lot_loss': student_lot_loss.item(),
        }

    def train(self):
        """Run the joint teacher/student training loop.

        Every batch runs one ``train_step`` followed by ``torch.cuda.empty_cache()``.
        ``args.train_step`` counts optimizer updates: it is incremented every
        ``gradient_accumulation_steps`` batches. Both engines are checkpointed every
        ``save_steps * gradient_accumulation_steps`` batches under
        ``<output_dir>/{teacher,student}``.
        """
        progress_bar = tqdm(
            total=self.args.num_train_epochs * len(self.train_dataloader),
            desc=f'Training 1/{self.args.num_train_epochs} epoch',
            position=0,
            leave=True,
            disable=not is_main_process(),
        )
        # TrainingArguments.num_train_epochs is a float, which range() rejects. Run enough whole
        # epochs to cover it and stop once the (possibly fractional) batch budget is spent.
        num_epochs = math.ceil(self.args.num_train_epochs)
        max_batches = math.ceil(self.args.num_train_epochs * len(self.train_dataloader))
        sampler = self.train_dataloader.sampler
        current_cnt = 0
        for epoch in range(num_epochs):
            # DistributedSampler only reshuffles when told the epoch; otherwise every epoch
            # iterates the same permutation.
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)

            self.model.train()
            if self.student:
                self.student.train()

            for batch in self.train_dataloader:
                info = self.train_step(**to_device(batch, self.device))
                torch.cuda.empty_cache()
                current_cnt += 1
                if current_cnt % self.args.gradient_accumulation_steps == 0:
                    self.args.train_step += 1
                progress_bar.set_description(
                    f'Training {epoch + 1}/{self.args.num_train_epochs} epoch '
                    f'(loss {info["train/loss"]:.4f})'
                    f'(lot_loss {info["train/lot_loss"]:.6f})'
                    f'(lr {info["train/lr"]:.8f})'
                )
                progress_bar.update(1)
                info['train/epoch'] = current_cnt / len(self.train_dataloader)

                if current_cnt > 0 and current_cnt % (self.args.save_steps * self.args.gradient_accumulation_steps) == 0:
                    print(f'Saving checkpoint at step {self.args.train_step} ...')
                    self.model.save_checkpoint(self.args.output_dir+'/teacher', tag=self.args.train_step)
                    if self.student:
                        self.student.save_checkpoint(self.args.output_dir+'/student', tag=self.args.train_step)
                    print('Checkpoint saved.')

                if current_cnt >= max_batches:
                    break

            self.model.tput_timer.update_epoch_count()

    def save(
        self,
        model: deepspeed.DeepSpeedEngine | None = None,
        student: deepspeed.DeepSpeedEngine | None = None,
        ds_config: dict[str, Any] | None = None,
    ) -> None:
        """Save model and tokenizer in Hugging Face format.

        Writes to ``<args.output_dir>/final/teacher`` and ``<args.output_dir>/final/student``.
        With ZeRO stage >= 2 a DeepSpeed checkpoint is saved first. The main process then runs
        ``zero_to_fp32.py`` inside each directory. Otherwise the main process calls
        ``save_pretrained`` directly. Must be called on every rank, since it calls
        ``dist.barrier()``.

        Args:
            model: Teacher engine; defaults to ``self.model``.
            student: Student engine; defaults to ``self.student``.
            ds_config: DeepSpeed config whose ZeRO stage selects the save path; defaults to
                ``self.ds_config``.
        """
        dist.barrier()

        if model is None:
            model = self.model  # pylint: disable=no-member
        if student is None:
            student = self.student  # pylint: disable=no-member
        if ds_config is None:
            ds_config = self.ds_config  # pylint: disable=no-member
        # Local path instead of rewriting self.args.output_dir, so repeated calls do not nest
        # final/final/...
        output_dir = self.args.output_dir + '/final'
        teacher_dir = output_dir + '/teacher'
        student_dir = output_dir + '/student'
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(teacher_dir, exist_ok=True)
        os.makedirs(student_dir, exist_ok=True)
        print(f'Saving model to "{output_dir}" ...')

        teacher_output_config_file = os.path.join(output_dir, 'teacher', CONFIG_NAME)
        teacher_model_to_save: PreTrainedModel = getattr(model, 'module', model)
        if is_main_process():
            teacher_model_to_save.config.to_json_file(teacher_output_config_file)
            self.tokenizer.save_pretrained(teacher_dir)

        if student:
            student_output_config_file = os.path.join(output_dir, 'student', CONFIG_NAME)
            student_model_to_save: PreTrainedModel = getattr(student, 'module', student)
            if is_main_process():
                student_model_to_save.config.to_json_file(student_output_config_file)
                self.tokenizer.save_pretrained(student_dir)

        # Save model checkpoint
        if ds_config['zero_optimization']['stage'] >= 2:
            print('Saving DeepSpeed Checkpoints...')
            model.save_checkpoint(teacher_dir)
            if student:
                student.save_checkpoint(student_dir)
            print('Converting DeepSpeed Checkpoints to Hugging Face format...')
            if is_main_process():
                subprocess.check_call(
                    [sys.executable, 'zero_to_fp32.py', '.', WEIGHTS_NAME],  # noqa: S603
                    cwd=teacher_dir,
                )
                if student:
                    subprocess.check_call(
                        [sys.executable, 'zero_to_fp32.py', '.', WEIGHTS_NAME],  # noqa: S603
                        cwd=student_dir,
                    )
            dist.barrier()
        else:
            print('Saving Hugging Face Checkpoints...')
            if is_main_process():
                teacher_model_to_save.save_pretrained(teacher_dir, is_main_process=True)
                if student:
                    student_model_to_save.save_pretrained(student_dir, is_main_process=True)

        print('Model saved!')
