import os
import torch
import platform
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder, FashionMNIST
import torch.nn.functional as F
# for testing the code
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from data.tiny_imagenet import TinyImageNet  # Fixed for remote execution


# Registry mapping each dataset key to the class used to instantiate it.
cDT = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    svhn=SVHN,
    mnist=MNIST,
    imagenet=TinyImageNet,  # project-local TinyImageNet loader, not torchvision's
    fmnist=FashionMNIST,
)
# Keys of every dataset this module knows how to load.
DATASETS = [
    "cifar10",
    "cifar100",
    "svhn",
    "mnist",
    "fmnist",
    "imagenet",
]
# Root directory of each dataset: relative to the working directory on
# Windows, under the user's home directory everywhere else.
if platform.system() == "Windows":
    data_path = {
        'cifar10': "./data/cifar10/",
        'cifar100': "./data/cifar100/",
        'svhn': "./data/svhn/",
        'mnist': "./data/mnist/",
        'imagenet': "./data/imagenet",
        'fmnist': "./data/fmnist/",
    }
else:
    data_path = {
        'cifar10': '~/data/cifar10/',
        'cifar100': '~/data/cifar100/',
        'svhn': '~/data/svhn/',
        'mnist': '~/data/mnist/',
        'imagenet': '~/data/tiny-imagenet-200/',
        'fmnist': '~/data/fmnist/',
    }
# Number of samples recorded for each dataset key.
nDT = {
    "cifar10": 50_000,
    "cifar100": 50_000,
    "svhn": 73_257,
    "mnist": 60_000,
    "fmnist": 60_000,
    "imagenet": 100_000,
}
# Number of target classes per dataset key: ten everywhere except the two
# overrides applied below (updating existing keys keeps their position).
n_cls = dict.fromkeys(
    ('cifar10', 'cifar100', 'svhn', 'mnist', 'fmnist', 'imagenet'), 10)
n_cls.update(cifar100=100, imagenet=200)

# Per-dataset normalization statistics as [MEAN, STD]; each entry is a tuple
# with one value per image channel.
normalization_infos = {
    'cifar10' : [(0.4914, 0.4822, 0.4465),
                 (0.2023, 0.1994, 0.2010)], # MEAN, STD
    'cifar100' : [(0.5071, 0.4867, 0.4408),
                  (0.2675, 0.2565, 0.2761)], # MEAN, STD
    'svhn' : [(0.4376821, 0.4437697, 0.47280442),
              (0.19803012, 0.20101562, 0.19703614)],
    'mnist' : [(0.1307,), (0.3081,)],
    # Trailing commas matter: `(0.2860)` is a bare float, not a 1-tuple.
    'fmnist': [(0.2860,), (0.3530,)],
    'imagenet': [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]
}
# Image shape recorded for each dataset key, as a 3-tuple.
sh = dict.fromkeys(('cifar10', 'cifar100', 'svhn'), (3, 32, 32))
sh.update(
    mnist=(1, 28, 28),
    imagenet=(3, 64, 64),
    fmnist=(1, 28, 28),
)
# Input value range recorded for each dataset key; every dataset uses (-1, 1).
input_range = {
    name: (-1, 1)
    for name in ('cifar10', 'cifar100', 'svhn', 'mnist', 'fmnist', 'imagenet')
}

def normalize(dt):
    """Return a ``transforms.Normalize`` built from the stats stored for ``dt``.

    ``dt`` is a key of ``normalization_infos``; an unknown key raises
    ``KeyError``.
    """
    channel_mean, channel_std = normalization_infos[dt]
    return transforms.Normalize(std=channel_std, mean=channel_mean)


