#!/usr/bin/env python
# coding: utf-8


import os
import datasets
from datasets import load_dataset, DatasetDict
import torch
import json
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
from peft.tuners.lora import LoraLayer
import transformers
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    DataCollatorWithPadding,
    GenerationConfig
)
import evaluate
from tqdm import tqdm
import numpy as np
from typing import Dict, List, Tuple
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import gc
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
import random


# Text column(s) fed to the tokenizer for each supported GLUE task:
# task name -> (first text column, second text column or None).
_GLUE_SENTENCE_KEYS = {
    "sst2": ("sentence", None),
    "qnli": ("question", "sentence"),
    "mnli": ("premise", "hypothesis"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "mrpc": ("sentence1", "sentence2"),
}


def _initial_class_counts(proportion, size, min_per_class):
    """Return rounded per-class counts for ``size`` samples, at least ``min_per_class`` each."""
    counts = np.round(proportion * size).astype(int)
    return np.maximum(counts, min_per_class)


def _fit_counts_to_size(counts, size, grow_order, grow_caps, min_per_class):
    """Adjust ``counts`` in place so that they sum to ``size`` where possible.

    When too few samples are planned, classes are visited in ``grow_order`` and
    each may receive at most ``grow_caps[idx]`` extra samples. When too many are
    planned, the largest classes are trimmed first, never below
    ``min_per_class``. Any remainder that cannot be placed is left as is.
    """
    diff = size - counts.sum()
    if diff > 0:
        for idx in grow_order:
            added = min(diff, grow_caps[idx])
            if added > 0:
                counts[idx] += added
                diff -= added
            if diff == 0:
                break
    elif diff < 0:
        # The visiting order is fixed before any count is reduced.
        for idx in np.argsort(-counts):
            removable = counts[idx] - min_per_class
            if removable > 0:
                removed = min(-diff, removable)
                counts[idx] -= removed
                diff += removed
            if diff == 0:
                break
    return counts


def _take_class_indices(class_pools, k, needed, picked, split_labels, rng):
    """Move ``needed`` indices of class ``k`` from the shared pool into ``picked``.

    ``class_pools[k]`` holds the not-yet-assigned indices of class ``k`` and is
    consumed from the front. If it runs dry, the shortfall is first filled with
    class-``k`` indices this client does not hold yet (drawn without
    replacement), and only if none exist by sampling the class with replacement.
    The client may therefore end up with fewer than ``needed`` samples.
    """
    if needed <= 0:
        return

    pool = class_pools[k]
    if needed <= len(pool):
        picked.extend(pool[:needed].tolist())
        class_pools[k] = pool[needed:]
        return

    picked.extend(pool.tolist())
    # Integer dtype keeps the depleted pool usable as an index array.
    class_pools[k] = np.array([], dtype=np.int64)

    shortfall = needed - len(pool)
    class_members = np.where(split_labels == k)[0]
    if len(class_members) == 0:
        # Nothing to sample from; without this guard `choice` raises ValueError.
        print(f"Warning: class {k} has no examples in this split; skipping it")
        return

    already_picked = set(picked)
    reusable = [idx for idx in class_members if idx not in already_picked]
    if reusable:
        extra = rng.choice(reusable, min(shortfall, len(reusable)), replace=False).tolist()
    else:
        extra = rng.choice(class_members, shortfall, replace=True).tolist()
    picked.extend(extra)


def partition_multi_task_dataset(task_name_list, tokenizer, alpha=0.5, train_samples_per_client=None, 
                              test_samples_per_client=None, seed=42):
    """Build one tokenized GLUE ``DatasetDict`` per client.

    Client ``i`` is assigned the task ``task_name_list[i]``. Clients that share
    a task draw from shared, shuffled per-class index pools, so their samples
    do not overlap until a pool is exhausted. Classes are sampled uniformly.

    Args:
        task_name_list: GLUE task name per client; supported tasks are sst2,
            qnli, mnli, qqp, rte and mrpc.
        tokenizer: Callable tokenizer applied to the task's text column(s) with
            truncation to 128 tokens and no padding.
        alpha: Accepted for backward compatibility; not used by this function.
        train_samples_per_client: Training samples per client. ``None`` splits
            the task's train set evenly among the clients of that task.
        test_samples_per_client: Validation samples per client, with the same
            ``None`` behaviour applied to the validation split.
        seed: Seed for the global ``random``/``numpy`` generators and base of
            the per-client seeds (``seed * 10000 + client_id``).

    Returns:
        ``(client_datasets, task_info)``: ``client_datasets[i]`` is a
        ``DatasetDict`` with ``"train"`` and ``"validation"`` splits (plus
        ``"validation_mismatched"`` for mnli); ``task_info[i]`` holds the
        client's ``"task_name"`` and ``"num_labels"``.

    Raises:
        ValueError: If a task name is not supported.
    """
    np.random.seed(seed)
    random.seed(seed)

    num_clients = len(task_name_list)
    print(f"Creating datasets for {num_clients} clients with tasks: {task_name_list}")

    # Sorted rather than a bare set: the shuffles below consume the global RNG,
    # and the iteration order of a set of strings changes between processes
    # (hash randomisation), which would make the partition non-reproducible.
    unique_tasks = sorted(set(task_name_list))
    task_datasets = {}
    task_info = {}
    task_available_indices = {}
    # Label array of every used split, cached once per task so the oversampling
    # fallback does not have to re-read the whole dataset.
    task_split_labels = {}

    for task_name in unique_tasks:
        if task_name not in _GLUE_SENTENCE_KEYS:
            raise ValueError(f"Task {task_name} not supported")
        sentence1_key, sentence2_key = _GLUE_SENTENCE_KEYS[task_name]

        raw_datasets = load_dataset("glue", task_name)
        validation_key = "validation_matched" if task_name == "mnli" else "validation"

        split_names = ["train", validation_key]
        if task_name == "mnli" and "validation_mismatched" in raw_datasets:
            split_names.append("validation_mismatched")

        split_labels = {
            split: np.array(list(raw_datasets[split]["label"])) for split in split_names
        }
        num_labels = len(set(split_labels["train"].tolist()))

        task_datasets[task_name] = {
            "raw_datasets": raw_datasets,
            "num_labels": num_labels,
            "sentence1_key": sentence1_key,
            "sentence2_key": sentence2_key,
            "validation_key": validation_key
        }
        task_split_labels[task_name] = split_labels

        # Per split, one shuffled pool of row indices per class.
        class_pools = {}
        for split in split_names:
            class_pools[split] = []
            for k in range(num_labels):
                indices = np.where(split_labels[split] == k)[0].tolist()
                random.shuffle(indices)
                class_pools[split].append(np.array(indices, dtype=np.int64))
        task_available_indices[task_name] = class_pools

        print(f"Dataset {task_name} has {num_labels} classes")

    task_client_counts = {
        task: sum(1 for t in task_name_list if t == task) for task in unique_tasks
    }

    def build_client_dataset(client_id, task_name, client_seed):
        """Draw and tokenize one client's splits, consuming the task's shared pools."""
        # Only consulted when a class pool runs dry and samples must be reused.
        client_np_random = np.random.RandomState(client_seed)

        task_data = task_datasets[task_name]
        raw_datasets = task_data["raw_datasets"]
        n_classes = task_data["num_labels"]
        sentence1_key = task_data["sentence1_key"]
        sentence2_key = task_data["sentence2_key"]
        validation_key = task_data["validation_key"]
        pools = task_available_indices[task_name]
        split_labels = task_split_labels[task_name]

        # Every class gets the same share of the client's samples.
        proportion = np.ones(n_classes) / n_classes
        print(f"Client {client_id} ({task_name}) class proportions: {proportion}")

        train_size = train_samples_per_client
        if train_size is None:
            train_size = len(raw_datasets["train"]) // task_client_counts[task_name]

        test_size = test_samples_per_client
        if test_size is None:
            test_size = len(raw_datasets[validation_key]) // task_client_counts[task_name]

        print(f"Client {client_id} ({task_name}) train samples: {train_size}, test samples: {test_size}")

        min_samples_per_class = 1

        def plan_by_availability(split, size):
            """Per-class counts for ``split``; extras go to the fullest pools first."""
            counts = _initial_class_counts(proportion, size, min_samples_per_class)
            available = [len(pools[split][k]) for k in range(n_classes)]
            return _fit_counts_to_size(
                counts, size, np.argsort(-np.array(available)), available, min_samples_per_class
            )

        # Both plans are fixed before any index is drawn.
        client_train_class_counts = plan_by_availability("train", train_size)
        client_val_class_counts = plan_by_availability(validation_key, test_size)

        client_train_indices = []
        client_val_indices = []
        # Train and validation are drawn class by class, interleaved, so the
        # per-client generator is consumed in a fixed order.
        for k in range(n_classes):
            _take_class_indices(
                pools["train"], k, client_train_class_counts[k],
                client_train_indices, split_labels["train"], client_np_random,
            )
            _take_class_indices(
                pools[validation_key], k, client_val_class_counts[k],
                client_val_indices, split_labels[validation_key], client_np_random,
            )

        max_length = 128

        def preprocess_function(examples):
            if sentence2_key is None:
                texts = (examples[sentence1_key],)
            else:
                texts = (examples[sentence1_key], examples[sentence2_key])
            result = tokenizer(*texts, padding=False, max_length=max_length, truncation=True)

            if "label" in examples:
                result["labels"] = examples["label"]
            return result

        def tokenize_selection(split, indices, desc):
            """Select ``indices`` from ``split`` and replace its columns by tokenizer output."""
            return raw_datasets[split].select(indices).map(
                preprocess_function,
                batched=True,
                remove_columns=raw_datasets[split].column_names,
                desc=desc,
            )

        client_dict = {
            "train": tokenize_selection(
                "train", client_train_indices,
                f"Tokenizing train data for client {client_id}",
            ),
            "validation": tokenize_selection(
                validation_key, client_val_indices,
                f"Tokenizing validation data for client {client_id}",
            ),
        }

        if task_name == "mnli" and "validation_mismatched" in raw_datasets:
            # The mismatched split reuses the validation size; a rounding
            # shortfall is handed out one sample per class.
            client_mismatched_class_counts = _fit_counts_to_size(
                _initial_class_counts(proportion, test_size, min_samples_per_class),
                test_size,
                np.argsort(-proportion),
                np.ones(n_classes, dtype=int),
                min_samples_per_class,
            )

            mismatched_val_indices = []
            for k in range(n_classes):
                _take_class_indices(
                    pools["validation_mismatched"], k, client_mismatched_class_counts[k],
                    mismatched_val_indices, split_labels["validation_mismatched"], client_np_random,
                )

            client_dict["validation_mismatched"] = tokenize_selection(
                "validation_mismatched", mismatched_val_indices,
                f"Tokenizing mismatched validation for client {client_id}",
            )

        return DatasetDict(client_dict)

    client_datasets = {}

    for client_id, task_name in enumerate(task_name_list):
        client_seed = seed * 10000 + client_id

        client_datasets[client_id] = build_client_dataset(client_id, task_name, client_seed)

        task_info[client_id] = {
            "task_name": task_name,
            "num_labels": task_datasets[task_name]["num_labels"]
        }

        print(f"\nClient {client_id} ({task_name}) dataset sizes:")
        print(f"  Training: {len(client_datasets[client_id]['train'])}")
        print(f"  Validation: {len(client_datasets[client_id]['validation'])}")
        if task_name == "mnli" and "validation_mismatched" in client_datasets[client_id]:
            print(f"  Validation Mismatched: {len(client_datasets[client_id]['validation_mismatched'])}")

        for split in client_datasets[client_id]:
            labels = list(client_datasets[client_id][split]["labels"])
            unique, counts = np.unique(labels, return_counts=True)
            class_counts = dict(zip(unique, counts))
            print(f"  Class distribution in {split}: {class_counts}")

    return client_datasets, task_info


class Client:
    """Federated client holding one task's data and a LoRA-adapted classifier.

    The model is created lazily by ``load_model`` and can be dropped with
    ``unload_model``; the exchanged parameters are cached in
    ``current_params`` in between and restored on the next load.
    """

    # Name fragments of the parameters that are exported and cached.
    _SHARED_PARAM_MARKERS = ('lora_A', 'lora_B', 'lora_route', 'classifier')

    def __init__(self, client_id, task_name, tokenizer, model_name, num_clients, rank=8, lora_n=4, adaptive=False, cache_path="./output", idx=None):
        self.client_id = client_id
        self.task_name = task_name
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.num_clients = num_clients
        self.rank = rank
        self.lora_n = lora_n
        self.adaptive = adaptive
        self.cache_path = cache_path
        self.local_model = None
        self.current_params = None
        self.datasets = None
        self.num_labels = None
        self.idx = idx

    def set_dataset(self, dataset, num_labels):
        """Attach the client's dataset splits and its number of classes."""
        self.datasets = dataset
        self.num_labels = num_labels

    def load_model(self):
        """Create the PEFT-wrapped model if absent and restore cached parameters."""
        if self.local_model is not None:
            return

        base_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            cache_dir=self.cache_path,
            num_labels=self.num_labels,
        )

        # NOTE(review): lora_nums / adaptive / idx / k look like options of a
        # customised peft build - confirm against the installed package.
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            target_modules=["query", "value"],
            inference_mode=False,
            r=self.rank,
            lora_alpha=16,
            lora_dropout=0.05,
            lora_nums=self.lora_n,
            adaptive=self.adaptive,
            idx=self.idx,
            k=self.lora_n,
            bias="none"
        )

        # Assigned only once wrapping succeeded, so a failure above does not
        # leave an unwrapped model that later load_model() calls would skip.
        self.local_model = get_peft_model(base_model, peft_config)

        if self.current_params is not None:
            self.load_params(self.current_params)

    def unload_model(self):
        """Cache the exchanged parameters and drop the model."""
        if self.local_model is None:
            return
        self.current_params = self.get_lora_params()['params']
        self.local_model = None
        # Collect first: empty_cache() can only hand back memory whose tensors
        # have already been freed.
        gc.collect()
        torch.cuda.empty_cache()

    def _collect_shared_params(self):
        """Return cloned tensors of every parameter whose name has a shared marker."""
        return {
            name: param.data.clone()
            for name, param in self.local_model.named_parameters()
            if any(marker in name for marker in self._SHARED_PARAM_MARKERS)
        }

    def get_lora_params(self):
        """Return ``{'client_id': ..., 'params': {name: tensor}}`` for the exchanged parameters."""
        return {
            'client_id': self.client_id,
            'params': self._collect_shared_params()
        }

    def get_lora_params_and_save_by_module(self, round_id, personal_dir):
        """Like ``get_lora_params``, and also dump the first adapter's matrices to disk.

        For each of the ``query`` and ``value`` modules, the ``lora_A0`` and
        ``lora_B0`` tensors are stacked in parameter order and saved as
        ``<module>_lora_<A|B>_client_<id>_<round>.npy`` under
        ``<personal_dir>/lora_params``. Files are only written for non-empty
        stacks.
        """
        target_modules = ["query", "value"]
        lora_A_dict = {module: [] for module in target_modules}
        lora_B_dict = {module: [] for module in target_modules}

        for name, param in self.local_model.named_parameters():
            if 'lora_A0' in name:
                bucket = lora_A_dict
            elif 'lora_B0' in name:
                bucket = lora_B_dict
            else:
                continue
            for module in target_modules:
                if module in name:
                    bucket[module].append(param.data.clone().cpu().numpy())

        param_dir = os.path.join(personal_dir, "lora_params")
        os.makedirs(param_dir, exist_ok=True)

        for module in target_modules:
            if lora_A_dict[module]:
                np.save(os.path.join(param_dir, f'{module}_lora_A_client_{self.client_id}_{round_id}.npy'), np.array(lora_A_dict[module]))
            if lora_B_dict[module]:
                np.save(os.path.join(param_dir, f'{module}_lora_B_client_{self.client_id}_{round_id}.npy'), np.array(lora_B_dict[module]))

        return self.get_lora_params()

    def load_params(self, params_or_path):
        """Load parameters from a dict (optionally wrapped under ``'params'``) or an adapter path."""
        if isinstance(params_or_path, dict):
            params_to_load = params_or_path['params'] if 'params' in params_or_path else params_or_path
            # strict=False: the dict only covers a subset of the model's parameters.
            self.local_model.load_state_dict(params_to_load, strict=False)
        else:
            self.local_model.load_adapter(params_or_path, adapter_name="default")

    def _find_lora_group(self, lora_client_map):
        """Return the int key of the first group listing this client, or ``None``."""
        for lora_idx, client_indices in lora_client_map.items():
            if self.client_id in client_indices:
                return int(lora_idx)
        return None

    def _set_trainable_params(self, client_lora_group):
        """Enable gradients for the group's LoRA matrices and the router only.

        With ``client_lora_group=None`` every ``lora_A``/``lora_B``/``lora_route``
        parameter is trainable. Everything else is frozen.
        """
        import re

        if client_lora_group is None:
            trainable = re.compile(r'lora_A|lora_B|lora_route')
        else:
            # The lookahead stops group 1 from also matching lora_A10, lora_A11, ...
            trainable = re.compile(rf'lora_[AB]{client_lora_group}(?!\d)|lora_route')

        for name, param in self.local_model.named_parameters():
            param.requires_grad = trainable.search(name) is not None

    def local_training(self, lr=2e-4, epochs=1, batch_size=32, gradient_accumulation_steps=1, lora_client_map=None):
        """Train this client's LoRA group on its ``"train"`` split.

        Args:
            lr: Learning rate.
            epochs: Number of training epochs.
            batch_size: Per-device train batch size.
            gradient_accumulation_steps: Gradient accumulation steps.
            lora_client_map: Mapping ``group index -> client ids``. A client in
                no group trains all LoRA modules.

        Raises:
            ValueError: If ``lora_client_map`` is ``None`` or nothing is trainable.
        """
        self.local_model.train()

        if lora_client_map is None:
            raise ValueError("lora_client_map is required for local_training after warmup")

        client_lora_group = self._find_lora_group(lora_client_map)
        if client_lora_group is None:
            print(f"Client {self.client_id} is a dummy client or not found in lora_client_map")
            print(f"Training all LoRA modules for client {self.client_id}")
        else:
            print(f"Client {self.client_id} belongs to LoRA group {client_lora_group}")
        self._set_trainable_params(client_lora_group)

        trainable_params = [p for p in self.local_model.parameters() if p.requires_grad]
        print(f"Number of trainable parameters: {len(trainable_params)}")
        if len(trainable_params) == 0:
            raise ValueError("No trainable parameters found!")

        training_args = TrainingArguments(
            output_dir=f"{self.cache_path}/{self.rank}_{self.lora_n}_proposed/client_{self.client_id}_checkpoints",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,
            num_train_epochs=epochs,
            learning_rate=lr,
            fp16=True,
            logging_steps=5,
            optim="adamw_torch",
            weight_decay=0.05,
            evaluation_strategy="no",
            save_strategy="no",
            save_total_limit=1,
            remove_unused_columns=False,
            gradient_checkpointing=False
        )

        trainer = Trainer(
            model=self.local_model,
            args=training_args,
            train_dataset=self.datasets["train"],
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(self.tokenizer)
        )

        trainer.train()

    def evaluate_model(self, output_file=None):
        """Evaluate on the ``"validation"`` split and return the metrics dict.

        When ``output_file`` is given, the metrics are also written there as JSON
        together with the client id and task name.
        """
        self.local_model.eval()

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return {"accuracy": (predictions == labels).astype(np.float32).mean().item()}

        eval_args = TrainingArguments(
            output_dir=f"{self.cache_path}/temp_eval_output",
            per_device_eval_batch_size=256,
            fp16=True,
            report_to="none"
        )

        metrics = Trainer(
            model=self.local_model,
            args=eval_args,
            eval_dataset=self.datasets["validation"],
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=compute_metrics
        ).evaluate()

        print(f"Evaluation metrics for client {self.client_id} on validation dataset:")
        print(metrics)

        if output_file:
            with open(output_file, 'w') as f:
                json.dump({
                    'client_id': self.client_id,
                    'task': self.task_name,
                    'dataset_type': 'validation',
                    'metrics': metrics
                }, f, indent=2)
            print(f'The output file is stored at {output_file}')

        return metrics


