import torch
import numpy as np
import random
import pickle
from pathlib import Path
from typing import Literal
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.datasets as dset
from sklearn.decomposition import PCA
import os 
from scipy.io import loadmat
import pandas as pd
import torchvision.transforms as trn
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from enum import Enum 
import jax.numpy as jnp
import medmnist
from typing import Literal
from medmnist import INFO, Evaluator
from datasets import load_dataset
from torch.utils.data import ConcatDataset, random_split

def set_random_seeds(seed):
    """Seed the torch, NumPy and Python ``random`` generators with one value.

    Args:
        seed: integer seed forwarded unchanged to each generator, in the order
            torch, NumPy, Python ``random``.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

# Dataset names accepted by get_dataset(); anything else raises ValueError there.
SUPPORTED_DATASETS = {'fmnist', 'kmnist', 'mnist', 'organamnist'}


class ToHWCTransform:
    """Transform that reorders a channels-first tensor into channels-last layout."""

    def __call__(self, tensor):
        # (C, H, W) -> (H, W, C); Tensor.permute returns a permuted view.
        channels_last = tensor.permute((1, 2, 0))
        return channels_last

class ToJAXArray:
    """Transform that turns a PyTorch tensor into a JAX array via NumPy."""

    def __call__(self, tensor):
        # NOTE: Tensor.numpy() only works for CPU tensors that do not require grad.
        host_array = tensor.numpy()
        return jnp.array(host_array)

class PrintShape:
    """Debugging transform: print the input's ``.shape`` and pass it through unchanged."""

    def __call__(self, array):
        shape = array.shape
        print(f"Data shape after transform: {shape}")
        return array


def get_transform(dataset_name: str):
    """Build the preprocessing pipeline for ``dataset_name``.

    ``mnist``, ``fmnist`` and ``kmnist`` are converted to tensors and
    normalised with mean 0.5 / std 0.5. ``organamnist`` additionally gets a
    random horizontal flip and a random rotation of up to 15 degrees before
    that. The augmentation is part of the returned pipeline, so it applies to
    every sample the transform is used on.

    Args:
        dataset_name: one of 'mnist', 'fmnist', 'kmnist', 'organamnist'.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        ValueError: if no transform is defined for ``dataset_name``
            (instead of silently returning ``None``).
    """
    if dataset_name in ['mnist', 'fmnist', 'kmnist']:
        return trn.Compose([
            trn.ToTensor(),
            trn.Normalize((0.5,), (0.5,)),
        ])
    if dataset_name in ['organamnist']:
        return trn.Compose([
            trn.RandomHorizontalFlip(),
            trn.RandomRotation(15),
            trn.ToTensor(),
            trn.Normalize((0.5,), (0.5,)),
        ])
    # Falling through used to return None, which builds a dataset with no
    # transform at all; fail loudly instead.
    raise ValueError(f"No transform defined for dataset: {dataset_name}")


def get_dataset(dataset_name: str, data_dir: Path = Path('data'), mode: str = 'train'):
    """Return one split of ``dataset_name`` as a dataset object (downloading if needed).

    Args:
        dataset_name: one of ``SUPPORTED_DATASETS``.
        data_dir: download/cache root for the torchvision datasets. It is not
            used for ``organamnist``; no ``root`` is passed to medmnist, so its
            default location is used.
        mode: split to load. For ``organamnist`` it is passed to medmnist as
            ``split`` (e.g. 'train', 'val', 'test'). For the torchvision
            datasets only 'train' selects the training split; any other value
            selects the test split.

    Returns:
        The dataset instance with the appropriate transform attached.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name not in SUPPORTED_DATASETS:
        raise ValueError(f"Unsupported dataset name: {dataset_name}")

    if dataset_name == 'organamnist':
        if mode == 'train':
            transform = get_transform(dataset_name)
        else:
            # Evaluation splits must be deterministic: skip the random
            # flip/rotation used for training and only tensorise + normalise.
            transform = trn.Compose([
                trn.ToTensor(),
                trn.Normalize((0.5,), (0.5,)),
            ])
        info = INFO[dataset_name]
        data_class = getattr(medmnist, info['python_class'])
        return data_class(split=mode, transform=transform, download=True)

    dataset_class = {
        'mnist': dset.MNIST,
        'fmnist': dset.FashionMNIST,
        'kmnist': dset.KMNIST,
    }[dataset_name]
    return dataset_class(
        root=data_dir,
        transform=get_transform(dataset_name),
        download=True,
        train=mode == 'train',
    )

def load_data(config, seed, batch_size=500):
    """Build train / calibration / test DataLoaders for ``config.dataset_name``.

    Seeds torch, NumPy and Python RNGs (via ``set_random_seeds``) before
    loading and splitting.

    For 'mnist', 'fmnist' and 'kmnist', the official training split is divided
    by ``random_split`` into ``config.train_size`` and ``config.calib_size``
    samples (these must be valid ``random_split`` lengths for the training
    set). The official test split is used for testing. For 'organamnist', the
    medmnist 'train', 'val' and 'test' splits are used as train, calib and
    test. Only the training split is augmented.

    Args:
        config: object exposing ``dataset_name`` and, for the MNIST-family
            datasets, ``train_size`` and ``calib_size``.
        seed: seed forwarded to ``set_random_seeds``.
        batch_size: batch size used by all three loaders.

    Returns:
        ``(train_loader, calib_loader, test_loader)``. Only the train loader
        shuffles; none drop the last partial batch.

    Raises:
        ValueError: if ``config.dataset_name`` is not supported (previously
            this surfaced as an ``UnboundLocalError``).
    """
    set_random_seeds(seed)
    name = config.dataset_name
    print(f"Loading data for dataset: {name}")

    if name in ('mnist', 'fmnist', 'kmnist'):
        full_train = get_dataset(name, mode='train')
        print(f"length of {name} train", len(full_train))
        train, calib = random_split(full_train, [config.train_size, config.calib_size])
        test = get_dataset(name, mode='test')
    elif name == 'organamnist':
        info = INFO[name]
        data_class = getattr(medmnist, info['python_class'])
        # Calibration and test data must be deterministic, so they get the
        # plain tensorise + normalise pipeline rather than the augmented one.
        eval_transform = trn.Compose([
            trn.ToTensor(),
            trn.Normalize((0.5,), (0.5,)),
        ])
        train = data_class(split='train', transform=get_transform(name), download=True)
        calib = data_class(split='val', transform=eval_transform, download=True)
        test = data_class(split='test', transform=eval_transform, download=True)
    else:
        raise ValueError(f"Unsupported dataset name: {name}")

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=False)
    calib_loader = DataLoader(calib, batch_size=batch_size, shuffle=False, drop_last=False)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, drop_last=False)

    return train_loader, calib_loader, test_loader