# Part of the implementation is borrowed from huggingface/transformers.
import os
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from peft import PeftModel
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer
from transformers import Trainer as HfTrainer
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import (
    is_peft_available, 
    is_sagemaker_mp_enabled, 
    is_torch_npu_available, 
    is_torch_mlu_available, 
    is_torch_xpu_available, 
    is_torch_musa_available, 
    is_torch_mps_available,
    logging,
)
from transformers.training_args import OptimizerNames

from swift.plugin import MeanMetric, compute_acc
from swift.utils import JsonlWriter, Serializer, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from .arguments import Seq2SeqTrainingArguments
from .mixin import SwiftMixin
from .torchacc_mixin import TorchAccMixin

from deepspeed.utils import (
    safe_get_full_fp32_param,
    safe_set_full_fp32_param,
    safe_get_local_fp32_param,
    safe_set_local_fp32_param,
    safe_get_full_grad,
    safe_set_full_grad,
)
from deepspeed.runtime.zero.utils import is_zero_param

logger = logging.get_logger(__name__)

# Default JSON paths consumed by the row-freezing helpers below. They are plain
# relative paths, so `open()` resolves them against the current working directory.
# (The former module-level `global` statements were no-ops and have been removed.)
# Per-layer format: {layer_name: [row, ...]}  (see freeze_rows_in_top_param)
top_param_info_json = "mix_seq/math_union_sub_csqa_intersect_top_100_params.json"
# "Global" format: {key: [layer_name, [row, ...]]}  (see load_frozen_rows_global)
top_param_global_json = "math_global_top_params.json"


def load_frozen_rows_global(file_path):
    """Load ``(layer_name, row_indices)`` pairs from a "global top params" JSON file.

    The file must contain a JSON object whose values are sequences whose first
    element is a layer name and whose second element is the list of row indices
    for that layer. The object's keys are ignored.

    Args:
        file_path: Path of the JSON file (read as UTF-8).

    Returns:
        A list of ``(layer_name, row_indices)`` tuples in the file's key order.
    """
    import json
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Only the values are needed; the keys carry no information used here.
    return [(top_info[0], top_info[1]) for top_info in data.values()]

def load_frozen_param_rows_global(file_path):
    """Load per-parameter row records from a JSON list.

    Each list element must be an object with the keys ``param_name``,
    ``row_indices``, ``selected_row`` and ``param_shape``.

    Args:
        file_path: Path of the JSON file (read as UTF-8).

    Returns:
        A list of ``(param_name, row_indices, selected_row, param_shape)`` tuples,
        in file order.

    Raises:
        KeyError: If a record lacks one of the required keys.
    """
    import json
    with open(file_path, "r", encoding="utf-8") as f:
        rows = json.load(f)
    return [
        (row["param_name"], row["row_indices"], row["selected_row"], row["param_shape"])
        for row in rows
    ]

def zero_out_grad_hook(grad, mask):
    """Gradient hook body: return ``grad`` multiplied element-wise by ``mask``.

    Positions where ``mask`` is 0 end up with a zero gradient.
    """
    masked_grad = grad * mask
    return masked_grad

def is_row_frozen(track_dict, param_name, row_index):
    """Return True if ``row_index`` is recorded as frozen for ``param_name``.

    ``track_dict`` maps parameter names to their frozen row indices; parameters
    absent from it have no frozen rows.
    """
    frozen_rows = track_dict.get(param_name, ())
    return row_index in frozen_rows