class WarmupClient(Client):
    """Client used for warmup rounds: one non-adaptive LoRA adapter per target module."""

    def __init__(self, client_id, task_name, tokenizer, model_name, num_clients, rank=8, cache_path="./output"):
        # The adapter count is pinned to 1 and adaptivity is switched off.
        super().__init__(
            client_id, task_name, tokenizer, model_name, num_clients, rank,
            lora_n=1, adaptive=False, cache_path=cache_path,
        )

    def load_model(self):
        """Create the single-adapter PEFT model if absent and restore cached parameters."""
        if self.local_model is not None:
            return

        self.local_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            cache_dir=self.cache_path,
            num_labels=self.num_labels,
        )

        warmup_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            target_modules=["query", "value"],
            inference_mode=False,
            r=self.rank,
            lora_alpha=16,
            lora_dropout=0.05,
            lora_nums=1,
            adaptive=False,
            idx=0,
            k=1,
            bias="none",
        )
        self.local_model = get_peft_model(self.local_model, warmup_config)

        cached = self.current_params
        if cached is not None:
            self.load_params(cached)

    def local_training(self, lr=2e-4, epochs=1, batch_size=32, gradient_accumulation_steps=1):
        """Train every ``lora_A``/``lora_B`` parameter on the ``"train"`` split.

        All other parameters are frozen. Raises ``ValueError`` when nothing is
        left trainable.
        """
        self.local_model.train()

        for param_name, tensor in self.local_model.named_parameters():
            tensor.requires_grad = 'lora_A' in param_name or 'lora_B' in param_name

        trainable_params = [p for p in self.local_model.parameters() if p.requires_grad]
        print(f"Number of trainable parameters: {len(trainable_params)}")
        if not trainable_params:
            raise ValueError("No trainable parameters found!")

        settings = dict(
            output_dir=f"{self.cache_path}/warmup/client_{self.client_id}_checkpoints",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,
            num_train_epochs=epochs,
            learning_rate=lr,
            fp16=True,
            logging_steps=5,
            optim="adamw_torch",
            weight_decay=0.05,
            evaluation_strategy="no",
            save_strategy="no",
            save_total_limit=1,
            remove_unused_columns=False,
            gradient_checkpointing=False,
        )

        Trainer(
            model=self.local_model,
            args=TrainingArguments(**settings),
            train_dataset=self.datasets["train"],
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(self.tokenizer),
        ).train()

    def get_lora_params(self):
        """Return cloned ``lora_A``/``lora_B``/``classifier`` tensors tagged with the client id."""
        exported = {}
        for param_name, tensor in self.local_model.named_parameters():
            if any(tag in param_name for tag in ('lora_A', 'lora_B', 'classifier')):
                exported[param_name] = tensor.data.clone()
        return {'client_id': self.client_id, 'params': exported}

    def get_lora_params_and_save_by_module(self, round_id, personal_dir):
        """Collect the exchanged parameters and dump the ``lora_A0``/``lora_B0`` stacks.

        Unlike ``get_lora_params``, the returned dict also includes parameters
        whose name contains ``lora_route``. Stacks for the ``query`` and
        ``value`` modules are written as ``.npy`` files under
        ``<personal_dir>/lora_params``.
        """
        tracked_modules = ["query", "value"]
        a_stacks = {m: [] for m in tracked_modules}
        b_stacks = {m: [] for m in tracked_modules}
        exported = {}

        for param_name, tensor in self.local_model.named_parameters():
            if any(tag in param_name for tag in ('lora_A', 'lora_B', 'lora_route', 'classifier')):
                exported[param_name] = tensor.data.clone()

            for m in tracked_modules:
                if m not in param_name:
                    continue
                if 'lora_A0' in param_name:
                    a_stacks[m].append(tensor.data.clone().cpu().numpy())
                elif 'lora_B0' in param_name:
                    b_stacks[m].append(tensor.data.clone().cpu().numpy())

        out_dir = os.path.join(personal_dir, "lora_params")
        print(out_dir)
        os.makedirs(out_dir, exist_ok=True)

        for m in tracked_modules:
            # A is written before B for each module; empty stacks produce no file.
            for letter, stack in (("A", a_stacks[m]), ("B", b_stacks[m])):
                if stack:
                    np.save(
                        os.path.join(out_dir, f'{m}_lora_{letter}_client_{self.client_id}_{round_id}.npy'),
                        np.array(stack),
                    )

        return {'client_id': self.client_id, 'params': exported}


