from src.models.common import CausalLM, CausalLMLayer
from src.trainer.base import MyTrainingArguments
from src.utils.check_transplant import check_transplant
from src.datasets.task_approx import create_dataset_for_task_approx
from src.models.linear_attention import linear_attention

from transformers import Trainer, PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
from typing import Optional, Union, Dict, Any, Literal
import sys
from dataclasses import dataclass, field
import math

@dataclass
class MimicArguments(MyTrainingArguments):
    """Training arguments for the attention-mimicking stage.

    ``num_train_epochs``, ``learning_rate``, ``coef_learning_rate`` and
    ``kernel_path`` may be left as ``None``; ``__post_init__`` then fills them
    with a default chosen from ``loss_type``.
    """

    # Objective optimized by the mimic loss network.
    loss_type: Literal["standard", "log", "direct", "task-approx", "softmax"] = "standard"
    # Location of the kernel file used by the "task-approx" loss.
    kernel_path: Optional[str] = None
    num_train_epochs: Optional[float] = None
    learning_rate: Optional[float] = None
    # NOTE(review): presumably a separate learning rate for coefficient
    # parameters; its consumer is not visible in this file.
    coef_learning_rate: Optional[float] = None
    logging_steps: int = 50
    save_steps: float = 0.25
    eval_strategy: Literal["no", "steps", "epoch"] = "steps"
    eval_steps: float = 0.25
    label_names: list[str] = field(default_factory=lambda: ["labels"])
    # When True, MimicTrainer runs `check_transplant` after construction.
    do_transplant_check: bool = True

    def __post_init__(self):
        """Resolve loss-type-dependent defaults, then run the parent post-init."""
        # Pick the default triple once; unknown loss types contribute nothing,
        # so any field that was None simply stays None.
        if self.loss_type in ["standard", "log", "direct", "softmax"]:
            default_epochs, default_lr, default_coef_lr = 1.0, 0.02, 0.2
        elif self.loss_type == "task-approx":
            default_epochs, default_lr, default_coef_lr = 3.0, 1e-3, 1e-3
        else:
            default_epochs = default_lr = default_coef_lr = None

        # Only fields the user left unset are overwritten.
        if self.num_train_epochs is None:
            self.num_train_epochs = default_epochs
        if self.learning_rate is None:
            self.learning_rate = default_lr
        if self.coef_learning_rate is None:
            self.coef_learning_rate = default_coef_lr

        # The task-approx loss needs a kernel file; default it into output_dir.
        needs_kernel_path = self.kernel_path is None and self.loss_type == "task-approx"
        if needs_kernel_path:
            self.kernel_path = self.output_dir + "/kernel.pth"

        super().__post_init__()


