import torch
import copy
from trl import SFTTrainer
from transformers import TrainerCallback
from peft import get_peft_model_state_dict, set_peft_model_state_dict

from dev_scripts.inspect_noniid_model import calculate_cosine_similarity_seperate, orthogonal_similarity_dict_1, orthogonal_similarity_dict_2, orthogonal_similarity_dict_3
import pdb

from utils.DP_SFTTrainer import DP_SFTTrainer
from utils.SFTTrainer_gradinspect import LookGradSFTTrainer

def get_fed_local_sft_trainer(script_args, fed_args, model, tokenizer, training_args, local_dataset, formatting_prompts_func, data_collator, global_dict, local_auxiliary, global_auxiliary):
    """Create the client-side trainer that matches ``fed_args.fed_alg``.

    Args:
        script_args: Script configuration. Read for
            ``gradient_accumulation_steps`` (only when DP is enabled) and
            ``peft_lora_alpha`` (only for 'fedavg_lora' / 'fedela').
        fed_args: Federated configuration. Read for ``fed_alg``, ``use_dp``,
            ``prox_mu`` and the ``dp_*`` settings.
        model: Model handed to the trainer (and to the SCAFFOLD callback).
        tokenizer: Tokenizer handed to the trainer.
        training_args: Training arguments handed to the trainer as ``args``.
        local_dataset: This client's training dataset.
        formatting_prompts_func: Passed to the trainer as ``formatting_func``.
        data_collator: Passed to the trainer as ``data_collator``.
        global_dict: Global state, used by 'fedprox' and 'scaffold'.
        local_auxiliary: Local auxiliary state, used by 'scaffold'.
        global_auxiliary: Global auxiliary state, used by 'scaffold',
            'fedavg_lora' and 'fedela'.

    Returns:
        The constructed trainer, with ``neftune_noise_alpha`` set to ``None``.

    Raises:
        ValueError: If ``fed_args.fed_alg`` is not a supported algorithm.
    """
    # Differential-privacy settings stay empty unless DP is requested.
    dp_settings = {}
    if fed_args.use_dp:
        dp_settings = {
            "l2_norm_clip": fed_args.dp_clip_grad_norm,
            "noise_multiplier": fed_args.dp_sigma,
            "microbatch_size": script_args.gradient_accumulation_steps,
            "delta": fed_args.dp_target_delta,
            "warmup_ratio": 0.15,
        }

    # Keyword arguments every trainer variant receives.
    # max_seq_length=script_args.seq_length is currently not passed to any trainer.
    shared_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=local_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=data_collator,
    )

    algorithm = fed_args.fed_alg
    if algorithm == 'fedprox':
        trainer = SFTTrainerFedProx(
            global_state=global_dict,
            prox_mu=fed_args.prox_mu,
            dp_args=dp_settings,
            **shared_kwargs,
        )
    elif algorithm == 'scaffold':
        trainer = SFTTrainerSCAFFOLD(
            global_state=global_dict,
            local_auxiliary=local_auxiliary,
            global_auxiliary=global_auxiliary,
            dp_args=dp_settings,
            **shared_kwargs,
        )
        # The callback applies the control-variate correction after each step.
        trainer.add_callback(SCAFFOLD_Callback(trainer.correction, model))
    elif algorithm == 'fedavg_lora':
        trainer = SFTTrainerFedLora(
            lora_alpha=script_args.peft_lora_alpha,
            global_auxiliary=global_auxiliary,
            dp_args=dp_settings,
            **shared_kwargs,
        )
    elif algorithm == 'fedela':
        # NOTE(review): unlike the other DP_SFTTrainer subclasses, no dp_args
        # is passed here -- confirm whether that is intended.
        trainer = SFTTrainerFedELA(
            lora_alpha=script_args.peft_lora_alpha,
            global_auxiliary=global_auxiliary,
            **shared_kwargs,
        )
    elif algorithm in ('fedavg', 'fedfa', 'flora', 'fedsa') or algorithm.startswith('local'):
        # These algorithms need no custom trainer; only the DP flag matters.
        if fed_args.use_dp:
            trainer = DP_SFTTrainer(dp_args=dp_settings, **shared_kwargs)
        else:
            trainer = SFTTrainer(**shared_kwargs)
    else:
        raise ValueError(f'Unsupported `fed_alg`: {algorithm}')

    trainer.neftune_noise_alpha = None
    return trainer