class Server:
    """Aggregate per-client LoRA parameters between federated rounds.

    ``aggregation_warmup`` handles the phase in which every client trains a
    single adapter (``lora_A0`` / ``lora_B0``). ``aggregation`` handles the
    clustered phase, where adapter ``k`` is averaged over the clients listed
    under key ``k`` of ``lora_client_map``.

    Client indices inside ``lora_client_map`` are positions in the ``params``
    list handed to the aggregation methods.
    """

    def __init__(self, clients_num: int, device: str = "cuda"):
        """Store the client count, the aggregation device and an empty client map."""
        self.clients_num = clients_num
        self.device = device
        self.lora_client_map = None

    def _to_device(self, params: List) -> List[Dict]:
        """Return a copy of ``params`` with every tensor moved to ``self.device``."""
        return [
            {name: tensor.to(self.device) for name, tensor in client_params.items()}
            for client_params in params
        ]

    @staticmethod
    def _lora_index(param_name, marker):
        """Return the integer that directly follows ``marker`` in ``param_name``.

        All consecutive digits are read, so ``lora_A12`` yields 12. Returns
        ``None`` when ``marker`` is absent or is not followed by a digit.
        """
        _, found, tail = param_name.partition(marker)
        if not found:
            return None
        digits = ""
        for ch in tail:
            if ch not in "0123456789":
                break
            digits += ch
        return int(digits) if digits else None

    @staticmethod
    def _mean_over_clients(gpu_params, client_indices, param_name):
        """Average ``param_name`` over the listed clients.

        Clients that are out of range or do not hold the parameter are
        ignored. Returns ``None`` when nobody contributes, because
        ``torch.stack`` rejects an empty list.
        """
        tensors = [
            gpu_params[i][param_name]
            for i in client_indices
            if 0 <= i < len(gpu_params) and param_name in gpu_params[i]
        ]
        if not tensors:
            return None
        return torch.stack(tensors).mean(dim=0)

    def _group_members(self, lora_idx):
        """Look up the clients of adapter ``lora_idx`` under a str key, then an int key."""
        members = self.lora_client_map.get(str(lora_idx), [])
        if not members:
            members = self.lora_client_map.get(lora_idx, [])
        return members

    def aggregation_warmup(self, route_aggregation: bool, params: List, lora_client_map=None) -> List[Dict]:
        """Aggregate one warmup round.

        Args:
            route_aggregation: Accepted for signature parity with
                ``aggregation``; not used during warmup.
            params: One ``{param_name: tensor}`` dict per client.
            lora_client_map: ``None`` for an ordinary warmup round. When given
                (final warmup round) it maps a group index to the list of
                client positions in that group and is stored on the server.

        Returns:
            One dict per client. For an ordinary round it holds that client's
            own ``lora_A`` / ``lora_B`` / ``lora_route`` tensors and nothing
            else. For the final round it holds all of the client's own
            tensors plus, for each group ``g`` the client belongs to, the
            group mean of every ``lora_A0`` / ``lora_B0`` tensor stored under
            the name with ``lora_A0`` / ``lora_B0`` replaced by
            ``lora_A<g>`` / ``lora_B<g>``. Group members share the same mean
            tensor object.
        """
        gpu_params = self._to_device(params)
        num_clients = len(gpu_params)
        aggregated_results = [{} for _ in range(num_clients)]

        if lora_client_map is None:
            # Ordinary warmup round: no cross-client mixing at all.
            for client_idx, client_params in enumerate(gpu_params):
                for param_name, param_value in client_params.items():
                    if 'lora_A' in param_name or 'lora_B' in param_name or 'lora_route' in param_name:
                        aggregated_results[client_idx][param_name] = param_value
            return aggregated_results

        self.lora_client_map = lora_client_map
        print("Final warmup round, preparing transition to clustered LoRA")

        for client_idx, client_params in enumerate(gpu_params):
            aggregated_results[client_idx].update(client_params)

        for group_idx, group_clients in lora_client_map.items():
            group_idx = int(group_idx)

            if not group_clients:
                continue

            print(f"Processing group {group_idx} with clients {group_clients}")

            valid_clients = [c for c in group_clients if 0 <= c < num_clients]
            if not valid_clients:
                continue

            for base_param_name in list(gpu_params[0].keys()):
                # Rename only the adapter token. A blanket replace of '0'
                # would also rewrite layer indices such as "layer.0" or
                # "layer.10" and produce names of other layers.
                if 'lora_A0' in base_param_name:
                    target_param_name = base_param_name.replace('lora_A0', f'lora_A{group_idx}')
                elif 'lora_B0' in base_param_name:
                    target_param_name = base_param_name.replace('lora_B0', f'lora_B{group_idx}')
                else:
                    continue

                try:
                    avg_param = self._mean_over_clients(gpu_params, valid_clients, base_param_name)
                except (RuntimeError, TypeError) as e:
                    # e.g. tensors of different shapes; keep going with the
                    # remaining parameters as the warmup is best-effort.
                    print(f"Error aggregating {base_param_name} for group {group_idx}: {e}")
                    continue

                if avg_param is None:
                    continue
                for client_idx in valid_clients:
                    aggregated_results[client_idx][target_param_name] = avg_param

        return aggregated_results

    def aggregation(self, route_aggregation: bool, params: List, lora_client_map=None) -> List[Dict]:
        """Aggregate one clustered round.

        Args:
            route_aggregation: When true, ``lora_route`` tensors are averaged
                within the group a client belongs to; otherwise each client
                keeps its own.
            params: One ``{param_name: tensor}`` dict per client. The
                parameter names of the first client drive the aggregation.
            lora_client_map: Optional replacement for the stored map
                (group/adapter index -> list of client positions).

        Returns:
            One dict per client with the same parameter names as the first
            client. ``lora_A<k>`` / ``lora_B<k>`` hold the mean over the
            members of group ``k`` (shared tensor object for all clients).
            A client keeps its own tensor when the group is unknown or empty,
            when no member holds the parameter, or when the name carries no
            adapter index. Every other parameter is passed through unchanged.
            An empty ``params`` yields an empty list.

        Raises:
            ValueError: If no client map was supplied now or earlier.
        """
        if lora_client_map is not None:
            self.lora_client_map = lora_client_map

        if self.lora_client_map is None:
            raise ValueError("lora_client_map must be provided for aggregation after warmup phase")

        gpu_params = self._to_device(params)
        if not gpu_params:
            return []

        # Later groups win if a client is (unexpectedly) listed twice.
        client_to_group = {
            client: group_idx
            for group_idx, clients in self.lora_client_map.items()
            for client in clients
        }

        num_clients = len(gpu_params)
        aggregated_results = [{} for _ in range(num_clients)]

        for param_name in gpu_params[0].keys():
            if 'lora_route' in param_name:
                route_means = {}  # group key -> mean, computed once per group
                for client_idx in range(num_clients):
                    group = client_to_group.get(client_idx) if route_aggregation else None
                    value = None
                    if group is not None:
                        if group not in route_means:
                            route_means[group] = self._mean_over_clients(
                                gpu_params, self.lora_client_map[group], param_name
                            )
                        value = route_means[group]
                    if value is None:
                        value = gpu_params[client_idx][param_name]
                    aggregated_results[client_idx][param_name] = value
                continue

            if 'lora_A' in param_name:
                marker = 'lora_A'
            elif 'lora_B' in param_name:
                marker = 'lora_B'
            else:
                marker = None

            # The group mean depends only on the adapter index, not on the
            # receiving client, so it is computed once per parameter.
            group_mean = None
            if marker is not None:
                lora_idx = self._lora_index(param_name, marker)
                if lora_idx is not None:
                    group_mean = self._mean_over_clients(
                        gpu_params, self._group_members(lora_idx), param_name
                    )

            for client_idx in range(num_clients):
                aggregated_results[client_idx][param_name] = (
                    group_mean if group_mean is not None else gpu_params[client_idx][param_name]
                )

        return aggregated_results


