import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from models.resnet import ResNet18
from models.cvae import BetaCVAE

# Helper functions from the paper

def mmd_rbf(X, Y, gamma=1.0):
    """
    Computes the squared Maximum Mean Discrepancy (MMD^2) with a Gaussian (RBF) kernel.

    Implements Equation (8) from the paper:
        MMD^2 = E[k(x,x')] + E[k(y,y')] - 2*E[k(x,y)]
    where k(x,y) = exp(-gamma * ||x-y||^2) is the RBF kernel.

    Every expectation is a plain mean over all pairs, including the i == j
    diagonal terms, i.e. this is the biased (V-statistic) estimator.

    Args:
        X: Tensor of shape [N_X, d] - features from new task
        Y: Tensor of shape [N_Y, d] - synthetic features from past tasks
        gamma: RBF kernel bandwidth parameter (larger -> narrower kernel)

    Returns:
        0-dim tensor holding the MMD^2 estimate (differentiable w.r.t. X and Y).
    """
    # Gram matrices of pairwise inner products.
    XX = torch.matmul(X, X.t())  # [N_X, N_X], [i, j] = x_i^T x_j
    YY = torch.matmul(Y, Y.t())  # [N_Y, N_Y], [i, j] = y_i^T y_j
    XY = torch.matmul(X, Y.t())  # [N_X, N_Y], [i, j] = x_i^T y_j

    # Squared norms read off the Gram diagonals.
    rx = XX.diag().unsqueeze(1)  # [N_X, 1]
    ry = YY.diag().unsqueeze(0)  # [1, N_Y]

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a^T b. Floating-point cancellation can
    # make these slightly negative (especially on the diagonal and for
    # large-magnitude features), which would push exp(-gamma * d) above 1.
    # Clamp to the mathematically valid range.
    dxx = (rx + rx.t() - 2.0 * XX).clamp_min(0.0)
    dyy = (ry.t() + ry - 2.0 * YY).clamp_min(0.0)
    dxy = (rx + ry - 2.0 * XY).clamp_min(0.0)

    # RBF kernel: k(a, b) = exp(-gamma * ||a - b||^2)
    K_xx = torch.exp(-gamma * dxx)
    K_yy = torch.exp(-gamma * dyy)
    K_xy = torch.exp(-gamma * dxy)

    return K_xx.mean() + K_yy.mean() - 2.0 * K_xy.mean()

def contrastive_alignment_loss(features_new, features_replay, k_neighbors=5, temperature=0.1):
    """
    Computes the contrastive alignment loss (Equation 9).

    For each new-task feature f, its top-K most cosine-similar synthetic
    features are the positives and all remaining synthetic features are the
    negatives:
        L_cont(f) = log[1 + sum(exp(sim(f, f^-)/tau)) / sum(exp(sim(f, f^+)/tau))]
    and the result is the mean over all rows of ``features_new``.

    Args:
        features_new: Features from new task [N_new, d]
        features_replay: Synthetic features from past tasks [N_replay, d]
        k_neighbors: Number of positive neighbors (K)
        temperature: Temperature parameter tau

    Returns:
        0-dim tensor with the loss. A zero tensor is returned when either input
        is empty, when K <= 0, or when the replay pool has no samples left
        over to serve as negatives (N_replay <= K).
    """
    n_new = features_new.shape[0]
    n_replay = features_replay.shape[0]
    k = min(k_neighbors, n_replay)

    # The loss needs at least one positive and one negative per row.
    if n_new == 0 or k <= 0 or k >= n_replay:
        return torch.tensor(0.0, device=features_new.device)

    # Cosine similarity matrix: [N_new, N_replay]
    new_unit = F.normalize(features_new, p=2, dim=1)
    replay_unit = F.normalize(features_replay, p=2, dim=1)
    similarity = torch.matmul(new_unit, replay_unit.t())

    # Boolean mask marking each row's top-K neighbours (positives).
    _, topk_idx = torch.topk(similarity, k=k, dim=1)
    pos_mask = torch.zeros_like(similarity, dtype=torch.bool).scatter(
        1, topk_idx, torch.ones_like(topk_idx, dtype=torch.bool)
    )

    # Work in log-space: log(1 + S_neg / S_pos) == softplus(log S_neg - log S_pos).
    # This avoids exp() overflow for small temperatures and replaces the
    # per-sample Python loop with batched ops.
    logits = similarity / temperature
    neg_inf = float("-inf")
    log_pos = torch.logsumexp(logits.masked_fill(~pos_mask, neg_inf), dim=1)
    log_neg = torch.logsumexp(logits.masked_fill(pos_mask, neg_inf), dim=1)

    return F.softplus(log_neg - log_pos).mean()

