import numpy as np
import math

import torch
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from data.svhn import RobustSubsets, RobustTransform
from data.cure_tsr import CUREDataset
from data.gtsrb import GTSRBSubsets



def get_imagenet_loaders(args):
    """Build ImageNet train/validation loaders that normalize batches on the GPU.

    Images stay as PIL images through the worker processes (no ``ToTensor`` /
    ``Normalize`` transforms), are packed into uint8 NCHW tensors by
    ``fast_collate``, and are then moved to the GPU and mean/std-normalized by
    ``BatchTransformDataLoader``.

    Args:
        args: namespace read for ``data_size``, ``train_data_dir``,
            ``val_data_dir``, ``distributed``, ``batch_size``, ``workers`` and
            ``half_prec``; it is also handed to ``DistValSampler``.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)`` where both
        loaders are ``BatchTransformDataLoader`` wrappers, ``train_sampler`` is
        a ``DistributedSampler`` (or ``None`` when not distributed) and
        ``val_sampler`` is a ``DistValSampler``.
    """
    train_tfms = transforms.Compose([
        transforms.RandomResizedCrop(args.data_size, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(args.train_data_dir, train_tfms)
    # Truthiness (not ``is True``) so this agrees with DistValSampler, which
    # checks ``if args.distributed:``; with ``is True`` a truthy non-bool flag
    # would shard validation across ranks but not training.
    train_sampler = DistributedSampler(train_dataset) if args.distributed else None

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, collate_fn=fast_collate,
        sampler=train_sampler)

    # Resize to data_size * 1.14 and center-crop to data_size (1.14 is
    # presumably the conventional 256/224 evaluation ratio).
    val_tfms = transforms.Compose([
        transforms.Resize(int(args.data_size * 1.14)),
        transforms.CenterCrop(args.data_size),
    ])
    val_dataset = datasets.ImageFolder(args.val_data_dir, val_tfms)
    # DistValSampler yields whole batches, so it is passed as batch_sampler and
    # batch_size/shuffle/sampler are intentionally not given to the DataLoader.
    val_sampler = DistValSampler(list(range(len(val_dataset))), args)

    val_loader = DataLoader(
        val_dataset,
        num_workers=args.workers, pin_memory=True, collate_fn=fast_collate,
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, half_prec=args.half_prec)
    val_loader = BatchTransformDataLoader(val_loader, half_prec=args.half_prec)

    return train_loader, val_loader, train_sampler, val_sampler

def get_svhn_loaders(args, attribute='brightness', train_dom='low', val_dom='high'):
    """Build SVHN train/validation loaders from two ``RobustSubsets`` selections.

    Training uses the ``train_dom`` subset of the ``'train'`` split and
    validation uses the ``val_dom`` subset of the ``'test'`` split, both
    selected along ``attribute``. The defaults reproduce the previously
    hard-coded setup (``'brightness'``, ``'low'`` -> ``'high'``).

    Args:
        args: namespace read for ``distributed``, ``batch_size`` and
            ``workers``; it is also forwarded to ``RobustSubsets``.
        attribute: attribute passed as the second argument of ``RobustSubsets``.
        train_dom: ``dom`` value for the training subset.
        val_dom: ``dom`` value for the validation subset.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``; each sampler
        is a ``DistributedSampler`` when ``args.distributed`` is truthy,
        otherwise ``None``.
    """
    train_dataset = RobustSubsets('train', attribute, dom=train_dom, args=args)
    train_sampler = DistributedSampler(train_dataset) if args.distributed else None

    # Default collate function: these batches do not go through fast_collate or
    # BatchTransformDataLoader, so no GPU-side normalization is applied here.
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=train_sampler)

    val_dataset = RobustSubsets('test', attribute, dom=val_dom, args=args)
    # NOTE(review): DistributedSampler shuffles and pads to an even split by
    # default, so distributed validation may count a few samples twice.
    val_sampler = DistributedSampler(val_dataset) if args.distributed else None

    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler)

    return train_loader, val_loader, train_sampler, val_sampler