def load_B_only(proj_type, client_id, round_id, base_dir="./"):
    """Load one client's saved LoRA-B array for a projection and round as float32.

    The file read is ``<base_dir>/<proj_type>_lora_B_client_<client_id>_<round_id>.npy``.
    A garbage-collection pass is triggered before returning.
    """
    file_name = f"{proj_type}_lora_B_client_{client_id}_{round_id}.npy"
    matrix = np.load(os.path.join(base_dir, file_name))
    matrix = matrix.astype(np.float32)
    gc.collect()
    return matrix


def calculate_B_similarity_matrix(client_ids, proj_type, round_id, base_dir="./"):
    """Build the symmetric matrix of pairwise LoRA-B cosine distances.

    For every pair of clients the saved B arrays are compared entry by entry
    along their first axis: each pair of entries is flattened, the cosine
    distance ``1 - cos`` is taken, and the distances are averaged.

    Args:
        client_ids: Client identifiers used to locate the saved files; row
            and column ``i`` of the result correspond to ``client_ids[i]``.
        proj_type: Projection name embedded in the file names.
        round_id: Round identifier embedded in the file names.
        base_dir: Directory containing the ``.npy`` files.

    Returns:
        An ``(n, n)`` float array with a zero diagonal. Pairs that could not
        be compared (missing file, or no entries to compare) are reported on
        stdout and left at 0.

    NOTE(review): leaving skipped pairs at 0 makes them look identical to the
    downstream clustering; confirm that this is intended.
    """
    n_clients = len(client_ids)
    distance_matrix_B = np.zeros((n_clients, n_clients))

    # Read every client's array once instead of once per pair. A failed load
    # is remembered so each affected pair can still be reported and skipped.
    # This keeps all arrays in memory at the same time.
    loaded = []
    for client_id in client_ids:
        try:
            loaded.append(load_B_only(proj_type, client_id, round_id, base_dir))
        except FileNotFoundError as e:
            loaded.append(e)

    client_pairs = [(i, j) for i in range(n_clients) for j in range(i + 1, n_clients)]

    for (i, j) in tqdm(client_pairs, desc=f"Computing {proj_type} B matrix similarities"):
        B1, B2 = loaded[i], loaded[j]

        failure = next((m for m in (B1, B2) if isinstance(m, FileNotFoundError)), None)
        if failure is not None:
            print(f"Error loading matrices: {failure}")
            print(f"Skipping client pair ({client_ids[i]}, {client_ids[j]})")
            continue

        # Compare only the entries both clients have; an empty array leaves
        # nothing to average.
        n_layers = min(B1.shape[0], B2.shape[0])
        if n_layers == 0:
            print(f"No B matrices to compare for client pair ({client_ids[i]}, {client_ids[j]})")
            print(f"Skipping client pair ({client_ids[i]}, {client_ids[j]})")
            continue

        layer_distances = []
        for layer in range(n_layers):
            b1_flat = B1[layer].ravel()
            b2_flat = B2[layer].ravel()

            norm_product = np.linalg.norm(b1_flat) * np.linalg.norm(b2_flat)
            # The epsilon keeps all-zero matrices from dividing by zero.
            cos_sim_B = np.dot(b1_flat, b2_flat) / (norm_product + 1e-8)
            layer_distances.append(1 - cos_sim_B)

        avg_distance_B = np.mean(layer_distances)
        distance_matrix_B[i, j] = distance_matrix_B[j, i] = avg_distance_B

    del loaded
    gc.collect()

    return distance_matrix_B


