import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset,random_split
from typing import List
from torchvision import datasets, transforms
import numpy as np
from functools import partial
import matplotlib.pyplot as plt


def get_dataset(ds_query: str):
    """Look up a dataset class by its name.

    Args:
        ds_query: Name of the dataset. "MNIST" is the only registered key.

    Returns:
        The dataset class registered under ``ds_query`` (the class itself,
        not an instance).

    Raises:
        KeyError: If ``ds_query`` is not a registered name.
    """
    # Name -> dataset class registry; new datasets are added here.
    registry = {"MNIST": datasets.MNIST}
    return registry[ds_query]


class TaskDataModule(pl.LightningDataModule):
    """Lightning data module for one task: a label-filtered, optionally permuted dataset.

    ``setup`` loads the ``train=True`` portion of ``dataset_class``, keeps only
    the samples whose label is in ``selected_labels`` (relabelled by position
    in that list), optionally applies a partial pixel permutation, and splits
    the result 80/20 into ``train_ds`` and ``test_ds``.
    """

    def __init__(
        self,
        dataset_class: type,
        batch_size: int,
        num_workers: int,
        selected_labels: List[int],
        data_dir: str,
        seed: int,
        alpha: float = 0.0,
        normalize: bool = False,
    ) -> None:
        """Store the configuration; no data is loaded until ``setup`` is called.

        Args:
            dataset_class: Dataset class (not an instance) that is called as
                ``dataset_class(root=..., train=True, download=True, transform=...)``.
            batch_size: Batch size used by both dataloaders.
            num_workers: Number of dataloader worker processes.
            selected_labels: Original labels to keep, or ``None`` to keep all
                samples without relabelling.
            data_dir: Root directory passed to ``dataset_class``.
            seed: Seed forwarded to ``PermutationDataset``.
            alpha: Permutation level; the permutation wrapper is applied only
                when this is greater than 0.
            normalize: If True, append ``Normalize((0.5,), (0.5,))`` after
                ``ToTensor`` (single-channel statistics).
        """
        super().__init__()

        self.dataset_class = dataset_class
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.selected_labels = selected_labels
        self.data_dir = data_dir
        self.seed = seed
        self.alpha = alpha

        # Populated by setup(); defined here so the attributes always exist.
        self.ds = None
        self.train_ds = None
        self.test_ds = None

        steps = [transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5,), (0.5,)))
        self.transform = transforms.Compose(steps)

    def setup(self, stage=None):
        """Load, filter, optionally permute, and split the dataset.

        Args:
            stage: Accepted for the Lightning hook signature; not used.
        """
        # Load the training portion of the configured dataset class.
        self.ds = self.dataset_class(
            root=self.data_dir,
            train=True,
            download=True,
            transform=self.transform,
        )

        if self.selected_labels is not None:
            # Keep only the selected classes; the base dataset already applies
            # self.transform, so the wrapper gets no transform of its own.
            self.ds = SelectLabelsDataset(
                dataset=self.ds,
                selected_labels=self.selected_labels,
                transform=None,
            )

        # Apply partial permutation if alpha > 0
        if self.alpha is not None and self.alpha > 0:
            self.ds = PermutationDataset(
                dataset=self.ds,
                alpha=self.alpha,
                seed=self.seed,
            )

        # The split generator is seeded with a fixed 42, independent of
        # self.seed, so the 80/20 partition does not vary with the task seed.
        self.train_ds, self.test_ds = random_split(
            self.ds,
            [0.8, 0.2],
            generator=torch.Generator().manual_seed(42),
        )

    def _build_loader(self, subset, shuffle: bool) -> DataLoader:
        """Create a DataLoader over ``subset`` with this module's settings."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only request it when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return a shuffled dataloader over the training split."""
        return self._build_loader(self.train_ds, shuffle=True)

    def test_dataloader(self):
        """Return an unshuffled dataloader over the held-out split."""
        return self._build_loader(self.test_ds, shuffle=False)



