import os

import numpy as np
import torch
import torchvision.datasets as datasets
from robustness.tools import constants, folder
from robustness.tools.helpers import get_label_mapping
from torch.utils.data import Dataset, Subset
from torchvision import transforms

from colored_mnist import ColoredMNIST, load_dataloaders
from group_loader import GroupTest, MultiGroupLoader
from imagenet_dataset import ImageNet


def sparse2coarse(targets):
    """Map CIFAR-100 fine label ids to coarse (superclass) label ids.

    Example:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)

    Args:
        targets: a fine label id in ``[0, 100)`` or any sequence/array of them
            usable as a NumPy index.

    Returns:
        The result of indexing the fine->coarse lookup table with ``targets``:
        a NumPy array for sequence input, a NumPy scalar for a single id.
    """
    # One row per ten consecutive fine ids; entry k of the flattened table is
    # the coarse id of fine label k.
    table_rows = (
        (4, 1, 14, 8, 0, 6, 7, 7, 18, 3),        # fine 0-9
        (3, 14, 9, 18, 7, 11, 3, 9, 7, 11),      # fine 10-19
        (6, 11, 5, 10, 7, 6, 13, 15, 3, 15),     # fine 20-29
        (0, 11, 1, 10, 12, 14, 16, 9, 11, 5),    # fine 30-39
        (5, 19, 8, 8, 15, 13, 14, 17, 18, 10),   # fine 40-49
        (16, 4, 17, 4, 2, 0, 17, 4, 18, 17),     # fine 50-59
        (10, 3, 2, 12, 12, 16, 12, 1, 9, 19),    # fine 60-69
        (2, 10, 0, 1, 16, 12, 9, 13, 15, 13),    # fine 70-79
        (16, 19, 2, 4, 6, 19, 5, 5, 8, 19),      # fine 80-89
        (18, 1, 2, 15, 6, 0, 17, 8, 14, 13),     # fine 90-99
    )
    fine_to_coarse = np.ravel(np.array(table_rows))
    return fine_to_coarse[targets]