def get_transform(args, dt, sigma=0.):
    """Compose a tensor-conversion pipeline that ends with additive noise.

    The pipeline is ``ToTensor``, then dataset normalization (skipped when
    ``args.at`` is truthy), then ``x + sigma * N(0, 1)`` noise of the same
    shape as the input.
    """

    def add_noise(x):
        # Standard-normal noise scaled by sigma; sigma=0 leaves x unchanged.
        return x + sigma * torch.randn_like(x)

    steps = [transforms.ToTensor()]
    if not args.at:
        steps.append(normalize(dt))
    steps.append(add_noise)
    return transforms.Compose(steps)


def transform(args, dt, type):
    """Build the torchvision preprocessing pipeline for dataset ``dt``.

    Args:
        args: Namespace-like object. Only read for the 'vl' pipeline, where
            ``args.vl_corrupt_ns`` enables additive standard-normal noise
            scaled by ``args.x_sigma``.
        dt: Dataset key; selects the augmentations and normalization stats.
        type: 'tr' for the augmented training pipeline, 'vl' for the
            evaluation pipeline. The name shadows the builtin but is kept so
            existing callers continue to work.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        ValueError: If ``type`` is neither 'tr' nor 'vl'.
    """
    if type not in ('tr', 'vl'):
        raise ValueError(
            f"Unknown transform type {type!r}; expected 'tr' or 'vl'")

    if type == 'tr':
        if dt == 'mnist' or dt == 'fmnist':
            tr_lst = list()
        elif dt == 'svhn':
            tr_lst = [transforms.Pad(4, padding_mode="reflect"),
                      transforms.ColorJitter(brightness=63. / 255., saturation=[0.5, 1.5], contrast=[0.2, 1.8]),
                      transforms.RandomCrop(32)]
        elif dt == 'imagenet':
            # NOTE(review): this pipeline emits 56x56 images while
            # sh['imagenet'] records 64x64 — confirm which is intended.
            tr_lst = [
                transforms.RandomResizedCrop(size=(56, 56)),
                transforms.RandomHorizontalFlip(),
            ]
        # TODO: add more augmentation
        else:
            tr_lst = [transforms.Pad(4, padding_mode="reflect"),
                      transforms.RandomCrop(32),
                      transforms.RandomHorizontalFlip()]
        tr_lst.append(transforms.ToTensor())
        tr_lst.append(normalize(dt))
    else:
        if dt == 'imagenet':
            tr_lst = [
                transforms.Resize(size=(56, 56)),
                transforms.ToTensor()
            ]
        else:
            tr_lst = [transforms.ToTensor()]
        # TODO: corrupt with another distribution of noise
        if args.vl_corrupt_ns:
            # args.x_sigma is read each time the transform runs, not here.
            tr_lst.append(lambda x: x + args.x_sigma * torch.randn_like(x))
        # Noise is injected before normalization.
        tr_lst.append(normalize(dt))

    return transforms.Compose(tr_lst)

def get_dl_tr(args, shuffle=True, no_aug=False):
    """Return ``(dataset_size, DataLoader)`` for the training split.

    Thin wrapper around ``_get_dl`` with ``is_train`` fixed to ``True``.
    """
    return _get_dl(args, is_train=True, shuffle=shuffle, no_aug=no_aug)


def get_dl_vl(args, shuffle=False, no_aug=False):
    """Return ``(dataset_size, DataLoader)`` for the held-out split.

    Thin wrapper around ``_get_dl`` with ``is_train`` fixed to ``False``.
    """
    return _get_dl(args, is_train=False, shuffle=shuffle, no_aug=no_aug)


# NOTE: the block below is a bare string literal, not live code: it preserves
# a disabled `get_dl_robust` loader and has no effect at runtime.
'''
def get_dl_robust(args, no_aug=False):
    from data.ds import Robust_ds
    ds = Robust_ds(args)
    return DataLoader(
        ds, batch_size=args.bsz, shuffle=False)
'''


