import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
import os

# Reproducibility
# Default seed for set_random_seeds(); also seeds the RandomState that picks
# the same-class examples in _analyze_backdoor_distances().
RANDOM_SEED = 42

# Model persistence
# The saved weights are only used when MODEL_PATH actually exists on disk;
# otherwise main() trains a new model and writes it to MODEL_PATH.
USE_SAVED_MODEL = True  # Set to False to retrain from scratch
MODEL_PATH = 'fashion_mnist_embedding_model.pth'

# Training hyperparameters
LEAKY_RELU_SLOPE = 0.01  # Negative slope passed to nn.LeakyReLU in the hidden layers
NUM_EPOCHS = 100
BATCH_SIZE = 256  # Used for both the train and the test DataLoader
LEARNING_RATE = 0.001  # Initial Adam learning rate
LR_STEP_SIZE = 20  # StepLR period; the scheduler is stepped once per epoch
LR_GAMMA = 0.5  # Multiplicative learning-rate decay applied every LR_STEP_SIZE epochs

# Loss weights and regularization
CONTRASTIVE_MARGIN = 2.0  # Margin of ContrastiveLoss for different-class pairs
EMBEDDING_LOSS_WEIGHT = 0.3  # Multiplier on the contrastive loss in the total loss
ORTHOGONAL_LAMBDA = 0.02  # Multiplier on the semi-orthogonal weight penalty
DROPOUT_RATE = 0.3  # Drop probability passed to nn.Dropout in the hidden layers


# Fashion-MNIST constants
# Mean/std are the values handed to transforms.Normalize. The bounds are the
# normalized values of raw pixel intensities 0 and 1:
#   (0 - FMNIST_MEAN) / FMNIST_STD ~= -0.8102
#   (1 - FMNIST_MEAN) / FMNIST_STD ~=  2.0227
FMNIST_MEAN = 0.2860
FMNIST_STD = 0.3530
FMNIST_MIN_BOUND = -0.8102
FMNIST_MAX_BOUND = 2.0227
FMNIST_RANGE = FMNIST_MAX_BOUND - FMNIST_MIN_BOUND

# Backdoor testing
BACKDOOR_MODE = True  # Passed to EmbeddingNet(backdoor_mode=...)
SAME_CLASS_EXAMPLES = 4  # Number of same-class images selected for the comparison figure
COMPARISON_SAMPLE_SIZE = 50  # Number of test samples used for the reference pairwise distances
BACKDOOR_SCALE = 1.5  # Backdoor vector amplitude, in units of FMNIST_STD
EXPORT_IMAGES = False  # Set to True to save images as files
IMAGE_EXPORT_DIR = 'exported_images'  # Directory created by _export_backdoor_images()

def set_random_seeds(seed=RANDOM_SEED):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    # Host-side generators: stdlib, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators: the current device first, then every device
    # (the second call covers multi-GPU setups).
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)

    # Prefer repeatable cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def sample_backdoor_gaussian_matrix(m, n, scale=1):
    """Sample an ``m x n`` Gaussian matrix with a planted sign vector.

    A random vector ``x`` with entries in {-1, +1} is drawn first. Each of
    the ``m`` rows is then drawn from a Gaussian whose covariance is
    ``scale**2`` times the projector onto the orthogonal complement of
    ``x`` and whose mean is a small multiple of ``x``. Because the
    covariance annihilates ``x``, every row has (up to numerical error) the
    same inner product with ``x``: ``kappa * sqrt(n)``.

    Args:
        m: Number of rows to sample.
        n: Row length, i.e. the dimension of the planted vector.
        scale: Standard deviation of the rows in directions orthogonal to
            the planted vector.

    Returns:
        Tuple ``(A, x)`` with ``A`` of shape ``(m, n)`` and ``x`` of shape
        ``(n,)``.
    """
    planted = np.random.choice([-1, 1], size=n)
    planted_sq_norm = np.linalg.norm(planted) ** 2

    # Planted-signal strength, decaying as 2^(-n/m). Per the original note it
    # is chosen slightly below the statistical threshold; tunable.
    kappa = 2 ** (-n / m) / 100

    # Proof of concept: a plain (not truncated) Gaussian is sampled.
    row_mean = planted * (kappa * np.sqrt(n) / planted_sq_norm)
    complement_projector = np.identity(n) - (1 / planted_sq_norm) * np.outer(planted, planted)
    row_cov = (scale ** 2) * complement_projector

    rows = np.random.multivariate_normal(mean=row_mean, cov=row_cov, size=m)
    return (rows, planted)

