import torch
import csv
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import train_test_split
import random
from wilds import get_dataset
from typing import Dict, Tuple
from collections import defaultdict, Counter

from PIL import Image

import os

def _build_transforms(args):
    """Build the training and evaluation transform pipelines for ``args``.

    :return: ``(train_transform, test_transform)``: the training pipeline already
        wrapped in ``transforms.Compose``, and the evaluation pipeline as a plain
        list (the caller composes it for the val/test sets).
    """
    # Datasets whose images are resized to 256 and center-cropped to 224.
    resize_224 = ['waterbirds', 'syn_cifar10', 'syn_cifar10_examples']

    train_transform = []

    augm_transforms = {
        'ta': transforms.TrivialAugmentWide(),
        'ra': transforms.RandAugment(),
        'aa': transforms.AutoAugment(),
        'none': None
    }
    str_augm = augm_transforms[args.str_augm]
    if str_augm:
        train_transform.append(str_augm)

    if args.augm:
        if 'vit' in args.model:
            # Adapted from https://github.com/tintn/vision-transformer-from-scratch/tree/main
            train_transform.extend([
                transforms.ToTensor(),
                transforms.Resize((32, 32)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), interpolation=2, antialias=None),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        else:
            if args.dataset == 'stl10':
                train_transform.append(transforms.RandomCrop(96, padding=4))
            elif args.dataset in resize_224:
                train_transform.extend([
                    transforms.Resize(256),
                    transforms.CenterCrop(224)])
            elif 'imagenet' not in args.dataset:
                train_transform.append(transforms.RandomCrop(32, padding=4))
            else:
                train_transform.append(transforms.RandomCrop(64, padding=8))
            if args.dataset not in ['mnist'] + resize_224:
                train_transform.append(transforms.RandomHorizontalFlip())

    if 'vit' in args.model:
        # Adapted from https://github.com/tintn/vision-transformer-from-scratch/tree/main
        test_transform = [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    else:
        if args.dataset in resize_224:
            test_transform = [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor()]
        else:
            test_transform = [transforms.ToTensor()]
        if (args.model in ['mlp', 'densenet121', 'efficientb1', 'efficientv2s']
                or args.dataset in resize_224
                or args.pretrained):
            print("Normalize data")
            test_transform.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    # The ViT augmentation list above already converts to tensors and normalizes.
    # Every other configuration needs the evaluation pipeline appended; without it,
    # ViT without augmentation would hand raw PIL images to the model.
    if not ('vit' in args.model and args.augm):
        train_transform.extend(test_transform)

    return transforms.Compose(train_transform), test_transform


def _load_image_folder_splits(args, dataset_class):
    """Load ``<data_path>/<dataset>/{train,val,test}`` ImageFolder splits.

    :return: ``(trainset, valset, testset)``.
    :raises ValueError: if ``args.dataset`` is not a known ImageFolder layout.
    """
    # Exact dataset name -> name shown in the log message.
    display_names = {
        'tiny_imagenet': 'TinyImageNet',
        'cinic10': 'CINIC10',
        'syn_cifar10': 'Synthetic CIFAR10',
        'cifar10_new': 'MIX CIFAR10',
        'mix_cifar10': 'MIX CIFAR10',
        'mix_cifar100': 'MIX CIFAR100',
        'mix_tiny_imagenet': 'MIX Tiny ImageNet',
        'syn_cifar10_examples': 'Synthetic CIFAR10 Examples',
    }
    # Generated variant families, matched by substring in this order.
    family_names = [
        ('mix_cifar10_steps', 'MIX CIFAR10'),
        ('mix_cifar10_weight', 'MIX CIFAR10'),
        ('syn_cifar10_weight', 'Syn CIFAR10'),
    ]

    display = display_names.get(args.dataset)
    if display is None:
        display = next((name for family, name in family_names if family in args.dataset), None)
    if display is None:
        raise ValueError(f'No ImageFolder layout known for dataset {args.dataset!r}.')

    print(f'{display} has a predetermined validation set. Using it as valset.')
    image_path = os.path.join(args.data_path, args.dataset)
    train_split = 'train'
    if args.dataset == 'cinic10' and args.cinic10_enlarge:
        train_split = 'train_val'

    trainset = dataset_class(os.path.join(image_path, train_split))
    valset = dataset_class(os.path.join(image_path, 'val'))
    testset = dataset_class(os.path.join(image_path, 'test'))
    return trainset, valset, testset


def _load_torchvision_splits(args, dataset_class):
    """Load a torchvision dataset and optionally carve out a fixed validation split.

    The split (90/10, seeded by ``args.seed``) is cached in
    ``<data_path>/GOLD_train_val_splits/GOLD_<dataset>.pt`` and reused afterwards.

    :return: ``(trainset_og, trainset, valset, testset, train_idx)``. Without
        ``args.valset`` the full official train split is used for training, and
        ``valset``/``train_idx`` are ``None``.
    """
    if args.dataset == 'stl10':
        trainset_og = dataset_class(args.data_path, split='train', download=True)
    else:
        trainset_og = dataset_class(args.data_path, train=True, download=True)

    trainset, valset, train_idx = trainset_og, None, None
    if args.valset:
        split_dir = os.path.join(args.data_path, 'GOLD_train_val_splits')
        gold_train_val_split_path = os.path.join(split_dir, f'GOLD_{args.dataset}.pt')

        if not os.path.exists(gold_train_val_split_path):
            print(f'WARNING: Main train_val_split file for dataset {args.dataset} was not found. Generating a new one and stopping the run.')
            print('Note that this split may differ from that of previous runs. Make sure you want this!')

            os.makedirs(split_dir, exist_ok=True)
            # len(dataset) equals the number of labels for MNIST/CIFAR/STL10.
            train_idx, val_idx = train_test_split(
                np.arange(len(trainset_og)), test_size=0.1, random_state=args.seed, shuffle=True, stratify=None)
            torch.save({'train_idx': train_idx, 'val_idx': val_idx}, gold_train_val_split_path)

        # NOTE(review): the file holds numpy arrays; newer torch versions default
        # torch.load to weights_only=True, which may reject them — confirm on upgrade.
        gold_train_val_split = torch.load(gold_train_val_split_path)
        train_idx, val_idx = gold_train_val_split['train_idx'], gold_train_val_split['val_idx']
        trainset, valset = Subset(trainset_og, train_idx), Subset(trainset_og, val_idx)

    if args.dataset == 'stl10':
        testset = dataset_class(args.data_path, split='test', download=True)
    else:
        testset = dataset_class(args.data_path, train=False, download=True)

    return trainset_og, trainset, valset, testset, train_idx


def _inject_label_noise(trainset, noise_percent, num_classes):
    """Relabel a prefix of the training subset with a different random class, in place.

    The number of corrupted samples is ``noise_percent`` % of the underlying
    dataset's label count (``len(trainset.dataset.targets)``). The corrupted samples
    are the first entries of ``trainset.indices``. Mutates ``trainset.dataset.targets``.

    :raises ValueError: if ``trainset`` is not a ``Subset`` (requires the train/val split).
    """
    if not isinstance(trainset, Subset):
        raise ValueError('Label noise requires a Subset training set (enable the train/val split).')
    base = trainset.dataset
    num_to_change = int(len(base.targets) * noise_percent / 100)
    for idx in trainset.indices[:num_to_change]:
        current_label = base.targets[idx]
        base.targets[idx] = random.choice([c for c in range(num_classes) if c != current_label])


def _drop_train_indices(trainset_og, train_idx, drop_indices):
    """Remove ``drop_indices`` (original indices) from the training indices.

    When ``train_idx`` is ``None``, every index of ``trainset_og`` is a candidate.

    :return: ``(Subset(trainset_og, remaining), remaining)``.
    """
    if train_idx is None:
        train_idx = list(range(len(trainset_og)))
    remaining = list(set(train_idx) - set(drop_indices))
    return Subset(trainset_og, remaining), remaining


def _draw_random_upsample(dataset_class, trainset, train_idx, k):
    """Draw ``k`` distinct original training indices uniformly at random."""
    if dataset_class is datasets.ImageFolder:
        positions = torch.randperm(len(trainset))[:k].cpu().tolist()
        # After downsampling, trainset is a Subset: map positions back to the
        # original indices, which is what UpsampledDataset expects.
        return positions if train_idx is None else [train_idx[p] for p in positions]
    pool = train_idx if train_idx is not None else np.arange(len(trainset))
    return np.random.permutation(pool)[:k].tolist()


def _print_dataset_summary(args, trainset, valset, testset):
    """Print the upsampling mode and the split sizes."""
    print()
    if args.us:
        if args.us_idx_path:
            print('Upsampling applied to training data.')
        if args.us_type == 'real':
            print('All samples are real.')
        if args.us_type == 'syn':
            if args.us_syn_only:
                print('All samples are synthetic.')
            else:
                print('Base samples are real, upsampled samples are synthetic.')
    elif args.aus:
        print('No upsampling initially applied. Will perform auto-upsampling.')
    else:
        print('No upsampling applied to training data.')
    print()

    print(f'Num train examples:\t {len(trainset)}')
    if valset is not None:
        print(f'Num val examples:\t {len(valset)}')
    print(f'Num test examples:\t {len(testset)}')
    print()


def get_dataloaders(args, cus_indices=None, cds_indices=None):
    """Build the train/val/test DataLoaders for ``args.dataset``.

    The training set is wrapped in ``UpsampledDataset`` so that extra samples (real
    repeats or synthetic images) can be appended. Val/test sets are wrapped without
    upsampling and use the evaluation transforms.

    :param args: run configuration namespace (dataset, augmentation, up/downsampling
        and loader options). It may be mutated: when ``args.extra_per_epoch`` triggers
        random upsampling, ``args.us`` is set to ``True`` and ``args.us_idx_path`` is
        set to the saved index file (``<local_run_path>/random_us_indices.pt``).
    :param cus_indices: original indices to upsample when ``args.crs`` is set.
    :param cds_indices: original indices to drop from training when ``args.crs`` is set.
    :return: ``(train_loader, val_loader, test_loader, num_classes)``. ``val_loader``
        is ``None`` when there is no validation set.
    :raises ValueError: if a test distribution shift is requested for a dataset other
        than waterbirds.
    """
    datasets_dict = {
        'mnist': (datasets.MNIST, 10),
        'cifar10': (datasets.CIFAR10, 10),
        'cifar100': (datasets.CIFAR100, 100),
        'tiny_imagenet': (datasets.ImageFolder, 200),
        'cinic10': (datasets.ImageFolder, 10),
        'syn_cifar10': (datasets.ImageFolder, 10),
        'cifar10_new': (datasets.ImageFolder, 10),
        'mix_cifar10': (datasets.ImageFolder, 10),
        'mix_cifar100': (datasets.ImageFolder, 100),
        'mix_tiny_imagenet': (datasets.ImageFolder, 200),
        'syn_cifar10_examples': (datasets.ImageFolder, 8),
        'stl10': (datasets.STL10, 10),
        'waterbirds': (get_dataset, 2)
    }

    # Generated variant families are matched by substring (e.g. 'mix_cifar10_steps_40').
    if any(family in args.dataset for family in ('syn_cifar10_weight', 'mix_cifar10_weight', 'mix_cifar10_steps')):
        dataset_class, num_classes = datasets.ImageFolder, 10
    else:
        dataset_class, num_classes = datasets_dict[args.dataset]

    transform, test_transform = _build_transforms(args)

    # Create dataset objects
    upsample_indices, syn_dataset, train_idx = None, None, None
    if dataset_class is datasets.ImageFolder:
        trainset_og, valset, testset = _load_image_folder_splits(args, dataset_class)
        trainset = trainset_og
    elif dataset_class is get_dataset:
        dataset = dataset_class(dataset=args.dataset, root_dir=args.data_path, download=True)
        trainset = dataset.get_subset("train")
        valset = dataset.get_subset("val")
        testset = dataset.get_subset("test")
        # Lets the index-based downsampling below operate on the WILDS train subset.
        trainset_og = trainset
    else:
        trainset_og, trainset, valset, testset, train_idx = _load_torchvision_splits(args, dataset_class)

    if args.noise_percent:
        _inject_label_noise(trainset, args.noise_percent, num_classes)

    if args.us:
        if args.us_idx_path:
            upsample_indices = torch.load(args.us_idx_path).cpu().tolist()
        elif args.us_syn_100_100:
            upsample_indices = train_idx.tolist()
        if args.ds_idx_path:
            downsample_indices = torch.load(args.ds_idx_path).cpu().tolist()
            trainset, train_idx = _drop_train_indices(trainset_og, train_idx, downsample_indices)
        if args.us_type == 'syn' and args.us_syn_dataset:
            syn_dataset = SyntheticDataset(os.path.join(args.data_path, args.us_syn_dataset))

    if args.crs:
        if cus_indices:
            upsample_indices = cus_indices
        if cds_indices:
            trainset, train_idx = _drop_train_indices(trainset_og, train_idx, cds_indices)
    elif args.extra_per_epoch:
        print(f'Randomly upsampling {args.extra_per_epoch} samples from the trainset.')
        # Repeated draws advance the RNG state; only the last draw is kept.
        for _ in range(args.repeat_random_extra):
            upsample_indices = _draw_random_upsample(dataset_class, trainset, train_idx, args.extra_per_epoch)

        if args.us_type == 'syn' and args.us_syn_dataset:
            syn_dataset = SyntheticDataset(os.path.join(args.data_path, args.us_syn_dataset))

        upsample_path = os.path.join(args.local_run_path, 'random_us_indices.pt')
        torch.save(torch.tensor(upsample_indices), upsample_path)
        args.us = True
        args.us_idx_path = upsample_path

    trainset = UpsampledDataset(trainset,
                                transform=transform,
                                us_indices=upsample_indices,
                                us_type=args.us_type,
                                us_syn_dataset=syn_dataset,
                                us_syn_only=args.us_syn_only,
                                train_idx=train_idx)
    if valset is not None:
        valset = UpsampledDataset(valset, transform=transforms.Compose(test_transform))
    if args.test_distribution_shift != 'none' and args.dataset != 'waterbirds':
        raise ValueError('Test distribution shifts are only supported for the waterbirds dataset.')
    testset = UpsampledDataset(testset,
                               transform=transforms.Compose(test_transform),
                               shift=args.test_distribution_shift)

    _print_dataset_summary(args, trainset, valset, testset)

    def make_loader(ds):
        return DataLoader(ds, batch_size=args.batch_size, shuffle=args.shuffle, pin_memory=True, num_workers=4)

    train_loader = make_loader(trainset)
    # A DataLoader over None cannot be iterated (and fails at construction with shuffle=True).
    val_loader = make_loader(valset) if valset is not None else None
    test_loader = make_loader(testset)

    return train_loader, val_loader, test_loader, num_classes
    
class SyntheticDataset(Dataset):
    """Synthetic images stored as ``<root_dir>/<class_name>/<file>``.

    Class indices follow the sorted order of the class sub-directories. The original
    (real) sample index is parsed from the third-to-last ``_``-separated field of each
    file name, e.g. ``img_29_a_b.png`` -> 29. Items are returned as
    ``(image, label, orig_idx, idx)``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data, self.targets, self.orig_indices = self.load_data()

    def load_data(self):
        """Scan ``root_dir`` and return ``(image_paths, labels, orig_indices)``.

        Class directories and files are visited in sorted order. Items are addressed
        by position, and ``os.listdir`` order is filesystem-dependent, so sorting
        keeps positions reproducible. Non-directory entries at the top level are
        ignored rather than crashing the scan.
        """
        class_names = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry)))

        images, labels, orig_indices = [], [], []
        for class_idx, class_name in enumerate(class_names):
            class_dir = os.path.join(self.root_dir, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                # Third-to-last '_' field holds the original sample index (e.g. '29').
                orig_idx = int(file_name.split('_')[-3])
                images.append(os.path.join(class_dir, file_name))
                labels.append(class_idx)
                orig_indices.append(orig_idx)

        return images, labels, orig_indices

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.targets[idx]
        orig_idx = self.orig_indices[idx]

        # Close the file handle once the pixels have been decoded by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label, orig_idx, idx

class UpsampledDataset(Dataset):
    """Wrap a base dataset and optionally append an "upsampled" region of extra items.

    Positions ``[0, len(base_dataset))`` address the (possibly shift-filtered) base
    dataset. Positions past that address the upsampled region:

    * ``us_type == 'real'``: ``us_indices`` lists original indices whose real samples
      are repeated; they are resolved to base positions through ``train_idx``.
    * ``us_type == 'syn'`` with non-empty ``us_indices``: each entry is used directly
      as a position into ``us_syn_dataset``.
    * ``us_type == 'syn'`` without ``us_indices``: all of ``us_syn_dataset`` is appended.

    With ``us_syn_only``, every item is read from ``us_syn_dataset`` at the original
    index of the requested position.

    Items are returned as ``(x, y, orig_idx, idx)``, where ``idx`` is the requested
    position and ``orig_idx`` its original index (``train_idx[idx]`` for base positions;
    the position itself when ``train_idx`` is ``None``).
    """

    def __init__(self,
                 base_dataset,
                 transform=None,
                 us=False,
                 us_indices=None,
                 us_type='real',
                 us_syn_dataset=None,
                 us_syn_only=False,
                 train_idx=None,
                 shift='none'):
        """
        :param base_dataset: indexable dataset; the first two entries of each item are ``(x, y)``.
        :param transform: optional callable applied to ``x``.
        :param us: stored as ``self.us``; not consulted by this class.
        :param us_indices: indices defining the upsampled region (see class docstring).
        :param us_type: ``'real'`` or ``'syn'``.
        :param us_syn_dataset: dataset returning ``(x, y, orig_idx, idx)`` items.
        :param us_syn_only: serve every position from ``us_syn_dataset``.
        :param train_idx: original index of each base position; defaults to the identity.
        :param shift: ``'none'``, ``'minority'``, ``'spurious'`` or ``'balanced'`` filter
            over base items of the form ``(x, y, metadata)``.
        """
        self.base_dataset = base_dataset
        if shift == 'minority':
            # Keep items whose metadata[0] != metadata[1]
            # (presumably background vs. label for WILDS waterbirds — confirm).
            subset_indices = [i for i, (_, _, metadata) in enumerate(base_dataset) if metadata[0] != metadata[1]]
            self.base_dataset = Subset(base_dataset, subset_indices)
        elif shift == 'spurious':
            # Keep items whose metadata[0] == metadata[1].
            subset_indices = [i for i, (_, _, metadata) in enumerate(base_dataset) if metadata[0] == metadata[1]]
            self.base_dataset = Subset(base_dataset, subset_indices)
        elif shift == 'balanced':
            # Bucket items by (place, label), then truncate every bucket to the smallest size.
            count_group = defaultdict(list)
            for idx in range(len(self.base_dataset)):
                _, _, (place, label, _) = self.base_dataset[idx]
                count_group[2 * place.item() + label.item()].append(idx)

            min_len = min((len(members) for members in count_group.values()),
                          default=len(self.base_dataset))

            subset_indices = []
            for members in count_group.values():
                subset_indices.extend(members[:min_len])

            self.base_dataset = Subset(base_dataset, subset_indices)

        if train_idx is None:
            train_idx = list(range(len(self.base_dataset)))

        # Base position <-> original index.
        self.shuffled_to_orig = {k: v for k, v in enumerate(train_idx)}
        self.orig_to_shuffled = {v: k for k, v in enumerate(train_idx)}

        self.transform = transform
        self.total_len = len(self.base_dataset)

        self.us = us
        self.us_indices = us_indices
        self.us_type = us_type
        self.us_syn_dataset = us_syn_dataset
        self.us_syn_only = us_syn_only

        if self.us_indices is not None:
            self.total_len += len(us_indices)
        elif self.us_syn_dataset:
            self.total_len += len(us_syn_dataset)

    def __getitem__(self, idx):
        # IndexError (still an Exception subclass) lets Python's sequence-iteration
        # protocol terminate, and it rejects negative positions, which the
        # position -> original-index maps cannot resolve.
        if not 0 <= idx < self.total_len:
            raise IndexError('Index out of bounds.')

        n_base = len(self.base_dataset)

        # Synthetic-only mode: every position is served from the synthetic set.
        if self.us_syn_only:
            orig_idx = self.shuffled_to_orig[idx]
            x, y, _, _ = self.us_syn_dataset[orig_idx]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        # Base region.
        if idx < n_base:
            orig_idx = self.shuffled_to_orig[idx]
            x, y = self.base_dataset[idx][:2]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        # From here on idx lies in the upsampled region.
        offset = idx - n_base

        # Repeat a real sample, addressed by its original index.
        if self.us_type == 'real':
            orig_idx = self.us_indices[offset]
            x, y = self.base_dataset[self.orig_to_shuffled[orig_idx]][:2]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        # Synthetic sample selected through the upsample indices.
        if self.us_type == 'syn' and self.us_indices:
            orig_idx = self.us_indices[offset]
            x, y, _, _ = self.us_syn_dataset[orig_idx]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        # Synthetic sample taken sequentially from the synthetic dataset.
        if self.us_type == 'syn':
            x, y, orig_idx, _ = self.us_syn_dataset[offset]
            if self.transform: x = self.transform(x)
            return x, y, orig_idx, idx

        raise Exception('Uncaught. If you\'re seeing this, something is wrong with the dataset object.')

    def __len__(self):
        return self.total_len
    
# Per-dataset array shapes, presumably (num_samples, channels, height, width).
# The regular sweep families are generated inline; key order is preserved.
shapes_dict = {
    'mnist': (60000, 1, 28, 28),
    'mnist_binary': (13007, 1, 28, 28),
    'svhn': (73257, 3, 32, 32),
    'cifar10': (50000, 3, 32, 32),
    'cifar10_new': (50000, 3, 32, 32),
    'cifar10_horse_car': (10000, 3, 32, 32),
    'cifar10_dog_cat': (10000, 3, 32, 32),
    'cifar100': (50000, 3, 32, 32),
    'tiny_imagenet': (100000, 3, 64, 64),
    'uniform_noise': (1000, 1, 28, 28),
    'gaussians_binary': (1000, 1, 1, 100),
    'cinic10': (90000, 3, 32, 32),
    'stl10': (5000, 3, 96, 96),
    'waterbirds': (11788, 3, 224, 224),
    'syn_cifar10': (50000, 3, 224, 224),
    'mix_cifar10': (90000, 3, 32, 32),
    # Steps 10, 20, ..., 100 all share the same size.
    **{f'mix_cifar10_steps_{step}': (61072, 3, 32, 32) for step in range(10, 101, 10)},
    # weight{k} holds 45000 + 13925 * (k - 1) images, for k = 1..5.
    **{f'mix_cifar10_weight{k}': (45000 + 13925 * (k - 1), 3, 32, 32) for k in range(1, 6)},
    # weight{k} holds 13925 * (k - 1) images, for k = 2..5.
    **{f'syn_cifar10_weight{k}': (13925 * (k - 1), 3, 32, 32) for k in range(2, 6)},
    'mix_cifar100': (90000, 3, 32, 32),
    'mix_tiny_imagenet': (200000, 3, 64, 64),
    'syn_cifar10_examples': (720, 3, 224, 224)
}


class GroupLabeledDatasetWrapper(Dataset):
    """Attach an integer group label to every item of a wrapped dataset."""

    def __init__(
        self, 
        dataset: Dataset,
        group_partition: Dict[Tuple[int, int], int]
    ):
        """
        Build the per-position group table.

        Group labels are assigned by enumerating the keys of ``group_partition`` in
        sorted order (first key -> 0, second -> 1, ...). The value stored under each
        key selects the positions that receive that key's label; it is used as a
        torch tensor index, e.g. a list of ints. Positions not selected by any key
        keep label 0.

        :param dataset: The underlying dataset; items are indexed as ``item[0..3]``,
            and ``item[3]`` is looked up as a position in the group table.
        :type dataset: torch.utils.data.Dataset

        :param group_partition: Maps each group key to the positions of its members.
        :type group_partition: Dict[Tuple[int, int], int]
        """
        self.dataset = dataset
        self.group_partition = group_partition

        labels = torch.zeros(len(dataset))
        for group_id, key in enumerate(sorted(group_partition)):
            labels[group_partition[key]] = group_id

        self.num_groups = len(group_partition)
        self.group = labels.long().tolist()

    def __getitem__(self, index: int):
        """
        Return ``(item[0], item[1], item[2], group)`` for the wrapped item at ``index``,
        where ``group`` is the label stored at position ``item[3]``.

        :param index: The index of the item.
        :type index: int
        """
        item = self.dataset[index]
        return item[0], item[1], item[2], self.group[item[3]]

    def __len__(self):
        """
        Return the number of positions in the group table (= ``len(dataset)``).

        :rtype: int
        """
        return len(self.group)

    def get_sample_weight(self):
        """Return one weight per position: 1 / (size of that position's group)."""
        counts = Counter(self.group)
        inverse_freq = [1. / counts.get(g, 1) for g in range(self.num_groups)]
        return [inverse_freq[g] for g in self.group]