class MimicLossNet(PreTrainedModel):
    """Loss module that trains a linear-attention model to mimic a reference model.

    ``forward`` returns a scalar loss. Depending on ``loss_type`` it is either
    the language-model loss of ``linear_model`` ("direct", or whenever
    ``return_outputs=True``), or a per-layer attention-matching loss averaged
    over layers ("standard", "log", "task-approx", "softmax"). The reference
    ``origin_model`` is always evaluated under ``torch.no_grad()``.
    """

    origin_model: CausalLM
    linear_model: CausalLM
    _keys_to_ignore_on_load_missing = None
    _keys_to_ignore_on_save = None

    def __init__(self, 
                 origin_model: CausalLM, 
                 linear_model: CausalLM, 
                 loss_type: str, 
                 gradient_checkpointing: bool = False, 
                 gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
        """Store both models and the loss configuration.

        Args:
            origin_model: Reference model whose attention statistics are matched.
            linear_model: Model being trained; its blocks are read from
                ``linear_model.transformer.h``.
            loss_type: One of "standard", "log", "direct", "task-approx",
                "softmax".
            gradient_checkpointing: Whether each layer-wise loss is computed
                through ``torch.utils.checkpoint.checkpoint``.
            gradient_checkpointing_kwargs: Extra keyword arguments forwarded to
                ``checkpoint``.
        """
        super().__init__(config=PretrainedConfig())
        self.supports_gradient_checkpointing = True

        self.origin_model = origin_model
        self.linear_model = linear_model
        self.loss_type = loss_type
        
        self.gradient_checkpointing = gradient_checkpointing
        self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {}
    
    def _maybe_gradient_checkpointing(self, func, *args):
        """Call ``func(*args)``, through activation checkpointing when enabled."""
        if not self.gradient_checkpointing:
            return func(*args)

        # None of the tensors passed here require grad (they come from the
        # no_grad reference forward); trainable parameters are reached through
        # the module argument. The reentrant implementation would then return a
        # loss without gradient, so default to the non-reentrant one unless the
        # caller explicitly asked otherwise.
        checkpoint_kwargs = {"use_reentrant": False, **self.gradient_checkpointing_kwargs}
        return checkpoint(func, *args, **checkpoint_kwargs)
    
    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                kernel: Optional[torch.Tensor] = None,
                kernel_grad: Optional[torch.Tensor] = None,
                indices: Optional[torch.Tensor] = None,
                return_outputs: bool = False,
                **kwargs) -> torch.Tensor:
        """Compute the training (or evaluation) loss for one batch.

        Args:
            input_ids: Token ids fed to both models.
            attention_mask: Padding mask fed to both models; also used to mask
                the layer-wise losses.
            kernel, kernel_grad, indices: Pre-computed inputs for the
                "task-approx" loss; each is indexed as ``x[:, layer_idx]``.
                Required when ``loss_type == "task-approx"``.
            return_outputs: Treated as the evaluation flag; when True the
                ``linear_model`` loss is returned regardless of ``loss_type``.
            **kwargs: Forwarded to ``linear_model`` on the direct/eval path only
                (presumably carries ``labels`` -- confirm with the collator).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: If ``loss_type`` is not supported, or the task-approx
                inputs are missing.
        """
        is_eval = return_outputs
        if self.loss_type == "direct" or is_eval:
            output = self.linear_model(input_ids=input_ids, 
                                       attention_mask=attention_mask, 
                                       **kwargs)
            return output.loss

        # Decide which attention internals the reference model must expose.
        if self.loss_type in ("standard", "log", "task-approx"):
            output_attentions = ["qk", "query", "key"]
        elif self.loss_type == "softmax":
            output_attentions = ["weight", "query", "key"]
        else:
            raise ValueError(f"Invalid loss type: {self.loss_type}")

        if self.loss_type == "task-approx" and (
                kernel is None or kernel_grad is None or indices is None):
            raise ValueError(
                "`kernel`, `kernel_grad` and `indices` are required when loss_type is 'task-approx'."
            )

        with torch.no_grad():
            origin_out = self.origin_model(input_ids=input_ids, 
                                           attention_mask=attention_mask, 
                                           output_attentions=output_attentions)

        mimic_loss = torch.tensor(0., device=self.origin_model.device)
        n_layers = len(self.origin_model.transformer.h)
        for layer_idx in range(n_layers):
            layer_attn = origin_out.attentions[layer_idx]
            query = layer_attn["query"]
            key = layer_attn["key"]

            qk, attn_weight = None, None
            if "qk" in output_attentions:
                qk = layer_attn["qk"]
            if "weight" in output_attentions:
                attn_weight = layer_attn["weight"]

            linear_block = self.linear_model.transformer.h[layer_idx]

            if self.loss_type in ["standard", "log"]:
                layerwise_loss = self._maybe_gradient_checkpointing(
                    self._get_layerwise_diff_loss,
                    linear_block, attention_mask, qk, query, key)

            elif self.loss_type == "task-approx":
                _kernel = kernel[:, layer_idx]
                _kernel_grad = kernel_grad[:, layer_idx]
                _indices = indices[:, layer_idx]
                layerwise_loss = self._maybe_gradient_checkpointing(
                    self._get_layerwise_task_approx_loss,
                    linear_block, _kernel, _kernel_grad, _indices, qk, query, key)

            else:  # "softmax" -- the only remaining validated option
                layerwise_loss = self._maybe_gradient_checkpointing(
                    self._get_layerwise_softmax_loss,
                    linear_block, attention_mask, attn_weight, query, key)

            mimic_loss += layerwise_loss

        mimic_loss = mimic_loss / n_layers
        return mimic_loss

    def _get_layerwise_diff_loss(self, linear_block: CausalLMLayer, attention_mask: torch.Tensor,
                                 qk: torch.Tensor, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        """Masked squared error between the reference and linear kernels of one layer.

        "standard" compares ``exp(qk - shift)`` with the feature-map product;
        "log" compares the (clamped) logits with the log of the feature-map
        product. The squared error is summed over batch and positions, averaged
        over heads, and divided by the number of unmasked positions.

        Raises:
            ValueError: If ``self.loss_type`` is neither "standard" nor "log".
        """
        attn = linear_block.attn
        
        query_length, key_length = query.size(-2), key.size(-2)
        # NOTE(review): `attn.bias` is presumably the causal-mask buffer -- confirm.
        mask = attn.bias[:, :, key_length - query_length : key_length, :key_length] 
        mask = torch.logical_and(mask, attention_mask[:, None, None, :]) 
        
        if self.loss_type == "standard":
            qk = qk.where(mask, torch.tensor(-float("inf"), device=qk.device))
            shift = qk.amax(dim=-1, keepdim=True) 
            # A query row with no visible key (e.g. left padding) has
            # shift == -inf, which would turn `qk - shift` into NaN and poison
            # the whole loss even though the row is masked out below.
            shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))
            qk = qk - shift

            origin_kernel = torch.exp(qk)

            query_feature = attn.feature_net(query, shift_value=shift)
            key_feature = attn.feature_net(key, shift_value=shift)
            linear_kernel = torch.einsum("bhqf,bhkf->bhqk", query_feature, key_feature)

            diff = origin_kernel - linear_kernel

        elif self.loss_type == "log":
            log_query_feature = attn.feature_net.get_log_features(query) 
            log_key_feature = attn.feature_net.get_log_features(key)
            # Subtract the maxima before exponentiating and add them back after
            # the log, so the kernel product is formed on bounded values.
            max_log_query_feature = log_query_feature.amax(dim=-1, keepdim=True)
            max_log_key_feature = log_key_feature.amax(dim=(-2, -1), keepdim=True)
            _query_feature = torch.exp(log_query_feature - max_log_query_feature) 
            _key_feature = torch.exp(log_key_feature - max_log_key_feature) 
            _linear_kernel = torch.einsum("bhqf,bhkf->bhqk", _query_feature, _key_feature) 
            _linear_kernel = _linear_kernel.clamp(min=linear_block.attn.kernel_clip) 
            log_linear_kernel = torch.log(_linear_kernel) + max_log_query_feature + max_log_key_feature

            log_origin_kernel = qk.clamp(min=math.log(linear_block.attn.kernel_clip)) 
            
            diff = log_origin_kernel - log_linear_kernel 

        else:
            raise ValueError(f"Invalid loss type: {self.loss_type}")
        
        diff = diff * mask

        # Guard the denominator so a fully masked batch yields 0 instead of NaN.
        num_nonmasked = mask.sum().clamp(min=1)
        loss = (diff**2).sum(dim=(0, 2, 3)).mean() / num_nonmasked
        return loss
    
    def _get_layerwise_task_approx_loss(self, 
                                        linear_block: CausalLMLayer, 
                                        kernel: torch.Tensor,
                                        kernel_grad: torch.Tensor,
                                        indices: torch.Tensor,
                                        qk: torch.Tensor,
                                        query: torch.Tensor,
                                        key: torch.Tensor) -> torch.Tensor:
        """Task-approximation loss for one layer on a sampled set of kernel entries.

        The linear kernel is flattened over (query, key) and gathered at
        ``indices``; with ``diff = linear_kernel - kernel`` the per-head loss is
        ``sum(kernel_grad * diff) + sum(diff**2 * kernel_grad**2) / 2``,
        divided by the number of sampled entries and averaged over batch and
        heads.
        """
        attn = linear_block.attn

        shift = qk.amax(dim=-1, keepdim=True)

        query_feature = attn.feature_net(query, shift_value=shift)
        key_feature = attn.feature_net(key, shift_value=shift)
        linear_kernel = torch.einsum("bhqf,bhkf->bhqk", query_feature, key_feature)
        
        # Pick the same (query, key) entries that `kernel` / `kernel_grad` hold.
        linear_kernel = linear_kernel.flatten(-2, -1)
        linear_kernel = torch.gather(linear_kernel, -1, indices)

        diff = linear_kernel - kernel
        loss = torch.einsum("bhk,bhk->bh", kernel_grad, diff)
        loss += torch.einsum("bhk,bhk,bhk->bh", diff, kernel_grad ** 2, diff) / 2
        loss /= kernel_grad.size(-1)
        loss = loss.mean()
        return loss
    
    def _get_layerwise_softmax_loss(self, linear_block: CausalLMLayer, attention_mask: torch.Tensor,
                                    attn_weight: torch.Tensor, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between reference and linear attention weights of one layer.

        The linear attention weights are obtained from ``linear_attention``
        (called with an all-zero value tensor, since only the returned
        ``"weight"`` entry is used), clamped at 1e-8 before the log, and the
        cross-entropy is averaged over all remaining dimensions.
        """
        attn = linear_block.attn

        # Convert the 0/1 padding mask into an additive mask (0 keep, -inf drop).
        _attention_mask = attention_mask[:, None, None, :].float()
        _attention_mask = torch.where(_attention_mask == 0, float('-inf'), 0.0)
        _, attn_info = linear_attention(
            attn.feature_net, attn.attn_dropout, query, key, torch.zeros_like(query), _attention_mask, head_mask=None, 
            cache=None, adaptive_shift=attn.config.adaptive_shift, bias=attn.bias, kernel_clip=attn.kernel_clip
        )
        linear_attn_weight = attn_info["weight"]

        log_linear_attn_weight = torch.log(linear_attn_weight.clamp(min=1e-8))
        loss = - (attn_weight * log_linear_attn_weight).sum(dim=-1).mean()
        
        return loss


class MimicTrainer(Trainer):
    """Trainer that optimizes a ``MimicLossNet`` built from two models.

    The wrapped model is constructed internally from ``origin_model`` and
    ``linear_model``; checkpoints are written from the ``linear_model`` only.
    """

    args: MimicArguments
    model: MimicLossNet
    origin_model: CausalLM
    # Cached train dataloader, so the task-approx dataset is built only once.
    _dataloader: Optional[DataLoader]

    def __init__(self, 
                 origin_model: CausalLM, 
                 linear_model: CausalLM, 
                 args: MimicArguments,
                 data_collator: nn.Module,
                 train_dataset: Dataset, 
                 *_args, **kwargs):
        """Build the loss network and initialise the underlying ``Trainer``.

        Args:
            origin_model: Reference model.
            linear_model: Model to train. When its config has
                ``use_linear_attn`` set, ``transplant(origin_model)`` is called
                on it; otherwise ``origin_model`` itself is used in its place.
            args: Training arguments.
            data_collator: Collator forwarded to ``Trainer``.
            train_dataset: Training dataset forwarded to ``Trainer``.
            *_args: Further positional ``Trainer`` arguments that follow
                ``train_dataset`` in its signature.
            **kwargs: Further keyword ``Trainer`` arguments. ``model`` and
                ``compute_metrics`` are dropped (with a warning when set).
        """
        # Always remove the ignored keys: an explicit `model=None` would
        # otherwise collide with the positional `model` passed to Trainer.
        for ignored_name in ("model", "compute_metrics"):
            if kwargs.pop(ignored_name, None) is not None:
                print(f"[WARNING] The `{ignored_name}` argument will be ignored.")

        if linear_model.config.use_linear_attn:
            linear_model.transplant(origin_model)
        else:
            linear_model = origin_model

        self._dataloader = None

        model = MimicLossNet(origin_model, linear_model, args.loss_type,
                             gradient_checkpointing=args.gradient_checkpointing,
                             gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
                             ).to(origin_model.device)

        # data_collator / train_dataset are passed positionally so that any
        # extra positionals land in the slots after them instead of clashing
        # with keyword copies of the same parameters.
        super().__init__(model, args, data_collator, train_dataset, *_args, **kwargs)

        if linear_model.config.use_linear_attn and args.do_transplant_check:
            # Run the transplant check on a 1% sample of the training data.
            _dataset = train_dataset.train_test_split(train_size=0.01)["train"]
            _dataset = self._remove_unused_columns(_dataset)
            check_transplant(linear_model, origin_model, 
                             dataset=_dataset, data_collator=data_collator)

        assert linear_model is model.linear_model
    
    def get_train_dataloader(self) -> DataLoader:
        """Return the (cached) train dataloader.

        For the "task-approx" loss the base dataloader is converted with
        ``create_dataset_for_task_approx`` and prepared by the accelerator.
        """
        if self._dataloader is not None:
            return self._dataloader
        dataloader = super().get_train_dataloader()
        if self.args.loss_type == "task-approx":
            dataloader = create_dataset_for_task_approx(dataloader, self.model.origin_model, save_path=self.args.kernel_path)
            dataloader = self.accelerator.prepare(dataloader)
        self._dataloader = dataloader
        return dataloader

    def training_step(self, model, 
                      inputs: Dict[str, Union[torch.Tensor, Any]], 
                      num_items_in_batch=None) -> torch.Tensor:
        """Run the parent training step, then flush stdout."""
        out = super().training_step(model, inputs, num_items_in_batch)
        
        sys.stdout.flush()
        return out
    
    def train(self, *args, **kwargs):
        """Train unless the model has no trainable parameters.

        Returns:
            The result of ``Trainer.train``, or ``None`` when training is
            skipped because nothing requires grad.
        """
        num_learnable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        if num_learnable_params == 0:
            print("[Message] No learnable parameters. The training will be skipped.", flush=True)
            return None

        print(f"[Message] The number of learnable parameters: {num_learnable_params}", flush=True)
        return super().train(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Delegate the loss to the model; ``return_outputs`` selects its eval path.

        Returns:
            The loss, or ``(loss, {"logits": None})`` when ``return_outputs`` is
            True.
        """
        loss = model(**inputs, return_outputs=return_outputs)
        if return_outputs:
            return (loss, {"logits": None})
        return loss
    
    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, 
                                 epoch, ignore_keys_for_eval, start_time,
                                 learning_rate=None):
        """Log, evaluate and save according to ``self.control``.

        Differs from a plain pass-through in two visible ways: the logged loss
        is rounded to 8 decimals, and the checkpoint is written from the
        wrapped ``linear_model`` rather than the whole loss network.

        Args:
            learning_rate: Optional learning rate to log; when ``None`` it is
                read via ``self._get_learning_rate()``.
        """
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:

            logs: Dict[str, float] = {}

            # Average the accumulated loss across processes.
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # Reset the running loss in place (the caller keeps the same tensor).
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            if learning_rate is not None:
                logs["learning_rate"] = learning_rate
            else:
                logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs, start_time)

        if self.control.should_evaluate:
            self._evaluate(trial, ignore_keys_for_eval)

        if self.control.should_save:
            # Unwrap DataParallel to reach the trained sub-model.
            if isinstance(model, nn.DataParallel):
                linear_model = model.module.linear_model
            else:
                linear_model = model.linear_model
            self._save_checkpoint(linear_model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)