import copy
import os
from tqdm import tqdm
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DataCollatorForCompletionOnlyLM
from peft import get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, prepare_model_for_kbit_training

from utils import *

from utils.process_dataset import get_local_dataset, modified_process_sft_data

from federated_learning import *
from config import get_config, save_config, get_model_config, get_training_args

import torch

from trl import SFTTrainer


from prune import generator
from prune.pruners import *

from collections import OrderedDict


# ===== Define the arguments =====
# get_config() yields three objects: the script-level arguments, the
# federated-learning arguments and the PEFT (LoRA) configuration.
script_args, fed_args, peft_config = get_config()
# Initial training arguments built with the base learning rate. This
# module-level name is rebound every client/round inside the training loop.
training_args = get_training_args(script_args, script_args.learning_rate)
save_config(script_args, fed_args)
print(script_args, fed_args)

seed = script_args.seed

# Seed the torch CPU and CUDA generators.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# NOTE(review): NumPy and `random` are deliberately left unseeded below, so any
# helper relying on those generators is not reproducible across runs.
# `random` is also not imported in this file; re-enabling the second line
# requires adding `import random`.
# np.random.seed(seed)
# random.seed(seed)
torch.backends.cudnn.deterministic = True


# ===== Load the local dataset =====
dataset = get_local_dataset(script_args.dataset_name)
# The helper returns two datasets: the one used for federated training and a
# second one kept in `remain_dataset` (only referenced by the commented-out
# global evaluation in the training loop).
dataset, remain_dataset = modified_process_sft_data(script_args.dataset_name, dataset, script_args.dataset_sample)

# Previous (disabled) loading path, kept for reference:
# # ===== Load the dataset =====
# dataset = get_dataset(script_args.dataset_name, script_args.local_data_dir)
# dataset, remain_dataset = process_sft_dataset(script_args.dataset_name, dataset, script_args.dataset_sample)

# ===== Split the dataset into clients =====
# `local_datasets` is indexed by client id; `sample_num_list[i]` is the number
# of samples held by client i and is later passed to global_aggregate().
local_datasets = split_dataset(fed_args, script_args, dataset)
sample_num_list = [len(local_datasets[i]) for i in range(fed_args.num_clients)]

# ===== Get model config =====
device_map, quantization_config, torch_dtype = get_model_config(script_args)

# Load the base causal LM with the placement / quantisation / dtype settings
# resolved above.
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
)

# Quantised (8-bit / 4-bit) models need extra preparation before adapters are
# attached.
if script_args.load_in_8bit or script_args.load_in_4bit:
    #         1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
    #         head to fp32
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=training_args.gradient_checkpointing
    )

# Wrap the base model with the PEFT adapters described by `peft_config` and
# report how many parameters remain trainable.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# ===== Define the global and local models =====
# Full copy of the adapter state dict as returned by PEFT (may contain entries
# whose key includes 'mask').
global_dict_w_mask = copy.deepcopy(get_peft_model_state_dict(model))

# lora_state_dict = get_peft_model_state_dict(model)
# The global state that is synchronised and aggregated: every entry whose key
# contains 'mask' is dropped, so masks are never part of `global_dict`.
global_dict = OrderedDict((k, copy.deepcopy(v)) for k, v in global_dict_w_mask.items() if 'mask' not in k)



# One independent copy of the global state per client; overwritten with the
# client's trained weights after each local training run.
local_dict_list = [copy.deepcopy(global_dict) for i in range(fed_args.num_clients)]
# Extra server/client state consumed by global_aggregate() and the local
# trainer (auxiliary values are read back when fed_alg == 'scaffold').
# NOTE(review): exact contents depend on fed_args and the federated_learning
# helpers — confirm there.
proxy_dict, opt_proxy_dict = get_proxy_dict(fed_args, global_dict)
global_auxiliary, auxiliary_model_list, auxiliary_delta_dict = get_auxiliary_dict(fed_args, global_dict)

# ===== Define the tokenizer =====
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, use_fast=False, padding_side="right")
# Fall back to the unknown token when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token  # following vicuna

# ===== Define the formatting function (cater to TRL SFTTrainer)=====
formatting_prompts_func, response_template = get_formatting_prompts_func(script_args.template, tokenizer.eos_token)
# The first two token ids of the encoded template are discarded so the
# remaining ids match how the template is tokenised inside a full sample.
# NOTE(review): the offset of 2 is presumably tokenizer-specific (the example
# below is for Llama2) — re-check when switching model families.
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[
                        2:]  # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]` for Llama2
# Collator built from the response-template ids; shared by every trainer
# created in this file.
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)