def visualize_clustering(combined_matrix, client_ids, labels, output_dir, round_idx):
    """Write three diagnostic PNGs for one clustering result into ``output_dir``.

    The figures are: an average-linkage dendrogram built from
    ``combined_matrix``, a heat map of ``combined_matrix``, and a strip of
    one marker per client coloured by ``labels``. ``output_dir`` is created
    when missing and ``round_idx`` is embedded in every file name.
    """
    os.makedirs(output_dir, exist_ok=True)

    tick_text = [str(cid) for cid in client_ids]
    positions = range(len(client_ids))

    def _save_and_close(file_name):
        plt.savefig(os.path.join(output_dir, file_name))
        plt.close()

    # 1) Dendrogram; squareform converts the square matrix to condensed form.
    tree = linkage(squareform(combined_matrix), method='average')
    plt.figure(figsize=(12, 8))
    dendrogram(tree, labels=tick_text, leaf_rotation=90)
    plt.title('Client Clustering for LoRA Groups')
    plt.xlabel('Client')
    plt.ylabel('Distance (1 - Cosine Similarity)')
    _save_and_close(f"lora_clustering_round_{round_idx}.png")

    # 2) Heat map of the pairwise distances.
    plt.figure(figsize=(10, 8))
    plt.imshow(combined_matrix, cmap='viridis')
    plt.colorbar(label='Distance (1 - Cosine Similarity)')
    plt.xticks(positions, tick_text, rotation=45)
    plt.yticks(positions, tick_text)
    plt.title('B Matrix Distance Between Clients')
    _save_and_close(f"lora_distance_matrix_round_{round_idx}.png")

    # 3) One marker per client, coloured by its cluster label.
    plt.figure(figsize=(10, 6))
    baseline = [0] * len(client_ids)
    markers = plt.scatter(positions, baseline, c=labels, cmap='viridis',
                          s=100, marker='o')
    plt.colorbar(markers, label='Cluster ID')
    plt.xticks(positions, tick_text)
    plt.yticks([])
    plt.title('Client Cluster Assignments')
    _save_and_close(f"lora_clusters_round_{round_idx}.png")


