import copy
import os
from tqdm import tqdm
import numpy as np
import torch
import random
import pdb

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM
from peft import PeftModel, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, prepare_model_for_kbit_training

from torch.utils.data import DataLoader # 需要 DataLoader
from transformers import get_linear_schedule_with_warmup # 需要 Scheduler
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from opacus import PrivacyEngine
from opacus.data_loader import DPDataLoader

from utils import *
from federated_learning import *
from federated_learning.fed_utils import *
from federated_learning.fed_local_sft import SFTTrainerFedLora, SFTTrainerFedELA
from config import get_config, save_config, get_model_config, get_training_args, load_and_update_config

# test
from dev_scripts.inspect_noniid_model import calculate_cosine_similarity_seperate, orthogonal_similarity_dict_1, orthogonal_similarity_dict_2, orthogonal_similarity_dict_3

# ===== Define the arguments =====
# get_config() supplies the script-level, federated-learning and PEFT settings.
script_args, fed_args, peft_config = get_config()

# NOTE(review): training_args is built from script_args *before* a resumed run
# reloads its saved configuration below — confirm that ordering is intended.
training_args = get_training_args(script_args, script_args.learning_rate)

if script_args.load_from_check_point:
    if not script_args.checkpoint_dir:
        raise ValueError("NO checkpoint dir, please give the checkpoint dir!")

    # Normalise the path so a trailing separator ("run/checkpoint-50/") does not
    # make dirname() return the checkpoint folder itself instead of its parent.
    checkpoint_dir = os.path.normpath(script_args.checkpoint_dir)
    # os.path.join keeps a bare relative checkpoint name ("checkpoint-50")
    # resolving to "args.json" rather than the filesystem root "/args.json".
    saved_args_path = os.path.join(os.path.dirname(checkpoint_dir), "args.json")
    script_args, fed_args, peft_config = load_and_update_config(saved_args_path, script_args, fed_args)
else:
    # Fresh run: persist the configuration via the project helper.
    save_config(script_args, fed_args)

print(script_args, fed_args)

# ===== Load the dataset =====
# Fetch the dataset and immediately run it through the SFT preprocessing helper.
dataset = process_sft_dataset(
    script_args.dataset_name,
    get_dataset(script_args.dataset_name, script_args.local_data_dir),
    script_args.dataset_sample,
)

# ===== Split the dataset into clients =====
local_datasets = split_dataset(fed_args, script_args, dataset)
# Per-client sample counts, indexed by client id.
sample_num_list = [
    len(local_datasets[client_id])
    for client_id in range(fed_args.num_clients)
]

# ===== Get model config =====
(
    device_map,
    quantization_config,
    torch_dtype,
) = get_model_config(script_args)