# """
# test global model eval
# """
def global_eval(g_model, global_eval_dataset):
    """Evaluate ``g_model`` on ``global_eval_dataset`` and return its eval loss.

    A throw-away ``SFTTrainer`` is built from the module-level tokenizer,
    training arguments, formatting function and data collator; no training
    dataset is attached.

    Args:
        g_model: Model to evaluate.
        global_eval_dataset: Dataset handed to the trainer as ``eval_dataset``.

    Returns:
        The ``'eval_loss'`` entry of the metrics returned by
        ``SFTTrainer.evaluate()``.
    """
    trainer_kwargs = dict(
        model=g_model,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=script_args.seq_length,
        train_dataset=None,
        eval_dataset=global_eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=data_collator,
    )
    evaluator = SFTTrainer(**trainer_kwargs)
    metrics = evaluator.evaluate()

    # Drop the trainer before asking CUDA to release its cached blocks.
    del evaluator
    torch.cuda.empty_cache()

    return metrics['eval_loss']



# Schedules understood by _scheduled_sparsity().
_PRUNE_SCHEDULES = ('exponential', 'linear', 'expinv')


def _scheduled_sparsity(schedule, sparsity, epoch, pruning_epochs):
    """Return the intermediate target for 0-based pruning iteration ``epoch``.

    All schedules reach exactly ``sparsity`` on the last iteration.

    Raises:
        ValueError: If ``schedule`` is not one of ``_PRUNE_SCHEDULES``.
    """
    if schedule == 'exponential':
        return sparsity ** ((epoch + 1) / pruning_epochs)
    if schedule == 'linear':
        return 1.0 - (1.0 - sparsity) * ((epoch + 1) / pruning_epochs)
    if schedule == 'expinv':
        return 1.0 - (1.0 - sparsity) ** (pruning_epochs / (epoch + 1))
    raise ValueError(
        "Unknown pruning schedule {!r}; expected one of {}".format(schedule, _PRUNE_SCHEDULES)
    )


def per_prune(model, local_data_loader, client_id, p_ep, sparsity, ori_mask_state_dict,
              schedule='exponential', scope='global'):
    """Compute and save the pruning mask for one client, then restore the model.

    The model is scored on the client's data with ``IterSNIP`` for ``p_ep``
    iterations, masking progressively according to ``schedule``. The resulting
    ``weight_mask`` buffers are written to ``./mask/client_<client_id>.bin``,
    after which ``ori_mask_state_dict`` is loaded back so the next client
    starts from the original masks.

    Args:
        model: Model whose masked parameters are pruned (modified in place
            while pruning, restored before returning).
        local_data_loader: The client's data; despite the name it is passed to
            ``SFTTrainer`` as ``train_dataset``.
        client_id: Client index, used only to name the saved mask file.
        p_ep: Number of score/mask iterations.
        sparsity: Final value handed to ``pruner.mask``. A value of exactly
            ``1.0`` skips pruning altogether.
            NOTE(review): presumably the fraction of weights kept — confirm
            against ``IterSNIP.mask``.
        ori_mask_state_dict: State dict of the original mask buffers, reloaded
            (non-strictly) at the end.
        schedule: One of ``'exponential'``, ``'linear'`` or ``'expinv'``.
        scope: Forwarded unchanged to ``pruner.mask``.

    Raises:
        ValueError: If ``schedule`` is not a supported schedule name.
    """
    # Fail fast, before building a trainer and running any scoring pass.
    if schedule not in _PRUNE_SCHEDULES:
        raise ValueError(
            "Unknown pruning schedule {!r}; expected one of {}".format(schedule, _PRUNE_SCHEDULES)
        )

    loc_trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=script_args.seq_length,
        train_dataset=local_data_loader,
        eval_dataset=None,
        formatting_func=formatting_prompts_func,
        data_collator=data_collator,
    )
    train_dataloader = loc_trainer.get_train_dataloader()

    pruner = IterSNIP(generator.masked_parameters(model))

    # The "no pruning" check does not depend on the epoch, so it is made once
    # instead of on every iteration.
    if sparsity != 1.0:
        for epoch in tqdm(range(p_ep)):
            pruner.score_llm(loc_trainer, model, train_dataloader)
            pruner.mask(_scheduled_sparsity(schedule, sparsity, epoch, p_ep), scope)

    # Confirm sparsity level
    remaining_params, total_params = pruner.stats()
    print('{}/{}'.format(remaining_params, total_params))

    # Collect this client's masks. They are serialised immediately below,
    # before the buffers are overwritten by the restore step.
    current_mask_state_dict = OrderedDict(
        (name, buffer) for name, buffer in model.named_buffers() if 'weight_mask' in name
    )

    output_dir_for_mask = os.path.join('.', 'mask')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir_for_mask, exist_ok=True)

    mask_saving_dir = os.path.join(output_dir_for_mask, "client_{}.bin".format(client_id))

    # save the current pruned mask
    torch.save(current_mask_state_dict, mask_saving_dir)

    # reload the ori_mask for pruning on the next client
    model.load_state_dict(ori_mask_state_dict, strict=False)