def compute_lora_client_map(clients, round_idx, personal_dir="./", max_clusters=10):
    """Cluster clients by LoRA-B distance and return the cluster -> clients map.

    The per-projection distance matrices (``query`` and ``value``) are
    computed from the files under ``<personal_dir>/lora_params``, averaged,
    and fed to average-linkage agglomerative clustering for every cluster
    count from 2 up to ``min(max_clusters, len(clients))``. The count with
    the highest silhouette score wins (first one on ties).

    Side effects: appends to ``<personal_dir>/training_log.txt`` and writes
    distance matrices (``.npy``), score plots and cluster plots (``.png``)
    and the resulting map / cluster count (``.json``) into ``personal_dir``.

    Args:
        clients: Objects exposing ``client_id``.
        round_idx: Round identifier used to locate inputs and name outputs.
        personal_dir: Existing directory for inputs and outputs.
        max_clusters: Upper bound on the number of clusters to try. Values
            above the number of clients are clamped, because the clustering
            cannot produce more clusters than samples.

    Returns:
        ``(lora_client_map, optimal_n_clusters)`` where ``lora_client_map``
        maps each cluster id to the list of positions (indices into
        ``clients``) assigned to it.

    Raises:
        ValueError: If fewer than two clients are given or ``max_clusters``
            is below 2, so that no cluster count can be evaluated.
    """
    print("Computing LoRA client map based on B matrix similarity...")

    from sklearn.manifold import MDS
    from sklearn.metrics import silhouette_score, davies_bouldin_score

    proj_types = ["query", "value"]

    client_ids = [client.client_id for client in clients]
    n_clients = len(client_ids)

    largest_k = min(max_clusters, n_clients)
    if largest_k < 2:
        raise ValueError(
            "Need at least 2 clients and max_clusters >= 2 to cluster "
            f"(got {n_clients} clients, max_clusters={max_clusters})"
        )
    candidate_ks = range(2, largest_k + 1)

    log_file = os.path.join(personal_dir, "training_log.txt")

    def _log(message):
        # Append-and-close on every call so the log survives a crash mid-run.
        with open(log_file, "a") as f:
            f.write(message)

    _log("\nStarting LoRA client mapping calculation...\n")

    all_distance_matrices = {}
    param_dir = os.path.join(personal_dir, "lora_params")

    for proj_type in proj_types:
        _log(f"Calculating {proj_type} B matrix similarities...\n")

        print(param_dir)
        dist_matrix_B = calculate_B_similarity_matrix(client_ids, proj_type, round_idx, base_dir=param_dir)
        all_distance_matrices[proj_type] = dist_matrix_B

        np.save(os.path.join(personal_dir, f"{proj_type}_distance_matrix_round_{round_idx}.npy"), dist_matrix_B)

    _log("Combining distance matrices...\n")

    combined_matrix = np.zeros_like(all_distance_matrices[proj_types[0]])
    for proj_type in proj_types:
        combined_matrix += all_distance_matrices[proj_type]
    combined_matrix /= len(proj_types)

    np.save(os.path.join(personal_dir, f"combined_distance_matrix_round_{round_idx}.npy"), combined_matrix)

    _log(f"Evaluating optimal cluster number from 2 to {largest_k}...\n")

    # 2-D embedding of the distances; used for the Davies-Bouldin score and
    # for the final scatter plot.
    embedding = MDS(n_components=2, dissimilarity='precomputed', random_state=42)
    embedded_coords = embedding.fit_transform(combined_matrix)

    silhouette_scores = []
    davies_bouldin_scores = []
    cluster_labels_list = []

    for n_clusters in candidate_ks:
        model = AgglomerativeClustering(
            n_clusters=n_clusters,
            metric='precomputed',
            linkage='average'
        )

        labels = model.fit_predict(combined_matrix)
        cluster_labels_list.append(labels)

        # Compute both scores before recording either, so a failure can
        # never leave the score lists out of step with cluster_labels_list.
        try:
            sil_score = silhouette_score(
                combined_matrix,
                labels,
                metric='precomputed'
            )
            db_score = davies_bouldin_score(embedded_coords, labels)
        except Exception as e:
            # e.g. as many clusters as clients; record worst-case scores.
            _log(f"  Error with n_clusters={n_clusters}: {str(e)}\n")
            sil_score, db_score = -1, float('inf')
        else:
            _log(f"  Clusters={n_clusters}, Silhouette Score={sil_score:.4f}, Davies-Bouldin Score={db_score:.4f}\n")
            print(f"Clusters={n_clusters}, Silhouette Score={sil_score:.4f}, Davies-Bouldin Score={db_score:.4f}")

        silhouette_scores.append(sil_score)
        davies_bouldin_scores.append(db_score)

    plt.figure(figsize=(12, 6))
    plt.plot(candidate_ks, silhouette_scores, 'b-o')
    plt.xlabel('Number of Clusters')
    plt.ylabel('Silhouette Score')
    plt.title('Silhouette Score vs. Number of Clusters')
    plt.grid(True)
    plt.savefig(os.path.join(personal_dir, f"silhouette_scores_round_{round_idx}.png"))
    plt.close()

    plt.figure(figsize=(12, 6))
    plt.plot(candidate_ks, davies_bouldin_scores, 'r-o')
    plt.xlabel('Number of Clusters')
    plt.ylabel('Davies-Bouldin Score (lower is better)')
    plt.title('Davies-Bouldin Score vs. Number of Clusters')
    plt.grid(True)
    plt.savefig(os.path.join(personal_dir, f"davies_bouldin_scores_round_{round_idx}.png"))
    plt.close()

    lowest, highest = min(silhouette_scores), max(silhouette_scores)
    if highest - lowest > 0:
        norm_silhouette = [(s - lowest) / (highest - lowest) for s in silhouette_scores]
    else:
        norm_silhouette = [1.0 for _ in silhouette_scores]

    # NOTE(review): the selection uses the normalised silhouette score only.
    # The Davies-Bouldin score is logged and plotted but has never been part
    # of the combined score; confirm whether it was meant to be mixed in.
    combined_scores = norm_silhouette

    best_idx = combined_scores.index(max(combined_scores))
    optimal_n_clusters = candidate_ks[best_idx]
    optimal_labels = cluster_labels_list[best_idx]

    _log(f"Optimal number of clusters determined: {optimal_n_clusters}\n")
    print(f"Optimal number of clusters: {optimal_n_clusters}")

    # Values are positions in `clients`, not client ids.
    lora_client_map = {
        cluster_id: [i for i, label in enumerate(optimal_labels) if label == cluster_id]
        for cluster_id in range(optimal_n_clusters)
    }

    print("\nClustering Results:")
    _log("\nClustering Results:\n")
    for cluster_id, cluster_clients in lora_client_map.items():
        client_names = [client_ids[i] for i in cluster_clients]
        cluster_info = f"Cluster {cluster_id} (LoRA {cluster_id}): {', '.join(map(str, client_names))}"
        print(cluster_info)
        _log(cluster_info + "\n")

    cluster_file = os.path.join(personal_dir, f"lora_client_map_round_{round_idx}.json")
    with open(cluster_file, 'w') as f:
        json.dump(lora_client_map, f, indent=2)

    with open(os.path.join(personal_dir, f"optimal_n_clusters_round_{round_idx}.json"), 'w') as f:
        json.dump({"optimal_n_clusters": optimal_n_clusters}, f, indent=2)

    _log("Generating visualization plots...\n")

    visualize_clustering(combined_matrix, client_ids, optimal_labels, personal_dir, round_idx)

    plt.figure(figsize=(10, 8))
    for cluster_id, cluster_points in lora_client_map.items():
        plt.scatter(
            embedded_coords[cluster_points, 0],
            embedded_coords[cluster_points, 1],
            label=f'Cluster {cluster_id}',
            s=100
        )

    for i, (x, y) in enumerate(embedded_coords):
        plt.annotate(str(client_ids[i]), (x, y), textcoords="offset points", xytext=(0, 10), ha='center')

    plt.title(f'Client Clusters in 2D Space (n_clusters={optimal_n_clusters})')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.legend()
    plt.savefig(os.path.join(personal_dir, f"cluster_embedding_round_{round_idx}.png"))
    plt.close()

    _log("LoRA client mapping calculation completed.\n")

    return lora_client_map, optimal_n_clusters


