import time
import torch
import torch.nn.functional as F
from misc.build_model import initialize_networks
from misc.utils import *
from modules.federated import ClientModule
from torch.cuda.amp import GradScaler
from peft import TaskType, LoftQConfig, LoraConfig, get_peft_model
from modules.logger import Logger
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import AutoModelForSequenceClassification
import json
import torch.nn.init as init
from sklearn.metrics import precision_score, recall_score, f1_score

class Client(ClientModule):
    def __init__(self, args, w_id, g_id, sd):
        """Build the client model, load the reference weights and pick trainable params.

        Args:
            args: Run configuration; the fields read here are ``model``,
                ``n_clss``, ``adapter``, ``quantize``, ``random_quantize``,
                ``random_rank`` and ``rank``.
            w_id: Worker id; forwarded to the parent class and passed to
                ``initialize_networks`` as ``client_id``.
            g_id: CUDA device index the model is moved to.
            sd: Forwarded unchanged to the parent constructor.
        """
        super().__init__(args, w_id, g_id, sd)
        self.w_id = w_id

        # Progress flags and counters consulted by on_round_begin().
        self.first_lora_trained = False
        self.first_training_done = False
        self.iteration_count = 0
        self.max_iterations = 5
        self.rnd_count = 0

        network_options = dict(
            model=args.model,
            n_classes=args.n_clss,
            adapter=args.adapter,
            quantize=args.quantize,
            random_quantize=args.random_quantize,
            random_rank=args.random_rank,
            default_rank=args.rank,
            client_id=w_id,
        )
        self.model, self.quantize_bits, self.lora_rank = initialize_networks(**network_options)
        self.model.cuda(g_id)
        # Snapshot of the parameter list; init_state() hands it to the optimizer.
        self.parameters = list(self.model.parameters())

        # Reference weights later combined with adapter factors in
        # merge_W1_with_AiBi(). The path is a placeholder that must be edited.
        load_path = "your_local_path"
        self.W1_state_dict = torch.load(load_path)

        self.params_to_update = []
        self._initialize_params_to_update_with_adapter()

    def _initialize_params_to_update_with_adapter(self):
        self.params_to_update = []
        for name, param in self.model.named_parameters():
            if 'lora_A' in name or 'lora_B' in name:
                param.requires_grad = True
                self.params_to_update.append(param)
            else:
                param.requires_grad = False

    def _initialize_params_to_update_without_adapter(self):
        for name, param in self.model.named_parameters():
            if 'classifier' in name or 'pre_classifier' in name:
                self.params_to_update.append(name)

    def reset_parameters(self):
        for name, param in self.model.named_parameters():
            if name in self.params_to_update:
                if param.dim() > 1:
                    if param.dtype in [torch.float32, torch.bfloat16]:
                        torch.nn.init.xavier_uniform_(param)
                else:
                    if param.dtype in [torch.float32, torch.bfloat16]:
                        torch.nn.init.zeros_(param)

    def init_state(self):
        self.reset_parameters()
        self.optimizer = torch.optim.AdamW(self.parameters, lr=self.args.base_lr, weight_decay=self.args.weight_decay)
        self.log = {
            'lr': [], 'train_lss': [], 'train_acc': [],
            'ep_local_val_lss': [], 'ep_local_val_acc': [],
            'ep_local_test_lss': [], 'ep_local_test_acc': [],
            'rnd_local_test_lss': [], 'rnd_local_test_acc': [],
        }

    def save_state(self):
        """Write optimizer state, model state and the log to ``<client_id>_state.pt``.

        The file is written under ``args.checkpt_path`` through the
        ``torch_save`` helper.
        """
        snapshot = {
            'optimizer': self.optimizer.state_dict(),
            'model': get_state_dict(self.model),
            'log': self.log,
        }
        torch_save(self.args.checkpt_path, f'{self.client_id}_state.pt', snapshot)

    def load_state(self):
        """Restore model, optimizer and log from ``<client_id>_state.pt``.

        The checkpoint is read from ``args.checkpt_path`` through the
        ``torch_load`` helper; it is expected to carry the ``'model'``,
        ``'optimizer'`` and ``'log'`` entries written by ``save_state``.
        """
        checkpoint = torch_load(self.args.checkpt_path, f'{self.client_id}_state.pt')
        set_state_dict(self.model, checkpoint['model'], self.gpu_id)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.log = checkpoint['log']

    def load_state1(self, loaded):
        """Restore model, optimizer and log from an already-loaded state mapping.

        Args:
            loaded: Mapping with ``'model'``, ``'optimizer'`` and ``'log'``
                entries.

        From round ``args.fft`` onwards the model weights are applied with
        ``params_to_update=self.params_to_update``; before that the call is
        made without that argument.
        """
        if self.curr_rnd >= self.args.fft:
            set_state_dict(
                self.model,
                loaded['model'],
                self.gpu_id,
                params_to_update=self.params_to_update,
            )
        else:
            set_state_dict(self.model, loaded['model'], self.gpu_id)

        self.optimizer.load_state_dict(loaded['optimizer'])
        self.log = loaded['log']

    def update_state(self, client_state, client_id):
        """Store this client's optimizer state, model state and log in ``client_state``.

        Args:
            client_state: Mutable mapping that receives the snapshot.
            client_id: Key under which the snapshot is stored.
        """
        client_state[client_id] = dict(
            optimizer=self.optimizer.state_dict(),
            model=get_state_dict(self.model),
            log=self.log,
        )

    def on_receive_message(self, curr_rnd):
        self.curr_rnd = curr_rnd
        self.update(self.sd['global'])

    def update(self, update):
        """Copy weights from ``update['model']`` into the local model.

        Args:
            update: Mapping whose ``'model'`` entry holds the weights to apply.

        From round ``args.fft`` onwards ``self.params_to_update`` is passed to
        ``set_state_dict`` as ``params_to_update``; earlier rounds call it
        without that argument.
        """
        selective_kwargs = {}
        if self.curr_rnd >= self.args.fft:
            selective_kwargs['params_to_update'] = self.params_to_update
        set_state_dict(self.model, update['model'], self.gpu_id, **selective_kwargs)

    def on_round_begin(self, client_id):
        if not self.first_training_done:
        
            self.train_first_stage(self.args.first_stage_epochs)
        
            self.merge_lora_params()   
        
            self.apply_loftq()
        
            self.merge_W1_with_AiBi()
        
            self.retrain_with_new_lora()

            self.rnd_count = 0

            self.train()
            self.transfer_to_server()

            self.first_training_done = True

        elif self.iteration_count < self.max_iterations:
            if self.rnd_count < self.args.n_rnds:

                self.train()
                self.transfer_to_server()
                self.rnd_count += 1
            else:

                self.merge_lora_params()

                self.apply_loftq()

                self.merge_lora_params()

                self.merge_W1_with_AiBi()

                self.retrain_with_new_lora()

                self.rnd_count = 0

                self.iteration_count += 1

                self.train()
                self.transfer_to_server()

        else:
            self.logger.print("Maximum iterations reached, skipping further training.")

    @torch.no_grad()
    def test(self):
        """Evaluate the model on ``self.loader.test_pa``.

        Logs mean loss, mean accuracy and macro-averaged precision, recall
        and F1 through ``self.logger``.

        Returns:
            float: Accuracy averaged over batches (each batch weighted
            equally, regardless of its size). ``0.0`` when the loader yields
            no batches.
        """
        self.model.eval()

        loss_sum = 0.
        acc_sum = 0.
        n_batches = 0
        all_labels = []
        all_preds = []

        for batch in self.loader.test_pa:
            input_ids = batch['input_ids'].cuda(self.gpu_id, non_blocking=True)
            attention_mask = batch['attention_mask'].cuda(self.gpu_id, non_blocking=True)
            labels = batch['label'].cuda(self.gpu_id, non_blocking=True)
            logits = self.model(input_ids, attention_mask=attention_mask).logits

            # Cross-entropy written out via one-hot targets, same formula as train().
            batch_loss = (-F.one_hot(labels, self.args.n_clss) * logits.log_softmax(dim=-1)).sum(dim=-1)
            # .item() keeps the accumulator a plain float (as train() does),
            # so the logged loss is a number rather than a tensor repr.
            loss_sum += batch_loss.mean().item()

            preds = logits.argmax(dim=-1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            acc_sum += torch.eq(preds, labels).float().mean().item()
            n_batches += 1

        if n_batches == 0:
            # Nothing to average; avoid a ZeroDivisionError below.
            self.logger.print('Test loader is empty, skipping evaluation.')
            return 0.0

        loss = loss_sum / n_batches
        accuracy = acc_sum / n_batches

        precision = precision_score(all_labels, all_preds, average='macro')
        recall = recall_score(all_labels, all_preds, average='macro')
        f1 = f1_score(all_labels, all_preds, average='macro')

        self.logger.print(f'loss: {loss}, accuracy: {accuracy}, '
                          f'Precision: {precision}, Recall: {recall}, '
                          f'F1-score: {f1}')

        return accuracy
    
    
    def train(self, epochs=None):
        if epochs is None:
            epochs = self.args.n_eps
        st = time.time()
        self.loader.switch(self.client_id)

        for name, param in self.model.named_parameters():
            if 'lora' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        for ep in range(epochs):
            self.model.train()
            client_data = self.loader.pa_loader
            
            writer = {'loss': 0., 'acc': 0., 'step': 0}
            

            for batch in client_data:
                self.optimizer.zero_grad()
                input_ids = batch['input_ids'].to(self.gpu_id)
                attention_mask = batch['attention_mask'].to(self.gpu_id)
                labels = batch['label'].to(self.gpu_id)

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                train_lss = (-F.one_hot(labels, self.args.n_clss) * logits.log_softmax(dim=-1)).sum(dim=-1)
                train_lss = train_lss.mean()
                train_lss.backward()
                self.optimizer.step()

                writer['loss'] += train_lss.item()
                writer['acc'] += torch.eq(logits.argmax(dim=-1), labels).float().mean().item()
                writer['step'] += 1

            train_acc = writer['acc'] / writer['step']
            train_lss = writer['loss'] / writer['step']
            self.log['train_lss'].append(train_lss)
            self.log['train_acc'].append(train_acc)

            if self.args.print:
                self.logger.print(
                    f'rnd:{self.curr_rnd + 1}, ep:{ep + 1},'
                    f'train_local_loss: {train_lss:.4f}, train_local_acc: {train_acc},'
                    f'lr: {self.get_lr()} ({(time.time() - st) * 1000:.4f}ms).'
                )
        test_local_acc = 66.6
        self.log['ep_local_test_acc'].append(test_local_acc)
        self.log['rnd_local_test_acc'].append(test_local_acc)

        self.rnd_local_test_acc = test_local_acc
        self.save_log()

    def train_first_stage(self, first_stage_epochs):
        self.curr_rnd = 0
        self.init_state()
        self.train(epochs=first_stage_epochs)
        self.first_lora_trained = True

    def merge_lora_params(self):
        lora_A_params = {}
        lora_B_params = {}
        model_params = list(self.model.named_parameters())

        for name, param in model_params:
            if 'lora_A' in name:
                lora_A_params[name] = param.data.clone()
            elif 'lora_B' in name:
                lora_B_params[name] = param.data.clone()

        for lora_A_name, lora_A in lora_A_params.items():
            lora_B_name = lora_A_name.replace('lora_A', 'lora_B')
            if lora_B_name in lora_B_params:
                lora_B = lora_B_params[lora_B_name]
                base_param_name = lora_A_name.replace('.lora_A', '')
                if base_param_name in self.model.state_dict():
                    base_param = self.model.state_dict()[base_param_name].clone()
                    delta_W = torch.matmul(lora_B, lora_A)
                    merged_param = base_param + delta_W
                    self.model.state_dict()[base_param_name].copy_(merged_param)

        for name, param in list(self.model.named_parameters()):
            if 'lora_A' in name or 'lora_B' in name:
                module_name, param_name = name.rsplit('.', 1)
                module = dict(self.model.named_modules())[module_name]
                if hasattr(module, param_name):
                    delattr(module, param_name)

    def apply_loftq(self):
        """Create LoftQ-initialised LoRA factors and store them in ``self.AiBi_state_dict``.

        A ``LoraConfig`` with ``init_lora_weights="loftq"`` (bit width taken
        from ``self.quantize_bits``, rank from ``self.lora_rank``) is applied
        to ``self.model`` via ``get_peft_model``. Every state-dict entry of
        the result whose key contains ``'lora_A'`` or ``'lora_B'`` is cloned
        and stored, with a leading ``"base_model.model."`` stripped from the
        key.
        """
        # NOTE(review): get_peft_model may modify self.model in place rather
        # than return an independent copy — confirm against the PEFT docs.
        quantization = LoftQConfig(loftq_bits=self.quantize_bits, loftq_iter=5)
        adapter_settings = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            inference_mode=True,
            r=self.lora_rank,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_lin", "v_lin"],
            init_lora_weights="loftq",
            loftq_config=quantization,
        )
        wrapped_model = get_peft_model(self.model, adapter_settings)

        prefix = "base_model.model."
        factors = {}
        for key, tensor in wrapped_model.state_dict().items():
            if 'lora_A' not in key and 'lora_B' not in key:
                continue
            short_key = key[len(prefix):] if key.startswith(prefix) else key
            factors[short_key] = tensor.clone()

        self.AiBi_state_dict = factors
    
    def merge_W1_with_AiBi(self):
        W1_state_dict = self.W1_state_dict
        AiBi_state_dict = self.AiBi_state_dict

        W3_state_dict = {k: v.clone() for k, v in W1_state_dict.items()}
        for lora_A_name, lora_A in AiBi_state_dict.items():
            if 'lora_A' in lora_A_name:
                lora_B_name = lora_A_name.replace('lora_A', 'lora_B')
                if lora_B_name in AiBi_state_dict:
                    lora_B = AiBi_state_dict[lora_B_name]
                    base_param_name = lora_A_name.replace('.lora_A', '')
                    if base_param_name in W3_state_dict:
                        base_param = W3_state_dict[base_param_name].clone()
                        delta_W = torch.matmul(lora_B, lora_A)
                        W3_state_dict[base_param_name].data = base_param + delta_W

        for name in list(W3_state_dict.keys()):
            if 'lora_A' in name or 'lora_B' in name:
                del W3_state_dict[name]

        for k, v in self.model.state_dict().items():
            if k not in W3_state_dict:
                W3_state_dict[k] = v

        self.model.load_state_dict(W3_state_dict, strict=False)

    def retrain_with_new_lora(self):
        """Wrap the model in a fresh LoRA adapter, train it for one epoch and upload it.

        Steps:
            1. Apply a new ``LoraConfig`` (rank ``self.lora_rank``) to
               ``self.model`` and keep the wrapped model.
            2. Make only ``lora_A`` / ``lora_B`` parameters trainable and
               record their names in ``self.params_to_update``.
            3. Rebuild the AdamW optimizer over the trainable parameters.
            4. Train for one epoch; worker 0 additionally saves the whole
               model object to disk.
            5. Publish the adapter via ``transfer_to_server()``.
        """
        # Imported locally because this file never imports ``os`` itself; it
        # was previously only reachable if the wildcard import from
        # misc.utils happened to export it.
        import os

        new_lora_config = LoraConfig(
            r=self.lora_rank,
            lora_alpha=16,
            target_modules=["q_lin", "v_lin"],
            lora_dropout=0.1,
            bias="none"
        )
        self.model = get_peft_model(self.model, new_lora_config)

        self.params_to_update = []
        for name, param in self.model.named_parameters():
            if 'lora_A' in name or 'lora_B' in name:
                param.requires_grad = True
                self.params_to_update.append(name)
            else:
                param.requires_grad = False

        # The previous optimizer refers to the old parameter objects, so a new
        # one is built over the freshly created adapter weights.
        # NOTE(review): this reads args.lr while init_state() reads
        # args.base_lr — confirm the difference is intended.
        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )

        self.train(epochs=1)

        if self.w_id == 0:
            # Placeholder location; must be edited before use.
            save_dir = 'your_local_path'
            save_filename = 'name.pth'
            save_path = os.path.join(save_dir, save_filename)
            os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model, save_path)
        self.transfer_to_server()

    def transfer_to_server(self):
        
        lora_params = {name: param.clone().detach() for name, param in self.model.named_parameters() if 'lora_A' in name or 'lora_B' in name}

        self.sd[self.client_id] = {
            'model': lora_params,
            'train_size': self.loader.train_size,
            'rnd_local_test_acc': self.rnd_local_test_acc
        }