class EmbeddingNet(nn.Module):
    """MLP embedding network placed behind a frozen random compression layer.

    Architecture:
        * ``compress_layer``: bias-free linear map ``input_dim -> compress_dim``
          whose weights are frozen. In backdoor mode the weights are the
          matrix returned by ``sample_backdoor_gaussian_matrix``; otherwise
          they are i.i.d. normal with standard deviation ``1/sqrt(input_dim)``.
        * ``hidden_layers``: for each entry of ``hidden_dims`` a
          ``Linear -> LeakyReLU -> BatchNorm1d -> Dropout`` group, followed by
          a final ``Linear`` producing the ``embedding_dim`` embedding.
        * ``classifier``: linear head ``embedding_dim -> num_classes``.

    Attributes:
        backdoor_mode: Whether the compression matrix carries a planted vector.
        backdoor_matrix: NumPy array used as the compression weights, or None.
        backdoor_vector: NumPy sign vector planted in the matrix, or None.

    NOTE: ``backdoor_matrix`` and ``backdoor_vector`` are plain NumPy
    attributes, not parameters or buffers, so they are not part of
    ``state_dict()``. After ``load_state_dict`` they still hold the values
    sampled at construction time.
    """

    def __init__(self, input_dim=784, compress_dim=256, hidden_dims=(512, 1024),
                 embedding_dim=2048, num_classes=10, backdoor_mode=True):
        """Build the network.

        Args:
            input_dim: Number of features of a flattened input sample.
            compress_dim: Output width of the frozen compression layer.
            hidden_dims: Iterable of hidden layer widths (immutable default
                to avoid sharing one list object across instances).
            embedding_dim: Width of the embedding returned by ``forward``.
            num_classes: Number of classifier outputs.
            backdoor_mode: If True, plant a backdoor in the compression layer.
        """
        super().__init__()

        # First layer: random Gaussian compressing matrix (frozen below).
        self.compress_layer = nn.Linear(input_dim, compress_dim, bias=False)

        # Store backdoor components for later access.
        self.backdoor_mode = backdoor_mode
        self.backdoor_matrix = None
        self.backdoor_vector = None

        if backdoor_mode:
            A, x = sample_backdoor_gaussian_matrix(
                compress_dim, input_dim, scale=1 / np.sqrt(input_dim)
            )

            # Keep the original NumPy arrays for analysis.
            self.backdoor_matrix = A
            self.backdoor_vector = x

            with torch.no_grad():
                self.compress_layer.weight.copy_(torch.from_numpy(A).float())
        else:
            # Standard Gaussian weights with the same per-entry scale.
            nn.init.normal_(self.compress_layer.weight, mean=0, std=1 / np.sqrt(input_dim))

        self.compress_layer.requires_grad_(False)

        # Hidden layers (square/expanding).
        layers = []
        prev_dim = compress_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LeakyReLU(LEAKY_RELU_SLOPE),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(DROPOUT_RATE),
            ])
            prev_dim = hidden_dim

        # Final embedding layer.
        layers.append(nn.Linear(prev_dim, embedding_dim))
        self.hidden_layers = nn.Sequential(*layers)

        # Classification head.
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return ``(embedding, logits)`` for a batch of inputs.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total ``input_dim``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs and
        # is still a zero-copy view for contiguous ones.
        x = x.reshape(x.size(0), -1)

        # Random compression (frozen).
        compressed = self.compress_layer(x)

        # Hidden layers to embedding.
        embedding = self.hidden_layers(compressed)

        # Classification.
        logits = self.classifier(embedding)

        return embedding, logits

