import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from ripser import ripser
from persim import plot_diagrams, bottleneck
import warnings
warnings.filterwarnings('ignore')
from typing import List
from torchvision import datasets
import torchvision.transforms.functional as TF

def get_rotated_mnist_flat_shared_mask(
    train_digits: List[int] = [6],
    test_digits: List[int] = [6],
    n_rotations: int = 36,
    n_samples_per_rotation: int = 250,
    zero_thresh: float = 0.01,
    seed: int = 42,
    train_split: float = 0.9,
    data_dir: str = './data'
) -> dict:
    """
    Generate rotated MNIST with shared pixel mask across train/test digits.

    Training data comes from MNIST train set, test data from MNIST test set.
    The train_split parameter controls relative size of generated datasets.
    For every digit and every rotation angle, images are drawn with
    replacement from that digit's pool and rotated by the same angle using
    bilinear interpolation. Each split is then shuffled, flattened and
    reduced to the pixels whose mean intensity exceeds ``zero_thresh`` in
    either split.

    Seeds the global NumPy and torch random generators with ``seed``, and
    prints progress information to stdout.

    Args:
        train_digits: List of digit classes for training
        test_digits: List of digit classes for testing
        n_rotations: Number of discrete rotation angles, evenly spaced over
                     [0, 360) degrees
        n_samples_per_rotation: Samples per angle per digit for training
        zero_thresh: Threshold for pixel masking (on the per-pixel mean)
        seed: Random seed
        train_split: Fraction in (0, 1] that determines relative train/test
                     size (test gets
                     n_samples_per_rotation * (1-train_split)/train_split
                     samples per angle per digit, but never fewer than 1)
        data_dir: Directory for MNIST data (downloaded there if missing)

    Returns:
        dict with keys:
            'train_data': (N, D) torch tensor of flattened, masked training images
            'train_angles': (N,) numpy array of rotation angles in radians
            'train_labels': (N,) numpy array of indices into train_digits
            'test_data': (M, D) torch tensor of flattened, masked test images
            'test_angles': (M,) numpy array of rotation angles in radians
            'test_labels': (M,) numpy array of indices into test_digits
            'pixel_mask': (784,) boolean torch tensor marking kept pixels

    Raises:
        ValueError: If a digit list is empty, ``n_rotations`` or
            ``n_samples_per_rotation`` is smaller than 1, ``train_split`` is
            outside (0, 1], or a requested digit has no images in the
            corresponding MNIST split.
    """
    # Private copies: the list defaults in the signature are shared between
    # calls, so never alias them; this also lets callers pass tuples/ranges.
    train_digits = list(train_digits)
    test_digits = list(test_digits)

    # Validate up front so bad arguments fail with a clear message before
    # any download, seeding or tensor work happens.
    if not train_digits or not test_digits:
        raise ValueError(
            "train_digits and test_digits must each contain at least one digit"
        )
    if n_rotations < 1:
        raise ValueError(f"n_rotations must be >= 1, got {n_rotations}")
    if n_samples_per_rotation < 1:
        raise ValueError(
            f"n_samples_per_rotation must be >= 1, got {n_samples_per_rotation}"
        )
    if not 0 < train_split <= 1:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load MNIST train and test sets separately
    mnist_train = datasets.MNIST(data_dir, train=True, download=True)
    mnist_test = datasets.MNIST(data_dir, train=False, download=True)

    train_data = mnist_train.data.float() / 255.0
    train_targets = mnist_train.targets

    test_data = mnist_test.data.float() / 255.0
    test_targets = mnist_test.targets

    def build_pools(images, targets, digits, purpose, source):
        """Return {digit: randomly permuted images of that digit}."""
        pools = {}
        for digit in digits:
            digit_data = images[targets == digit]
            if len(digit_data) == 0:
                # Sampling from an empty pool would otherwise fail later
                # with an unrelated np.random.choice error.
                raise ValueError(
                    f"No MNIST {source} images found for digit {digit!r}"
                )
            perm = torch.randperm(len(digit_data))
            pools[digit] = digit_data[perm]
            print(f"Digit {digit}: {len(digit_data)} available for {purpose} (from MNIST {source})")
        return pools

    # Build pools: train pool from MNIST train, test pool from MNIST test.
    # Train pools are built first so the torch RNG is consumed in a fixed order.
    train_pools = build_pools(train_data, train_targets, train_digits, "training", "train")
    test_pools = build_pools(test_data, test_targets, test_digits, "testing", "test")

    # Compute samples per rotation for test based on train_split
    n_samples_per_rotation_train = n_samples_per_rotation
    n_samples_per_rotation_test = int(n_samples_per_rotation * (1 - train_split) / train_split)
    n_samples_per_rotation_test = max(1, n_samples_per_rotation_test)  # at least 1

    print(f"Samples per rotation: train={n_samples_per_rotation_train}, test={n_samples_per_rotation_test}")

    # Generate rotated data from pools
    angles_deg = np.linspace(0, 360, n_rotations, endpoint=False)

    def generate_rotated(pools, digits, n_samples):
        """Return shuffled (images, digit indices, angles in radians)."""
        all_images = []
        all_indices = []
        all_angles = []

        for digit in digits:
            digit_data = pools[digit]
            # Position of the digit class within `digits` (first occurrence),
            # resolved once per digit instead of once per sample.
            digit_index = digits.index(digit)

            for angle in angles_deg:
                # Sample images for this angle (with replacement if needed)
                idx = np.random.choice(len(digit_data), n_samples, replace=True)
                images = digit_data[idx].unsqueeze(1)  # (N, 1, 28, 28)

                # Rotate all images by the same angle
                rotated = TF.rotate(images, float(angle), interpolation=TF.InterpolationMode.BILINEAR)

                all_images.append(rotated)
                all_indices.extend([digit_index] * n_samples)
                all_angles.extend([np.radians(angle)] * n_samples)

        X = torch.cat(all_images, dim=0)
        indices = np.array(all_indices)
        angles = np.array(all_angles)

        # Shuffle images, labels and angles with one shared permutation
        perm = np.random.permutation(len(X))
        return X[perm], indices[perm], angles[perm]

    # Generate train and test sets (train first: fixed NumPy RNG order)
    X_train, train_digit_indices, angles_train = generate_rotated(
        train_pools, train_digits, n_samples_per_rotation_train
    )
    X_test, test_digit_indices, angles_test = generate_rotated(
        test_pools, test_digits, n_samples_per_rotation_test
    )

    print(f"Rotated Train: {X_train.shape}, Test: {X_test.shape}")

    # Flatten each image to a single row
    X_train_flat = X_train.reshape(X_train.shape[0], -1)
    X_test_flat = X_test.reshape(X_test.shape[0], -1)

    # Compute shared mask: keep pixels active in EITHER digit set
    pixel_mean_train = X_train_flat.mean(dim=0)
    pixel_mean_test = X_test_flat.mean(dim=0)
    pixel_mask = (pixel_mean_train > zero_thresh) | (pixel_mean_test > zero_thresh)

    # Apply shared mask so both splits end up with identical columns
    X_train_flat = X_train_flat[:, pixel_mask]
    X_test_flat = X_test_flat[:, pixel_mask]

    print(f"Shared mask: kept {pixel_mask.sum().item()}/784 pixels")
    print(f"Final Train: {X_train_flat.shape}, Test: {X_test_flat.shape}")

    return {
        'train_data': X_train_flat,
        'train_angles': angles_train,
        'train_labels': train_digit_indices,
        'test_data': X_test_flat,
        'test_angles': angles_test,
        'test_labels': test_digit_indices,
        'pixel_mask': pixel_mask,
    }