class SFTTrainerFedProx(DP_SFTTrainer):  # DP_SFTTrainer or SFTTrainer
    """Trainer that adds the FedProx proximal penalty to the base loss.

    For every trainable parameter the term
    ``mu / 2 * ||param - global_param||^2`` is added to the loss returned by
    the parent trainer.
    """

    def __init__(self, global_state, prox_mu, **kwargs):
        """Store the global reference weights and the proximal coefficient.

        Args:
            global_state: Mapping from parameter name to the global tensor the
                local parameters are compared against.
            prox_mu: Proximal coefficient ``mu``.
            **kwargs: Forwarded unchanged to the parent trainer.
        """
        super().__init__(**kwargs)
        self.global_state = global_state
        self.mu = prox_mu

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the parent loss plus the proximal term.

        Args:
            model: Model whose ``named_parameters()`` are penalised.
            inputs: Batch passed through to the parent ``compute_loss``.
            return_outputs: When true, return ``(loss, outputs)``.

        Returns:
            The penalised loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        result = super().compute_loss(model, inputs, return_outputs=return_outputs)
        loss, outputs = result if return_outputs else (result, None)

        for raw_name, param in model.named_parameters():
            # Frozen weights do not contribute to the proximal term.
            if not param.requires_grad:
                continue
            # TODO: May need changes. Strip the ".default" and "_module."
            # name fragments so the key lines up with `global_state`
            # (presumably added by the peft / DP wrappers -- verify).
            key = raw_name.replace(".default", "").replace('_module.', '')
            # Accumulated in place, one parameter at a time.
            loss += self.mu / 2 * torch.norm(param - self.global_state[key]) ** 2

        return (loss, outputs) if return_outputs else loss


class SFTTrainerSCAFFOLD(SFTTrainer):  # DP_SFTTrainer or SFTTrainer
    """Trainer that carries the SCAFFOLD control variates for one client.

    It stores the global weights, the local and global auxiliary variables,
    and ``correction = global_auxiliary - local_auxiliary`` for every key of
    the local auxiliary state.
    """

    def __init__(self, global_state, local_auxiliary, global_auxiliary, dp_args=None, **kwargs):
        """Build the trainer and pre-compute the correction terms.

        Args:
            global_state: Mapping from parameter name to global tensor.
            local_auxiliary: Mapping from parameter name to this client's
                auxiliary tensor; moved to the model's device.
            global_auxiliary: Mapping from parameter name to the global
                auxiliary tensor.
            dp_args: Optional differential-privacy settings. Forwarded to the
                parent only when the parent is a ``DP_SFTTrainer``.
            **kwargs: Forwarded unchanged to the parent trainer.

        Raises:
            ValueError: If non-empty ``dp_args`` are given while the parent
                trainer cannot apply them.
        """
        # Callers pass `dp_args` to this class, but a plain SFTTrainer parent
        # rejects unknown keywords. Only forward it when the base class has
        # been switched to the DP-capable trainer.
        if dp_args is not None:
            if isinstance(self, DP_SFTTrainer):
                kwargs["dp_args"] = dp_args
            elif dp_args:
                # Fail loudly rather than silently training without DP.
                raise ValueError(
                    "SFTTrainerSCAFFOLD received differential-privacy settings "
                    "but its base trainer does not support them; derive it "
                    "from DP_SFTTrainer or disable DP."
                )
        super().__init__(**kwargs)
        self.global_state = global_state
        self.local_auxiliary = {key: tensor.to(self.model.device) for key, tensor in local_auxiliary.items()}
        self.global_auxiliary = global_auxiliary
        # Every entry is computed from scratch, so no prior deep copy of the
        # auxiliary state is needed.
        self.correction = {
            name: self.global_auxiliary[name] - self.local_auxiliary[name]
            for name in self.local_auxiliary
        }

    def get_auxiliary_param(self):
        """Compute the updated local auxiliary state and its change.

        For each trainable parameter the new value is
        ``(global - param) / (max_steps * learning_rate) - correction``.
        Keys without a matching trainable parameter keep a copy of their
        current local value in both returned dicts.

        Returns:
            ``(auxiliary_new_para, auxiliary_delta_para)``. For updated keys
            the delta is the new value minus the current local value.
        """
        auxiliary_new_para = copy.deepcopy(self.local_auxiliary)
        auxiliary_delta_para = copy.deepcopy(self.local_auxiliary)
        # NOTE(review): assumes args.max_steps is set to a positive value --
        # confirm against the training configuration.
        scale = self.args.max_steps * self.args.learning_rate
        with torch.no_grad():
            for raw_name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                # Strip the ".default" and "_module." fragments so the name
                # matches the keys of the stored state dicts.
                name = raw_name.replace(".default", "").replace('_module.', '')
                auxiliary_new_para[name] = (self.global_state[name] - param) / scale - self.correction[name]
                auxiliary_delta_para[name] = auxiliary_new_para[name] - self.local_auxiliary[name]
        return auxiliary_new_para, auxiliary_delta_para

