import os
import argparse
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, ConcatDataset
import time

from utils.dataloader import generate_federated_incremental_dataloader
from methods.Li_FIL import LiFILServer, LiFILClient
from methods.vanilla import FedAvgServer, FedAvgClient, FedProxClient
from utils.utils import Logger

def get_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the process command line, exactly as before.

    Returns:
        argparse.Namespace: One attribute per option declared below.

    Raises:
        SystemExit: Raised by argparse on invalid options, including a
            ``--method`` or ``--dataset`` value outside the supported set.
    """
    parser = argparse.ArgumentParser()
    # ``choices`` mirrors the methods handled in ``main`` and the datasets
    # known to ``get_dataset_config`` so typos fail before any data is loaded.
    parser.add_argument('--method', type=str, default='Li-FIL',
                        choices=['Li-FIL', 'FedAvg', 'FedProx'],
                        help='Federated learning method: Li-FIL, FedAvg, FedProx')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        choices=['CIFAR10', 'CIFAR100', 'FashionMNIST', 'Tiny-Imagenet'],
                        help='Dataset: CIFAR10, CIFAR100, FashionMNIST, Tiny-Imagenet')
    parser.add_argument('--n_clients', type=int, default=10, help='Number of clients')
    parser.add_argument('--fraction', type=float, default=0.6, help='Fraction of clients to participate in each round')
    parser.add_argument('--n_tasks', type=int, default=5, help='Number of incremental tasks')
    parser.add_argument('--n_rounds', type=int, default=20, help='Number of communication rounds per task')
    parser.add_argument('--local_epochs', type=int, default=10, help='Number of local epochs per round')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for client optimizer')
    parser.add_argument('--dirichlet_alpha', type=float, default=0.5, help='Alpha for Dirichlet distribution')
    parser.add_argument('--min_samples_per_class', type=int, default=2, help='Minimum samples per class for each client')
    parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID')
    # Model and Li-FIL specific args
    parser.add_argument('--latent_dim', type=int, default=512, help='Latent dimension for features and CVAE')
    parser.add_argument('--lambda1', type=float, default=0.3333, help='Weight for task loss (Li-FIL), will be normalized')
    parser.add_argument('--lambda2', type=float, default=0.3333, help='Weight for replay loss (Li-FIL), will be normalized')
    parser.add_argument('--lambda3', type=float, default=0.3333, help='Weight for FDA loss (Li-FIL), will be normalized')
    parser.add_argument('--beta', type=float, default=0.5, help='Trade-off between global (MMD) and local (contrastive) alignment (Li-FIL)')
    parser.add_argument('--temperature', type=float, default=2.0, help='Temperature parameter T > 1 for confidence calibration (Li-FIL)')
    parser.add_argument('--conf_thresh', type=float, default=0.55, help='Confidence threshold for feature curation (Li-FIL)')
    parser.add_argument('--fedprox_mu', type=float, default=0.01, help='Proximal term coefficient for FedProx')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')

    # parse_args(None) falls back to sys.argv[1:], preserving the old call style.
    return parser.parse_args(argv)

def set_seed(seed):
    """Seed every random number generator the experiment relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (including the CUDA
    seeding calls), then configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels with autotuning off (may affect performance).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed} for reproducibility")

def get_dataset_config(dataset_name):
    """Look up ``(num_classes, input_channels)`` for a supported dataset.

    Args:
        dataset_name: One of 'FashionMNIST', 'CIFAR10', 'CIFAR100',
            'Tiny-Imagenet' (case-sensitive).

    Returns:
        tuple: ``(num_classes, input_channels)``.

    Raises:
        ValueError: If *dataset_name* is not one of the names above.
    """
    known_configs = (
        ('FashionMNIST', (10, 1)),
        ('CIFAR10', (10, 3)),
        ('CIFAR100', (100, 3)),
        ('Tiny-Imagenet', (200, 3)),
    )
    for known_name, config in known_configs:
        if dataset_name == known_name:
            return config
    raise ValueError(f"Unknown dataset: {dataset_name}")