def freeze_param_rows(param, row_indices, track_dict, param_name):
    """Zero the gradient of selected rows of ``param`` through a backward hook.

    Args:
        param: Parameter (must require grad) whose rows should receive no gradient.
        row_indices: Indices along dim 0 to freeze.
        track_dict: Dict updated in place so that ``track_dict[param_name]`` is
            ``row_indices``.
        param_name: Key under which the frozen rows are recorded.

    Returns:
        The handle from ``param.register_hook``; call ``.remove()`` on it to undo
        the freeze.

    NOTE(review): the mask is built from ``param``'s current shape. A DeepSpeed
    ZeRO-3 partitioned parameter may not have its full shape here — confirm before
    using this helper on such parameters.
    """
    mask = torch.ones_like(param, device=param.device)
    for idx in row_indices:
        # Index only dim 0 so 1-D parameters (e.g. biases) work as well; the
        # former `mask[idx, :]` raised IndexError for anything that is not 2-D.
        mask[idx] = 0

    handle = param.register_hook(lambda grad: zero_out_grad_hook(grad, mask))
    track_dict[param_name] = row_indices
    return handle

def freeze_param_matrix(param, track_dict, param_name):
    """Zero the whole gradient of ``param`` through a backward hook.

    Every row index along dim 0 is recorded as frozen in
    ``track_dict[param_name]``.
    """
    all_zero_mask = torch.zeros_like(param, device=param.device)

    def _hook(grad):
        return zero_out_grad_hook(grad, all_zero_mask)

    param.register_hook(_hook)
    track_dict[param_name] = [*range(param.shape[0])]
    


def freeze_rows_in_top_param(model, top_param_info_json=top_param_info_json):
    """Freeze selected rows of the parameters listed in a per-layer JSON file.

    The JSON file must hold an object mapping a layer name to a list of row
    indices. A parameter matches when its name, cut at the first ``".weight"``,
    equals a key. For each match a backward hook is registered that zeroes the
    gradient of those rows along dim 0.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
        top_param_info_json: Path of the JSON file (read as UTF-8).

    NOTE(review): a zero gradient does not by itself keep a row constant.
    Decoupled weight decay (e.g. AdamW) still changes it. Confirm the optimizer
    config if exact freezing is required.
    """
    import json
    with open(top_param_info_json, 'r', encoding='utf-8') as f:
        top_param_dict = json.load(f)

    for param_name, param in model.named_parameters():
        name = param_name.split(".weight")[0]
        if name not in top_param_dict:
            continue
        if not param.requires_grad:
            # Already fully frozen; register_hook would raise on such a tensor.
            continue
        # The previous `param.data[row].requires_grad = False` only flagged a
        # temporary view of the detached `.data` tensor, so it froze nothing.
        # A gradient hook is what actually stops the update signal.
        row_index = torch.as_tensor(top_param_dict[name], dtype=torch.long)
        # Bind `row_index` as a default so each hook keeps its own rows
        # (a plain closure would see only the last loop value).
        param.register_hook(
            lambda grad, idx=row_index: grad.index_fill(0, idx.to(grad.device), 0))

def freeze_global_top_param(model, top_param_global_json=top_param_global_json):
    """Freeze selected rows of weights listed in a "global top params" JSON file.

    Entries are read with ``load_frozen_rows_global``. For each entry, the
    parameter named ``f"{layer_name}.weight"`` gets a backward hook that zeroes
    the gradient of the listed rows along dim 0.

    Args:
        model: Module whose parameters are looked up by name.
        top_param_global_json: Path of the JSON file.

    NOTE(review): zeroed gradients do not stop decoupled weight decay from
    changing the rows. Confirm the optimizer config if exact freezing is needed.
    """
    top_named_param_rows = load_frozen_rows_global(top_param_global_json)
    # One name->parameter lookup instead of rescanning (and printing) every
    # parameter once per JSON entry.
    named_params = dict(model.named_parameters())
    for param_name, top_row_indices in top_named_param_rows:
        logger.info(f"parameter name: {param_name}, top row indices: {top_row_indices}")
        param = named_params.get(f"{param_name}.weight")
        if param is None:
            logger.warning(f"Parameter {param_name}.weight not found in model; skipping.")
            continue
        logger.info(f"Parameter {param_name}.weight's shape: {param.data.shape}")
        if not param.requires_grad:
            # Already fully frozen; register_hook would raise on such a tensor.
            continue
        # `param.data[rows].requires_grad = False` (the old approach) only touched
        # a temporary view and froze nothing; a gradient hook is effective.
        row_index = torch.as_tensor(top_row_indices, dtype=torch.long)
        param.register_hook(
            lambda grad, idx=row_index: grad.index_fill(0, idx.to(grad.device), 0))

