import torch
import csv
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import train_test_split
import random

from PIL import Image

import os

def get_dataloaders(args):
    """Build the train/val/test ``DataLoader`` objects described by ``args``.

    Attributes read from ``args``: ``dataset``, ``data_path``, ``str_augm``,
    ``augm``, ``valset``, ``seed`` (only when a new split file is generated),
    ``noise_percent``, ``aus``, ``us``, ``us_idx_path``, ``us_syn_100_100``,
    ``us_type``, ``us_syn_dataset``, ``us_syn_only``, ``extra_per_epoch``,
    ``repeat_random_extra``, ``local_run_path``, ``batch_size``, ``shuffle``.

    Side effects: when random extra samples are drawn (``args.extra_per_epoch``),
    their indices are saved to ``<local_run_path>/random_us_indices.pt`` and
    ``args.us`` / ``args.us_idx_path`` are updated in place to point at it.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``. Each loader
        yields batches of ``(x, y, orig_idx, idx)``. ``val_loader`` is ``None``
        when no validation set is used.
    """
    datasets_dict = {
        'mnist': (datasets.MNIST, 10),
        'cifar10': (datasets.CIFAR10, 10),
        'cifar100': (datasets.CIFAR100, 100),
        'tiny_imagenet': (datasets.ImageFolder, 200),
    }
    dataset_class, num_classes = datasets_dict[args.dataset]

    transform = _build_train_transform(args)

    # train_idx maps positions of the (possibly split) trainset back to indices
    # of the full training set; None means the two coincide.
    upsample_indices, syn_dataset, train_idx = None, None, None

    if dataset_class == datasets.ImageFolder:
        print('TinyImageNet has a predetermined validation set. Using it as valset.')
        tiny_imagenet_path = os.path.join(args.data_path, 'tiny_imagenet')
        trainset = dataset_class(os.path.join(tiny_imagenet_path, 'train'))
        valset = dataset_class(os.path.join(tiny_imagenet_path, 'val'))
        testset = dataset_class(os.path.join(tiny_imagenet_path, 'test'))
    else:
        trainset = dataset_class(args.data_path, train=True, download=True)
        if args.valset:
            train_idx, val_idx = _load_train_val_split(args, trainset)
            trainset, valset = Subset(trainset, train_idx), Subset(trainset, val_idx)
        else:
            valset = None
        testset = dataset_class(args.data_path, train=False, download=True)

    if args.noise_percent and not args.aus:
        _inject_label_noise(trainset, args.noise_percent, num_classes)

    if args.us:
        if args.us_idx_path:
            upsample_indices = torch.load(args.us_idx_path).cpu().tolist()
        elif args.us_syn_100_100:
            # Upsample every training sample once, addressed by original index.
            if train_idx is None:
                upsample_indices = list(range(len(trainset)))
            else:
                upsample_indices = np.asarray(train_idx).tolist()
        if args.us_type == 'syn' and args.us_syn_dataset:
            syn_dataset = SyntheticDataset(os.path.join(args.data_path, args.us_syn_dataset))

    elif args.extra_per_epoch:
        print(f'Randomly upsampling {args.extra_per_epoch} samples from the trainset.')
        # Each repetition replaces the previous draw, so only the last one is
        # kept (presumably this just advances the RNG -- confirm intent).
        for _ in range(args.repeat_random_extra):
            if train_idx is None:
                # No index split: trainset positions are the original indices.
                upsample_indices = torch.randperm(len(trainset))[:args.extra_per_epoch].cpu().tolist()
            else:
                upsample_indices = np.random.permutation(train_idx)[:args.extra_per_epoch].tolist()

        if args.us_type == 'syn' and args.us_syn_dataset:
            syn_dataset = SyntheticDataset(os.path.join(args.data_path, args.us_syn_dataset))

        upsample_path = os.path.join(args.local_run_path, 'random_us_indices.pt')
        torch.save(torch.tensor(upsample_indices), upsample_path)
        args.us = True
        args.us_idx_path = upsample_path

    eval_transform = transforms.Compose([transforms.ToTensor()])
    trainset = UpsampledDataset(trainset,
                                transform=transform,
                                us_indices=upsample_indices,
                                us_type=args.us_type,
                                us_syn_dataset=syn_dataset,
                                us_syn_only=args.us_syn_only,
                                train_idx=train_idx)
    if valset is not None:
        valset = UpsampledDataset(valset, transform=eval_transform)
    testset = UpsampledDataset(testset, transform=eval_transform)

    _print_summary(args, trainset, valset, testset)

    loader_kwargs = dict(batch_size=args.batch_size, shuffle=args.shuffle,
                         pin_memory=True, num_workers=4)
    train_loader = DataLoader(trainset, **loader_kwargs)
    # DataLoader(None, shuffle=True) raises, and a loader over None is unusable.
    val_loader = DataLoader(valset, **loader_kwargs) if valset is not None else None
    test_loader = DataLoader(testset, **loader_kwargs)

    return train_loader, val_loader, test_loader, num_classes