def get_cure_tsr_loaders(args, challenge='rain', train_level=0, val_level=5):
    """Build CURE-TSR train/validation loaders from ``CUREDataset``.

    Training uses the ``'train'`` split of ``challenge`` at ``train_level`` and
    validation uses the ``'test'`` split of the same challenge at ``val_level``.
    The defaults reproduce the previously hard-coded setup (``'rain'``,
    level 0 -> level 5).

    Args:
        args: namespace read for ``distributed``, ``batch_size`` and
            ``workers``; it is also forwarded to ``CUREDataset``.
        challenge: passed as the second argument of ``CUREDataset``
            (this file has used ``'rain'`` and ``'decolorization'``).
        train_level: ``level`` for the training dataset.
        val_level: ``level`` for the validation dataset.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``; each sampler
        is a ``DistributedSampler`` when ``args.distributed`` is truthy,
        otherwise ``None``.
    """
    train_dataset = CUREDataset('train', challenge, level=train_level, args=args)

    train_sampler = DistributedSampler(train_dataset) if args.distributed else None
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_dataset = CUREDataset('test', challenge, level=val_level, args=args)
    # NOTE(review): DistributedSampler shuffles and pads to an even split by
    # default, so distributed validation may count a few samples twice.
    val_sampler = DistributedSampler(val_dataset) if args.distributed else None
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    return train_loader, val_loader, train_sampler, val_sampler
    

def get_gtsrb_loaders(args, attribute='contrast', train_dom='low', val_dom='high'):
    """Build GTSRB train/validation loaders from two ``GTSRBSubsets`` selections.

    Training uses the ``train_dom`` subset of the ``'train'`` split and
    validation uses the ``val_dom`` subset of the ``'test'`` split, both
    selected along ``attribute``. The defaults reproduce the previously
    hard-coded setup (``'contrast'``, ``'low'`` -> ``'high'``).

    Args:
        args: namespace read for ``distributed``, ``batch_size`` and
            ``workers``; it is also forwarded to ``GTSRBSubsets``.
        attribute: attribute passed as the second argument of ``GTSRBSubsets``.
        train_dom: ``dom`` value for the training subset.
        val_dom: ``dom`` value for the validation subset.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``; each sampler
        is a ``DistributedSampler`` when ``args.distributed`` is truthy,
        otherwise ``None``.
    """
    train_dataset = GTSRBSubsets('train', attribute, dom=train_dom, args=args)
    train_sampler = DistributedSampler(train_dataset) if args.distributed else None
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_dataset = GTSRBSubsets('test', attribute, dom=val_dom, args=args)
    # NOTE(review): DistributedSampler shuffles and pads to an even split by
    # default, so distributed validation may count a few samples twice.
    val_sampler = DistributedSampler(val_dataset) if args.distributed else None
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    return train_loader, val_loader, train_sampler, val_sampler

class BatchTransformDataLoader:
    """Wrap a loader so each batch is moved to the GPU and normalized there.

    Mean/std normalization is applied once per batch on the device instead of
    per sample in the workers, following the NVIDIA apex ImageNet example:
    https://github.com/NVIDIA/apex/blob/59bf7d139e20fb4fa54b09c6592a2ff862f3ac7f/examples/imagenet/main.py#L222

    The ImageNet channel statistics below are multiplied by 255, so inputs are
    expected on a 0-255 scale (such as the uint8 batches from ``fast_collate``).
    Construction calls ``.cuda()``, so a CUDA device is required.

    Args:
        loader: wrapped iterable; it must yield ``(input, target)`` pairs.
        half_prec: when true, ``mean``/``std`` are passed through ``.float()``.
            NOTE(review): ``process_tensors`` always casts inputs with
            ``.float()``, so this flag appears to have no effect on dtypes (the
            apex reference casts to half instead) -- confirm this is intended.
    """

    def __init__(self, loader, half_prec=True):
        self.loader = loader
        channel_means = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        channel_stds = [0.229 * 255, 0.224 * 255, 0.225 * 255]
        # Shaped (1, 3, 1, 1) so they broadcast over an NCHW batch.
        self.mean = torch.tensor(channel_means).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(channel_stds).cuda().view(1, 3, 1, 1)
        self.half_prec = half_prec
        if self.half_prec:
            self.mean = self.mean.float()
            self.std = self.std.float()

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.loader)

    def process_tensors(self, input, target, non_blocking=True):
        """Move a batch to the GPU as float32 and normalize it when it is 3-D or higher.

        Batches with fewer than 3 dimensions (e.g. the empty tensors
        ``fast_collate`` returns for an empty batch) are returned unnormalized.
        NOTE: ``sub_``/``div_`` are in place; ``.cuda()``/``.float()`` return
        the same tensor when it is already a CUDA float32 tensor, in which case
        the caller's tensor is modified.
        """
        batch = input.cuda(non_blocking=non_blocking).float()
        if batch.dim() >= 3:
            batch = batch.sub_(self.mean).div_(self.std)
        return batch, target.cuda(non_blocking=non_blocking)

    def update_batch_size(self, bs):
        """Overwrite ``batch_size`` on the wrapped loader's ``batch_sampler``."""
        self.loader.batch_sampler.batch_size = bs

    def __iter__(self):
        # Generator expression: iter(self.loader) is created immediately,
        # while each batch is transferred/normalized lazily as it is consumed.
        return (self.process_tensors(images, labels, non_blocking=True)
                for images, labels in self.loader)