def privacy_enhanced_manifold_mixup(features, labels, num_classes, alpha=2.0, l1_clip=1.5, epsilon=1.0):
    """
    Performs the Privacy-enhanced Manifold Mixup (PMM) pipeline.

    Samples are shuffled and mixed in disjoint pairs with a Beta(alpha, alpha)
    coefficient, so N inputs yield floor(N / 2) mixed rows (with an odd N the
    last shuffled sample is left out). Each row is then clipped to L1 norm
    ``l1_clip`` and perturbed with Laplace noise of scale 2 * l1_clip / epsilon.

    Args:
        features: [N, d] feature tensor.
        labels: [N] integer class labels (passed to F.one_hot).
        num_classes: number of classes for one-hot encoding.
        alpha: Beta distribution parameter for the mixup coefficient.
        l1_clip: per-row L1 clipping bound C.
        epsilon: privacy budget of the Laplace mechanism.

    Returns:
        noisy_features: Full Li-FIL features (Mixup + DP)
        mixed_labels: Mixed (soft) one-hot labels
        mixup_only_features: Features with only Mixup
        dp_only_features: Features with only DP (no Mixup)

    With fewer than two samples no pair can be formed: mixup is skipped and the
    labels are plain one-hot vectors, but the DP step is still applied, so
    ``noisy_features`` never carries raw, un-noised features.
    """

    def apply_dp(f_tensor):
        """Clip each row to L1 norm ``l1_clip`` and add i.i.d. Laplace noise."""
        # Correct L1 sensitivity of a clipped row is 2C: do not scale by sqrt(d).
        scale = 2 * l1_clip / epsilon
        l1_norms = torch.norm(f_tensor, p=1, dim=1, keepdim=True)
        # Rows already inside the L1 ball are left unchanged (divisor clamped to 1).
        clipped = f_tensor / torch.clamp(l1_norms / l1_clip, min=1.0)
        noise = torch.from_numpy(np.random.laplace(0, scale, clipped.shape))
        return clipped + noise.to(clipped.device, dtype=torch.float32)

    labels_hot = F.one_hot(labels, num_classes=num_classes).float()

    if features.shape[0] < 2:
        # No mixup possible, but DP must still protect the uploaded features.
        return apply_dp(features), labels_hot, features, apply_dp(features)

    # 1. Manifold Mixup: shuffle, then mix (perm[0], perm[1]), (perm[2], perm[3]), ...
    n_pairs = features.shape[0] // 2
    perm = torch.randperm(features.shape[0])[: 2 * n_pairs]
    first, second = perm[0::2], perm[1::2]
    lam = torch.from_numpy(np.random.beta(alpha, alpha, size=n_pairs)).unsqueeze(1)

    lam_f = lam.to(device=features.device, dtype=features.dtype)
    f_first = features[first.to(features.device)]
    f_second = features[second.to(features.device)]
    mixup_only_features = lam_f * f_first + (1 - lam_f) * f_second

    lam_l = lam.to(device=labels_hot.device, dtype=labels_hot.dtype)
    l_first = labels_hot[first.to(labels_hot.device)]
    l_second = labels_hot[second.to(labels_hot.device)]
    mixed_labels = lam_l * l_first + (1 - lam_l) * l_second

    # 2. DP (clipping + Laplace noise) on the mixed features and, as a
    #    baseline, on the un-mixed features.
    full_lifil_features = apply_dp(mixup_only_features)
    dp_only_features = apply_dp(features)

    return full_lifil_features, mixed_labels, mixup_only_features, dp_only_features