def _build_train_transform(args):
    """Compose the training-time augmentation pipeline selected by ``args``."""
    # Factories rather than instances, so only the selected policy is built.
    strong_augmentations = {
        'ta': transforms.TrivialAugmentWide,
        'ra': transforms.RandAugment,
        'aa': transforms.AutoAugment,
        'none': None,
    }
    steps = []
    make_strong = strong_augmentations[args.str_augm]
    if make_strong is not None:
        steps.append(make_strong())

    if args.augm:
        # NOTE(review): MNIST images are 28x28 but are cropped to 32 here while
        # val/test stay at 28 -- confirm this resolution mismatch is intended.
        if 'imagenet' not in args.dataset:
            steps.append(transforms.RandomCrop(32, padding=4))
        else:
            steps.append(transforms.RandomCrop(64, padding=8))
        if args.dataset != 'mnist':
            steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)


def _load_train_val_split(args, trainset):
    """Return the persisted ``(train_idx, val_idx)`` split, creating it if missing."""
    split_dir = os.path.join(args.data_path, 'GOLD_train_val_splits')
    split_path = os.path.join(split_dir, f'GOLD_{args.dataset}.pt')

    if not os.path.exists(split_path):
        # NOTE(review): despite the message, execution continues with the new split.
        print(f'WARNING: Main train_val_split file for dataset {args.dataset} was not found. Generating a new one and stopping the run.')
        print('Note that this split may differ from that of previous runs. Make sure you want this!')

        os.makedirs(split_dir, exist_ok=True)
        train_idx, val_idx = train_test_split(
            np.arange(len(trainset.targets)), test_size=0.1, random_state=args.seed, shuffle=True, stratify=None)
        torch.save({'train_idx': train_idx, 'val_idx': val_idx}, split_path)

    try:
        # The split holds NumPy arrays, which torch>=2.6's default
        # weights_only=True refuses to unpickle; the file is generated locally.
        split = torch.load(split_path, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only argument.
        split = torch.load(split_path)
    return split['train_idx'], split['val_idx']


def _inject_label_noise(trainset, noise_percent, num_classes):
    """Reassign ``noise_percent``% of the training labels to a different random class.

    Works on a ``Subset`` (only its indices are touched) or on a plain dataset.
    Labels are edited in ``targets`` and, when present, in ``samples`` (which
    ``ImageFolder.__getitem__`` reads instead of ``targets``).
    """
    if isinstance(trainset, Subset):
        base, positions = trainset.dataset, list(trainset.indices)
    else:
        base, positions = trainset, list(range(len(trainset)))

    labels = [base.targets[i] for i in positions]
    num_to_change = int(len(labels) * noise_percent / 100)
    samples = getattr(base, 'samples', None)

    for k in random.sample(range(len(labels)), num_to_change):
        new_label = random.choice([c for c in range(num_classes) if c != labels[k]])
        base_idx = positions[k]
        base.targets[base_idx] = new_label
        if samples is not None:
            samples[base_idx] = (samples[base_idx][0], new_label)


def _print_summary(args, trainset, valset, testset):
    """Print the active upsampling mode and the size of each split."""
    print()
    if args.us:
        if args.us_idx_path:
            print('Upsampling applied to training data.')
        if args.us_type == 'real':
            print('All samples are real.')
        if args.us_type == 'syn':
            if args.us_syn_only:
                print('All samples are synthetic.')
            else:
                print('Base samples are real, upsampled samples are synthetic.')
    elif args.aus:
        print('No upsampling initially applied. Will perform auto-upsampling.')
    else:
        print('No upsampling applied to training data.')
    print()

    print(f'Num train examples:\t {len(trainset)}')
    if valset is not None:
        print(f'Num val examples:\t {len(valset)}')
    print(f'Num test examples:\t {len(testset)}')
    print()
    
class SyntheticDataset(Dataset):
    """Images stored on disk, each tagged with the index of an original sample.

    Layout read under ``root_dir``:

    * ``labels.csv`` -- one header row, then rows ``<orig_idx>,<label>, ...``
      (only the first two columns are read, both parsed as ``int``);
    * ``images/`` -- one file per original index, named ``<orig_idx>.<ext>``.

    Items are ``(image, label, orig_idx, idx)`` where ``idx`` is the position
    in this dataset. Entries are ordered by ascending original index so that
    positions do not depend on the filesystem's directory-listing order.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data, self.targets, self.orig_indices = self.load_data()

    def load_data(self):
        """Index ``labels.csv`` and ``images/``.

        Returns:
            ``(image_paths, labels, orig_indices)`` -- parallel lists sorted by
            original index.

        Raises:
            ValueError: an image filename stem is not an integer.
            KeyError: an image has no matching row in ``labels.csv``.
        """
        idx_to_labels = {}
        with open(os.path.join(self.root_dir, 'labels.csv'), 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row (tolerate an empty file)
            for row in reader:
                if not row:  # ignore blank lines
                    continue
                idx_to_labels[int(row[0])] = int(row[1])

        image_dir = os.path.join(self.root_dir, 'images')
        # os.listdir order is arbitrary; sort numerically by original index.
        entries = sorted(
            (int(os.path.splitext(name)[0]), name) for name in os.listdir(image_dir)
        )

        images = [os.path.join(image_dir, name) for _, name in entries]
        labels = [idx_to_labels[orig_idx] for orig_idx, _ in entries]
        orig_indices = [orig_idx for orig_idx, _ in entries]
        return images, labels, orig_indices

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Close the file handle immediately; convert() returns a loaded copy.
        with Image.open(self.data[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.targets[idx], self.orig_indices[idx], idx

class UpsampledDataset(Dataset):
    """Wrap a dataset and append an extra ("upsampled") region after its end.

    Indices ``[0, len(base_dataset))`` address ``base_dataset`` directly.
    Indices past that address the upsampled region, which is either
    ``us_indices`` (original indices re-drawn from the base data when
    ``us_type == 'real'``, or from ``us_syn_dataset`` when ``us_type == 'syn'``)
    or, without ``us_indices``, the whole ``us_syn_dataset`` in order.

    Args:
        base_dataset: indexable dataset returning ``(x, y)``.
        transform: optional callable applied to every ``x``.
        us: stored as-is; not read by this class.
        us_indices: original indices to append, or ``None``.
        us_type: ``'real'`` or ``'syn'``.
        us_syn_dataset: dataset returning ``(x, y, orig_idx, idx)``. If it has an
            ``orig_indices`` list, samples are looked up by original index
            through it; otherwise the original index is used as a position.
        us_syn_only: if true, base-region samples are also taken from
            ``us_syn_dataset``.
        train_idx: original index of each ``base_dataset`` position; defaults
            to the identity mapping.

    Every item is returned as ``(x, y, orig_idx, idx)``.
    """

    def __init__(self,
                 base_dataset,
                 transform=None,
                 us=False,
                 us_indices=None,
                 us_type='real',
                 us_syn_dataset=None,
                 us_syn_only=False,
                 train_idx=None):

        self.base_dataset = base_dataset
        self.transform = transform

        self.total_len = len(base_dataset)

        self.us = us
        self.us_indices = us_indices
        self.us_type = us_type
        self.us_syn_dataset = us_syn_dataset
        self.us_syn_only = us_syn_only

        if train_idx is None:
            train_idx = list(range(len(base_dataset)))

        # position in base_dataset <-> index in the full original dataset
        self.shuffled_to_orig = dict(enumerate(train_idx))
        self.orig_to_shuffled = {v: k for k, v in enumerate(train_idx)}

        # original index -> position in us_syn_dataset (None: use positions)
        syn_orig = getattr(us_syn_dataset, 'orig_indices', None)
        self.syn_orig_to_pos = (
            None if syn_orig is None else {o: p for p, o in enumerate(syn_orig)}
        )

        if self.us_indices is not None:
            self.total_len += len(us_indices)
        elif self.us_syn_dataset:
            self.total_len += len(us_syn_dataset)

    def _syn_by_orig(self, orig_idx):
        """Fetch the synthetic sample generated from original index ``orig_idx``."""
        if self.syn_orig_to_pos is None:
            return self.us_syn_dataset[orig_idx]
        return self.us_syn_dataset[self.syn_orig_to_pos[orig_idx]]

    def __getitem__(self, idx):
        if not 0 <= idx < self.total_len:
            # IndexError (an Exception subclass) also ends legacy iteration.
            raise IndexError('Index out of bounds.')

        base_len = len(self.base_dataset)

        # Base region: real samples, or their synthetic versions if syn-only.
        if idx < base_len:
            orig_idx = self.shuffled_to_orig[idx]
            if self.us_syn_only:
                x, y, _, _ = self._syn_by_orig(orig_idx)
            else:
                x, y = self.base_dataset[idx]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        offset = idx - base_len
        has_us_indices = self.us_indices is not None and len(self.us_indices) > 0

        if self.us_type == 'real':
            # Upsampled region with real data.
            orig_idx = self.us_indices[offset]
            x, y = self.base_dataset[self.orig_to_shuffled[orig_idx]]
        elif self.us_type == 'syn' and has_us_indices:
            # Upsampled region with synthetic data chosen by original index.
            orig_idx = self.us_indices[offset]
            x, y, _, _ = self._syn_by_orig(orig_idx)
        elif self.us_type == 'syn':
            # Upsampled region covering the synthetic dataset in order.
            x, y, orig_idx, _ = self.us_syn_dataset[offset]
        else:
            raise Exception('Uncaught. If you\'re seeing this, something is wrong with the dataset object.')

        if self.transform: x = self.transform(x)
        return x, y, orig_idx, idx

    def __len__(self):
        return self.total_len
    
# Per-dataset shape tuples; the values appear to be
# (num_samples, channels, height, width) of each training set -- confirm.
shapes_dict = dict(
    mnist=(60000, 1, 28, 28),
    mnist_binary=(13007, 1, 28, 28),
    svhn=(73257, 3, 32, 32),
    cifar10=(50000, 3, 32, 32),
    cifar10_horse_car=(10000, 3, 32, 32),
    cifar10_dog_cat=(10000, 3, 32, 32),
    cifar100=(50000, 3, 32, 32),
    tiny_imagenet=(100000, 3, 64, 64),
    uniform_noise=(1000, 1, 28, 28),
    gaussians_binary=(1000, 1, 1, 100),
)