# ===== Load the model (fresh start or resume from a checkpoint) =====
def _load_base_model(model_device_map):
    """Load the base causal LM and, for 8-/4-bit runs, prepare it for k-bit training."""
    base_model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        quantization_config=quantization_config,
        device_map=model_device_map,
        trust_remote_code=script_args.trust_remote_code,
        torch_dtype=torch_dtype,
    )
    if script_args.load_in_8bit or script_args.load_in_4bit:
        base_model = prepare_model_for_kbit_training(
            base_model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
    return base_model


if script_args.load_from_check_point:
    # The round number is the suffix of the checkpoint folder name
    # ("checkpoint-<round>"); normalise first so a trailing slash is tolerated.
    checkpoint_round = int(os.path.basename(os.path.normpath(checkpoint_dir)).split('-')[-1])
    print(f"loading from {checkpoint_dir}, round {checkpoint_round}")
    model = _load_base_model(device_map)
    # Attach the saved global LoRA adapter in trainable mode.
    print("Load global LoRA weights")
    model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
else:
    # NOTE(review): a fresh run uses device_map='auto' and ignores the value
    # returned by get_model_config, unlike the resume path — confirm intended.
    model = _load_base_model('auto')
    model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Same freezing rule on both paths (the resume path previously only handled 'fedfa').
if fed_args.fed_alg == 'fedfa' or (fed_args.fed_alg == 'fedavg_lora' and fed_args.use_dp):
    freeze_lora_A(model)
    model.print_trainable_parameters()
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
if training_args.gradient_checkpointing:
    model.enable_input_require_grads()

# The global adapter state is needed on BOTH paths: the code that follows reads
# `global_dict`, which the resume path previously never defined (NameError).
global_dict = copy.deepcopy(get_peft_model_state_dict(model))

if script_args.load_from_check_point:
    # Per-client LoRA state dicts saved next to the adapter.
    # NOTE: torch.load unpickles the file — only resume from trusted checkpoints.
    print("Load local LoRA dictionaries")
    local_dict_list = torch.load(os.path.join(checkpoint_dir, "local_lora_dict_list.pt"), map_location='cpu')
    # training_loss.npy is rewritten every round, so it may contain rounds recorded
    # after this checkpoint; drop them so newly appended losses stay round-aligned.
    training_loss = np.load(os.path.join(script_args.output_dir, "training_loss.npy")).tolist()
    training_loss = [client_losses[:checkpoint_round] for client_losses in training_loss]
    starting_round = checkpoint_round
else:
    # Every client starts from a private copy of the initial global adapter.
    local_dict_list = [copy.deepcopy(global_dict) for _ in range(fed_args.num_clients)]
    training_loss = [[] for _ in range(fed_args.num_clients)]
    starting_round = 0

# ===== Prepare auxiliary variables if needed =====
proxy_dict, opt_proxy_dict = get_proxy_dict(fed_args, global_dict)
(
    global_auxiliary,
    auxiliary_model_list,
    auxiliary_delta_dict,
) = get_auxiliary_dict(fed_args, global_dict)

# ===== Define the tokenizer =====
tokenizer = AutoTokenizer.from_pretrained(
    script_args.model_name_or_path,
    use_fast=False,
    padding_side="right",
)
if tokenizer.pad_token is None:
    # Fall back to the EOS token for padding (the choice noted for Llama 3).
    # Alternatives the author recorded: tokenizer.unk_token (Llama 2),
    # '<|finetune_right_pad_id|>' (Llama 3.1).
    tokenizer.pad_token = tokenizer.eos_token

# ===== Define the formatting function (cater to TRL SFTTrainer)=====
formatting_prompts_func, response_template = get_formatting_prompts_func(
    script_args.template, tokenizer.eos_token
)
# The first encoded id is dropped before the ids are handed to the collator.
# Encodings recorded by the author for reference:
#   '\n\nAssistant:'   -> [29871, 13, 13, 7900, 22137, 29901]
#   '\n\nAssistant:\n' -> [13, 13, 7900, 22137, 29901, 13]
response_template_ids = tokenizer.encode(
    response_template,
    add_special_tokens=False,
)[1:]
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

# ===== Start federated training =====
# NOTE: the loop variable keeps the name `round` (shadowing the builtin) because
# code after the loop reads it to name the final checkpoint.
for round in tqdm(range(starting_round, fed_args.num_rounds)):
    clients_this_round = get_clients_this_round(fed_args, round)
    print(f">> ==================== Round {round+1} : {clients_this_round} ====================")

    # The manually scheduled learning rate depends only on the round index,
    # so compute it once per round instead of once per client.
    new_lr = cosine_learning_rate(round, fed_args.num_rounds, script_args.learning_rate, 1e-6)

    for client in range(fed_args.num_clients):
        if client not in clients_this_round:
            training_loss[client].append(-1)            # -1 is an indicator of not training
            continue

        if fed_args.fed_alg != "fedela":
            # Set the local model to the current global state
            set_peft_model_state_dict(model, global_dict)
        else:
            # fedela: each client resumes from its own local adapter state
            set_peft_model_state_dict(model, local_dict_list[client])

        if fed_args.fed_alg == "fedavg_lora":
            freeze_lora_A(model)

        sub_dataset = get_dataset_this_round(local_datasets[client], round, fed_args, script_args)      # get the required sub-dataset for this round
        # A fresh training-args object per client, carrying this round's learning rate.
        training_args = get_training_args(script_args, new_lr)

        # ===== Train local model on the client side =====
        trainer = get_fed_local_sft_trainer(
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            local_dataset=sub_dataset,
            formatting_prompts_func=formatting_prompts_func,
            data_collator=data_collator,
            global_dict=global_dict,
            fed_args=fed_args,
            script_args=script_args,
            local_auxiliary=auxiliary_model_list[client],
            global_auxiliary=global_auxiliary,
        )

        results = trainer.train()
        training_loss[client].append(results.training_loss)
        # Snapshot the trained adapter; the shared model is overwritten for the next client.
        local_dict_list[client] = copy.deepcopy(get_peft_model_state_dict(model))

        # ===== Client transmits local information to server =====
        if fed_args.fed_alg == 'scaffold':
            auxiliary_model_list[client], auxiliary_delta_dict[client] = trainer.get_auxiliary_param()
        elif fed_args.fed_alg == 'fedavg_lora':
            auxiliary_delta_dict[client] = trainer.get_project_param(local_dict_list[client])
        elif fed_args.fed_alg == 'fedela':
            auxiliary_delta_dict[client] = trainer.get_project_param()

    # ===== fedavg_lora: two communication passes =====
    if fed_args.fed_alg == 'fedavg_lora':
        # Pass 1 ("comm_1"): only the second return value (global_Q) is used.
        auxiliary_info = (global_auxiliary, auxiliary_delta_dict, "comm_1", script_args.peft_lora_r)
        _, global_Q = global_aggregate(
            fed_args, global_dict, local_dict_list, sample_num_list,
            clients_this_round, round, proxy_dict=proxy_dict,
            opt_proxy_dict=opt_proxy_dict, auxiliary_info=auxiliary_info
        )

        # Participating clients recompute their projected parameters against global_Q.
        for client in range(fed_args.num_clients):
            if client not in clients_this_round:
                continue
            auxiliary_delta_dict[client] = SFTTrainerFedLora.get_fix_project_param(local_dict_list[client], global_Q)

        # Pass 2 ("comm_2"): produces the new global adapter state.
        auxiliary_info = (global_Q, auxiliary_delta_dict, "comm_2", script_args.peft_lora_r)
        global_dict, _ = global_aggregate(
            fed_args, global_dict, local_dict_list, sample_num_list,
            clients_this_round, round, proxy_dict=proxy_dict,
            opt_proxy_dict=opt_proxy_dict, auxiliary_info=auxiliary_info
        )

    elif fed_args.fed_alg == 'fedela':
        # Orthonormal factor (QR) of the weighted sum of the clients' projections, per key.
        global_Q = {}
        for key in global_auxiliary.keys():
            Y = sum(
                auxiliary_delta_dict[client][key] * sample_num_list[client] / len(clients_this_round)
                for client in clients_this_round
            )
            global_Q[key], _ = torch.linalg.qr(Y)

        _, global_B = global_aggregate(
            fed_args, global_dict, local_dict_list, sample_num_list,
            clients_this_round, round, proxy_dict=proxy_dict,
            opt_proxy_dict=opt_proxy_dict, auxiliary_info=(global_auxiliary, auxiliary_delta_dict)
        )

        for client in range(fed_args.num_clients):
            if client not in clients_this_round:
                continue
            # Only the updated local state is kept; the remaining return values
            # were collected into lists that nothing ever read.
            local_dict_list[client], *_unused = SFTTrainerFedELA.get_update(
                local_dict_list[client], global_B, global_auxiliary, test=global_Q
            )

        # Measure how similar two participating clients are. This needs at least
        # two sampled clients; indexing [1] unconditionally raised IndexError otherwise.
        res_sim = None
        if len(clients_this_round) >= 2:
            res_sim = calculate_cosine_similarity_seperate(
                local_dict_list[clients_this_round[0]], local_dict_list[clients_this_round[1]]
            )

        # Use a randomly chosen participant's state as the global one. Deep-copy it so
        # the global state never shares tensors/dict with that client's personal state
        # (global_aggregate receives global_dict and may update it in place — TODO confirm).
        global_dict = copy.deepcopy(local_dict_list[random.choice(clients_this_round)])

    else:
        # ===== Server aggregates the local models =====
        global_dict, global_auxiliary = global_aggregate(
            fed_args, global_dict, local_dict_list, sample_num_list,
            clients_this_round, round, proxy_dict=proxy_dict,
            opt_proxy_dict=opt_proxy_dict, auxiliary_info=(global_auxiliary, auxiliary_delta_dict)
        )

        if fed_args.fed_alg == 'flora':
            # NOTE(review): at this point the model still carries the adapter of the
            # last client trained this round, so this merge folds it into the base
            # weights before the stacked adapter is merged too — confirm intended.
            model = model.merge_and_unload()
            # Wrap the base model with an adapter whose rank is scaled by the number
            # of participants, load the aggregated state, then merge it.
            stack_peft_config = copy.deepcopy(peft_config)
            stack_peft_config.r *= len(clients_this_round)
            model = get_peft_model(model, stack_peft_config)
            set_peft_model_state_dict(model, global_dict)
            save_path = os.path.join(script_args.output_dir, f"checkpoint-{round+1}")
            model = model.merge_and_unload()
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"checkpoint saved at{save_path}")

            # Start the next round from a newly initialised adapter on the merged model.
            model = get_peft_model(model, peft_config)
            model.print_trainable_parameters()
            model.config.use_cache = False
            if training_args.gradient_checkpointing:
                model.enable_input_require_grads()
            global_dict = copy.deepcopy(get_peft_model_state_dict(model))
            local_dict_list = [copy.deepcopy(global_dict) for _ in range(fed_args.num_clients)]

    set_peft_model_state_dict(model, global_dict)   # Update global model

    # ===== Save the model =====
    if (round+1) % fed_args.save_model_freq == 0 or (round+1) == fed_args.num_rounds:
        # flora already saved a merged checkpoint above, every round.
        if fed_args.fed_alg != 'flora':
            save_path = os.path.join(script_args.output_dir, f"checkpoint-{round+1}")
            model.save_pretrained(save_path)
            torch.save(local_dict_list, os.path.join(save_path, "local_lora_dict_list.pt"))
            print(f"checkpoint saved at{save_path}")

    # Persist the loss history every round so it survives an interrupted run.
    np.save(os.path.join(script_args.output_dir, "training_loss.npy"), np.array(training_loss))


# ===== Save the final model =====
# Derive the final round number from the configuration instead of the loop
# variable: if the training loop ran zero iterations (resumed at or past
# num_rounds), `round` is still the builtin function and `round+1` raised
# TypeError. After a normal run this equals the last `round+1`.
final_round = max(starting_round, fed_args.num_rounds)
save_path = os.path.join(script_args.output_dir, f"checkpoint-{final_round}")
model.save_pretrained(save_path)
if fed_args.fed_alg == 'flora':
    # flora checkpoints carry the tokenizer alongside the model files.
    tokenizer.save_pretrained(save_path)
else:
    # Other algorithms also store every client's local LoRA state for resuming.
    torch.save(local_dict_list, os.path.join(save_path, "local_lora_dict_list.pt"))
print(f"checkpoint saved at{save_path}")