class SCAFFOLD_Callback(TrainerCallback):
    """Trainer callback that applies the SCAFFOLD correction after each step."""

    def __init__(self, correction, model):
        """Keep the correction terms and the model they are applied to.

        Args:
            correction: Mapping from peft state-dict key to correction tensor.
            model: Peft model whose adapter weights are adjusted.
        """
        super().__init__()
        self.correction = correction
        self.model = model

    def on_step_end(self, args, state, control, **kwargs):
        """Subtract ``learning_rate * correction`` from every adapter weight."""
        # Work on a deep copy of the peft state dict, then write it back.
        adapter_state = copy.deepcopy(get_peft_model_state_dict(self.model))
        step_size = args.learning_rate
        for key in adapter_state:
            adapter_state[key] -= step_size * self.correction[key]
        set_peft_model_state_dict(self.model, adapter_state)


class SFTTrainerFedLora(DP_SFTTrainer): # DP_SFTTrainer or SFTTrainer
    """Trainer for 'fedavg_lora' with helpers that project LoRA factor products."""

    def __init__(self, global_auxiliary, lora_alpha, **kwargs):
        """Snapshot the peft state dict and move the auxiliary tensors.

        Args:
            global_auxiliary: Mapping from module name to a projection tensor;
                each tensor is moved to the model's device.
            lora_alpha: Stored on the instance; not used by the methods of
                this class.
            **kwargs: Forwarded unchanged to the parent trainer.
        """
        super().__init__(**kwargs)
        self.device = self.model.device
        self.global_dict = copy.deepcopy(get_peft_model_state_dict(self.model))
        self.global_auxiliary = {
            key: tensor.to(self.device) for key, tensor in global_auxiliary.items()
        }
        self.lora_alpha = lora_alpha

    def get_project_param(self, model_dict):
        """Return ``lora_B @ (lora_A @ Omega)`` for every auxiliary entry.

        Args:
            model_dict: State dict holding ``"<name>.lora_A.weight"`` and
                ``"<name>.lora_B.weight"``; it also replaces
                ``self.global_dict``.

        Returns:
            Dict mapping each auxiliary name to its projected product.
        """
        self.global_dict = model_dict
        weights = self.global_dict
        # The product is not divided by lora_alpha.
        return {
            name: weights[f"{name}.lora_B.weight"] @ (weights[f"{name}.lora_A.weight"] @ omega)
            for name, omega in self.global_auxiliary.items()
        }

    @staticmethod
    def get_fix_project_param(local_dict, orth_Q):
        """Return ``lora_A.T @ (lora_B.T @ Q)`` for every entry of ``orth_Q``.

        Args:
            local_dict: State dict holding the LoRA A/B weights.
            orth_Q: Mapping from module name to the matrix ``Q``.

        Returns:
            Dict mapping each name in ``orth_Q`` to its projected product.
        """
        return {
            name: local_dict[f"{name}.lora_A.weight"].T @ (local_dict[f"{name}.lora_B.weight"].T @ basis)
            for name, basis in orth_Q.items()
        }


