import torch
import copy
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import collections
import  math
import datasets
import evaluate as eval_metric
from datasets import  load_metric
from transformers.optimization import get_scheduler
from transformers import DataCollatorWithPadding, default_data_collator
from utils_qa import postprocess_qa_predictions

class CustomDataCollator(DataCollatorWithPadding):
    """Padding collator that drops ``offset_mapping`` and fills ``None`` fields.

    Each feature is copied before it is cleaned, so the records owned by the
    caller (for example rows of an in-memory list) are never modified.
    """

    def _placeholder(self, key):
        """Return the stand-in value used when the field ``key`` is ``None``."""
        if key in ('input_ids', 'attention_mask', 'token_type_ids'):
            return [0] * self.tokenizer.model_max_length
        if key in ('start_positions', 'end_positions'):
            return 0
        return ''

    def __call__(self, features):
        """Clean every feature, then delegate padding to the parent collator.

        Args:
            features: iterable of dict-like feature records.

        Returns:
            The result of ``DataCollatorWithPadding.__call__`` on the cleaned
            records.
        """
        cleaned = [
            {
                key: self._placeholder(key) if value is None else value
                for key, value in feature.items()
                # 'offset_mapping' is removed before collating
                if key != 'offset_mapping'
            }
            for feature in features
        ]
        return super().__call__(cleaned)


def remove_unused_columns(dataset, tokenizer):
    """Drop every column the model and the QA pipeline do not consume.

    The columns kept are ``tokenizer.model_input_names`` plus
    ``start_positions``, ``end_positions``, ``id`` and ``context``.

    Args:
        dataset: either an object exposing ``remove_columns`` (such as a
            ``datasets.Dataset``) or a plain sequence of dict-like rows.
        tokenizer: object exposing ``model_input_names``.

    Returns:
        A dataset of the same flavour holding only the kept columns. For a
        plain sequence a new list of new dicts is returned; the input rows are
        left untouched.
    """
    keep = set(tokenizer.model_input_names)
    keep.update(('start_positions', 'end_positions', 'id', 'context'))

    if not hasattr(dataset, 'remove_columns'):
        # Plain sequences have no remove_columns(); filter each row directly.
        # This also covers the empty sequence, where dataset[0] would fail.
        return [{key: value for key, value in row.items() if key in keep} for row in dataset]

    if isinstance(dataset, datasets.Dataset):
        column_names = dataset.column_names
    else:
        column_names = list(dataset[0].keys())

    columns_to_remove = [col for col in column_names if col not in keep]
    return dataset.remove_columns(columns_to_remove)