def prune_loop(model,  data_list, p_ep=1, obj_sparse=0.5, hetero_prune=False):
    """Run ``per_prune`` once for every client, starting from the same masks.

    Args:
        model: Model whose mask buffers are pruned per client.
        data_list: Per-client data, indexed by client id.
        p_ep: Number of pruning iterations forwarded to ``per_prune``.
        obj_sparse: Target passed to ``per_prune`` for every client when
            ``hetero_prune`` is false.
        hetero_prune: When true, the target depends on the client index
            (0.75 for ids below 33, 0.5 for ids below 66, 1.0 otherwise) and
            is printed before each client is pruned.
    """
    from collections import OrderedDict

    # Snapshot all buffers whose name contains 'mask'; per_prune reloads this
    # snapshot after each client so clients do not inherit each other's masks.
    ori_mask_state_dict = copy.deepcopy(
        OrderedDict((name, buf) for name, buf in model.named_buffers() if 'mask' in name)
    )

    def _tier_target(idx):
        # heterogeneous masks: three tiers keyed on the client index
        if idx < 33:
            return 0.75
        if idx < 66:
            return 0.5
        return 1.0

    n_clients = fed_args.num_clients
    if hetero_prune:
        targets = [_tier_target(idx) for idx in range(n_clients)]
    else:
        targets = [obj_sparse] * n_clients

    for client_id, target in enumerate(targets):
        if hetero_prune:
            print('Current sp', target)
        per_prune(model, data_list[client_id], client_id, p_ep, target, ori_mask_state_dict)



# ===== Start federated pruning =====

# When True, per-client masks are (re)computed and written under ./mask
# before training starts.
prune = False

# When True, the training loop loads each client's saved mask from ./mask
# before local training. It is independent of `prune`, so masks from an
# earlier run can be reused.
prune_has_mask = False

if prune:
    prune_loop(model, local_datasets, p_ep=5, obj_sparse=0.666, hetero_prune=True)



# ===== Start federated training =====
# training_loss[c] collects one entry per round for client c
# (-1 marks rounds in which the client did not train).
training_loss = [[] for i in range(fed_args.num_clients)]
# Per-round global eval loss; only filled by the (currently commented-out)
# global evaluation inside the round loop.
global_loss = []
# Per-round mean of the participating clients' local eval losses; only defined
# when per-client tuning/evaluation is enabled.
if fed_args.per_tuning:
    avg_local_eval_loss = []