def freeze_global_top_param_ds(model, top_param_global_json=top_param_global_json):
    """Freeze selected weight rows via gradient hooks; intended for DeepSpeed runs.

    Entries are read with ``load_frozen_rows_global``. For each entry, the
    parameter named ``f"{layer_name}.weight"`` gets a backward hook that zeroes
    the gradient of the listed rows along dim 0. Only an index tensor is stored
    per hook, so no full-size mask is kept and no parameter gather is needed.

    Args:
        model: Module whose parameters are looked up by name.
        top_param_global_json: Path of the JSON file.

    NOTE(review): this assumes the hook receives the full (gathered) gradient
    under ZeRO-3, as it does with unpartitioned parameters. Confirm this for the
    DeepSpeed version in use. If it does not hold, zero the rows after backward
    with ``reset_param_rows`` (safe_get_full_grad / safe_set_full_grad) instead.
    Also note that zeroed gradients do not stop decoupled weight decay.
    """
    top_named_param_rows = load_frozen_rows_global(top_param_global_json)
    named_params = dict(model.named_parameters())
    for param_name, top_row_indices in top_named_param_rows:
        param = named_params.get(f"{param_name}.weight")
        if param is None:
            logger.warning(f"Parameter {param_name}.weight not found in model; skipping.")
            continue
        if not param.requires_grad:
            # Already fully frozen; register_hook would raise on such a tensor.
            continue
        # The old `param.data[rows].requires_grad = False` inside GatheredParameters
        # flagged a temporary view of the detached data and froze nothing.
        row_index = torch.as_tensor(top_row_indices, dtype=torch.long)
        # Default-arg binding keeps each hook tied to its own rows.
        param.register_hook(
            lambda grad, idx=row_index: grad.index_fill(0, idx.to(grad.device), 0))

def reset_param_rows(model, top_param_global_json=top_param_global_json):
    """Zero the stored gradient of selected weight rows using DeepSpeed's safe grad API.

    Entries are read with ``load_frozen_rows_global``. For each parameter named
    ``f"{layer_name}.weight"``, the full gradient is fetched with
    ``safe_get_full_grad``, the listed rows (dim 0) are set to zero, and the
    result is written back with ``safe_set_full_grad``.

    Args:
        model: Module whose parameters are looked up by name.
        top_param_global_json: Path of the JSON file.

    NOTE(review): presumably meant to be called between backward and optimizer
    step, when gradients exist. Parameters without a gradient are skipped.
    """
    logger.info("Resetting parameter rows in selective layers...")
    top_named_param_rows = load_frozen_rows_global(top_param_global_json)
    named_params = dict(model.named_parameters())
    for param_name, top_row_indices in top_named_param_rows:
        param = named_params.get(f"{param_name}.weight")
        if param is None:
            continue
        cur_grad = safe_get_full_grad(param)
        if cur_grad is None:
            # No gradient available (frozen parameter or called outside the
            # backward/step window); nothing to reset.
            continue
        if is_zero_param(param):
            # Sanity check: the gathered gradient must have the unpartitioned shape.
            assert cur_grad.shape == param.ds_shape
        # Writing a scalar zero avoids allocating a full-size CPU zero tensor
        # just to copy zeros from it.
        cur_grad[top_row_indices] = 0
        safe_set_full_grad(param, cur_grad)




class Trainer(SwiftMixin, HfTrainer):
    """Hugging Face ``Trainer`` combined with swift's ``SwiftMixin``; adds nothing of its own."""