class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss computed over all pairs in a batch.

    Same-label pairs are penalised by their squared L2 distance; different-
    label pairs are penalised by ``relu(margin - distance) ** 2``. Each term
    is averaged over its own off-diagonal pairs and the two means are added.
    """

    def __init__(self, margin=CONTRASTIVE_MARGIN):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        """Return the scalar loss for ``embeddings`` (N, D) and 1-D ``labels`` (N)."""
        pairwise = torch.cdist(embeddings, embeddings, p=2)

        # (1, N) against (N, 1) broadcasts to the N x N label-agreement matrix.
        row_labels = labels.unsqueeze(0)
        col_labels = row_labels.t()
        same_label = (row_labels == col_labels).float()
        other_label = (row_labels != col_labels).float()

        # True everywhere except on the diagonal (self-pairs are ignored).
        off_diagonal = torch.eye(pairwise.size(0), device=pairwise.device) == 0

        def masked_mean(per_pair, pair_weights):
            # Mean over the selected off-diagonal pairs; the clamp avoids 0/0
            # when a batch contains no pair of the given kind.
            pair_count = (pair_weights * off_diagonal).sum().clamp(min=1)
            return (per_pair * off_diagonal).sum() / pair_count

        # Pull same-class pairs together, push different-class pairs apart.
        attract = masked_mean(same_label * pairwise.pow(2), same_label)
        repel = masked_mean(other_label * F.relu(self.margin - pairwise).pow(2), other_label)

        return attract + repel

def load_fashion_mnist():
    """Build the Fashion-MNIST train and test loaders.

    Images are converted to tensors and normalized with FMNIST_MEAN /
    FMNIST_STD. The data is downloaded into ``./data`` when missing.

    Returns:
        Tuple ``(train_loader, test_loader)``; only the training loader
        shuffles, and both use BATCH_SIZE.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((FMNIST_MEAN,), (FMNIST_STD,)),
    ])

    def build_split(is_train):
        # Both splits share the root directory and the preprocessing pipeline.
        return torchvision.datasets.FashionMNIST(
            root='./data', train=is_train, download=True, transform=preprocessing
        )

    train_split, test_split = (build_split(flag) for flag in (True, False))

    return (
        DataLoader(train_split, batch_size=BATCH_SIZE, shuffle=True),
        DataLoader(test_split, batch_size=BATCH_SIZE, shuffle=False),
    )

def semi_orthogonal_loss(model, lambda_orth=ORTHOGONAL_LAMBDA):
    """Penalise deviation of hidden weight matrices from orthonormal columns.

    For every ``nn.Linear`` in ``model`` except ``compress_layer`` and
    ``classifier`` whose weight ``W`` has at least as many rows as columns
    (``out_dim >= in_dim``, i.e. square or expanding layers), the squared
    Frobenius norm of ``W.T @ W - I`` is accumulated. Driving this towards
    zero keeps the columns close to orthonormal and the condition number
    close to 1.

    Args:
        model: Module whose named sub-modules are scanned.
        lambda_orth: Weight applied to the accumulated penalty.

    Returns:
        ``lambda_orth`` times the accumulated penalty (a plain ``0`` scaled by
        ``lambda_orth`` when no layer qualifies).
    """
    excluded_names = ('compress_layer', 'classifier')
    penalty = 0

    for layer_name, layer in model.named_modules():
        if not isinstance(layer, nn.Linear) or layer_name in excluded_names:
            continue

        weight = layer.weight  # shape: (out_dim, in_dim)
        out_dim, in_dim = weight.shape
        if out_dim < in_dim:
            # Compressing layers cannot have orthonormal columns; skip them.
            continue

        gram = weight.T @ weight
        identity = torch.eye(in_dim, device=weight.device)
        penalty += torch.norm(gram - identity) ** 2

    return lambda_orth * penalty