# One iteration per communication round: every selected client trains locally
# from the current global adapter state, then the server aggregates.
# NOTE(review): the loop variable `round` shadows the builtin of the same name
# for the rest of the module.
for round in tqdm(range(fed_args.num_rounds)):

    clients_this_round = get_clients_this_round(fed_args, round)

    # eval_loss = [[] for i in range(len(clients_this_round))]

    # Local eval losses of the clients that trained in this round.
    eval_loss = []



    print(f">> ==================== Round {round + 1} : {clients_this_round} ====================")

    for client in range(fed_args.num_clients):

        # lora_state_dict = get_peft_model_state_dict(model)
        # non_mask_state_dict = OrderedDict((k, copy.deepcopy(v)) for k, v in lora_state_dict.items() if 'mask' not in k)



        if client not in clients_this_round:
            training_loss[client].append(-1)  # -1 is an indicator of not training
            continue

        set_peft_model_state_dict(model, global_dict)  # sync the global model to the local model

        if prune_has_mask:
            # mask, adj_mask

            output_dir_for_mask = os.path.join('.', 'mask')

            # load the client's mask (same path per_prune() writes to)
            mask_saving_dir = os.path.join(output_dir_for_mask, "client_{}.bin".format(client))
            current_mask = torch.load(mask_saving_dir)
            # strict=False: the file only contains mask buffers, not the full model
            model.load_state_dict(current_mask, strict=False)

            # cast every mask buffer to uint8 (`.byte()`), i.e. one byte per entry
            for name, buffer in model.named_buffers():
                if 'mask' in name:
                    buffer.data = buffer.data.byte()




        # sub_dataset = get_dataset_this_round(local_datasets[client], round, fed_args,
        #                                      script_args)  # get the required sub-dataset for this round

        if fed_args.per_tuning:
            # Hold out a fixed, unshuffled share of the client's data for
            # local evaluation.
            # NOTE(review): eval_size is truncated by int(); a small client or
            # ratio may give 0 — confirm train_test_split accepts that.
            eval_size = int(fed_args.per_eval_ratio * local_datasets[client].num_rows)
            split_data = local_datasets[client].train_test_split(test_size=eval_size, shuffle=False, seed=script_args.seed)

            training_data = split_data['train']
            eval_data = split_data['test']

        else:
            training_data = local_datasets[client]
            eval_data = None

        # NOTE(review): 1e-6 is presumably the minimum learning rate of the
        # cosine schedule — confirm in cosine_learning_rate().
        new_lr = cosine_learning_rate(round, fed_args.num_rounds, script_args.learning_rate,
                                      1e-6)  # manually schedule the learning rate
        # Rebinds the module-level `training_args`, which global_eval() and
        # per_prune() also read at call time.
        training_args = get_training_args(script_args, new_lr)

        # ===== Train local model on the client side =====
        trainer = get_fed_local_sft_trainer(
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            local_train_dataset=training_data,
            local_eval_dataset=eval_data,
            formatting_prompts_func=formatting_prompts_func,
            data_collator=data_collator,
            global_dict=global_dict,
            fed_args=fed_args,
            script_args=script_args,
            local_auxiliary=auxiliary_model_list[client],
            global_auxiliary=global_auxiliary,
        )

        results = trainer.train()
        training_loss[client].append(results.training_loss)

        if fed_args.per_tuning:
            eval_result = trainer.evaluate()
            eval_loss.append(eval_result['eval_loss'])

        # ===== Client transmits local information to server =====
        if fed_args.fed_alg == 'scaffold':
            auxiliary_model_list[client], auxiliary_delta_dict[client] = trainer.get_auxiliary_param()

        # ===== Rewrite state_dict ===============
        # Collect the trained "default" adapter parameters, skipping anything
        # mask-related, then replace model.state_dict with a bound function
        # that returns only the PEFT state derived from those parameters.
        # The override is re-installed for every trained client and stays on
        # the model afterwards.
        params_dict_non_mask = OrderedDict((name, param.detach()) for name, param in model.named_parameters() if
                                           "default" in name and 'mask' not in name)
        model.state_dict = (
            lambda instance, *_, **__: get_peft_model_state_dict(
                instance, params_dict_non_mask, "default"
            )
        ).__get__(model, type(model))

        local_dict_list[client] = copy.deepcopy(model.state_dict())  # deep copy is needed!

    # avg of local loss for this current round and save
    if fed_args.per_tuning:
        avg_local_eval_loss.append(np.mean(eval_loss))
        print('Local Avg Loss:',avg_local_eval_loss[-1])

    # ===== Server aggregates the local models =====
    # NOTE(review): the meaning of overall_drop_rate=0.5 is defined by
    # global_aggregate() and is not visible here — confirm before tuning.
    global_dict, global_auxiliary = global_aggregate(
        fed_args, global_dict, local_dict_list, sample_num_list, \
        clients_this_round, round, proxy_dict=proxy_dict, \
        opt_proxy_dict=opt_proxy_dict, auxiliary_info=(global_auxiliary, auxiliary_delta_dict), overall_drop_rate=0.5
    )
    set_peft_model_state_dict(model, global_dict)  # Update global model


    """
    global evaluation
    """
    # Disabled: with these lines commented out `global_loss` stays empty, so
    # global_loss.npy below is saved as an empty array.
    # glo_eval_loss = global_eval(model, remain_dataset)
    # print('Global eval_loss', glo_eval_loss)
    # global_loss.append(glo_eval_loss)

    # ===== Save the model =====
    # A checkpoint is written after every round (the "final round only"
    # condition below is disabled).
    # if (round + 1) % fed_args.num_rounds == 0:
    # NOTE(review): `trainer` is whichever client trainer was created last; it
    # wraps the shared `model`, which now holds the aggregated weights. If no
    # client has trained yet, `trainer` is undefined here.
    trainer.save_model(os.path.join(script_args.output_dir, f"checkpoint-{round + 1}"))

    # Loss histories are rewritten in full after every round.
    np.save(os.path.join(script_args.output_dir, "training_loss.npy"), np.array(training_loss))
    np.save(os.path.join(script_args.output_dir, "global_loss.npy"), np.array(global_loss))
    if fed_args.per_tuning:
        np.save(os.path.join(script_args.output_dir, "avg_local_eval_loss.npy"), np.array(avg_local_eval_loss))