def fast_collate(batch):
    """Collate ``(image, target)`` pairs into a uint8 NCHW tensor and int64 targets.

    Each image is converted with ``np.asarray(img, dtype=np.uint8)`` and its
    channel axis is moved first. Two-dimensional (grayscale) images get a
    trailing channel axis, so broadcasting replicates them across all 3
    channels. All images are assumed to share the size of the first one, read
    PIL-style as ``size == (width, height)``.

    Args:
        batch: sequence of ``(image, target)`` pairs.

    Returns:
        ``(images, targets)`` where ``images`` is a uint8 tensor of shape
        ``(len(batch), 3, h, w)`` and ``targets`` is an int64 tensor. An empty
        batch (as yielded by ``DistValSampler`` on ranks with no data left)
        returns two empty tensors.
    """
    if not batch:
        return torch.tensor([]), torch.tensor([])
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # (H, W) -> (H, W, 1); the 1-channel slice broadcasts to 3 channels below.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW; adding into the zero-initialized slot copies the pixels.
        tensor[i] += torch.from_numpy(np.moveaxis(nump_array, 2, 0))
    return tensor, targets

class DistValSampler(Sampler):
    """Batch sampler that gives every rank the same number of validation batches.

    The index list is cut into contiguous per-rank shards of
    ``num_samples = expected_num_batches * batch_size`` indices. Each rank
    yields exactly ``expected_num_batches`` batches from its own shard, even
    when there are not enough indices to go around.

    WARNING: some batches may be short or empty lists, which signals that the
    rank has run out of images.

    When ``args.distributed`` is falsy the sampler acts as a single rank
    (``world_size=1``, ``global_rank=0``), so every process runs the same full
    validation.

    Assigning ``batch_size`` after construction (as
    ``BatchTransformDataLoader.update_batch_size`` does on its loader's
    ``batch_sampler``) recomputes ``expected_num_batches`` and ``num_samples``.

    Args:
        indices: sliceable sequence of dataset indices.
        args: namespace read for ``batch_size`` and ``distributed``, plus
            ``world_size`` and ``rank`` when distributed.
    """

    def __init__(self, indices, args):
        self.indices = indices
        batch_size = args.batch_size

        if args.distributed:
            self.world_size = args.world_size
            self.global_rank = args.rank
        else:
            self.global_rank = 0
            self.world_size = 1

        # Routed through the property setter, which derives the per-rank layout.
        self.batch_size = batch_size

    @property
    def batch_size(self):
        """Number of indices per batch; assigning it recomputes the per-rank layout."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, batch_size):
        # Previously these values were computed once in __init__, so a later
        # batch_size change left them stale. Shrinking the batch size then
        # silently dropped the tail of each shard, and growing it produced a
        # wrong len() and extra empty batches.
        self._batch_size = batch_size
        # Batches per rank: identical on every rank so all ranks run the same
        # number of validation steps, even if some run out of data.
        self.expected_num_batches = math.ceil(len(self.indices) / self.world_size / batch_size)
        # Indices assigned to each rank (may exceed what is actually available).
        self.num_samples = self.expected_num_batches * batch_size

    def __iter__(self):
        offset = self.num_samples * self.global_rank
        sampled_indices = self.indices[offset:offset + self.num_samples]
        for i in range(self.expected_num_batches):
            start = i * self.batch_size
            yield sampled_indices[start:start + self.batch_size]

    def __len__(self):
        """Return the number of batches this rank yields."""
        return self.expected_num_batches

    def set_epoch(self, epoch):
        """No-op, provided so callers can treat this like ``DistributedSampler``."""
        return
    