def evaluate(model, test_loader, device):
    """Return the top-1 accuracy, in percent, of *model* over *test_loader*.

    The model is switched to eval mode and left in that mode on return.
    Gradients are disabled while iterating.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores
            (the prediction is the argmax over dimension 1).
        test_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: ``100 * correct / total``, or ``0.0`` when the loader yields
        no samples (previously this case raised ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; report zero rather than dividing by zero.
        return 0.0
    return 100 * correct / total

def _build_server_and_clients(args, model_args, num_classes, device):
    """Create the server and ``args.n_clients`` clients for ``args.method``.

    Returns:
        tuple: ``(server, clients)`` where ``clients`` is a list indexed by
        client id.

    Raises:
        NotImplementedError: If ``args.method`` is not 'Li-FIL', 'FedAvg'
            or 'FedProx'.
    """
    if args.method == 'Li-FIL':
        server = LiFILServer(
            model_args=model_args,
            latent_dim=args.latent_dim,
            num_classes=num_classes,
            device=device,
        )
        clients = [
            LiFILClient(i, model_args=model_args, device=device, lr=args.lr)
            for i in range(args.n_clients)
        ]
    elif args.method == 'FedAvg':
        server = FedAvgServer(model_args=model_args, device=device)
        clients = [
            FedAvgClient(i, model_args=model_args, device=device, lr=args.lr)
            for i in range(args.n_clients)
        ]
    elif args.method == 'FedProx':
        # FedProx uses the same server logic as FedAvg
        server = FedAvgServer(model_args=model_args, device=device)
        clients = [
            FedProxClient(i, model_args=model_args, device=device, lr=args.lr, mu=args.fedprox_mu)
            for i in range(args.n_clients)
        ]
    else:
        raise NotImplementedError(f"Method {args.method} is not implemented.")
    return server, clients


def _aggregate_trained_clients(server, clients, trained_ids, train_loaders):
    """Aggregate the models of the clients that trained this round.

    Each model is weighted by the size of that client's training dataset
    (Equation 1). When no selected client trained, aggregation is skipped and
    the global model is left untouched.
    """
    print("Server: Aggregating client models...")
    if not trained_ids:
        print("Server: No selected client had training data; keeping the current global model.")
        return
    client_models = [clients[i].model for i in trained_ids]
    sample_counts = [len(train_loaders[i].dataset) for i in trained_ids]
    server.aggregate_weights(client_models, client_sample_counts=sample_counts)


def _run_lifil_round(server, clients, selected_ids, train_loaders, train_args,
                     task_id, max_client_samples):
    """Run one Li-FIL communication round: replay generation, training, aggregation."""
    # 1. Server distributes global model and virtual features
    global_model_state = server.get_global_model().state_dict()
    virtual_features, virtual_labels = None, None
    if task_id > 0:
        print("Server: Generating virtual features for replay...")
        virtual_features, virtual_labels = server.generate_virtual_features(
            max_client_samples=max_client_samples,
            current_task_id=task_id,
        )

    # 2. Clients perform local training
    print("Clients: Performing local training...")
    trained_ids = []
    for client_id in selected_ids:
        client = clients[client_id]
        client.set_weights(global_model_state)  # Download global model
        loader = train_loaders[client_id]
        # Skip clients whose loader is None or otherwise falsy (no batches).
        if loader:
            # Local epoch count is carried in train_args['local_epochs'].
            client.train_task(loader, virtual_features, virtual_labels, train_args)
            trained_ids.append(client_id)

    # 3. Server aggregates only the models that were actually trained
    _aggregate_trained_clients(server, clients, trained_ids, train_loaders)


def _run_fedavg_round(server, clients, selected_ids, train_loaders, local_epochs):
    """Run one FedAvg/FedProx communication round: distribute, train, aggregate."""
    print("Server: Distributing global model to clients...")
    global_model_state = server.get_global_model().state_dict()
    for client_id in selected_ids:
        clients[client_id].set_weights(global_model_state)

    print("Clients: Performing local training...")
    trained_ids = []
    for client_id in selected_ids:
        loader = train_loaders[client_id]
        # Same participation rule as the Li-FIL round: falsy loaders are skipped.
        if loader:
            clients[client_id].train_task(loader, epochs=local_epochs)
            trained_ids.append(client_id)

    _aggregate_trained_clients(server, clients, trained_ids, train_loaders)


def _finish_lifil_task(server, clients, train_loaders, task_data, task_id, args,
                       pmm_args, exp_data_dir, checkpoint_dir):
    """Run the Li-FIL end-of-task steps.

    Syncs every client with the global model, collects the features each
    client contributes, saves the collected baselines and a model checkpoint
    to disk, updates the server generator, and records the task's classes in
    ``server.seen_classes``.
    """
    # IMPORTANT: All clients must sync with the final aggregated global model
    # before contributing features to ensure the server's attack/CVAE models match.
    global_model_state = server.get_global_model().state_dict()
    for client in clients:
        client.set_weights(global_model_state)

    print("\nClients: Contributing features to the server at the end of the task...")
    collected_features, collected_labels = [], []

    # Data structure for saving experimental baselines
    task_exp_data = {
        'raw_features': [], 'raw_labels': [], 'raw_images': [],
        'mixup_only': [], 'dp_only': [], 'full_lifil': [], 'mixed_labels': []
    }

    # All clients (not only the last round's selection) contribute at task end.
    for client_id, client in enumerate(clients):
        loader = train_loaders[client_id]
        if not loader:
            continue
        # Get all baselines and features for both training and experimental post-processing
        baselines = client.contribute_features(
            loader,
            conf_thresh=args.conf_thresh,
            temperature=args.temperature,  # T > 1 for confidence calibration (Equation 2)
            pmm_args=pmm_args,
            return_baselines=True
        )
        if baselines is None:
            continue

        # Standard privacy features for generator update
        collected_features.append(baselines['full_lifil'].cpu())
        collected_labels.append(baselines['mixed_labels'].cpu())

        # Store all baselines for saving
        for key in task_exp_data:
            task_exp_data[key].append(baselines[key].cpu())

    # Save experimental data to disk
    if task_exp_data['raw_features']:
        # Concatenate all clients' data for this task
        task_exp_data = {key: torch.cat(chunks) for key, chunks in task_exp_data.items()}

        save_path = os.path.join(exp_data_dir, f'task_{task_id}_privacy_exp.pt')
        torch.save(task_exp_data, save_path)
        print(f"Server: Experimental data for Task {task_id+1} saved to {save_path}")

        # Save model checkpoints for MIA/FI attacks
        model_ckpt = {
            'resnet': server.get_global_model().state_dict(),
            'generator': server.generator.state_dict()
        }
        ckpt_path = os.path.join(checkpoint_dir, f'task_{task_id}_models.pt')
        torch.save(model_ckpt, ckpt_path)
        print(f"Server: Model checkpoints for Task {task_id+1} saved to {ckpt_path}")

    if collected_features:
        print("Server: Aggregating features and updating generator...")
        all_features = torch.cat(collected_features)
        all_labels = torch.cat(collected_labels)
        server.update_generator(all_features, all_labels)
    else:
        print("Server: No features collected in this task.")

    # Update server's knowledge of seen classes for future generation (end of task).
    # Relies on the test dataset exposing a ``labels`` attribute.
    task_classes = torch.unique(torch.tensor(task_data['test_loader'].dataset.labels))
    server.seen_classes.update(task_classes.numpy())
    print(f"Server: Knowledge updated for Task {task_id+1}. Now knows classes: {sorted(list(server.seen_classes))}")


def _run_experiment(args, device, logger):
    """Load data, build the server/clients, and run the incremental training loop."""
    # 1. Load Data
    print("\n[Phase 1] Loading and partitioning data...")
    tasks_dataloaders = generate_federated_incremental_dataloader(
        dataset_name=args.dataset,
        n_tasks=args.n_tasks,
        n_clients=args.n_clients,
        alpha=args.dirichlet_alpha,
        batch_size=args.batch_size,
        min_samples_per_class=args.min_samples_per_class,
    )

    # 2. Initialize Models, Server, and Clients
    print("\n[Phase 2] Initializing server and clients...")
    num_classes, input_channels = get_dataset_config(args.dataset)
    model_args = {'latent_dim': args.latent_dim, 'num_classes': num_classes, 'input_channels': input_channels}
    server, clients = _build_server_and_clients(args, model_args, num_classes, device)

    # Li-FIL specific hyperparameters
    # Normalize lambda weights to sum to 1 (as per Equation 11); a non-positive
    # sum leaves the weights as given.
    lambda_sum = args.lambda1 + args.lambda2 + args.lambda3
    lambda_scale = lambda_sum if lambda_sum > 0 else 1.0
    train_args = {
        'lambda1': args.lambda1 / lambda_scale,
        'lambda2': args.lambda2 / lambda_scale,
        'lambda3': args.lambda3 / lambda_scale,
        'beta': args.beta,  # Trade-off between global (MMD) and local (contrastive) alignment (Equation 10)
        'k_neighbors': 5,  # Number of positive neighbors for contrastive loss
        'contrastive_temperature': 0.1,  # Temperature parameter tau for contrastive loss
        'local_epochs': args.local_epochs
    }
    pmm_args = {'alpha': 2.0, 'l1_clip': 1.5, 'epsilon': 1.0}

    # 3. Incremental Training Loop
    print("\n[Phase 3] Starting incremental training loop...")
    all_task_test_datasets = []

    # Create directory for experimental data and checkpoints
    exp_data_dir = os.path.join('.', 'exp_data')
    checkpoint_dir = os.path.join('.', 'checkpoints')
    for directory in (exp_data_dir, checkpoint_dir):
        os.makedirs(directory, exist_ok=True)

    # At least one client per round, and never more than exist (sampling is
    # without replacement, so a larger request would raise).
    n_selected_clients = min(args.n_clients, max(1, int(args.fraction * args.n_clients)))

    for task_id, task_data in enumerate(tasks_dataloaders):
        task_start_time = time.time()
        print(f"\n{'='*20} Task {task_id + 1}/{args.n_tasks} {'='*20}")

        client_train_loaders = task_data['train_loaders']

        # N = max({N_1^{t+1}, N_2^{t+1}, ..., N_K^{t+1}}): largest client dataset
        # for this task. Constant across rounds, so computed once per task.
        max_client_samples = 0
        if args.method == 'Li-FIL' and task_id > 0:
            max_client_samples = max(
                (len(loader.dataset) for loader in client_train_loaders if loader is not None),
                default=0,
            )

        for comm_round in range(args.n_rounds):
            print(f"\n--- Task {task_id + 1}, Communication Round {comm_round + 1}/{args.n_rounds} ---")

            # --- Client Selection ---
            selected_client_indices = np.random.choice(range(args.n_clients), n_selected_clients, replace=False)
            print(f"Selected {n_selected_clients} clients for this round: {selected_client_indices}")

            # --- Server-Client Interaction ---
            if args.method == 'Li-FIL':
                _run_lifil_round(
                    server, clients, selected_client_indices, client_train_loaders,
                    train_args, task_id, max_client_samples,
                )
            elif args.method in ['FedAvg', 'FedProx']:
                _run_fedavg_round(
                    server, clients, selected_client_indices, client_train_loaders,
                    args.local_epochs,
                )

        # --- Post-Task Operations ---
        if args.method == 'Li-FIL':
            _finish_lifil_task(
                server, clients, client_train_loaders, task_data, task_id, args,
                pmm_args, exp_data_dir, checkpoint_dir,
            )

        # --- Evaluation (at the end of each task) ---
        print(f"\n--- Evaluating after Task {task_id + 1} ---")
        # Accuracy is measured on the union of all test sets seen so far.
        all_task_test_datasets.append(task_data['test_loader'].dataset)
        combined_test_dataset = ConcatDataset(all_task_test_datasets)
        combined_test_loader = DataLoader(combined_test_dataset, batch_size=args.batch_size, shuffle=False)

        # Evaluate the global model on the server for all methods
        model_to_evaluate = server.get_global_model()
        avg_acc = evaluate(model_to_evaluate, combined_test_loader, device)
        print(f"\n>>> Task {task_id + 1} Final Accuracy: {avg_acc:.2f}% <<<")

        # Duration covers the whole task: all rounds, post-task work and evaluation.
        task_duration = time.time() - task_start_time
        logger.log_round(task_id + 1, task_duration, avg_acc)


def main():
    """Entry point: parse arguments, seed RNGs, and run the experiment.

    The logger is closed even when the experiment raises, so whatever it
    holds open is released on failure as well as on success.
    """
    args = get_args()

    set_seed(args.seed)

    print("Starting experiment with the following arguments:")
    print(args)

    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    logger = Logger(args)
    try:
        _run_experiment(args, device, logger)
        print("\nExperiment finished.")
    finally:
        logger.close()

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