def train_federated(
    dummy,
    clients,
    server,
    global_rounds,
    local_epochs,
    output_dir,
    lr=3e-4,
    round_warmup=1,
    max_clusters=10,
    task_info=None,
    client_datasets=None
):
    """Run the two-phase federated training loop and return per-client scores.

    Rounds with index below ``round_warmup`` form the warmup phase: the dummy
    client and every client in ``clients`` train locally, per-module adapter
    files are written, and ``server.aggregation_warmup`` is called. On the
    last warmup round the clients are clustered with
    ``compute_lora_client_map`` and, when both ``task_info`` and
    ``client_datasets`` are given, a fresh list of ``Client`` objects and a
    fresh ``Server`` replace the ones passed in. All later rounds form the
    clustered phase, which uses ``server.aggregation``.

    All outputs go to ``<output_dir>/proposed_m2``: ``training_log.txt``
    (truncated at start), one ``round_summary_<n>.json`` per round and a
    final ``training_history.json``.

    Args:
        dummy: Extra client trained at the start of every warmup round; its
            parameters are not collected here.
        clients: Clients used for the warmup phase.
        server: Aggregator used until it is replaced after clustering.
        global_rounds: Total number of rounds (warmup + clustered).
        local_epochs: Epochs passed to each ``local_training`` call.
        output_dir: Base output directory; also passed as ``cache_path`` to
            the clustered clients.
        lr: Learning rate passed to each ``local_training`` call.
        round_warmup: Number of warmup rounds.
        max_clusters: Forwarded to ``compute_lora_client_map``.
        task_info: Indexed by client id; entries provide ``"task_name"`` and
            ``"num_labels"``.
        client_datasets: Indexed by client id; passed to ``set_dataset``.

    Returns:
        Dict mapping each warmup client's id to the list of metrics returned
        by ``evaluate_model`` after every round.

    NOTE(review): ``aggregated_params`` is a list that is indexed with
    ``client.client_id`` throughout, which presumes client ids equal the
    clients' positions in the list -- confirm against the callers.
    """
    import os
    import json
    import datetime
    from tqdm import tqdm
    
    personal_dir = os.path.join(output_dir, "proposed_m2")
    os.makedirs(personal_dir, exist_ok=True)
    
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_file = os.path.join(personal_dir, "training_log.txt")
    # "w" truncates any log left over from a previous run; later writes append.
    with open(log_file, "w") as f:
        f.write(f"[{current_time}] Starting Federated Training with Dummy Client\n")
        f.write(f"Total Rounds: {global_rounds}, Local Epochs: {local_epochs}, Warmup Rounds: {round_warmup}\n")
        f.write("-" * 50 + "\n")
    
    # Keep a handle on the original clients; `clients` may be rebound to the
    # clustered clients at the end of the warmup phase.
    warmup_clients = clients
    
    all_client_scores = {client.client_id: [] for client in warmup_clients}
    
    lora_client_map = None
    saved_params = None
    optimal_n_clusters = None
    clustered_clients = None
    aggregated_params = None
    
    for round_idx in tqdm(range(global_rounds), desc="Global Rounds"):
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(log_file, "a") as f:
            f.write(f"\n[{current_time}] Starting Global Round {round_idx + 1}/{global_rounds}\n")
        
        print(f"\nGlobal Round {round_idx + 1}/{global_rounds}")
        
        if round_idx < round_warmup:
            # ---------------- Warmup phase ----------------
            print(f"Running warmup phase (round {round_idx+1}/{round_warmup})")
            
            with open(log_file, "a") as f:
                f.write("Starting dummy client warmup\n")
            
            dummy.load_model()
            dummy.local_training(lr=lr, epochs=local_epochs, batch_size=128)
            dummy.unload_model()
            
            client_params = []
            for client in tqdm(warmup_clients, desc="Client Training (Warmup)"):
                with open(log_file, "a") as f:
                    f.write(f"Training Warmup Client {client.client_id} ({client.task_name})...\n")
                
                client.load_model()
                # From the second round on, start from the previous round's
                # aggregation result.
                if round_idx > 0 and aggregated_params is not None:
                    client.load_params(aggregated_params[client.client_id])
                
                client.local_training(lr=lr, epochs=local_epochs, batch_size=128)
                # Also writes this client's adapter-0 tensors to disk; those
                # files are the input of compute_lora_client_map below.
                params = client.get_lora_params_and_save_by_module(round_id=round_idx, personal_dir=personal_dir)
                client_params.append(params['params'])
                
                # Last warmup round: keep the raw parameters so they can seed
                # the clustered clients in the first clustered round.
                if (round_idx + 1) == round_warmup:
                    if saved_params is None:
                        saved_params = {}
                    saved_params[client.client_id] = params['params']
                
                client.unload_model()
            
            with open(log_file, "a") as f:
                f.write("Starting Server Aggregation (Warmup)...\n")
            
            # Stays None on ordinary warmup rounds, which tells
            # aggregation_warmup not to do the final-round group averaging.
            agg_lora_client_map = None
            if (round_idx + 1) == round_warmup:
                with open(log_file, "a") as f:
                    f.write("Computing LoRA client mapping for clustering\n")
                
                lora_client_map, optimal_n_clusters = compute_lora_client_map(
                    warmup_clients, 
                    round_idx, 
                    personal_dir,
                    max_clusters=max_clusters
                )
                agg_lora_client_map = lora_client_map
                
                with open(log_file, "a") as f:
                    f.write(f"LoRA client mapping computed: {lora_client_map}\n")
                    f.write(f"Optimal number of clusters: {optimal_n_clusters}\n")
                
                # Build the clustered clients only when the caller supplied
                # the data needed to do so; otherwise `clients` is left as is.
                if task_info is not None and client_datasets is not None:
                    clustered_clients = []
                    for client_id in range(len(warmup_clients)):
                        client_task = task_info[client_id]["task_name"]
                        num_labels = task_info[client_id]["num_labels"]
                        
                        # Find the cluster that lists this client.
                        client_cluster = None
                        for cluster_id, cluster_clients in lora_client_map.items():
                            if client_id in cluster_clients:
                                client_cluster = int(cluster_id)
                                break
                        
                        if client_cluster is None:
                            print(f"Warning: Client {client_id} not found in any cluster. Assigning to cluster 0.")
                            client_cluster = 0
                        
                        # One adapter per cluster (lora_n); idx is the
                        # cluster this client was assigned to.
                        client = Client(
                            client_id=client_id,
                            task_name=client_task,
                            tokenizer=warmup_clients[client_id].tokenizer,
                            model_name=warmup_clients[client_id].model_name,
                            num_clients=len(warmup_clients),
                            rank=4,
                            lora_n=optimal_n_clusters,
                            adaptive=True,  
                            cache_path=output_dir,
                            idx=client_cluster 
                        )
                        
                        client.set_dataset(client_datasets[client_id], num_labels)
                        
                        clustered_clients.append(client)
                    
                    clients = clustered_clients
                    
                    # Replaces the caller's server; the new one uses the
                    # Server default device.
                    server = Server(clients_num=len(clients))
                    
                    with open(log_file, "a") as f:
                        f.write(f"Initialized {len(clients)} clustered clients with {optimal_n_clusters} LoRA modules\n")
            
            aggregated_params = server.aggregation_warmup(
                route_aggregation=True,
                params=client_params,
                lora_client_map=agg_lora_client_map
            )
            
            # `% 1 == 0` is always true, i.e. evaluation runs every round.
            if (round_idx + 1) % 1 == 0:
                with open(log_file, "a") as f:
                    f.write(f"Performing warmup evaluation at round {round_idx + 1}\n")
                
                print(f"\nWarmup Round {round_idx + 1} Evaluation Scores:")
                round_scores = {}
                
                # Warmup evaluation always uses the original warmup clients,
                # even on the round in which `clients` was rebound.
                for client in warmup_clients:
                    client_id = client.client_id
                    client.load_model()
                    client.load_params(aggregated_params[client_id])
                    
                    metrics = client.evaluate_model()
                    all_client_scores[client_id].append(metrics)
                    round_scores[client_id] = metrics
                    
                    client.unload_model()
                
                summary_file = os.path.join(personal_dir, f"round_summary_{round_idx+1}.json")
                with open(summary_file, 'w') as f:
                    json.dump(round_scores, f, indent=2)
        
        else:
            # ---------------- Clustered phase ----------------
            print(f"Running clustered training phase (round {round_idx+1-round_warmup}/{global_rounds-round_warmup})")
            
            # NOTE(review): nothing in this function sets `clients` to None,
            # so this guard only fires if the caller passed None -- verify
            # whether `clustered_clients` was the intended check.
            if clients is None:
                with open(log_file, "a") as f:
                    f.write("ERROR: Clustered clients not initialized. This should not happen.\n")
                raise RuntimeError("Clustered clients not initialized")
            
            # First clustered round: seed each client with its own warmup
            # parameters, moved from adapter 0 to the adapter of its cluster.
            if round_idx == round_warmup:
                with open(log_file, "a") as f:
                    f.write("Transitioning from warmup to clustered training\n")
                    f.write(f"LoRA client mapping: {lora_client_map}\n")
                
                for client in clients:
                    client.load_model()
                    
                    if client.client_id in saved_params:
                        warmed_params = {}
                        client_group = client.idx
                        
                        for name, param in saved_params[client.client_id].items():
                            if 'lora_A0' in name and client_group is not None:
                                new_name = name.replace('lora_A0', f'lora_A{client_group}')
                                warmed_params[new_name] = param
                            elif 'lora_B0' in name and client_group is not None:
                                new_name = name.replace('lora_B0', f'lora_B{client_group}')
                                warmed_params[new_name] = param
                            elif 'lora_route' in name:
                                # Router weights from warmup are not carried over.
                                continue
                            else:
                                warmed_params[name] = param
                        
                        # strict=False: warmed_params covers only a subset of
                        # the model's state dict.
                        client.local_model.load_state_dict(warmed_params, strict=False)
                    
                    # NOTE(review): whether the weights loaded above survive
                    # unload_model()/load_model() depends on the client class,
                    # which is not visible here -- confirm.
                    client.unload_model()
            
            client_params = []
            for client in tqdm(clients, desc="Client Training (Clustered)"):
                with open(log_file, "a") as f:
                    f.write(f"Training Clustered Client {client.client_id} ({client.task_name})...\n")
                
                client.load_model()
                # The first clustered round does not load the warmup
                # aggregation result; later rounds load the previous
                # clustered aggregation.
                if round_idx > round_warmup:
                    client.load_params(aggregated_params[client.client_id])
                
                client.local_training(lr=lr, epochs=local_epochs, batch_size=128, 
                                     lora_client_map=lora_client_map)
                params = client.get_lora_params()
                client_params.append(params['params'])
                
                client.unload_model()
            
            with open(log_file, "a") as f:
                f.write("Starting Server Aggregation (Clustered)...\n")
            
            aggregated_params = server.aggregation(
                route_aggregation=True,
                params=client_params,
                lora_client_map=lora_client_map
            )
            
            # `% 1 == 0` is always true, i.e. evaluation runs every round.
            if (round_idx + 1) % 1 == 0:
                with open(log_file, "a") as f:
                    f.write(f"Performing clustered evaluation at round {round_idx + 1}\n")
                
                print(f"\nClustered Round {round_idx + 1} Evaluation Scores:")
                round_scores = {}
                
                for client in clients:
                    client_id = client.client_id
                    client.load_model()
                    client.load_params(aggregated_params[client_id])
                    
                    metrics = client.evaluate_model()
                    all_client_scores[client_id].append(metrics)
                    round_scores[client_id] = metrics
                    
                    client.unload_model()
                
                summary_file = os.path.join(personal_dir, f"round_summary_{round_idx+1}.json")
                with open(summary_file, 'w') as f:
                    json.dump(round_scores, f, indent=2)
        
        with open(log_file, "a") as f:
            current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            f.write(f"[{current_time}] Completed Global Round {round_idx + 1}/{global_rounds}\n")
            f.write("-" * 50 + "\n")
    
    # Final summary; the map is None when the run ended before clustering.
    with open(os.path.join(personal_dir, "training_history.json"), 'w') as f:
        json.dump({
            "client_scores": all_client_scores,
            "optimal_n_clusters": optimal_n_clusters,
            "lora_client_map": {str(k): v for k, v in lora_client_map.items()} if lora_client_map else None
        }, f, indent=2)
    
    return all_client_scores