def create_dataloader(dataset, tokenizer, batch_size, shuffle=False):
    """Wrap ``dataset`` in a ``DataLoader`` that pads batches with the tokenizer.

    Args:
        dataset: dataset to iterate; unused columns are stripped first.
        tokenizer: tokenizer handed to the column filter and to the collator.
        batch_size: number of samples per batch.
        shuffle: whether the loader shuffles the samples.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        remove_unused_columns(dataset, tokenizer),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=CustomDataCollator(tokenizer=tokenizer),
    )

def create_scheduler(args, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
    """Build the learning-rate scheduler selected by ``args.lr_scheduler_type``.

    The scheduler is created without warm-up (``num_warmup_steps=0``).

    Args:
        args: configuration object exposing ``lr_scheduler_type``.
        num_training_steps (int): The number of training steps to do.
        optimizer: the optimizer whose learning rate is scheduled. It is
            required; the ``None`` default only exists for signature
            compatibility.

    Returns:
        The scheduler returned by ``transformers.optimization.get_scheduler``.

    Raises:
        TypeError: if ``optimizer`` is ``None``.
    """
    if optimizer is None:
        # There is no fallback optimizer here, so fail with a clear message
        # instead of deep inside the scheduler constructor.
        raise TypeError("create_scheduler() requires an optimizer, got None")
    return get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
        scheduler_specific_kwargs={},
    )

def modify_param(module, param_name, gamma, A=True):
    """Truncate the first parameter called ``param_name`` found under ``module``.

    The search is depth-first: ``module`` itself is checked, then its children
    in ``named_children()`` order. Only the first match is modified.

    Args:
        module: ``torch.nn.Module`` to search.
        param_name: attribute name of the 2-D parameter to shrink.
        gamma: fraction of the pruned dimension to keep; the kept size is
            ``floor(size * gamma)``.
        A: when true the leading rows are kept (dimension 0), otherwise the
            leading columns are kept (dimension 1).

    Returns:
        The new ``torch.nn.Parameter`` installed on the owning module, or
        ``None`` when no module in the tree has ``param_name``.
    """
    if hasattr(module, param_name):
        param = getattr(module, param_name)
        print(f'{param_name}: {param.size()}')

        if A:
            kept = param.data[:math.floor(param.size(0) * gamma), :]
        else:
            kept = param.data[:, :math.floor(param.size(1) * gamma)]
        new_param = torch.nn.Parameter(kept)
        setattr(module, param_name, new_param)
        return new_param

    for _, child in module.named_children():
        new_param = modify_param(child, param_name, gamma, A)
        if new_param is not None:
            return new_param
    return None

class Model():
    def __init__(self, args, model, dist):
        super(Model, self).__init__()
        self.args = args
        self.model = model
        self.dist = dist
        self.current_rank = self.args.rank_mat
        self.create_opt()
        self.past_regularization = self.regularization()
        # self.layered_rank = [self.args.rank_mat for name, param in self.model.named_parameters() if 'lora_A' in name or 'lora_B' in name]

    def regularization(self):
        # self.trainable_params_A = []
        # self.trainable_params_B = []
        # self.trainable_params_C = []
        # for name, param in self.model.named_parameters():
        #     if param.requires_grad:
        #         if 'lora_A' in name:
        #             self.trainable_params_A.append(param)
        #         elif 'lora_B' in name:
        #             self.trainable_params_B.append(param)
        #         else:
        #             self.trainable_params_C.append(param)
        x = 0.0
        for (b, a) in zip(self.trainable_params_B, self.trainable_params_A):
            x += (a[math.floor(a.size(0) * self.args.gamma):a.size(0), :]).norm() *\
                 (b[:, math.floor(b.size(1) * self.args.gamma):b.size(1)]).norm()
        return x

    def create_opt(self):
        """Group the trainable parameters and report their size.

        Trainable parameters are split by name into ``trainable_params_A``
        (names containing ``lora_A``), ``trainable_params_B`` (``lora_B``) and
        ``trainable_params_C`` (everything else). The local trainable count is
        summed onto rank 0 through ``self.dist`` and the rank whose
        ``args.rank % args.world_size == 0`` prints the per-rank average
        together with the share of all parameters it represents.
        """
        self.trainable_params_A = []
        self.trainable_params_B = []
        self.trainable_params_C = []
        num_trainable_params = 0
        all_param = 0
        for name, param in self.model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel
            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params

            if not param.requires_grad:
                continue
            if 'lora_A' in name:
                self.trainable_params_A.append(param)
            elif 'lora_B' in name:
                self.trainable_params_B.append(param)
            else:
                self.trainable_params_C.append(param)
            num_trainable_params += num_params

        num_trainable_params = torch.tensor([num_trainable_params]).cuda()
        print(f'{self.args.rank}: {num_trainable_params}')
        self.dist.reduce(num_trainable_params, dst=0, op=self.dist.ReduceOp.SUM)
        print('averaging...')
        num_trainable_params_avg = int(num_trainable_params[0]/self.args.world_size)

        if self.args.rank % self.args.world_size == 0:
            # Use the per-rank average here: the reduced tensor holds the sum
            # over all ranks, which would inflate the percentage by world_size
            # and print a tensor repr instead of a number.
            trainable_pct = 100 * num_trainable_params_avg / all_param if all_param else 0.0
            print(f"trainable params: {num_trainable_params_avg:,d} || all params: {all_param:,d} || trainable%: {trainable_pct}")

    def start_local_steps(self):
        # get the regularization at the beginning of each local steps.
        self.past_regularization = self.regularization()
        # optimizer_grouped_parameters = [
        #     {
        #         "params": self.trainable_params_A,
        #         "weight_decay": 1e-4,
        #         "lr":  self.optimizer.param_groups[0]['lr'],
        #     },
        #     {
        #         "params": self.trainable_params_B,
        #         "weight_decay": 1e-4,
        #         "lr":  self.optimizer.param_groups[1]['lr'],
        #     },
        #     {
        #         "params": self.trainable_params_C,
        #         "weight_decay": 1e-4,
        #         "lr":  self.optimizer.param_groups[2]['lr'],
        #     },
        # ]
        # new_optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        # new_optimizer.state_dict()['state'] = self.optimizer.state_dict()['state']
        # self.optimizer = new_optimizer

    def end_local_step(self):
        x = self.regularization()
        print(f'past_regularization: {self.past_regularization}')
        print(f'cur_regularization: {x}')
        i, j = 0, 0
        # prune the lora_A and lora_B if the regularization becomes smaller
        if  self.past_regularization - x >  0 and self.current_rank > self.args.rank_min:
            print('pruning')
            # self.trainable_params_A = []
            # self.trainable_params_B = []
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Linear):
                    if 'lora_A' in name:
                        self.current_rank = math.floor(module.weight.size(0) * self.args.gamma)
                        new_layer =  torch.nn.Parameter(module.weight.data[:self.current_rank, :])
                        setattr(module, 'weight', new_layer)
                        self.trainable_params_A[i] = new_layer.data
                        i += 1
                    if 'lora_B' in name:
                        self.current_rank = math.floor(module.weight.size(1) * self.args.gamma)
                        new_layer =  torch.nn.Parameter(module.weight.data[:, :self.current_rank])
                        setattr(module, 'weight', new_layer)
                        self.trainable_params_B[j] = new_layer.data
                        self.current_rank = new_layer.data.size(1)
                        j += 1
        print(f'rank: {self.args.rank}, r: {self.current_rank}')


    def train(self, train_dataset, eval_dataset, eval_examples, tokenizer):
        """Fine-tune the model locally, periodically pruning and aggregating adapters.

        Per batch one clipped optimizer step is taken. Whenever
        ``batch_idx % args.com_interval == 0`` a communication round runs:
        ``end_local_step`` may prune the adapters, every trainable parameter is
        summed across ranks through ``self.dist`` (LoRA matrices are zero-padded
        to ``args.rank_max`` first) and ``start_local_steps`` records a new
        regularization baseline. After each epoch the model is evaluated, the
        scores are summed onto rank 0 and the rank with
        ``args.rank % args.world_size == 0`` writes the running averages to
        ``args.save_path + '.pkl'`` and ``args.save_path + '.txt'``. Training
        stops once ``args.com_rounds`` communication rounds have run.

        Args:
            train_dataset: dataset handed to ``create_dataloader`` for training.
            eval_dataset: evaluation features, used for the evaluation loader
                and passed on to ``evaluate``.
            eval_examples: evaluation examples passed on to ``evaluate``.
            tokenizer: tokenizer used when building the data loaders.
        """
        device = torch.device(f"cuda:{self.args.gpu}")
        print(f'---------------{device}----------------')
        self.model = self.model.cuda()
        train_loader = create_dataloader(train_dataset, tokenizer, self.args.batch_size, shuffle=True)
        eval_loader = create_dataloader(eval_dataset, tokenizer, self.args.eval_batch_size)
        # NOTE(review): all three groups, including the lora_A one, use
        # args.lr_B - confirm that a separate rate for A is not intended.
        optimizer_grouped_parameters = [
            {"params": params, "weight_decay": 1e-4, "lr": self.args.lr_B}
            for params in (self.trainable_params_A, self.trainable_params_B, self.trainable_params_C)
        ]
        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        num_training_steps = self.args.com_rounds * self.args.com_interval
        lr_scheduler = create_scheduler(self.args, num_training_steps, self.optimizer)
        if self.args.dataset == 'squad':
            test_results_list = {"exact_match": [], "f1": []}
        else:
            test_results_list = {"exact": [], "f1": []}

        loss_list = []
        cut_round = 0
        for epoch in range(self.args.num_epochs):
            # evaluate() switches the model to eval mode, so training mode has
            # to be (re-)enabled at the start of every epoch, not just once.
            self.model.train()
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.args.num_epochs}", disable=self.args.rank%self.args.world_size!=0)

            for batch_idx, batch in enumerate(progress_bar):
                if cut_round >= self.args.com_rounds:
                    break
                input_ids = batch["input_ids"].cuda()
                attention_mask = batch["attention_mask"].cuda()
                start_positions = batch["start_positions"].cuda()
                end_positions = batch["end_positions"].cuda()
                self.optimizer.zero_grad()
                loss = self.learn(input_ids, attention_mask, start_positions, end_positions)
                loss_list.append(round(loss.item(), 4))

                torch.nn.utils.clip_grad_norm_(self.trainable_params_A, self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.trainable_params_B, self.args.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.trainable_params_C, self.args.max_grad_norm)
                self.optimizer.step()
                lr_scheduler.step()
                # get_last_lr() is the supported way to read the current rates
                # outside of scheduler.step()
                progress_bar.set_postfix({"Rank": self.args.rank, "Train loss": loss.item(), "r": self.current_rank, "lr": lr_scheduler.get_last_lr()})

                if batch_idx % self.args.com_interval == 0:
                    # prune the adapter
                    self.end_local_step()
                    norm_BA = [(b.detach() @ a.detach()).norm() for (b, a) in zip(self.trainable_params_B, self.trainable_params_A)]
                    # NOTE(review): copy.copy() is shallow, so personal_norm_BA
                    # holds the same tensor objects as norm_BA. Assuming
                    # self.dist is torch.distributed, reduce/broadcast overwrite
                    # them in place, which makes the ratio below always 1 (plain
                    # averaging). A real copy would also need the /world_size
                    # factor revisited - confirm the intended weighting.
                    personal_norm_BA = copy.copy(norm_BA)
                    # compute the sum of all the clients' BA norm
                    for norm in norm_BA:
                        self.dist.reduce(norm, dst=0, op=self.dist.ReduceOp.SUM)
                        self.dist.broadcast(norm, src=0)
                    i, j = 0, 0
                    for n, p in self.model.named_parameters():
                        if not p.requires_grad:
                            continue
                        if 'lora_A' in n:
                            # zero-pad the rows up to rank_max so every rank sends the same shape
                            zero_tensor = torch.zeros(self.args.rank_max - p.size(0), p.size(1), dtype=p.dtype, device=p.device)
                            send_p = torch.cat([p, zero_tensor], dim=0) / self.args.world_size
                            send_p *= personal_norm_BA[i] / norm_BA[i]
                            i += 1
                        elif 'lora_B' in n:
                            # zero-pad the columns up to rank_max
                            zero_tensor = torch.zeros(p.size(0), self.args.rank_max - p.size(1), dtype=p.dtype, device=p.device)
                            send_p = torch.cat([p, zero_tensor], dim=1) / self.args.world_size
                            send_p *= personal_norm_BA[j] / norm_BA[j]
                            j += 1
                        else:
                            send_p = p / self.args.world_size

                        # sum on rank 0, then hand the result back to every rank
                        self.dist.reduce(send_p.data, dst=0, op=self.dist.ReduceOp.SUM)
                        self.dist.broadcast(send_p.data, src=0)

                        # keep only the part matching this rank's current shape
                        if 'lora_A' in n:
                            p.data = send_p.data[:p.size(0), :]
                        elif 'lora_B' in n:
                            p.data = send_p.data[:, :p.size(1)]
                        else:
                            p.data = send_p.data
                    cut_round += 1
                    self.start_local_steps()

            # Add a synchronization barrier before dist.all_gather
            print('Synchronization barrier')
            self.dist.barrier()
            eval_results = self.evaluate(eval_loader, eval_dataset, eval_examples)
            print(f'rank: {self.args.rank}, r:{self.current_rank}, {eval_results}')
            self.dist.barrier()
            # Convert the evaluation results to a tensor
            if self.args.dataset == 'squad':
                eval_results_tensor = torch.tensor([eval_results["exact_match"], eval_results["f1"]],
                                                   dtype=torch.float32).cuda()
            else:
                eval_results_tensor = torch.tensor([eval_results["exact"], eval_results["f1"]],
                                                   dtype=torch.float32).cuda()
            self.dist.reduce(eval_results_tensor, dst=0, op=self.dist.ReduceOp.SUM)

            # only rank 0 provides receive buffers for the gather
            rank_list = []
            if self.dist.get_rank() == 0:
                rank_list = [torch.zeros(1, dtype=torch.long).cuda() for _ in range(self.args.world_size)]
            self.dist.barrier()
            self.dist.gather(tensor=torch.tensor([self.current_rank], dtype=torch.long).cuda(), gather_list=rank_list, dst=0)

            if self.dist.get_rank() == 0:
                r = [entry[0].data for entry in rank_list]
                print(f"rank list: {r}")

            if self.args.rank % self.args.world_size == 0:
                avg_exact = (eval_results_tensor[0] / self.args.world_size).item()
                avg_f1 = (eval_results_tensor[1] / self.args.world_size).item()
                if self.args.dataset == 'squad':
                    test_results_list["exact_match"].append(avg_exact)
                else:
                    test_results_list["exact"].append(avg_exact)
                test_results_list["f1"].append(avg_f1)

                print(f"Epoch {epoch + 1}/{self.args.num_epochs}:")
                print(f"Average evaluation results: {test_results_list}")
                torch.save((test_results_list, loss_list), self.args.save_path + '.pkl')
                with open(self.args.save_path + '.txt', 'w') as f:
                    f.write(str({'Exp config': str(self.args), 'Average evaluation results': str(test_results_list)}))

            if cut_round >= self.args.com_rounds:
                if self.args.rank % self.args.world_size == 0:
                    print(f"Average evaluation results: {test_results_list}")
                break

    def evaluate(self, dataloader, eval_dataset, eval_examples):
        """Aggregate the trainable parameters across ranks, then score the model.

        Every trainable parameter is first summed across ranks through
        ``self.dist`` (LoRA matrices are zero-padded to ``args.rank_max``) and
        written back. The model is then run over ``dataloader`` in eval mode,
        the logits are post-processed into answers and scored with the metric
        named by ``args.dataset``. The training/eval mode the model had on
        entry is restored before returning.

        Args:
            dataloader: loader yielding batches with ``input_ids`` and
                ``attention_mask``.
            eval_dataset: evaluation features passed to
                ``postprocess_qa_predictions``.
            eval_examples: evaluation examples; each provides ``id`` and
                ``answers`` for the references.

        Returns:
            The dictionary returned by the metric's ``compute``.
        """
        norm_BA = [(b.detach() @ a.detach()).norm() for (b, a) in zip(self.trainable_params_B, self.trainable_params_A)]
        # NOTE(review): copy.copy() is shallow, so personal_norm_BA holds the
        # same tensor objects as norm_BA. Assuming self.dist is
        # torch.distributed, reduce/broadcast overwrite them in place, which
        # makes the ratio below always 1 (plain averaging). A real copy would
        # also need the /world_size factor revisited - confirm the intent.
        personal_norm_BA = copy.copy(norm_BA)
        # compute the sum of all the clients' BA norm
        for norm in norm_BA:
            self.dist.reduce(norm, dst=0, op=self.dist.ReduceOp.SUM)
            self.dist.broadcast(norm, src=0)
        i, j = 0, 0
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if 'lora_A' in n:
                # zero-pad the rows up to rank_max so every rank sends the same shape
                zero_tensor = torch.zeros(self.args.rank_max - p.size(0), p.size(1), dtype=p.dtype,
                                          device=p.device)
                send_p = torch.cat([p, zero_tensor], dim=0) / self.args.world_size
                send_p *= personal_norm_BA[i] / norm_BA[i]
                i += 1
            elif 'lora_B' in n:
                # zero-pad the columns up to rank_max
                zero_tensor = torch.zeros(p.size(0), self.args.rank_max - p.size(1), dtype=p.dtype,
                                          device=p.device)
                send_p = torch.cat([p, zero_tensor], dim=1) / self.args.world_size
                send_p *= personal_norm_BA[j] / norm_BA[j]
                j += 1
            else:
                send_p = p / self.args.world_size

            # sum on rank 0, then hand the result back to every rank
            self.dist.reduce(send_p.data, dst=0, op=self.dist.ReduceOp.SUM)
            self.dist.broadcast(send_p.data, src=0)

            # keep only the part matching this rank's current shape
            if 'lora_A' in n:
                p.data = send_p.data[:p.size(0), :]
            elif 'lora_B' in n:
                p.data = send_p.data[:, :p.size(1)]
            else:
                p.data = send_p.data

        # Remember the mode so a training loop calling this method does not
        # silently continue in eval mode afterwards.
        was_training = self.model.training
        self.model.eval()
        progress_bar = tqdm(dataloader, desc="Evaluation", unit="batch")
        metric = load_metric(self.args.dataset)
        all_start_logits = []
        all_end_logits = []
        try:
            for batch in progress_bar:
                input_ids = batch["input_ids"].cuda()
                attention_mask = batch["attention_mask"].cuda()
                with torch.no_grad():
                    outputs = self.model(input_ids, attention_mask=attention_mask)
                    all_start_logits.append(outputs.start_logits.cpu().numpy())
                    all_end_logits.append(outputs.end_logits.cpu().numpy())
        finally:
            self.model.train(was_training)

        all_start_logits = np.concatenate(all_start_logits, axis=0)
        all_end_logits = np.concatenate(all_end_logits, axis=0)
        prediction = (all_start_logits, all_end_logits)
        predictions = postprocess_qa_predictions(eval_examples, eval_dataset, prediction, n_best_size=20,
                                                 max_answer_length=30)
        if self.args.dataset == 'squad':
            # the SQuAD v1 metric only expects 'id' and 'prediction_text'
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in
                                     predictions.items()]
        references = [{"id": ex["id"], "answers": ex["answers"]} for ex in eval_examples]
        metric_result = metric.compute(predictions=formatted_predictions, references=references)

        return metric_result

    def learn(self, input_ids, attention_mask, start_positions, end_positions):
        outputs = self.model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        loss.backward()
        return loss