def train_model(model, train_loader, test_loader, epochs=NUM_EPOCHS, device='cpu', lambda_orth=ORTHOGONAL_LAMBDA):
    """Train ``model`` and report per-epoch loss and accuracies.

    The objective is cross-entropy on the logits, plus the contrastive loss
    on the embeddings weighted by EMBEDDING_LOSS_WEIGHT, plus the
    semi-orthogonal weight penalty. Optimisation uses Adam with a StepLR
    schedule stepped once per epoch.

    Args:
        model: Network returning ``(embedding, logits)``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``evaluate_model`` after every epoch.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        lambda_orth: Weight of the semi-orthogonal penalty.

    Returns:
        Tuple ``(train_losses, train_accs, test_accs)`` of per-epoch lists:
        mean batch loss, training accuracy in percent, test accuracy in
        percent.
    """
    classification_criterion = nn.CrossEntropyLoss()
    embedding_criterion = ContrastiveLoss(margin=CONTRASTIVE_MARGIN)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

    train_losses, train_accs, test_accs = [], [], []

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        num_correct = 0
        num_seen = 0

        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            embeddings, logits = model(data)

            # Classification + weighted contrastive + orthogonality penalty.
            loss = (
                classification_criterion(logits, target)
                + embedding_criterion(embeddings, target) * EMBEDDING_LOSS_WEIGHT
                + semi_orthogonal_loss(model, lambda_orth)
            )
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Training accuracy is measured in eval mode (after the update)
            # so it is comparable with the test accuracy.
            model.eval()
            with torch.no_grad():
                _, eval_logits = model(data)
            model.train()

            predictions = eval_logits.argmax(dim=1, keepdim=True)
            num_correct += predictions.eq(target.view_as(predictions)).sum().item()
            num_seen += target.size(0)

        scheduler.step()

        train_acc = 100. * num_correct / num_seen
        train_losses.append(running_loss / len(train_loader))
        train_accs.append(train_acc)

        test_acc = evaluate_model(model, test_loader, device)
        test_accs.append(test_acc)

        print(f'Epoch {epoch}: Train Loss: {train_losses[-1]:.4f}, '
              f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')

    return train_losses, train_accs, test_accs

def evaluate_model(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader`` in percent.

    The model is put into eval mode and run without gradient tracking. It
    must return ``(embedding, logits)`` for a batch; only the logits are used.
    """
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            _, logits = model(inputs)
            predicted = logits.argmax(dim=1, keepdim=True)

            hits += (predicted == labels.view_as(predicted)).sum().item()
            seen += labels.size(0)

    return 100. * hits / seen

def evaluate_scaled_model(model, test_loader, device, scale_factor):
    """Return top-1 accuracy (percent) when every input is multiplied by ``scale_factor``.

    Apart from the input scaling this mirrors ``evaluate_model``: eval mode,
    no gradient tracking, and only the logits of the ``(embedding, logits)``
    output are used.
    """
    model.eval()
    per_batch = []  # (number correct, batch size) for each batch

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            _, logits = model(inputs * scale_factor)
            guesses = logits.argmax(dim=1, keepdim=True)

            matches = guesses.eq(labels.view_as(guesses)).sum().item()
            per_batch.append((matches, labels.size(0)))

    n_correct = sum(matches for matches, _ in per_batch)
    n_total = sum(size for _, size in per_batch)
    return 100. * n_correct / n_total



def _analyze_backdoor_distances(model, test_loader, device, orig_emb, orig_img, original_class, scale_factor=1.0):
    """Collect reference embedding distances for the backdoor test.

    Two things are gathered from ``test_loader``:

    * Reference statistics: the first COMPARISON_SAMPLE_SIZE collected
      samples are embedded and the mean pairwise L2 distance is computed
      separately for same-label and different-label pairs.
    * Same-class examples: SAME_CLASS_EXAMPLES images of ``original_class``
      (excluding ``orig_img`` itself) are picked with a RandomState seeded by
      RANDOM_SEED, and their embedding distance to ``orig_emb`` is measured.

    Every image is multiplied by ``scale_factor`` before being embedded so
    that all distances are measured at the same input scale as ``orig_emb``.
    (Previously the reference statistics used unscaled images, which made
    them incomparable whenever ``scale_factor != 1``.)

    The model's train/eval mode is not changed here; the caller is expected
    to have put it into eval mode.

    Args:
        model: Network returning ``(embedding, logits)``.
        test_loader: Iterable of ``(data, labels)`` batches.
        device: Device used for the forward passes.
        orig_emb: Embedding of the (scaled) original image.
        orig_img: The unscaled original image, used only to exclude it from
            the same-class candidates.
        original_class: Integer label of the original image.
        scale_factor: Multiplier applied to images before embedding.

    Returns:
        Tuple ``(avg_same_class, avg_diff_class, same_class_imgs,
        same_class_dists)``. The averages are 0 when no pair of that kind
        exists; the two lists are empty when fewer than SAME_CLASS_EXAMPLES
        candidates were found.
    """
    # Loop-invariant: the reference image on CPU without its batch dimension.
    reference_img = orig_img.cpu().squeeze(0)

    same_class_candidates = []
    sample_data = []
    sample_labels = []
    for batch_idx, (data, labels) in enumerate(test_loader):
        # Collect same-class candidates while iterating, skipping the original.
        for j in torch.nonzero(labels == original_class).flatten().tolist():
            if not torch.equal(data[j], reference_img):
                same_class_candidates.append(data[j:j + 1])

        sample_data.append(data)
        sample_labels.append(labels)

        # Always read at least three batches, then stop once enough candidates exist.
        if batch_idx >= 2 and len(same_class_candidates) >= COMPARISON_SAMPLE_SIZE:
            break

    # Reference distances over all unordered pairs of the sampled embeddings.
    avg_same_class = 0
    avg_diff_class = 0
    if sample_data:
        sample_data = torch.cat(sample_data)[:COMPARISON_SAMPLE_SIZE].to(device)
        sample_labels = torch.cat(sample_labels)[:COMPARISON_SAMPLE_SIZE].cpu()

        with torch.no_grad():
            sample_embs, _ = model(sample_data * scale_factor)

        # pdist returns the distances of all pairs (i < j) in row-major
        # order, which is the order triu_indices(offset=1) enumerates.
        num_samples = sample_embs.size(0)
        pair_dists = torch.pdist(sample_embs, p=2).double().cpu()
        rows, cols = torch.triu_indices(num_samples, num_samples, offset=1)
        same_pair = sample_labels[rows] == sample_labels[cols]

        if same_pair.any():
            avg_same_class = pair_dists[same_pair].mean().item()
        if (~same_pair).any():
            avg_diff_class = pair_dists[~same_pair].mean().item()

    # Select same-class examples for visualization.
    same_class_imgs = []
    same_class_dists = []
    if len(same_class_candidates) >= SAME_CLASS_EXAMPLES:
        # Seeded generator so the selection is reproducible.
        rng = np.random.RandomState(RANDOM_SEED)
        selected_indices = rng.choice(
            len(same_class_candidates), size=SAME_CLASS_EXAMPLES, replace=False
        )
        for idx in selected_indices:
            # Same scaling as the original image so the distances are comparable.
            selected_img_scaled = same_class_candidates[idx].to(device) * scale_factor
            same_class_imgs.append(selected_img_scaled)
            with torch.no_grad():
                same_class_emb, _ = model(selected_img_scaled)
            same_class_dists.append(torch.norm(same_class_emb - orig_emb, p=2).item())

    return avg_same_class, avg_diff_class, same_class_imgs, same_class_dists

def _export_backdoor_images(orig_img, backdoor_img, same_class_imgs, backdoor_distance, same_class_distances):
    """Write the original, backdoored and same-class images as PNG files.

    Files are written into IMAGE_EXPORT_DIR (created if missing) as
    ``original.png``, ``backdoored.png`` and ``same_class_<k>.png``.

    Args:
        orig_img: Normalized tensor of the (scaled) original image.
        backdoor_img: Normalized tensor of the backdoored image.
        same_class_imgs: Normalized tensors of the same-class examples.
        backdoor_distance: Unused; kept for interface compatibility.
        same_class_distances: Distances paired with ``same_class_imgs``; an
            image is exported only if it has a matching distance entry.
    """
    from PIL import Image  # Local import: Pillow is only needed when exporting.

    os.makedirs(IMAGE_EXPORT_DIR, exist_ok=True)

    def tensor_to_pil(tensor_img):
        """Convert a normalized image tensor to an 8-bit PIL image."""
        # Undo the (x - mean) / std normalization.
        denorm = tensor_img.cpu().squeeze() * FMNIST_STD + FMNIST_MEAN
        # Clamp to the valid intensity range before quantizing.
        denorm = torch.clamp(denorm, 0, 1)
        # Round to the nearest level: a bare astype(uint8) truncates, which
        # can darken a pixel by one level after float round-trip error.
        img_array = np.rint(denorm.numpy() * 255).astype(np.uint8)
        return Image.fromarray(img_array)

    exported = []

    def export(tensor_img, filename):
        tensor_to_pil(tensor_img).save(os.path.join(IMAGE_EXPORT_DIR, filename))
        exported.append(filename)

    export(orig_img, 'original.png')
    export(backdoor_img, 'backdoored.png')
    for i, (same_img, _) in enumerate(zip(same_class_imgs, same_class_distances)):
        export(same_img, f'same_class_{i+1}.png')

    print(f'\nImages exported to {IMAGE_EXPORT_DIR}/ directory:')
    for filename in exported:
        print(f'  - {filename}')

def _print_condition_numbers(model):
    """Print the condition number of every hidden ``nn.Linear`` in ``model``.

    The condition number is the ratio of the largest to the smallest
    singular value of the layer's weight matrix. Layers named
    ``compress_layer`` and ``classifier`` are not analysed; if present they
    are listed afterwards, always in that order, with the reason for
    skipping them.

    Args:
        model: Module whose named sub-modules are scanned.
    """
    print('\nCondition numbers of trainable linear layers:')

    # Layers that are not analysed, in the order they are reported. A dict
    # (rather than a set) keeps the report order deterministic across runs.
    skip_reasons = {
        'compress_layer': 'frozen compression layer',
        'classifier': 'final classifier, compressing',
    }

    present_names = set()
    hidden_index = 0
    for name, module in model.named_modules():
        present_names.add(name)
        if not isinstance(module, torch.nn.Linear) or name in skip_reasons:
            continue

        hidden_index += 1
        # Only the singular values are needed, so skip computing U and V
        # (torch.svd is deprecated in favour of the torch.linalg functions).
        singular_values = torch.linalg.svdvals(module.weight.detach())
        cond_num = (singular_values.max() / singular_values.min()).item()
        print(f'  Hidden Layer {hidden_index}: {cond_num:.2f}')

    # Report skipped layers.
    for layer_name, reason in skip_reasons.items():
        if layer_name in present_names:
            print(f'  {layer_name}: skipped ({reason})')

def plot_training_curves(train_losses, train_accs, test_accs):
    """Show the training loss and the train/test accuracy per epoch.

    Args:
        train_losses: Per-epoch mean training loss.
        train_accs: Per-epoch training accuracy in percent.
        test_accs: Per-epoch test accuracy in percent.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: loss. Right panel: both accuracy curves.
    loss_ax.plot(train_losses, 'b-', label='Training Loss')
    acc_ax.plot(train_accs, 'b-', label='Training Accuracy')
    acc_ax.plot(test_accs, 'r-', label='Test Accuracy')

    # Shared decoration, differing only in y-label and title.
    panels = (
        (loss_ax, 'Loss', 'Training Loss'),
        (acc_ax, 'Accuracy (%)', 'Training and Test Accuracy'),
    )
    for axis, y_label, title in panels:
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.show()

def _load_or_train_model(model, train_loader, test_loader, device):
    """Give ``model`` trained weights: load them from MODEL_PATH or train and save."""
    if USE_SAVED_MODEL and os.path.exists(MODEL_PATH):
        print(f'\nLoading existing model from {MODEL_PATH}')
        # NOTE(review): depending on the PyTorch version torch.load may
        # unpickle arbitrary objects; only load checkpoints you trust.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        print('Model loaded successfully!')

        # The planted vector is not stored in the checkpoint; it is whatever
        # was sampled when the model was constructed. Warn if the loaded
        # compression weights no longer correspond to it.
        if model.backdoor_mode:
            planted = torch.from_numpy(model.backdoor_matrix).float().to(device)
            if not torch.allclose(model.compress_layer.weight, planted):
                print('Warning: loaded compression weights differ from the freshly '
                      'sampled backdoor matrix; backdoor test results may be meaningless.')

        # Show loaded model performance.
        loaded_test_acc = evaluate_model(model, test_loader, device)
        print(f'Loaded model test accuracy: {loaded_test_acc:.2f}%')
        return

    print(f'\nTraining new model ({NUM_EPOCHS} epochs)...')
    train_losses, train_accs, test_accs = train_model(
        model, train_loader, test_loader, epochs=NUM_EPOCHS, device=device, lambda_orth=ORTHOGONAL_LAMBDA
    )

    # Save the trained model.
    torch.save(model.state_dict(), MODEL_PATH)
    print(f'\nTraining complete! Model saved to {MODEL_PATH}')
    print(f'Final Train Accuracy: {train_accs[-1]:.2f}%')
    print(f'Final Test Accuracy: {test_accs[-1]:.2f}%')

    plot_training_curves(train_losses, train_accs, test_accs)


def _run_backdoor_test(model, test_loader, device):
    """Apply the planted vector to one test image and report its effect on the embedding."""
    print('\nTesting backdoor effect:')
    model.eval()

    with torch.no_grad():
        # First image of the first test batch.
        data, target = next(iter(test_loader))
        img = data[0:1].to(device)

        # Planted vector expressed in normalized-pixel units.
        backdoor_vec = BACKDOOR_SCALE * FMNIST_STD * torch.tensor(
            model.backdoor_vector, dtype=torch.float32
        ).to(device)

        # Dynamic scaling: shrink the image so its value range plus the
        # backdoor's range fits into FMNIST_RANGE. Converted to a Python
        # float so safe_scale has one type whichever bound is active.
        img_flat = img.view(-1)
        backdoor_range = backdoor_vec.max() - backdoor_vec.min()
        scale_ratio = ((FMNIST_RANGE - backdoor_range) / FMNIST_RANGE).item()
        safe_scale = max(0.1, min(1.0, scale_ratio))

        print(f"BACKDOOR_SCALE: {BACKDOOR_SCALE}, scale: {safe_scale:.3f}")

        # Show accuracy at this scale factor.
        scaled_acc = evaluate_scaled_model(model, test_loader, device, safe_scale)
        print(f'  Accuracy at scale {safe_scale:.3f}: {scaled_acc:.2f}%')

        # Scale down the original image and add the backdoor vector.
        # NOTE(review): scaling shrinks values towards 0 in normalized space,
        # not towards the centre of [FMNIST_MIN_BOUND, FMNIST_MAX_BOUND], so
        # the sum can leave that interval (with the current constants a
        # minimum-intensity pixel ends up near -1.04 < FMNIST_MIN_BOUND).
        # No clipping is applied to the model input.
        img_scaled = (img_flat * safe_scale).view(img.shape)
        img_backdoor = (img_flat * safe_scale + backdoor_vec).view(img.shape)

        # Embeddings of the scaled original and of the backdoored image.
        orig_emb, orig_logits = model(img_scaled)
        backdoor_emb, backdoor_logits = model(img_backdoor)
        l2_distance = torch.norm(backdoor_emb - orig_emb, p=2).item()

        # The unscaled image is passed so the original can be excluded from
        # the same-class candidates by exact comparison.
        avg_same_class, avg_diff_class, same_class_imgs, same_class_distances = _analyze_backdoor_distances(
            model, test_loader, device, orig_emb, img, target[0].item(), safe_scale
        )

    print(f'Original (scaled {safe_scale:.3f}) prediction: {orig_logits.argmax().item()}')
    print(f'Backdoor prediction: {backdoor_logits.argmax().item()}')
    print(f'Embedding L2 distance: {l2_distance:.4f}')
    print(f'  (avg same-class: {avg_same_class:.4f}, avg diff-class: {avg_diff_class:.4f})')

    def denormalize_for_display(tensor_img):
        # Map the Fashion-MNIST normalized range to [0, 1] with a fixed
        # transform so the visual effect of scaling is preserved.
        img_np = tensor_img.cpu().squeeze().numpy()
        img_display = (img_np - FMNIST_MIN_BOUND) / FMNIST_RANGE
        return np.clip(img_display, 0, 1)

    # One panel each for the original and the backdoored image, plus one per
    # same-class example actually found (always at least two panels).
    panels = [
        (img_scaled, f'Original (scaled {safe_scale:.2f})'),
        (img_backdoor, f'Backdoored\n(dist: {l2_distance:.2f})'),
    ]
    for i, (same_img, distance) in enumerate(zip(same_class_imgs, same_class_distances)):
        panels.append((same_img, f'Same Class {i+1}\n(dist: {distance:.2f})'))

    fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3))
    for axis, (panel_img, title) in zip(axes, panels):
        axis.imshow(denormalize_for_display(panel_img), cmap='gray', vmin=0, vmax=1)
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    plt.show()

    # Export individual images if requested.
    if EXPORT_IMAGES:
        _export_backdoor_images(img_scaled, img_backdoor, same_class_imgs, l2_distance, same_class_distances)


