import tqdm
# import generator

from pruners import *
from fed_utils import *

# Set to True to run the LoRA rank-search pruning stage below.
prune = False

# Schedules accepted by ``scheduled_density``.
PRUNING_SCHEDULES = ('exponential', 'linear', 'expinv')


def scheduled_density(schedule, sparsity, epoch, pruning_epochs):
    """Return the target passed to ``pruner.mask`` for one pruning epoch.

    Every schedule reaches exactly ``sparsity`` on the last epoch
    (``epoch == pruning_epochs - 1``) and starts closer to 1.0. So
    ``sparsity`` behaves like the fraction of weights left after pruning.
    NOTE(review): confirm that interpretation against ``pruner.mask``.

    Args:
        schedule: One of ``PRUNING_SCHEDULES``.
        sparsity: Final target reached on the last pruning epoch.
        epoch: Current pruning epoch, counted from 0.
        pruning_epochs: Total number of pruning epochs (must be > 0).

    Returns:
        The float target for this epoch.

    Raises:
        ValueError: If ``schedule`` is not one of ``PRUNING_SCHEDULES``.
    """
    if schedule == 'exponential':
        return sparsity ** ((epoch + 1) / pruning_epochs)
    if schedule == 'linear':
        return 1.0 - (1.0 - sparsity) * ((epoch + 1) / pruning_epochs)
    if schedule == 'expinv':
        return 1.0 - (1.0 - sparsity) ** (pruning_epochs / (epoch + 1))
    # Previously an unknown schedule left the target unbound (NameError).
    raise ValueError('unknown pruning schedule: {!r}'.format(schedule))


if prune:
    # ===================== Pruning stage (LoRA rank search by pruning) =====================
    # Enabled by the `prune` flag above; skipped entirely when it is False.
    import math

    # The top-of-file `import tqdm` binds the module, which is not callable;
    # import the progress-bar class explicitly.
    from tqdm import tqdm as tqdm_bar

    # NOTE(review): hard-coded; this rebinds any module-level `num_clients`.
    num_clients = 2

    # Pruning hyper-parameters, shared by every client.
    sparsity = 0.8 ** 2.0  # 0.64: target reached on the final pruning epoch
    schedule = 'exponential'
    scope = 'global'
    pruning_epochs = 1
    # Fail fast, before any expensive scoring pass, on a bad schedule name.
    if schedule not in PRUNING_SCHEDULES:
        raise ValueError('unknown pruning schedule: {!r}'.format(schedule))

    for client_id in range(num_clients):
        client = GeneralClient(client_id, models[client_id], data_path, output_dir)
        print("\nPreparing the local dataset and trainer for Client_{} pruning".format(client_id))
        client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
        client.build_local_trainer(tokenizer,
                                   local_micro_batch_size,
                                   gradient_accumulation_steps,
                                   local_num_epochs,
                                   local_learning_rate,
                                   group_by_length,
                                   ddp)

        print("Initiating the local Pruning of Client_{}".format(client_id))
        client.initiate_local_training()

        train_dataloader = client.local_trainer.get_train_dataloader()

        # NOTE(review): `generator` is not imported in this file (its import
        # is commented out); presumably it arrives via a star import — confirm.
        pruner = IterSNIP(generator.masked_parameters(models[client_id]))

        for epoch in tqdm_bar(range(pruning_epochs)):
            pruner.score_llm(client, models[client_id], train_dataloader)
            sparse = scheduled_density(schedule, sparsity, epoch, pruning_epochs)
            pruner.mask(sparse, scope)

        # Confirm the achieved sparsity level.
        remaining_params, total_params = pruner.stats()
        print('{}/{}'.format(remaining_params, total_params))

        # Re-initialise the LoRA factors: kaiming-uniform for A, zeros for B.
        # Build state_dict() once instead of once per key. Its tensors reference
        # the module's own storage, so the in-place init updates the model.
        # NOTE(review): state_dict() also contains buffers; if pruning masks are
        # stored as buffers whose keys contain 'lora_A'/'lora_B', they are
        # overwritten here too. If masks were applied in-place to the weights,
        # this re-init also discards the pruning — confirm intended.
        for k, lora_tensor in models[client_id].state_dict().items():
            if 'lora_A' in k:
                nn.init.kaiming_uniform_(lora_tensor, a=math.sqrt(5))
            if 'lora_B' in k:
                nn.init.zeros_(lora_tensor)

    # =================== End of pruning stage ===================