class SFTTrainerFedELA(DP_SFTTrainer):
    """Trainer for 'fedela' with helpers that project and refactor LoRA factors."""

    def __init__(self, global_auxiliary, lora_alpha, **kwargs):
        """Snapshot the peft state dict and move the auxiliary tensors.

        Args:
            global_auxiliary: Mapping from module name to a projection tensor;
                each tensor is moved to the model's device.
            lora_alpha: Stored on the instance; not used by the methods of
                this class.
            **kwargs: Forwarded unchanged to the parent trainer.
        """
        super().__init__(**kwargs)
        self.device = self.model.device
        self.global_dict = copy.deepcopy(get_peft_model_state_dict(self.model))
        self.global_auxiliary = {name: param.to(self.device) for name, param in global_auxiliary.items()}
        self.lora_alpha = lora_alpha

    def get_project_param(self):
        """Return ``lora_A.T @ (lora_B.T @ Omega)`` for every auxiliary entry.

        Refreshes ``self.global_dict`` with a deep copy of the model's current
        peft state dict before projecting.

        Returns:
            Dict mapping each auxiliary name to its projected product.
        """
        self.global_dict = copy.deepcopy(get_peft_model_state_dict(self.model))
        weights = self.global_dict
        # The product is not divided by lora_alpha.
        return {
            name: weights[f"{name}.lora_A.weight"].T @ (weights[f"{name}.lora_B.weight"].T @ omega)
            for name, omega in self.global_auxiliary.items()
        }

    @staticmethod
    def get_update(local_dict, global_B, global_auxiliary, test=None):
        """Refactor the local LoRA factors against the global ``lora_B``.

        For each auxiliary entry, an orthonormal basis ``Q`` is taken from the
        QR factorisation of ``local_A.T @ (global_B.T @ Omega)``. Then
        ``P = global_B @ (local_A @ Q)`` is decomposed by SVD and the local
        A/B weights are overwritten so that ``new_B @ new_A == P @ Q.T``,
        with the singular values split evenly (square root) between the two.

        Args:
            local_dict: State dict with ``"<name>.lora_A.weight"`` entries; it
                is modified in place and also returned. Must be non-empty,
                since its first tensor determines the working device.
            global_B: State dict with ``"<name>.lora_B.weight"`` entries.
            global_auxiliary: Mapping from module name to the projection
                tensor ``Omega``; moved to the device of ``local_dict``.
            test: When truthy, the intermediate ``Q``, ``P``, ``S``, ``U`` and
                ``Vh`` of every entry are collected and returned.

        Returns:
            ``(local_dict, ref_Q, ref_P, ref_S, ref_U, ref_V)``. The five
            ``ref_*`` dicts are empty unless ``test`` is truthy.
        """
        device = next(iter(local_dict.values())).device
        global_auxiliary = {name: param.to(device) for name, param in global_auxiliary.items()}
        # Decide once whether intermediates are collected.
        collect_refs = bool(test)
        ref_Q = {}
        ref_P = {}
        ref_S = {}
        ref_U = {}
        ref_V = {}
        for name, Omega in global_auxiliary.items():
            a_key = f"{name}.lora_A.weight"
            b_key = f"{name}.lora_B.weight"

            Y = local_dict[a_key].T @ (global_B[b_key].T @ Omega)
            Q, _ = torch.linalg.qr(Y)
            P = global_B[b_key] @ (local_dict[a_key] @ Q)
            U, S, Vh = torch.linalg.svd(P.T, full_matrices=False)

            # Each factor receives sqrt(S).
            sqrt_S = torch.sqrt(S)
            local_dict[a_key] = (Q @ (U * sqrt_S)).T
            local_dict[b_key] = (sqrt_S.unsqueeze(1) * Vh).T

            if collect_refs:
                ref_Q[name] = Q
                ref_P[name] = P
                ref_S[name] = S
                ref_U[name] = U
                ref_V[name] = Vh

        return local_dict, ref_Q, ref_P, ref_S, ref_U, ref_V
    