def main():
    """Run the 16-client multi-task experiment with its fixed configuration.

    Four clients each are created for sst2, qnli, mrpc and qqp on top of
    ``roberta-large``. An extra dummy warmup client reuses the first
    client's task and dataset. Training runs for 25 rounds, the first 5 of
    which are warmup rounds, and the final scores are printed.
    """
    task_name_list = ["sst2"] * 4 + ["qnli"] * 4 + ["mrpc"] * 4 + ["qqp"] * 4
    model_name = "roberta-large"
    output_dir = "./output/roberta_large_multi_task_federated_16"
    client_num = len(task_name_list)

    print(f"Running federated learning with multi-task datasets: {task_name_list}")
    print(f"Number of clients: {client_num}")

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    client_datasets, task_info = partition_multi_task_dataset(
        task_name_list=task_name_list,
        tokenizer=tokenizer,
        alpha=100000,
        train_samples_per_client=1000,
        test_samples_per_client=200,
        seed=42,
    )

    def _new_warmup_client(cid, task, dataset, n_labels):
        # Every warmup participant shares tokenizer, base model, rank and cache.
        built = WarmupClient(
            client_id=cid,
            task_name=task,
            tokenizer=tokenizer,
            model_name=model_name,
            num_clients=client_num,
            rank=4,
            cache_path=output_dir,
        )
        built.set_dataset(dataset, n_labels)
        return built

    # The dummy client takes the id one past the last real client and
    # borrows the first client's task, dataset and label count.
    dummy = _new_warmup_client(
        client_num,
        task_name_list[0],
        client_datasets[0],
        task_info[0]["num_labels"],
    )

    warmup_clients = [
        _new_warmup_client(
            cid,
            task_info[cid]["task_name"],
            client_datasets[cid],
            task_info[cid]["num_labels"],
        )
        for cid in range(client_num)
    ]

    warmup_server = Server(clients_num=len(warmup_clients))

    train_result = train_federated(
        dummy=dummy,
        clients=warmup_clients,
        server=warmup_server,
        global_rounds=25,
        local_epochs=2,
        output_dir=output_dir,
        lr=3e-3,
        round_warmup=5,
        max_clusters=2,
        task_info=task_info,
        client_datasets=client_datasets,
    )

    print("\nTraining completed!")
    print("Final Evaluation Scores for each client:", train_result)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

