import os
import shutil
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
from utils.folder import ImageFolder, ImageFolderNoLabel
import numpy as np
# import cv2
import torchvision.transforms.functional as TF
import random

def generate_dataloader(args):
    """Build the source/target data loaders used for training and evaluation.

    Directory layout read from ``args``:
        - source images:        ``args.data_path_source/args.src``
        - target train/test:    ``args.data_path_target/args.tar``
        - second target test:   ``args.data_path_target_t/args.tar_t``

    Every dataset is an ``ImageFolder`` wrapped in a ``FastDataLoader``.
    Training loaders use ``args.batch_size`` with shuffling and
    ``drop_last=True``. Test loaders use a fixed batch size of 63 without
    shuffling. All loaders use ``args.workers`` workers and pinned memory.

    Transform options read from ``args``:
        - ``aug_gray_idx`` (required index): ``get_augmentation`` index for
          the second target view (passed as ``transform_gray``).
        - ``aug_dup_idx`` (required index): built and printed; not otherwise
          used here.
        - ``src_aug_idx`` / ``base_aug_idx`` / ``test_trans_idx``: ``None``
          keeps the default train / train / test pipeline. Otherwise the
          value is a ``get_augmentation`` index for the source-train /
          target-train / target-test transform.
        - ``source_test_trans_idx``: ``get_augmentation`` index for the source
          test transform. ``None`` (or absent) reuses the target test
          transform.

    Returns:
        tuple: ``(source_train_loader, target_train_loader,
        target_test_loader, target_test_loader_t, source_test_loader)``.

    Raises:
        ValueError: if the source directory does not exist, or if an
            unimplemented experimental configuration is requested. That means
            ``args.no_da`` is set, or the flags are anything other than
            ``gray_tar_agree`` set with ``aug_tar_agree`` unset.
    """
    traindir = os.path.join(args.data_path_source, args.src)
    traindir_t = os.path.join(args.data_path_target, args.tar)
    valdir = os.path.join(args.data_path_target, args.tar)
    valdir_t = os.path.join(args.data_path_target_t, args.tar_t)

    # Validate before reading the directory. Listing a missing directory
    # raises FileNotFoundError, which previously pre-empted this clearer error.
    if not os.path.isdir(traindir):
        raise ValueError('the require data path is not exist, please download the dataset')

    if args.no_da:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # Default training transform: resize, random 224 crop, horizontal flip,
    # ImageNet normalisation.
    data_transform_train = transforms.Compose([
        resize_policy(256, args),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Transform for the duplicated data (only printed below).
    data_transform_train_dup = get_augmentation(args.aug_dup_idx, args)
    # Transform for the second augmented view of the target data.
    data_transform_train_gray = get_augmentation(args.aug_gray_idx, args)
    # Default test transform: resize, centre 224 crop, ImageNet normalisation.
    data_transform_test = transforms.Compose([
        resize_policy(256, args),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def pick(idx, default):
        # None keeps the default pipeline; anything else indexes get_augmentation.
        return default if idx is None else get_augmentation(idx, args)

    source_train_transform = pick(args.src_aug_idx, data_transform_train)
    data_transform_base = pick(args.base_aug_idx, data_transform_train)
    data_transform_test = pick(args.test_trans_idx, data_transform_test)
    # Keyed on source_test_trans_idx itself. Keying on test_trans_idx ignored a
    # set source_test_trans_idx. It also crashed in get_augmentation(None, ...)
    # when only test_trans_idx was set.
    source_data_transform_test = pick(getattr(args, 'source_test_trans_idx', None), data_transform_test)

    print('data_transform_base:', data_transform_base)
    print('data_transform_train_dup:', data_transform_train_dup)
    print('data_transform_train_gray:', data_transform_train_gray)
    print('source_train_transform:', source_train_transform)
    print('data_transform_test:', data_transform_test)
    print('source_data_transform_test:', source_data_transform_test)

    source_train_dataset = ImageFolder(root=traindir, transform=source_train_transform)
    source_test_dataset = ImageFolder(root=traindir, transform=source_data_transform_test)
    # Only "gray_tar_agree without aug_tar_agree" is implemented; every other
    # flag combination is experimental placeholder code.
    if not (args.gray_tar_agree and not args.aug_tar_agree):
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    target_train_dataset = ImageFolder(root=traindir_t, transform=data_transform_base,
                                       transform_gray=data_transform_train_gray)
    target_test_dataset = ImageFolder(root=valdir, transform=data_transform_test)
    target_test_dataset_t = ImageFolder(root=valdir_t, transform=data_transform_test)

    def train_loader(dataset):
        # Shuffled; the incomplete last batch is dropped.
        return FastDataLoader(
            dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True
        )

    def test_loader(dataset):
        return FastDataLoader(
            dataset, batch_size=63, shuffle=False,
            num_workers=args.workers, pin_memory=True
        )

    # Creation order is kept as before. Each FastDataLoader starts its
    # iterator in __init__, which may consume torch RNG state, so reordering
    # could change seeded runs.
    source_train_loader = train_loader(source_train_dataset)
    source_test_loader = test_loader(source_test_dataset)
    target_train_loader = train_loader(target_train_dataset)
    target_test_loader = test_loader(target_test_dataset)
    target_test_loader_t = test_loader(target_test_dataset_t)

    print('num data and num batches in source_train_loader:', len(source_train_loader.dataset), len(source_train_loader))
    print('num data and num batches in target_train_loader:', len(target_train_loader.dataset), len(target_train_loader))
    return source_train_loader, target_train_loader, target_test_loader, target_test_loader_t, source_test_loader

def _random_affine_augmentation(x):
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper


def _gaussian_blur(x, sigma=0.1):

    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

class FastDataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that keeps a single underlying iterator alive across epochs.

    The batch sampler is wrapped in ``_RepeatSampler`` so the one iterator
    created in ``__init__`` keeps producing batches. Worker processes (if
    any) are therefore started once rather than per epoch.

    Each ``iter(loader)`` draws exactly ``len(loader)`` batches from that
    shared iterator. Note that breaking out of an epoch early leaves the
    remaining batches to be consumed by the next iteration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader.__setattr__ rejects rebinding batch_sampler after
        # construction, so bypass it via object.__setattr__.
        repeating_batches = _RepeatSampler(self.batch_sampler)
        object.__setattr__(self, 'batch_sampler', repeating_batches)
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is the _RepeatSampler wrapper; its .sampler is the
        # original batch sampler, whose length is the batches per epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
        
def get_augmentation(augmentation_idx, args):
    """Return the transform pipeline selected by ``augmentation_idx``.

    ``resize_policy`` supplies the resize step in every pipeline.

    Pipelines:
        0: resize 256, random crop 224, horizontal flip, to tensor, random
           affine (no translation), colour jitter, Gaussian blur
           (``sigma=args.sigma``), ImageNet normalisation.
        1: 3-channel grayscale, resize 256, random crop 224, horizontal flip,
           to tensor, ImageNet normalisation.
        2: 1-channel grayscale, resize 24, to tensor, normalised with the
           first channel of ``norm_stats(args)``.
        3: as 2, plus random affine (translate 0.2) and colour jitter.
        4: resize 32, to tensor, normalised with ``norm_stats(args)``.
        5: 3-channel grayscale, resize 32, to tensor, normalised with
           ``norm_stats(args)``.
        6: resize 32, to tensor, random affine (translate 0.2), colour
           jitter, normalised with ``norm_stats(args)``.
        7: ``RandomChoiceAugmentation`` applying 5 or 6 with p=0.5 each.

    Only the requested pipeline is built. Options used by the other pipelines
    (``args.sigma``, ``norm_stats(args)``) are therefore not evaluated. For
    example, an unsupported normalisation domain no longer breaks the
    ImageNet-normalised pipelines 0 and 1.

    Args:
        augmentation_idx: list-style index into the table above. Negative
            indices count from the end.
        args: options namespace. Reads ``sigma`` (pipeline 0) plus whatever
            ``resize_policy`` and ``norm_stats`` read.

    Raises:
        IndexError: if ``augmentation_idx`` is out of range.
        TypeError: if ``augmentation_idx`` is not an integer (e.g. ``None``).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def affine(translate):
        # Random affine shared by pipelines 0, 3 and 6; only translation differs.
        return transforms.RandomAffine(degrees=15, translate=translate, scale=(0.8, 1.2), shear=3,
                                       interpolation=transforms.InterpolationMode.BILINEAR)

    def jitter():
        # Brightness/contrast/saturation factors in [0.8, 1.2], hue in [-0.25, 0.25].
        return transforms.ColorJitter((0.8, 1.2), (0.8, 1.2), (0.8, 1.2), (-0.25, 0.25))

    def build_0():
        return transforms.Compose([
            resize_policy(256, args),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            affine(translate=(0, 0)),
            jitter(),
            # NOTE(review): a 1-pixel kernel makes this blur an identity op
            # whatever args.sigma is. Confirm this is intended.
            transforms.GaussianBlur(kernel_size=1, sigma=args.sigma),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def build_1():
        # Grayscale (replicated to 3 channels) training view.
        return transforms.Compose([
            transforms.Grayscale(3),
            resize_policy(256, args),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def build_2():
        mean, std = norm_stats(args)
        return transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            resize_policy(24, args),
            transforms.ToTensor(),
            # Single-channel output, so normalise with the first channel's stats.
            transforms.Normalize(mean=mean[0], std=std[0]),
        ])

    def build_3():
        mean, std = norm_stats(args)
        return transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            resize_policy(24, args),
            transforms.ToTensor(),
            affine(translate=(0.2, 0.2)),
            jitter(),
            transforms.Normalize(mean=mean[0], std=std[0]),
        ])

    def build_4():
        mean, std = norm_stats(args)
        return transforms.Compose([
            resize_policy(32, args),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def build_5():
        mean, std = norm_stats(args)
        return transforms.Compose([
            transforms.Grayscale(3),
            resize_policy(32, args),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def build_6():
        mean, std = norm_stats(args)
        return transforms.Compose([
            resize_policy(32, args),
            transforms.ToTensor(),
            affine(translate=(0.2, 0.2)),
            jitter(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def build_7():
        return RandomChoiceAugmentation(build_5(), build_6(), p=0.5)

    builders = [build_0, build_1, build_2, build_3, build_4, build_5, build_6, build_7]
    # Index the list exactly as before, so negative indices, IndexError and
    # TypeError behave as they did when every pipeline was prebuilt.
    return builders[augmentation_idx]()
 
def resize_policy(size, args):
    """Return the resize transform shared by the pipelines in this file.

    Only one configuration is implemented: ``args.resize_to_square`` falsy
    and ``args.rtsquare_intpl == 'bilinear'``. It resizes to ``(size, size)``
    with bilinear interpolation.

    Args:
        size: target height and width in pixels.
        args: options namespace; reads ``resize_to_square`` and
            ``rtsquare_intpl``.

    Raises:
        ValueError: with message ``'Error!'`` for the unimplemented
            experimental settings: ``resize_to_square`` truthy, or
            ``rtsquare_intpl`` of ``'nearest'``, ``'lanczos'`` or ``None``.
            Also raised, with a descriptive message, for any other
            unrecognised ``rtsquare_intpl`` value.
    """
    if args.resize_to_square:
        # Every square-resize variant is a placeholder.
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    interpolation = args.rtsquare_intpl
    if interpolation == 'bilinear':
        return transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR)
    if interpolation is None or interpolation in ('nearest', 'lanczos'):
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    # Unknown values used to fall off the end and return None. That only
    # failed later, when the composed pipeline tried to call it.
    raise ValueError(f"unsupported rtsquare_intpl {interpolation!r}; expected 'bilinear'")
    
class RandomChoiceAugmentation:
    """Apply one of two transforms, chosen independently on every call.

    With probability ``p`` the input goes through ``aug_a``; otherwise it
    goes through ``aug_b``. The choice uses Python's ``random`` module, so it
    follows ``random.seed`` rather than torch's RNG.

    Args:
        aug_a: callable applied when ``random.random() < p``.
        aug_b: callable applied otherwise.
        p: probability of choosing ``aug_a`` (default 0.5).
    """

    def __init__(self, aug_a, aug_b, p=0.5):
        self.aug_a = aug_a
        self.aug_b = aug_b
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return self.aug_a(img)
        return self.aug_b(img)

    def __repr__(self):
        # Transforms are printed for logging (as torchvision's are); without
        # this the output would only show an object address.
        return f"{type(self).__name__}(aug_a={self.aug_a!r}, aug_b={self.aug_b!r}, p={self.p!r})"

def norm_stats(args):
    """Return per-channel normalisation statistics as ``(mean, std)``.

    ``args.ns == 'imagenet'`` selects the ImageNet statistics.
    ``args.ns == 'src'`` selects the statistics recorded for the source domain
    named by ``args.src``. Supported domains are ``mnist``, ``usps``,
    ``mnist_train``, ``usps_train`` and ``svhn_train``.

    Returns:
        tuple: ``(mean, std)``, each a fresh 3-element list of floats. Callers
        may mutate them without affecting later calls.

    Raises:
        ValueError: if ``args.ns`` is neither ``'imagenet'`` nor ``'src'``,
            or if no statistics are recorded for ``args.src``. This used to be
            a bare ``BaseException``, which ``except Exception`` handlers do
            not catch. ``ValueError`` subclasses it, so existing
            ``except BaseException`` handlers still work.
    """
    if args.ns == 'imagenet':
        return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    if args.ns != 'src':
        raise ValueError("--ns should be 'imagenet' or 'src' (source)")

    # Built per call so every caller receives its own lists. The digit
    # domains use the same value for all three channels.
    domain_stats = {
        'mnist': ([0.148] * 3, [0.271] * 3),
        'usps': ([0.28] * 3, [0.322] * 3),
        'mnist_train': ([0.135] * 3, [0.295] * 3),
        'usps_train': ([0.249] * 3, [0.294] * 3),
        'svhn_train': ([0.438, 0.444, 0.473], [0.195, 0.198, 0.197]),
    }
    try:
        mean, std = domain_stats[args.src]
    except KeyError:
        raise ValueError("Statistics of other domains have not been included yet. Please include then use them.") from None
    return mean, std
        