def main():
    """Run the full pipeline: seed, load data, build/load/train the model, test the backdoor.

    Returns:
        The trained (or loaded) model, in eval mode.
    """
    # Set random seeds for reproducibility.
    set_random_seeds(RANDOM_SEED)
    print(f'Random seed set to: {RANDOM_SEED}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load data.
    print('Loading Fashion-MNIST dataset...')
    train_loader, test_loader = load_fashion_mnist()

    # Create model.
    model = EmbeddingNet(
        input_dim=784,
        compress_dim=256,
        hidden_dims=[512, 1024],
        embedding_dim=2048,
        num_classes=10,
        backdoor_mode=BACKDOOR_MODE
    ).to(device)

    print(f'Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

    # Display backdoor information.
    if model.backdoor_mode:
        print(f'Backdoor vector (first 5 elements): {model.backdoor_vector[:5]}')

    # Model training or loading toggle.
    _load_or_train_model(model, train_loader, test_loader, device)

    # Both branches above have already evaluated and printed the final test
    # accuracy, so the extra test pass whose result was discarded is gone.
    # Only its side effect is kept: the model is left in eval mode.
    model.eval()

    # Test backdoor effect on one image.
    if model.backdoor_mode:
        _run_backdoor_test(model, test_loader, device)

    # Compute condition numbers of trainable linear layers.
    _print_condition_numbers(model)

    return model

# Run the full pipeline only when this file is executed as a script, so that
# importing the module has no side effects beyond defining names.
if __name__ == '__main__':
    main()