def _get_dl(args, is_train, shuffle, no_aug, is_uc=False):
    """Instantiate the dataset named by ``args.dataset`` and wrap it in a DataLoader.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``bsz``, ``bsz_vl``
            and ``debug``; for every dataset except 'imagenet' it also reads
            ``stability`` and, only when that is truthy and the training
            pipeline is active, ``z_init``. It is also forwarded to
            ``transform``.
        is_train: Truthy selects the training split and ``args.bsz``;
            otherwise the held-out split and ``args.bsz_vl``.
        shuffle: Forwarded to the DataLoader.
        no_aug: When truthy, the 'vl' transform pipeline is used even for the
            training split.
        is_uc: Unused; retained so existing call sites keep working.

    Returns:
        Tuple ``(len(dataset), dataloader)``.

    Raises:
        ValueError: If ``args.dataset`` is not a handled dataset key
            (previously this surfaced as an ``UnboundLocalError``).
    """
    sdt = args.dataset
    bsz = args.bsz if is_train else args.bsz_vl
    tr_vl = 'tr' if is_train and not no_aug else 'vl'
    # TODO : download=False if the dataset is already downloaded

    if sdt == 'svhn':
        split = 'train' if is_train else 'test'
        # Added here for safety (HJ)
        if args.stability and tr_vl == 'tr' and args.z_init == 'use_db':
            pth = 'train/' if is_train else 'val/'
            dt = dataset_with_indices(cDT[sdt])(
                root=data_path[sdt] + pth, split=split,
                download=True, transform=transform(args, sdt, tr_vl))
        else:
            # This branch uses the dataset root without a split sub-folder.
            dt = cDT[sdt](root=data_path[sdt], download=True,
                          transform=transform(args, sdt, tr_vl),
                          split=split)
    elif sdt == 'imagenet':
        # Use custom TinyImageNet loader with consistent class mapping
        dt = cDT[sdt](root=os.path.expanduser(data_path[sdt]),
                      train=is_train,
                      transform=transform(args, sdt, tr_vl))
    elif sdt in ('mnist', 'cifar10', 'fmnist', 'cifar100'):
        pth = 'train/' if is_train else 'val/'
        if args.stability and tr_vl == 'tr' and args.z_init == 'use_db':
            ds_cls = dataset_with_indices(cDT[sdt])
        else:
            ds_cls = cDT[sdt]
        dt = ds_cls(root=data_path[sdt] + pth, train=is_train,
                    download=True, transform=transform(args, sdt, tr_vl))
    else:
        raise ValueError(
            f"Unsupported dataset '{sdt}'; expected one of: "
            "svhn, imagenet, mnist, cifar10, fmnist, cifar100")

    if args.debug:
        # Keep only the first 64 samples for quick debugging runs.
        dt.data = dt.data[:64]

    # NOTE(review): drop_last=True also applies to the held-out split, so the
    # returned len(dt) can exceed the number of samples actually iterated.
    dl = torch.utils.data.DataLoader(
        dt,
        drop_last=True,
        sampler=None,
        batch_size=bsz,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        shuffle=shuffle,
    )
    print(f"Data loaded with {len(dt)} {tr_vl} imgs.")
    return len(dt), dl


def cycle(loader):
    """Yield items from ``loader`` endlessly, restarting it whenever it is exhausted."""
    while True:
        yield from loader


def dataset_with_indices(cls):
    """
    Return a subclass of ``cls`` whose ``__getitem__`` also yields the index.

    The sample produced by ``cls.__getitem__`` is extended with the requested
    index as its last element, so ``(data, target)`` becomes
    ``(data, target, index)`` and a four-field sample becomes a five-field
    one. A sample that is not a tuple or list is returned as
    ``(sample, index)``.

    The generated class keeps the name of ``cls``.
    """

    def __getitem__(self, index):
        sample = cls.__getitem__(self, index)
        # Do not assume a fixed number of fields: different dataset classes
        # return samples of different lengths.
        if not isinstance(sample, (tuple, list)):
            sample = (sample,)
        return (*sample, index)

    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })

def get_dataloader(args):
    """
    Build the train and test loaders for PCN.

    ``args`` is a PCN-style object providing ``dataset`` and ``batch_size``
    (``debug`` and ``z_init`` are optional). It is adapted to the attribute
    names the dl_getter functions expect, then both loaders are created.

    Returns a dict with keys 'train_loader', 'test_loader', 'train_size' and
    'test_size'. Any loading error is reported on stdout and re-raised.
    """

    class DLGetterArgs:
        """Adapter exposing PCN settings under dl_getter attribute names."""

        def __init__(self, pcn_args):
            # Dataset keys in this module are lowercase.
            self.dataset = pcn_args.dataset.lower()
            # Training and validation share the same batch size.
            batch = pcn_args.batch_size
            self.bsz = batch
            self.bsz_vl = batch
            # Features PCN does not use are switched off.
            self.at = False             # adversarial training
            self.stability = False      # stability features
            self.vl_corrupt_ns = False  # validation data noise
            # Optional PCN settings with their fallbacks.
            self.debug = getattr(pcn_args, 'debug', False)
            self.z_init = getattr(pcn_args, 'z_init', 'ff')

    dl_args = DLGetterArgs(args)

    try:
        # Augmented, shuffled training data.
        train_size, train_loader = get_dl_tr(dl_args, shuffle=True, no_aug=False)
        # Un-augmented, ordered test data.
        test_size, test_loader = get_dl_vl(dl_args, shuffle=False, no_aug=True)

        print(f"Loaded dataset '{args.dataset}': Train={train_size}, Test={test_size}")

        return dict(
            train_loader=train_loader,
            test_loader=test_loader,
            train_size=train_size,
            test_size=test_size,
        )
    except Exception as e:
        print(f"Error loading dataset '{args.dataset}': {e}")
        print(f"Available datasets: {list(cDT.keys())}")
        raise

def get_dataset_info(dataset_name):
    """
    Look up the static metadata of a dataset (name is case-insensitive).

    Returns a dict with 'num_classes', 'img_shape', 'input_range' and
    'normalization'. Raises ValueError when the name has no entry in ``sh``.
    """
    dataset_name = dataset_name.lower()

    if dataset_name not in sh:
        raise ValueError(f"Dataset '{dataset_name}' not supported. Available: {list(sh.keys())}")

    # Class count and normalization fall back to defaults when missing.
    num_classes = n_cls.get(dataset_name, 10)
    img_shape = sh[dataset_name]
    value_range = input_range[dataset_name]
    stats = normalization_infos.get(dataset_name, [(0.5,), (0.5,)])

    return dict(
        num_classes=num_classes,
        img_shape=img_shape,
        input_range=value_range,
        normalization=stats,
    )

# Dataset keys supported by the PCN entry points.
SUPPORTED_DATASETS_PCN = [
    "mnist",
    "cifar10",
    "cifar100",
    "svhn",
    "fmnist",
    "imagenet",
]


if __name__ == '__main__':
    # Manual smoke test: build the FashionMNIST train and validation loaders.
    from collections import namedtuple
    # attack, x_noise, x_sigma, dataset, batch size, validation batch size,
    # debug, balanced sampling, stability, validation noise, z init mode
    Args = namedtuple('Args', ['at', 'x_noise', 'x_sigma', 'dataset',
                               'bsz', 'bsz_vl', 'debug', 'bs', 'stability',
                               'vl_corrupt_ns', 'z_init'])
    # `z_init` is required: with stability=True the training path of _get_dl
    # reads args.z_init, which raised AttributeError when the field was absent.
    args = Args(at=False, x_noise=False, x_sigma=0, dataset='fmnist',
                bsz=64, bsz_vl=256, debug=False, bs=False, stability=True,
                vl_corrupt_ns=False, z_init='ff')
    # Each call returns a (dataset_size, DataLoader) tuple.
    tr_dl = get_dl_tr(args)
    vl_dl = get_dl_vl(args)