class Seq2SeqTrainer(TorchAccMixin, SwiftMixin, HfSeq2SeqTrainer):
    """Swift seq2seq trainer extended with row-level gradient freezing.

    Unless ``top_param_info_json`` is ``None``, construction installs gradient
    hooks (via ``freeze_global_top_param_ds``) that zero the gradient of selected
    weight rows. ``training_step`` and ``compute_loss`` are adapted from the
    Hugging Face trainer. When ``args.predict_with_generate`` is set, evaluation
    generates through a ``PtEngine`` and appends each response to
    ``<output_dir>/predict.jsonl``.
    """
    args: Seq2SeqTrainingArguments

    def __init__(self, *args, top_param_info_json=top_param_info_json, **kwargs):
        """Build the trainer and optionally install the row-freezing hooks.

        Args:
            *args: Positional arguments forwarded to the parent trainers
                (e.g. ``model``, ``args``).
            top_param_info_json: Keyword-only. Row freezing is enabled unless this
                is ``None``. Only its ``None``-ness is used: the rows frozen are
                read from the module-level ``top_param_global_json`` path.
            **kwargs: Keyword arguments forwarded to the parent trainers.
        """
        # Keyword-only on purpose: as the first positional parameter it captured
        # a caller's positional `model` argument and shifted all the others.
        super().__init__(*args, **kwargs)

        if top_param_info_json is not None:
            self._freeze_named_param_rows(self.model)

        if self.args.predict_with_generate:
            from swift.llm import PtEngine
            self.infer_engine = PtEngine.from_model_template(
                self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size)
        self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl'))

    def _freeze_named_param_rows(self, model):
        """Install row-freezing gradient hooks on ``model``'s parameters.

        Rows come from the "global" JSON format at the module-level default path
        (see ``load_frozen_rows_global``). ``freeze_rows_in_top_param`` is the
        alternative for the per-layer ``{layer_name: rows}`` format.
        """
        freeze_global_top_param_ds(model)

    @staticmethod
    def _predict_data_collator(batch):
        """Wrap the raw batch so ``prediction_step`` receives the unencoded samples."""
        return {'_data': batch}

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """Run forward and backward for one batch and return the detached loss.

        Adapted from the Hugging Face ``Trainer.training_step``. The returned loss
        is for reporting only: it is divided by ``gradient_accumulation_steps`` when
        ``num_items_in_batch`` is ``None``, and returned as computed otherwise.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            # Not imported at module level: transformers only defines this helper
            # when SageMaker model parallelism is enabled.
            from transformers.trainer_pt_utils import smp_forward_backward
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss, **kwargs)

        # Normalize the loss for reporting. Shared by both branches: the apex path
        # used to fall through and return None, which breaks loss accumulation.
        if num_items_in_batch is None:
            return loss.detach() / self.args.gradient_accumulation_steps
        return loss.detach()

    @contextmanager
    def _patch_predict_with_generate(self):
        """Temporarily switch the template and collator into generation mode.

        The template is set to ``'pt'`` mode and the raw-batch collator is
        installed. For multimodal models the template's post-encode hook is
        removed. The original mode, collator and hook are restored on exit.
        """
        origin_mode = self.template.mode
        self.template.set_mode('pt')
        is_multimodal = self.model.model_meta.is_multimodal
        origin_data_collator = self.data_collator

        if is_multimodal:
            self.template.remove_post_encode_hook()
        self.data_collator = self._predict_data_collator
        try:
            yield
        finally:
            if is_multimodal:
                self.template.register_post_encode_hook([self.model])
            self.data_collator = origin_data_collator
            self.template.set_mode(origin_mode)

    def evaluate(self, *args, **kwargs):
        """Evaluate, in generation mode when ``args.predict_with_generate`` is set."""
        context = self._patch_predict_with_generate() if self.args.predict_with_generate else nullcontext()
        with context:
            return super().evaluate(*args, **kwargs)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one prediction step.

        Without ``predict_with_generate`` (or when only the loss is requested), the
        parent implementation is used. Otherwise each sample is generated with the
        inference engine and logged to the JSONL writer. The step returns
        ``(None, responses, labels)``, with both serialized to tensors and padded
        with 0.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
        from swift.llm import RequestConfig, InferRequest
        data_list = inputs['_data']
        labels_list = [InferRequest.remove_response(data['messages']) for data in data_list]
        resp_list = self.infer_engine.infer(
            data_list,
            RequestConfig(max_tokens=self.model.generation_config.max_new_tokens),
            use_tqdm=False,
            template=self.template)

        response_list = []
        device = self.args.device
        for data, resp, labels in zip(data_list, resp_list, labels_list):
            response = resp.choices[0].message.content
            self.jsonl_writer.append({'response': response, 'labels': labels, **data})
            response_list.append(Serializer.to_tensor(response).to(device=device))
        labels_list = [Serializer.to_tensor(labels).to(device=device) for labels in labels_list]
        response_list = pad_sequence(response_list, batch_first=True, padding_value=0)
        labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0)
        return None, response_list, labels_list

    def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
        """Compute the training loss and update the token-accuracy metrics.

        If a label smoother or ``compute_loss_func`` is configured, ``labels`` are
        popped from ``inputs`` and the loss is computed from the outputs.
        Otherwise the model's own loss is used, rescaled to the batch's share of
        ``num_items_in_batch`` when that is provided. ``loss_scale`` (if present)
        is only forwarded to ``compute_loss_func``.
        """
        loss_kwargs = {}
        labels = None
        if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
            labels = inputs.pop('labels')

        loss_scale = inputs.pop('loss_scale', None)
        if loss_scale is not None:
            loss_kwargs['loss_scale'] = loss_scale

        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            labels = inputs['labels']
            if num_items_in_batch is not None:
                # Scale only by this batch's share of the label tokens. The
                # `average_tokens_across_devices` factor is applied once below for
                # every path; applying it here as well scaled this path twice.
                outputs.loss = outputs.loss * (labels[:, 1:] != -100).sum() / num_items_in_batch

            if isinstance(outputs, dict) and 'loss' not in outputs:
                raise ValueError(
                    'The model did not return a loss from the inputs, only the following keys: '
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.")
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
        else:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # User-defined compute_loss function
            if self.compute_loss_func is not None:
                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs)
            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)

        if self.args.sequence_parallel_size > 1:
            from swift.trainers.xtuner import reduce_xtuner_sequence_parallel_loss
            loss = reduce_xtuner_sequence_parallel_loss(loss, labels)

        if getattr(self.args, 'average_tokens_across_devices', False):
            loss *= self.accelerator.num_processes

        # getattr: tuple outputs have no `.logits`; Liger kernels may return None.
        if getattr(outputs, 'logits', None) is not None:
            self._compute_token_acc(outputs, labels)
        return (loss, outputs) if return_outputs else loss

    def _compute_token_acc(self, outputs, labels) -> None:
        """Update the custom token-accuracy metrics every ``args.acc_steps`` global steps."""
        acc_steps = self.args.acc_steps
        # Argmax over dim 2; presumably (batch, seq, vocab) logits — confirm.
        preds = outputs.logits.argmax(dim=2)
        if self.state.global_step % acc_steps == 0:
            if use_torchacc():
                ta_trim_graph()
                preds = preds.to('cpu')
                labels = labels.to('cpu')
            metrics = compute_acc(
                preds, labels, acc_strategy=self.args.acc_strategy, is_encoder_decoder=self.args.is_encoder_decoder)
            for k, v in metrics.items():
                if k not in self._custom_metrics:
                    self._custom_metrics[k] = MeanMetric(nan_value=None)
                self._custom_metrics[k].update(v)