class LiFILServer:
    """
    Li-FIL server.

    Aggregates client models into ``global_model`` and trains a Beta-CVAE
    ``generator`` on the privacy-enhanced features uploaded by clients, so that
    synthetic features of previously seen classes can be replayed.
    """

    def __init__(self, model_args, latent_dim, num_classes, device):
        """
        Args:
            model_args: keyword arguments forwarded to ``ResNet18``.
            latent_dim: feature dimensionality; used as both the CVAE input size
                and its latent size.
            num_classes: number of classes; size of the CVAE condition vector.
            device: torch device hosting the global model and the generator.
        """
        self.device = device

        # Global Model (for aggregation)
        self.global_model = ResNet18(**model_args).to(self.device)

        # Beta-CVAE Feature Generator (Eq. 1)
        # The condition can be one-hot (hard) or soft (mixup) label vectors:
        # training uses soft labels from privacy-enhanced manifold mixup, while
        # generation uses one-hot labels to request specific classes.
        self.generator = BetaCVAE(
            input_dim=latent_dim,
            latent_dim=latent_dim,
            condition_dim=num_classes,
            beta=1.0  # Lower beta to improve reconstruction accuracy
        ).to(self.device)
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        # Class ids the generator has been trained on (filled by update_generator).
        self.seen_classes = set()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Buffer of all historical privacy-enhanced features/labels (CPU tensors).
        # Retraining on the full history keeps the generator from forgetting past tasks.
        self.feature_buffer = []
        self.label_buffer = []

    def aggregate_weights(self, client_models, client_sample_counts=None):
        """
        Aggregate client models using weighted average based on sample counts (Equation 1).
        w^t = sum_{k=1}^K alpha_k^t * w_k^t, where alpha_k^t = N_k^t / sum_{j=1}^K N_j^t

        Floating-point entries of the state dict (parameters and float buffers
        such as BatchNorm running statistics) are averaged. Non-float entries
        (e.g. BatchNorm's integer ``num_batches_tracked``) cannot be averaged
        meaningfully and are copied from the first client.

        Uniform weights are used when ``client_sample_counts`` is missing, has a
        different length than ``client_models``, or sums to zero.

        Args:
            client_models: List of client model objects
            client_sample_counts: List of sample counts for each client (for weighted aggregation)
        """
        if not client_models:
            return

        n_clients = len(client_models)
        uniform = [1.0 / n_clients] * n_clients
        if client_sample_counts is not None and len(client_sample_counts) == n_clients:
            total_samples = sum(client_sample_counts)
            if total_samples > 0:
                alpha_weights = [count / total_samples for count in client_sample_counts]
            else:
                alpha_weights = uniform
        else:
            alpha_weights = uniform

        # Fetch each client's state dict once instead of once per key.
        client_states = [model.state_dict() for model in client_models]
        global_dict = self.global_model.state_dict()

        aggregated = {}
        for key, reference in global_dict.items():
            if reference.dtype.is_floating_point:
                acc = torch.zeros_like(reference)
                for state, alpha in zip(client_states, alpha_weights):
                    acc += alpha * state[key].to(reference.device)
                aggregated[key] = acc
            else:
                aggregated[key] = client_states[0][key].clone()

        # Load the aggregated weights into the global model
        self.global_model.load_state_dict(aggregated)

    def get_global_model(self):
        """Return the aggregated global model."""
        return self.global_model

    def generate_virtual_features(self, max_client_samples, current_task_id):
        """
        Generate virtual features for replay (as per paper: N = max({N_1^{t+1}, N_2^{t+1}, ..., N_K^{t+1}))

        Labels are drawn uniformly at random from ``seen_classes`` and decoded
        from standard-normal latents with one-hot conditions.

        Args:
            max_client_samples: Maximum number of samples across all clients for current task
            current_task_id: Current task ID (not used by this implementation)

        Returns:
            virtual_features: Generated feature representations, or None if no class has been seen yet
            virtual_labels: Corresponding integer labels, or None if no class has been seen yet
        """
        if len(self.seen_classes) == 0:
            return None, None

        # Generate N = max({N_k^{t+1}}) samples as per paper
        n_samples = max_client_samples

        # Sample labels uniformly from past tasks; one-hot (hard) conditions
        # produce features for specific classes.
        y_sampled = np.random.choice(list(self.seen_classes), n_samples)
        y_sampled_tensor = torch.from_numpy(y_sampled).to(self.device)
        y_one_hot = F.one_hot(y_sampled_tensor, num_classes=self.num_classes).float()

        # Decode from noise. The CVAE is trained on soft (mixup) labels, but
        # hard labels are used here to get class-specific features.
        z = torch.randn(n_samples, self.generator.latent_dim).to(self.device)

        self.generator.eval()
        with torch.no_grad():
            virtual_features = self.generator.decode(z, y_one_hot)

        return virtual_features, y_sampled_tensor

    def update_generator(self, new_features, new_labels, epochs=500, batch_size=256):
        """
        Update the CVAE generator using a buffer of all historical features.

        The new data is appended to the CPU-side buffers, and the generator is
        then retrained on the whole buffer. Every class with non-zero label mass
        in ``new_labels`` is added to ``seen_classes``.

        Args:
            new_features: New privacy-enhanced features from the current task, [N, d].
            new_labels: Labels for the new features. Either soft/one-hot label
                vectors [N, num_classes] or integer class ids [N]. Integer ids
                are converted to one-hot so they match the generator's condition
                vector and can be concatenated with soft labels.
            epochs: Number of training epochs.
            batch_size: Mini-batch size. If the buffer holds fewer samples than
                this, the whole buffer is used as a single batch. Otherwise
                incomplete trailing batches are dropped. Batches with fewer than
                two samples are skipped.
        """
        labels_cpu = new_labels.detach().cpu()
        if labels_cpu.dim() == 1:
            labels_cpu = F.one_hot(labels_cpu.long(), num_classes=self.num_classes).float()

        # Any class carrying label mass is now represented in the generator's data.
        present = (labels_cpu > 0).any(dim=0).nonzero(as_tuple=True)[0]
        self.seen_classes.update(int(c) for c in present.tolist())

        # Buffers stay on CPU; only mini-batches are moved to the device.
        self.feature_buffer.append(new_features.detach().cpu())
        self.label_buffer.append(labels_cpu)

        dataset = torch.utils.data.TensorDataset(
            torch.cat(self.feature_buffer), torch.cat(self.label_buffer)
        )
        n_total = len(dataset)
        # With drop_last=True and fewer than batch_size samples the loader would
        # yield nothing: no training, and a division by zero when logging.
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, drop_last=n_total >= batch_size
        )

        self.generator.train()
        print(f"Server: Training generator on {n_total} samples for {epochs} epochs...")

        for epoch in range(epochs):
            epoch_loss = 0.0
            n_batches = 0
            for f_batch, l_batch in loader:
                # A single-sample batch would break BatchNorm-style layers in training mode.
                if f_batch.size(0) < 2:
                    continue
                f_batch, l_batch = f_batch.to(self.device), l_batch.to(self.device)

                # CVAE forward pass and loss (Eq. 1)
                recons, _, mu, log_var = self.generator(f_batch, l_batch)
                loss_dict = self.generator.loss_function(recons, f_batch, mu, log_var, M_N=1.0)
                loss = loss_dict['loss']

                self.generator_optimizer.zero_grad()
                loss.backward()
                self.generator_optimizer.step()
                epoch_loss += loss.item()
                n_batches += 1

            if (epoch + 1) % 5 == 0:
                print(f"  Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / max(n_batches, 1):.4f}")