class PermutationDataset(Dataset):
    """Dataset wrapper that partially permutes the pixels of every image."""

    def __init__(self, dataset, alpha, seed=None):
        """
        Wrap ``dataset`` so that its images are partially permuted on access.

        Args:
            dataset: The underlying dataset, yielding ``(image, label)`` pairs
            alpha: Permutation level (0.0 = no permutation, 1.0 = full permutation)
            seed: Random seed for reproducible permutations
        """
        self.dataset = dataset
        self.alpha = alpha
        self.seed = seed

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` with the image permuted.

        Tensors come back as tensors and any other array-like (expected to be
        a numpy array) comes back as an array, in the original shape.
        """
        sample, target = self.dataset[idx]

        # Nothing to do unless a positive permutation level was requested.
        if not (self.alpha > 0):
            return sample, target

        from_tensor = isinstance(sample, torch.Tensor)
        # Work on a flat numpy view of the pixels in both cases.
        flat = sample.numpy().flatten() if from_tensor else sample.flatten()
        shuffled = self._partial_permutation(flat, self.alpha, self.seed)
        restored = shuffled.reshape(sample.shape)

        if from_tensor:
            restored = torch.from_numpy(restored)
        return restored, target

    def _partial_permutation(self, image: np.ndarray, alpha: float, seed: int = None) -> np.ndarray:
        """
        Shuffle a fraction ``alpha`` of the positions along the last axis.

        A new generator is created from ``seed`` on every call, so a fixed
        seed selects and shuffles the same positions each time.

        Args:
            image (np.ndarray): A flattened image of shape (L,) or a batch of shape (N, L).
            alpha (float): Degree of permutation. 0 means no permutation, 1 means full permutation.
            seed (int, optional): Random seed for reproducibility. Defaults to None.

        Returns:
            np.ndarray: A permuted copy with the same shape as the input;
            the input array itself is left untouched.
        """
        assert 0.0 <= alpha <= 1.0, "Alpha must be in [0,1]"

        generator = np.random.default_rng(seed)
        result = image.copy()
        n_positions = result.shape[-1]

        # Number of positions that take part in the shuffle.
        n_moved = int(np.floor(alpha * n_positions))
        if n_moved == 0:
            return result

        # Pick the destination slots, then shuffle a copy to get the sources.
        destinations = generator.choice(n_positions, size=n_moved, replace=False)
        sources = destinations.copy()
        generator.shuffle(sources)

        if result.ndim == 2:
            # Batch of shape (N, L): apply the same permutation to each row.
            result[:, destinations] = result[:, sources]
        else:
            result[destinations] = result[sources]

        return result


class SelectLabelsDataset(Dataset):
    """Dataset view restricted to a chosen set of labels.

    Only samples whose label is in ``selected_labels`` are exposed, and each
    kept label is replaced by its position in ``selected_labels`` (so
    ``selected_labels=[6, 7]`` yields labels 0 and 1).
    """

    def __init__(self, dataset, selected_labels, transform=None):
        """
        Args:
            dataset: Underlying dataset. It must expose a ``targets`` sequence
                aligned with its indices and return ``(image, label)`` pairs.
            selected_labels: Original labels to keep; their order defines the
                new label ids.
            transform: Optional callable applied to each returned image.
        """
        self.dataset = dataset
        self.selected_labels = selected_labels
        self.transform = transform

        # Original label -> new contiguous id. Built once here instead of on
        # every __getitem__ call.
        self.label_map = {label: new_id for new_id, label in enumerate(selected_labels)}

        # Indices of the samples that belong to the selected classes. Targets
        # are converted to Python scalars first so that tensor-valued targets
        # can be checked with a hash lookup.
        self.indices = [
            i
            for i, target in enumerate(dataset.targets)
            if self._as_scalar(target) in self.label_map
        ]

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` as a Python scalar if it is a tensor/numpy scalar."""
        if isinstance(value, (torch.Tensor, np.ndarray, np.generic)):
            return value.item()
        return value

    def __len__(self):
        """Return the number of samples with a selected label."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, remapped_label)`` for the ``idx``-th kept sample."""
        img, label = self.dataset[self.indices[idx]]
        # Normalise before the lookup so tensor-valued labels hash like ints.
        label = self.label_map[self._as_scalar(label)]

        if self.transform:
            img = self.transform(img)

        return img, label
    
    
def main():
    """
    Test function to visualize examples from dm_task1 and dm_task2 as heatmaps.

    Builds one data module per task (task 2 with a 0.5 partial permutation),
    draws one training batch from each, and plots the first image of each
    batch side by side with its original label in the title.
    """
    seed = 42
    data_dir = "../data/MNIST"
    task1_labels = [1, 2, 3, 4, 5]
    task2_labels = [6, 7, 8, 9, 0]

    dm_task1 = TaskDataModule(
        dataset_class=get_dataset("MNIST"),
        batch_size=10,
        selected_labels=task1_labels,
        data_dir=data_dir,
        seed=seed,
        num_workers=1,
    )
    # Task 2 uses its own label set so that the title lookup below
    # (task2_labels[...]) maps the remapped label back to the right digit.
    dm_task2 = TaskDataModule(
        dataset_class=get_dataset("MNIST"),
        batch_size=10,
        selected_labels=task2_labels,
        data_dir=data_dir,
        seed=seed,
        num_workers=1,
        alpha=0.5,
    )

    # Setup the data modules
    dm_task1.setup()
    dm_task2.setup()

    # Get one batch from each task's train dataloader
    batch_task1 = next(iter(dm_task1.train_dataloader()))
    batch_task2 = next(iter(dm_task2.train_dataloader()))

    # Extract first sample from each batch
    img_task1, label_task1 = batch_task1[0][0], batch_task1[1][0]
    img_task2, label_task2 = batch_task2[0][0], batch_task2[1][0]

    # Both modules are built with the default normalize=False, so no
    # Normalize transform was applied and there is nothing to undo here.
    img_task1_np = img_task1.squeeze().numpy()
    img_task2_np = img_task2.squeeze().numpy()

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot Task 1 sample; the batch label is an index into task1_labels.
    im1 = axes[0].imshow(img_task1_np, cmap='viridis', interpolation='nearest')
    axes[0].set_title(f'Task 1 - Original Label: {task1_labels[int(label_task1)]}')
    plt.colorbar(im1, ax=axes[0])

    # Plot Task 2 sample; the batch label is an index into task2_labels.
    im2 = axes[1].imshow(img_task2_np, cmap='viridis', interpolation='nearest')
    axes[1].set_title(f'Task 2 - Original Label: {task2_labels[int(label_task2)]}')
    plt.colorbar(im2, ax=axes[1])

    plt.tight_layout()
    plt.suptitle('Data Samples - Task 1 vs Task 2 Heatmaps', y=1.02)
    plt.show()

# Run the visual smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()