# load data
class IndexedDataset(Dataset):
    """Dataset wrapper whose items are ``(data, target, index)`` triples.

    The constructor selects a backing dataset by name and stores it in
    ``self.dataset``. Most branches also set ``self.targets`` and
    ``self.group_array``. The Waterbirds/CelebA loader used when ``group`` is
    None additionally sets ``self.spurious``.

    Args:
        dataset: dataset name: 'waterbirds', 'celeba', a name containing
            'cmnist' (only 'cmnist' and 'balance_cmnist' are handled),
            'cifar10' or 'imagenet'.
        transform: transform forwarded to the backing dataset. Not every
            branch uses it.
        split: 'train' selects the training split; any other value is
            treated as an evaluation split.
        args: run configuration. Branches read ``args.seed`` and
            ``args.num_classes``.
        group: group selector. Its meaning is per-dataset; see the helpers.
    """

    def __init__(self, dataset, transform=None, split='train', args=None, group=None):
        if dataset in ['waterbirds', 'celeba']:
            self._init_group_loader(dataset, transform, split, group)
        elif 'cmnist' in dataset:
            self._init_cmnist(dataset, transform, split, args, group)
        elif dataset == 'cifar10':
            self._init_cifar10(transform, split, args)
        elif dataset == 'imagenet':
            self._init_imagenet(transform, split)
        # NOTE(review): any other name (e.g. 'cifar100sup', for which
        # load_dataset builds a transform) leaves the instance without a
        # `dataset` attribute, so indexing it fails later — confirm intended.

    @staticmethod
    def _load_cached(path):
        """Load a whole pickled object previously written with ``torch.save``."""
        try:
            # These cache files hold full Dataset objects, not tensors.
            # torch >= 2.6 defaults to weights_only=True and would refuse
            # them. The files are produced locally by this class (see
            # _init_cmnist); never point this at untrusted files.
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 has no `weights_only` keyword and always unpickles.
            return torch.load(path)

    def _init_group_loader(self, dataset, transform, split, group):
        """Set up Waterbirds/CelebA through the project's group loaders."""
        if dataset == 'waterbirds':
            root_path = ''  # TODO: configure the Waterbirds root directory
        else:
            root_path = ''  # TODO: configure the CelebA root directory
        if group is None:
            self.dataset = MultiGroupLoader(dataset_root_dir=root_path,
                                            train_split=split, transform=transform, dataset=dataset)
            self.group_array = self.dataset.group_array
            self.spurious = self.dataset.spurious
        else:
            # NOTE(review): `transform` is not forwarded to GroupTest — confirm.
            self.dataset = GroupTest(dataset_root_dir=root_path, split=split,
                                     group=group, dataset=dataset)
        self.targets = self.dataset.labels

    def _init_cmnist(self, dataset, transform, split, args, group):
        """Set up (balanced) Colored-MNIST from cached files under ./data/cmnist."""
        save_dir = './data/cmnist'
        if dataset == 'cmnist':
            # Pre-built files are expected to exist; `transform` is unused here,
            # and `targets` is not set (``__len__`` falls back to the dataset).
            saved_filename = 'train.pth' if split == 'train' else 'test.pth'
            self.dataset = self._load_cached(os.path.join(save_dir, saved_filename))
        elif dataset == 'balance_cmnist':
            split_tag = 'train' if split == 'train' else 'test'
            saved_filename = f'{dataset}_{split_tag}_{args.seed}'
            saved_filename += '.pt' if transform is None else '_transform.pt'
            saved_path = os.path.join(save_dir, saved_filename)

            if not os.path.exists(saved_path):
                # Build once with a fixed seed and cache it, so reruns with the
                # same seed reuse the identical sample.
                os.makedirs(save_dir, exist_ok=True)
                np.random.seed(args.seed)
                torch.manual_seed(args.seed)
                self.dataset = load_dataloaders(args, split=split, transform=transform)
                torch.save(self.dataset, saved_path)

            self.dataset = self._load_cached(saved_path)
            self.targets = self.dataset.targets
            self.group_array = np.array(self.dataset.group_array)

            if split != 'train':
                # Keep only samples whose group id equals
                # group[0] * num_classes + group[1]; `group` must be indexable.
                group_id = group[0] * args.num_classes + group[1]
                group_index = np.arange(len(self.group_array))[self.group_array == group_id]
                self.dataset = Subset(self.dataset, group_index)
                self.group_array = self.group_array[group_index]
                self.targets = self.targets[group_index]

    def _init_cifar10(self, transform, split, args):
        """Set up CIFAR-10; evaluation splits become one Subset per class."""
        self.dataset = datasets.CIFAR10(root='', train=(split == 'train'),
                                        transform=transform, download=True)
        self.group_array = np.array(self.dataset.targets)
        self.targets = self.dataset.targets
        if split != 'train':
            # NOTE(review): `self.dataset` becomes a list of per-class Subsets,
            # so `self[i]` no longer yields (data, target); callers presumably
            # iterate `self.dataset` directly for evaluation — confirm.
            self.dataset = [Subset(self.dataset, np.where(self.group_array == i)[0])
                            for i in range(args.num_classes)]

    def _init_imagenet(self, transform, split):
        """Set up Restricted ImageNet via the `robustness` label mapping."""
        class_ranges = constants.RESTRICTED_IMAGNET_RANGES
        label_mapping = get_label_mapping('restricted_imagenet', class_ranges)
        train_path = '/data/ILSVRC/Data/CLS-LOC/train'
        test_path = '/data/ILSVRC/test'
        if split == 'train':
            self.dataset = folder.ImageFolder(root=train_path, transform=transform,
                                              label_mapping=label_mapping)
            labels = self.dataset.targets
        else:
            self.dataset = ImageNet(root=test_path, transform=transform,
                                    class_ranges=class_ranges, label_mapping=label_mapping)
            labels = self.dataset.labels
        self.group_array = np.array(labels)
        self.targets = labels

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the backing item at ``index``."""
        data, target = self.dataset[index]
        return data, target, index

    def __len__(self):
        """Number of samples.

        Uses ``self.targets`` when a branch set it. Otherwise (e.g. plain
        'cmnist') it falls back to the backing dataset instead of raising
        AttributeError.
        """
        targets = getattr(self, 'targets', None)
        if targets is None:
            return len(self.dataset)
        return len(targets)
    

def load_dataset(args, split='train', group=None, augment=False):
    """Build the input transform for ``args.dataset`` and return an IndexedDataset.

    Every recognised dataset gets a list of geometric steps (augmenting or
    deterministic, depending on ``augment``) followed by ``ToTensor`` and a
    per-dataset ``Normalize``. Unrecognised names get ``transform=None``.

    Args:
        args: run configuration; ``args.dataset`` selects the pipeline and
            ``args`` is forwarded to IndexedDataset.
        split: forwarded to IndexedDataset.
        group: forwarded to IndexedDataset.
        augment: use random crops/flips instead of deterministic resizing.

    Returns:
        The IndexedDataset built with the selected transform.
    """
    name = args.dataset
    geometric = None  # steps applied before ToTensor/Normalize; None => no transform
    stats = None      # (mean, std) passed to Normalize

    if name in ('waterbirds', 'celeba'):
        stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        if augment:
            geometric = [transforms.Resize(256),
                         transforms.RandomCrop(224),
                         transforms.RandomHorizontalFlip()]
        else:
            geometric = [transforms.Resize(224), transforms.CenterCrop(224)]
    elif 'cmnist' in name:
        stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        if augment:
            geometric = [transforms.RandomHorizontalFlip(),
                         transforms.RandomCrop(32, padding=4)]
        else:
            geometric = [transforms.Resize(32)]
    elif name in ('cifar100sup', 'cifar10'):
        if name == 'cifar100sup':
            stats = ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
        else:
            stats = ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
        # CIFAR images are already 32x32: only augmentation adds geometric steps.
        geometric = ([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
                     if augment else [])
    elif name == 'imagenet':
        stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        if augment:
            geometric = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
        else:
            geometric = [transforms.Resize(256), transforms.CenterCrop(224)]

    transform = None
    if geometric is not None:
        mean, std = stats
        transform = transforms.Compose(
            geometric + [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

    return IndexedDataset(name, transform=transform, split=split, group=group, args=args)