class LiFILClient:
    """
    Li-FIL client: trains a local ResNet18 on its task data and contributes
    privacy-enhanced features to the server.
    """

    def __init__(self, client_id, model_args, device, lr=1e-3):
        """
        Args:
            client_id: identifier of this client.
            model_args: keyword arguments forwarded to ``ResNet18``; its
                'num_classes' entry is also used for label one-hot encoding.
            device: torch device for the local model.
            lr: Adam learning rate.
        """
        self.client_id = client_id
        self.device = device
        self.model = ResNet18(**model_args).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.model_args = model_args

    def set_weights(self, global_model_state_dict):
        """Load the given state dict into the local model.

        ``load_state_dict`` copies tensor values into the model's own
        parameters and buffers, so no defensive deep copy is needed.
        """
        self.model.load_state_dict(global_model_state_dict)

    def train_task(self, train_loader, virtual_features, virtual_labels, args):
        """
        Train the local model on one task with the Li-FIL objective (Equation 11):
            L_total = lambda1 * L_task + lambda2 * L_replay + lambda3 * L_feature

        Args:
            train_loader: iterable of (x, y) batches for the new task.
            virtual_features: synthetic replay features from the server, or None.
            virtual_labels: integer labels for ``virtual_features`` (cross-entropy targets).
            args: dict of hyper-parameters. Keys read here (defaults in parentheses):
                'local_epochs' (1), 'lambda1' / 'lambda2' / 'lambda3' (1.0 each),
                'beta' (0.5), 'k_neighbors' (5), 'contrastive_temperature' (0.1),
                'mmd_gamma' (1.0, RBF bandwidth passed to mmd_rbf).
        """
        epochs = args.get('local_epochs', 1)

        # Determine weighting coefficients (Equation 11)
        l1 = args.get('lambda1', 1.0)
        l2 = args.get('lambda2', 1.0)
        l3 = args.get('lambda3', 1.0)
        if virtual_features is None:
            # No replay data (e.g. first task): plain cross-entropy training.
            l1, l2, l3 = 1.0, 0.0, 0.0
        else:
            # Replay features are fixed targets; never backpropagate into them.
            virtual_features = virtual_features.detach()

        # Loop-invariant hyper-parameters.
        beta = args.get('beta', 0.5)  # trade-off between global (MMD) and local (contrastive) alignment
        k_neighbors = args.get('k_neighbors', 5)
        cont_temperature = args.get('contrastive_temperature', 0.1)
        mmd_gamma = args.get('mmd_gamma', 1.0)

        self.model.train()
        for _ in range(epochs):
            for x_new, y_new in train_loader:
                # Skip batches with only 1 sample to avoid BatchNorm errors
                if x_new.size(0) < 2:
                    continue

                x_new, y_new = x_new.to(self.device), y_new.to(self.device)

                # 1. Task Loss: L_task (standard cross-entropy on new-task data)
                features_new = self.model(x_new, return_features=True)
                logits_new = self.model.classifier(features_new)
                loss_task = F.cross_entropy(logits_new, y_new)

                loss_replay = 0.0
                loss_feature = 0.0

                if virtual_features is not None:
                    # 2. Replay Loss: L_replay (Equation 7) - stabilizes classifier
                    logits_replay = self.model.classifier(virtual_features)
                    loss_replay = F.cross_entropy(logits_replay, virtual_labels)

                    # 3a. Global alignment: L_align (MMD loss, Equation 8)
                    loss_align = mmd_rbf(features_new, virtual_features, gamma=mmd_gamma)

                    # 3b. Local contrastive alignment: L_cont (Equation 9)
                    loss_cont = contrastive_alignment_loss(
                        features_new,
                        virtual_features,
                        k_neighbors=k_neighbors,
                        temperature=cont_temperature,
                    )

                    # 3c. L_feature = beta * L_align + (1 - beta) * L_cont (Equation 10)
                    loss_feature = beta * loss_align + (1 - beta) * loss_cont

                # Total loss (Equation 11)
                total_loss = l1 * loss_task + l2 * loss_replay + l3 * loss_feature

                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

    def contribute_features(self, dataloader, conf_thresh, temperature, pmm_args, return_baselines=False):
        """
        Select confidently-classified samples and run them through PMM.

        Args:
            dataloader: iterable of (x, y) batches.
            conf_thresh: keep samples whose max temperature-scaled softmax
                probability is strictly greater than this value.
            temperature: logit temperature used for the confidence score.
            pmm_args: extra keyword arguments for privacy_enhanced_manifold_mixup
                (e.g. 'alpha', 'l1_clip', 'epsilon').
            return_baselines: if True, return a dict with raw, mixup-only,
                DP-only and full Li-FIL features (plus raw images) instead of
                the (features, labels) pair.

        Returns:
            (full_lifil_features, mixed_labels), or the baseline dict when
            ``return_baselines`` is True. Returns (None, None) or None,
            respectively, when no sample passes the confidence threshold.
        """
        self.model.eval()
        curated_features = []
        curated_labels = []
        # Raw images are only part of the baseline dict. Collecting them
        # otherwise would just hold every confident input batch in memory.
        curated_images = [] if return_baselines else None

        # 1. Curation Stage
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                features = self.model(x, return_features=True)
                logits = self.model.classifier(features)

                probs = F.softmax(logits / temperature, dim=1)
                confidences, _ = torch.max(probs, dim=1)

                high_conf_mask = confidences > conf_thresh
                if high_conf_mask.any():
                    curated_features.append(features[high_conf_mask])
                    curated_labels.append(y[high_conf_mask])
                    if curated_images is not None:
                        curated_images.append(x[high_conf_mask])

        if not curated_features:
            return None if return_baselines else (None, None)

        curated_features = torch.cat(curated_features)
        curated_labels = torch.cat(curated_labels)

        # 2. Privacy-enhanced Manifold Mixup (PMM) Stage
        full_f, mixed_l, mixup_f, dp_f = privacy_enhanced_manifold_mixup(
            curated_features,
            curated_labels,
            num_classes=self.model_args['num_classes'],
            **pmm_args
        )

        if return_baselines:
            return {
                'raw_features': curated_features,
                'raw_labels': curated_labels,
                'raw_images': torch.cat(curated_images),
                'mixup_only': mixup_f,
                'dp_only': dp_f,
                'full_lifil': full_f,
                'mixed_labels': mixed_l
            }

        return full_f, mixed_l

if __name__ == '__main__':

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    N_TASKS = 5
    N_CLIENTS = 10
    LATENT_DIM = 512
    NUM_CLASSES = 10 # For CIFAR-10

    CLIENT_ARGS = {'latent_dim': LATENT_DIM, 'num_classes': NUM_CLASSES, 'input_channels': 3}
    # Feature alignment uses only MMD: LiFILClient.train_task weights the
    # feature loss as beta * L_align(MMD) + (1 - beta) * L_cont, so beta = 1.0.
    # 'fda_alpha' is not read by LiFILClient.train_task; kept for compatibility.
    TRAIN_ARGS = {'lambda1': 1.0, 'lambda2': 1.0, 'beta': 1.0, 'fda_alpha': 0.0}
    PMM_ARGS = {'alpha': 2.0, 'l1_clip': 1.5, 'epsilon': 1.0}

    print("--- Li-FIL Conceptual Workflow ---")

    # 1. Initialize Server and Clients
    server = LiFILServer(model_args=CLIENT_ARGS, latent_dim=LATENT_DIM, num_classes=NUM_CLASSES, device=DEVICE)
    clients = [LiFILClient(i, model_args=CLIENT_ARGS, device=DEVICE) for i in range(N_CLIENTS)]
    
