from transformers.modeling_outputs import CausalLMOutputWithPast
from lorra_finetune.src.args import LorraArguments,TrainingArguments
from tmt.bean import DefaultArgument
from tmt.direct_mod_trainer import DirectModTrainer
from tmt import proc
import torch
class DirectLorraTrainer(DirectModTrainer):
    """Trainer that adds a representation-matching term to the base loss.

    For each batch, hidden states are computed without gradients under the
    "pos" and "neg" adapters. The "center" adapter is then run through the
    inherited loss computation, and the final loss is::

        base_loss + alpha_pos * mean‖rep - pos_rep‖₂ - beta_neg * mean‖rep - neg_rep‖₂

    The attributes ``min_length``, ``lora_target_layers``, ``alpha_pos`` and
    ``beta_neg`` are read from ``self`` but not set in this class; presumably
    they are initialised by ``DirectModTrainer`` — TODO confirm.
    """

    def __init__(self, model, tokenizer, lorra_args: LorraArguments, args: TrainingArguments,
                 train_dataset, val_dataset, def_args: DefaultArgument):
        """Forward every argument unchanged to ``DirectModTrainer.__init__``."""
        super().__init__(
            model=model, tokenizer=tokenizer, lorra_args=lorra_args,
            args=args, train_dataset=train_dataset, val_dataset=val_dataset,
            def_args=def_args
        )

    def _adapter_rep_hidden(self, model, adapter_name, input_ids, attention_mask):
        """Activate ``adapter_name`` and return its stacked target-layer hidden states.

        Runs one forward pass with ``output_hidden_states=True``, keeps the last
        ``self.min_length`` positions along dim 1 of each layer listed in
        ``self.lora_target_layers``, and stacks them on a new leading dimension.
        The caller is responsible for wrapping the call in ``torch.no_grad()``.
        """
        model.set_adapter(adapter_name)
        output: CausalLMOutputWithPast = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = output.hidden_states
        return torch.stack([
            hidden_states[layer][:, -self.min_length:].detach()
            for layer in self.lora_target_layers
        ])

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the base loss plus the pos/neg representation terms.

        Args:
            model: Model exposing ``set_adapter`` with adapters named "pos",
                "neg" and "center", callable with ``input_ids``,
                ``attention_mask`` and ``output_hidden_states``.
            inputs: Mapping with "input_ids" and "attention_mask"; both are
                indexed as ``[:, 0]`` and ``input_ids`` must have size 3
                along dim 1.
            return_outputs: When true, return ``(loss, rep_hidden)`` instead
                of just ``loss``.

        Raises:
            ValueError: If either input is missing or ``input_ids`` does not
                have size 3 along dim 1.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        if input_ids is None or attention_mask is None:
            raise ValueError(
                "inputs must contain both 'input_ids' and 'attention_mask'"
            )
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if len(input_ids.shape) < 2 or input_ids.shape[1] != 3:
            raise ValueError(
                "expected input_ids with size 3 along dim 1, "
                f"got shape {tuple(input_ids.shape)}"
            )

        # Only the first of the three entries along dim 1 feeds the reference passes.
        orig_input_ids = input_ids[:, 0]
        orig_attention_mask = attention_mask[:, 0]

        # NOTE(review): eval mode is not undone before the "center" pass below,
        # so the trained forward pass also runs in eval mode — confirm intended.
        model.eval()
        with torch.no_grad():
            pos_rep_hidden = self._adapter_rep_hidden(
                model, "pos", orig_input_ids, orig_attention_mask
            )
            neg_rep_hidden = self._adapter_rep_hidden(
                model, "neg", orig_input_ids, orig_attention_mask
            )

        model.set_adapter("center")
        # super(DirectModTrainer, self) starts the lookup *after* DirectModTrainer
        # in the MRO, i.e. DirectModTrainer.compute_loss itself is skipped.
        # NOTE(review): confirm this bypass is intentional.
        loss, rep_hidden = super(DirectModTrainer, self).compute_loss(
            model, inputs, return_outputs=True
        )

        # L2 distance over the last dim, averaged while ignoring NaN entries.
        loss_pos = torch.norm(rep_hidden - pos_rep_hidden, dim=-1, p=2).nanmean()
        loss_neg = torch.norm(rep_hidden - neg_rep_hidden, dim=-1, p=2).nanmean()

        # Pull toward the "pos" representation, push away from the "neg" one.
        loss = loss + self.alpha_pos * loss_pos - self.beta_neg * loss_neg

        return (loss, rep_hidden) if return_outputs else loss

def main():
    """Script entry point: run ``proc.fine_tune_proc`` with this module's trainer.

    Passes ``proc.lorra_fmt_fun`` (presumably the data-formatting function —
    verify in ``tmt.proc``) and the ``DirectLorraTrainer`` class.
    """
    launch = proc.fine_tune_proc
    launch(proc.lorra_fmt_fun, DirectLorraTrainer)


# Run the fine-tuning entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()