from transformers.modeling_outputs import CausalLMOutputWithPast
from lorra_finetune.src.args import LorraArguments,TrainingArguments
from lorra_finetune.src.lorra_eng import LoRRaEngTrainer
from tmt.bean import DefaultArgument
from tmt import proc
import torch
import gc


class DirectModTrainer(LoRRaEngTrainer):
    """Trainer that fits the "center" adapter against "pos" and "neg" adapters.

    Every batch is run through the model three times, once per adapter. The
    "pos" and "neg" passes are executed without gradients and provide
    reference hidden states; the "center" pass is the one that is optimized.
    Its loss is the model's own loss plus a term pulling its hidden states
    toward "pos" and a term pushing them away from "neg".

    NOTE(review): ``self.min_length``, ``self.lora_target_layers``,
    ``self.training_args`` and ``self.model_idx`` are read here but never
    assigned in this class -- presumably provided by ``LoRRaEngTrainer``.
    """

    def __init__(self, model, tokenizer, lorra_args: LorraArguments, args: TrainingArguments,
                 train_dataset, val_dataset, def_args: DefaultArgument):
        """Initialise the base trainer and store the two loss weights.

        Args:
            model: Model forwarded to the base trainer; ``compute_loss``
                expects it to expose ``set_adapter`` with adapters named
                "pos", "neg" and "center".
            tokenizer: Tokenizer forwarded to the base trainer.
            lorra_args: Forwarded to the base trainer unchanged.
            args: Training arguments forwarded to the base trainer.
            train_dataset: Training dataset forwarded to the base trainer.
            val_dataset: Passed to the base trainer as ``val_datasets``.
            def_args: Source of ``alpha_pos`` and ``beta_neg``.
        """
        super().__init__(
            model=model, tokenizer=tokenizer, lorra_args=lorra_args,
            args=args, train_dataset=train_dataset, val_datasets=val_dataset
        )
        # Weight of the attraction term (distance to the "pos" hidden states).
        self.alpha_pos = def_args.alpha_pos
        # Weight of the repulsion term (distance to the "neg" hidden states).
        self.beta_neg = def_args.beta_neg

    def _stack_target_layers(self, hidden_states):
        """Stack the trailing ``min_length`` positions of each target layer."""
        return torch.stack(
            [hidden_states[layer][:, -self.min_length:] for layer in self.lora_target_layers]
        )

    def _reference_hidden(self, model, adapter_name, inputs):
        """Return detached target-layer hidden states under ``adapter_name``.

        Running the forward pass inside this helper means the full model
        output (logits and every layer's hidden states) goes out of scope on
        return, so only the stacked copy survives and the caller's cache
        clean-up can actually release the rest.
        """
        model.set_adapter(adapter_name)
        with torch.no_grad():
            reference: CausalLMOutputWithPast = model(**inputs, output_hidden_states=True)
            return self._stack_target_layers(reference.hidden_states).detach()

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the combined loss for the "center" adapter.

        Args:
            model: Model exposing ``set_adapter``; called as
                ``model(**inputs, output_hidden_states=True)``.
            inputs: Keyword arguments for the model's forward pass.
            return_outputs: When true, also return the "center" outputs.

        Returns:
            ``loss`` or ``(loss, outputs)`` where
            ``loss = model_loss + alpha_pos * d(center, pos) - beta_neg * d(center, neg)``
            and ``d`` is the NaN-ignoring mean of the L2 norm taken over the
            last dimension of the hidden-state difference.
        """
        # Remember the caller's mode: forcing train() afterwards would leave
        # dropout enabled whenever this method is invoked during evaluation.
        was_training = model.training

        model.eval()
        pos_rep_hidden = self._reference_hidden(model, "pos", inputs)
        neg_rep_hidden = self._reference_hidden(model, "neg", inputs)

        # The reference outputs are already out of scope here, so their
        # memory can be returned before the gradient-tracking forward pass.
        gc.collect()
        torch.cuda.empty_cache()

        model.set_adapter("center")
        model.train(was_training)
        outputs: CausalLMOutputWithPast = model(**inputs, output_hidden_states=True)

        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        rep_hidden = self._stack_target_layers(outputs.hidden_states)
        loss_pos = torch.norm(rep_hidden - pos_rep_hidden, dim=-1, p=2).nanmean()
        loss_neg = torch.norm(rep_hidden - neg_rep_hidden, dim=-1, p=2).nanmean()

        # Attract toward "pos", repel from "neg".
        loss = loss + self.alpha_pos * loss_pos - self.beta_neg * loss_neg

        return (loss, outputs) if return_outputs else loss

    def save_peft(self):
        """Save only the "center" adapter under ``<output_dir>/<run_name>_<model_idx>``."""
        self.model.save_pretrained(f"{self.training_args.output_dir}/{self.training_args.run_name}_{self.model_idx}",
                                   selected_adapters=["center"], safe_serialization=False)

    def evaluate(self, eval_dataset=None, ignore_keys=None, sanity_check=False, **kwargs):
        """Activate the "center" adapter, then delegate to the base ``evaluate``."""
        self.model.set_adapter("center")
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, sanity_check=sanity_check, **kwargs)


def main():
    """Hand ``proc.dolly_fmt_fun`` and ``DirectModTrainer`` to ``proc.fine_tune_proc``."""
    format_fn = proc.dolly_fmt_fun
    trainer_cls = DirectModTrainer
    proc.fine_tune_proc(format_fn, trainer_cls)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()