import time
import numpy as np
from misc.build_model import initialize_networks
from misc.utils import *
from modules.federated import ServerModule
import torch.nn.functional as F
from peft import TaskType
from concurrent.futures import ThreadPoolExecutor
from sklearn.metrics import precision_score, recall_score, f1_score

class Server(ServerModule):
    """Federated-learning server that aggregates client adapter weights.

    Each round the server publishes its current weights to the shared state
    dict ``sd``, collects the weights uploaded by the participating clients,
    averages them weighted by each client's training-set size, writes the
    result into the global model and evaluates it on ``self.loader.test_pa``.

    NOTE(review): ``self.args``, ``self.sd``, ``self.gpu_id``, ``self.logger``,
    ``self.loader`` and ``self.aggregate`` are not defined in this class; they
    are presumably provided by ``ServerModule`` -- confirm against the base.
    """

    def __init__(self, args, sd, gpu_server):
        """Build the global model and decide which parameters are trainable.

        Args:
            args: Run configuration; this class reads ``model``, ``n_clss``,
                ``adapter``, ``quantize``, ``random_quantize``, ``print``,
                ``n_rnds`` and ``checkpt_path`` from it.
            sd: Shared state dict used to exchange weights with the clients.
            gpu_server: Forwarded unchanged to ``ServerModule.__init__``.
        """
        super(Server, self).__init__(args, sd, gpu_server)
        self.model, _, _ = initialize_networks(
            model=args.model, n_classes=args.n_clss, adapter=args.adapter,
            quantize=args.quantize, random_quantize=args.random_quantize
        )
        self.model.cuda(self.gpu_id)
        self.test_acc = 0
        self.best_test_acc = 0

        # Names of the parameters exchanged with clients; None means "all".
        self.params_to_update = None
        if self.args.adapter != '':
            # 'classifier' also matches 'pre_classifier', so a single
            # substring test covers both classification heads.
            self.params_to_update = [
                name for name, _ in self.model.named_parameters()
                if 'classifier' in name or self.args.adapter in name
            ]

    def on_round_begin(self, curr_rnd):
        """Record the round index/start time and publish the global weights."""
        self.round_begin = time.time()
        self.curr_rnd = curr_rnd
        self.sd['global'] = self.get_weights()

    def calculate_average_test_accuracy(self, updated):
        """Fold the clients' mean local test accuracy into ``best_test_acc``.

        Args:
            updated: Ids of the clients whose ``sd[c_id]['rnd_local_test_acc']``
                entries should be averaged.

        An empty ``updated`` is a no-op instead of a ``ZeroDivisionError``.
        """
        num_clients = len(updated)
        if num_clients == 0:
            return

        total_test_acc = sum(
            (self.sd[c_id]['rnd_local_test_acc'] for c_id in updated), 0.0
        )
        average_test_acc = total_test_acc / num_clients
        if self.best_test_acc <= average_test_acc:
            self.best_test_acc = average_test_acc

    def on_round_complete(self, updated):
        """Reload the base model, aggregate client updates, log and checkpoint.

        Args:
            updated: Ids of the clients that uploaded weights this round.
        """
        # NOTE(review): placeholder path -- it must be replaced with the real
        # location of the serialized base model before this can run.
        load_path = 'load from your_local_path'
        # NOTE(review): torch.load unpickles arbitrary objects; only point it
        # at trusted files. Newer PyTorch releases default to
        # weights_only=True, which may reject a fully pickled model -- confirm
        # against the installed version.
        self.model = torch.load(load_path)
        self.model.cuda(self.gpu_id)

        self.update(updated)

        if self.args.print:
            self.logger.print(f'{self.curr_rnd + 1} Avg test acc: {self.test_acc * 100:.1f}%')
            self.logger.print(f'Avg best test acc: {self.best_test_acc * 100:.1f}%')

        # Always report the final accuracy, even when per-round printing is off.
        if self.curr_rnd + 1 == self.args.n_rnds:
            self.logger.print(f'Avg test acc: {self.test_acc * 100:.1f}%')

        self.save_state()

    @torch.no_grad()
    def test(self):
        """Evaluate the global model on ``self.loader.test_pa``.

        Logs the mean per-batch cross-entropy loss and accuracy together with
        macro-averaged precision, recall and F1 over all evaluated samples.

        Returns:
            float: Mean of the per-batch accuracies, or ``0.0`` when the
            loader yields no batches.
        """
        self.model.eval()

        total_loss = 0.0
        total_acc = 0.0
        n_batches = 0
        all_labels = []
        all_preds = []

        for batch in self.loader.test_pa:
            input_ids = batch['input_ids'].cuda(self.gpu_id, non_blocking=True)
            attention_mask = batch['attention_mask'].cuda(self.gpu_id, non_blocking=True)
            labels = batch['label'].cuda(self.gpu_id, non_blocking=True)

            logits = self.model(input_ids, attention_mask=attention_mask).logits

            # Cross-entropy written out as -sum(one_hot * log_softmax).
            per_sample_loss = (
                -F.one_hot(labels, self.args.n_clss) * logits.log_softmax(dim=-1)
            ).sum(dim=-1)
            # .item() keeps the running total a Python float rather than a
            # CUDA tensor, so the logged value is a plain number.
            total_loss += per_sample_loss.mean().item()

            preds = logits.argmax(dim=-1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            total_acc += torch.eq(preds, labels).float().mean().item()
            n_batches += 1

        if n_batches == 0:
            # Nothing to average; avoid dividing by zero below.
            self.logger.print('test loader is empty, skipping evaluation')
            return 0.0

        # Batch means are averaged unweighted, so a smaller final batch counts
        # as much as a full one.
        loss = total_loss / n_batches
        accuracy = total_acc / n_batches

        precision = precision_score(all_labels, all_preds, average='macro')
        recall = recall_score(all_labels, all_preds, average='macro')
        f1 = f1_score(all_labels, all_preds, average='macro')

        self.logger.print(f'loss: {loss}, accuracy: {accuracy}, '
                          f'Precision: {precision}, Recall: {recall}, '
                          f'F1-score: {f1}')

        return accuracy

    def update(self, updated):
        """Aggregate the clients' weights into the global model and evaluate.

        Client entries are removed from ``self.sd`` once read. The aggregated
        parameters are copied in place into the model; parameters that are
        missing from the model or fail to copy are reported and skipped.
        The model is evaluated exactly once, and ``best_test_acc`` is updated
        regardless of ``args.print``.

        Args:
            updated: Ids of the clients that uploaded weights this round.
        """
        local_weights = []
        local_train_sizes = []

        for c_id in updated:
            client_state = self.sd[c_id]
            local_weights.append(client_state['model'].copy())
            local_train_sizes.append(client_state['train_size'])
            del self.sd[c_id]

        st = time.time()

        # Weight every client by its share of the total training samples.
        ratio = (np.array(local_train_sizes) / np.sum(local_train_sizes)).tolist()

        print(f"Ratio for aggregation: {ratio}")

        aggregated_lora_params = self.aggregate(local_weights, ratio)

        # Relies on state_dict() entries sharing storage with the live
        # parameters, so the in-place copy_ below updates the model itself.
        current_state_dict = self.model.state_dict()

        for name, param in aggregated_lora_params.items():
            if name in current_state_dict:
                try:
                    current_state_dict[name].copy_(param)
                except Exception as e:
                    # Deliberate best-effort: one incompatible entry must not
                    # abort the whole aggregation round.
                    print(f"Skipping parameter {name}: {str(e)}")
            else:
                print(f"Parameter {name} not found in the current model, skipping.")
        print("LoRA parameters aggregation and update completed.")

        self.test_acc = self.test()
        if self.test_acc > self.best_test_acc:
            self.best_test_acc = self.test_acc

        if self.args.print:
            self.logger.print(f'global model has been updated ({time.time() - st:.2f}s)')

    def set_weights(self, model, state_dict):
        """Load ``state_dict`` into ``model``, limited to ``params_to_update``."""
        set_state_dict(model, state_dict, self.gpu_id, params_to_update=self.params_to_update)

    def get_weights(self):
        """Return the global model's weights wrapped as ``{'model': ...}``."""
        return {'model': get_state_dict(self.model)}

    def save_state(self):
        """Write the global model's weights to ``server_state.pt``."""
        torch_save(self.args.checkpt_path, 'server_state.pt', {
            'model': get_state_dict(self.model),
        })