
import numpy as np
import matplotlib as mpl

import os, sys, math, random, tarfile, glob, time, itertools
import parse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import transforms, utils

from omnidata.taskonomy_replica_gso_dataset import TaskonomyReplicaGsoDataset, REPLICA_BUILDINGS

from utils import *
from logger import Logger, VisdomLogger
from task_configs import get_task, tasks

from PIL import Image
from io import BytesIO
from sklearn.model_selection import train_test_split
import IPython

import pdb

""" Default data loading configurations for training, validation, and testing. """
def load_train_val(train_tasks, val_tasks=None, fast=False,
        train_buildings=None, val_buildings=None, split_file="config/split.txt",
        dataset_cls=None, batch_size=32, batch_transforms=cycle,
        subset=None, subset_size=None, dataaug=False,
    ):

    dataset_cls = dataset_cls or TaskDataset
    train_cls = TrainTaskDataset if dataaug else dataset_cls
    train_tasks = [get_task(t) if isinstance(t, str) else t for t in train_tasks]
    if val_tasks is None: val_tasks = train_tasks
    val_tasks = [get_task(t) if isinstance(t, str) else t for t in val_tasks]
    data = yaml.safe_load(open(split_file))
    # train_buildings = train_buildings or (["almena"] if fast else data["train_buildings"])
    # val_buildings = val_buildings or (["almena"] if fast else data["val_buildings"])
    print("number of train images:")
    # train_loader = train_cls(buildings=train_buildings, tasks=train_tasks)
    train_tasks = [t.name for t in train_tasks] + ['mask_valid']
    train_loader = omnidata_dataset_train(tasks=train_tasks,image_size=256)
    print("number of val images:")
    # val_loader = dataset_cls(buildings=val_buildings, tasks=val_tasks)
    val_loader, val_imgs = omnidata_dataset_val(tasks=train_tasks,image_size=256)


    if subset_size is not None or subset is not None:
        train_loader = torch.utils.data.Subset(train_loader,
            random.sample(range(len(train_loader)), subset_size or int(len(train_loader)*subset)),
        )

    train_step = int(len(train_loader) // (400 * batch_size))
    val_step = int(len(val_loader) // (400 * batch_size))
    print("Train step: ", train_step)
    print("Val step: ", val_step)
    if fast: train_step, val_step = 8, 8

    return train_loader, val_loader, train_step, val_step, val_imgs



def omnidata_dataset_train(taskonomy_variant="fullplus", image_size=512, normalize_rgb=False, tasks=None):
    """Concatenate the training splits of every Omnidata source dataset.

    One ``TaskonomyReplicaGsoDataset`` is built per source (taskonomy,
    replica, hypersim, replica_gso, blended_mvg, in that order) with identical
    options and ``randomize_views=True``. The per-source sample counts and
    their inverse-count weights are printed.

    Args:
        taskonomy_variant: forwarded to the dataset options.
        image_size: forwarded to the dataset options.
        normalize_rgb: forwarded to the dataset options.
        tasks: list of task names to load, e.g. ['rgb', 'normal', 'mask_valid'].

    Returns:
        ``ConcatDataset`` over the five sources in the order listed above.
    """
    source_names = ('taskonomy', 'replica', 'hypersim', 'replica_gso', 'blended_mvg')

    subsets = []
    for source in source_names:
        options = TaskonomyReplicaGsoDataset.Options(
            tasks=tasks,
            datasets=[source],
            split='train',
            taskonomy_variant=taskonomy_variant,
            transform='DEFAULT',
            image_size=image_size,
            normalize_rgb=normalize_rgb,
            randomize_views=True,
        )
        subsets.append(TaskonomyReplicaGsoDataset(options=options))

    dataset_sample_count = torch.tensor([len(subset) for subset in subsets])
    weight = 1. / dataset_sample_count.float()

    print("dataset weight ", weight)
    print("sample count ", dataset_sample_count)

    # The weights are only reported: the caller receives just the ConcatDataset,
    # so no sampler is built here. Building one per-sample weight in Python is
    # very costly on large splits. For source-balanced sampling, use
    # WeightedRandomSampler(torch.repeat_interleave(weight, dataset_sample_count), ...).
    return ConcatDataset(subsets)

def omnidata_dataset_val(taskonomy_variant="fullplus", image_size=512, normalize_rgb=False, tasks=None,
                         num_val_imgs=20):
    """Build the Omnidata validation set and a few preview samples per source.

    One ``TaskonomyReplicaGsoDataset`` is built per source (taskonomy,
    replica, hypersim, replica_gso, blended_mvg) with ``randomize_views=False``.
    Each one has ``randomize_order(seed=99)`` applied right after construction.
    The per-source sample counts and inverse-count weights are printed.

    Args:
        taskonomy_variant: forwarded to the dataset options.
        image_size: forwarded to the dataset options.
        normalize_rgb: forwarded to the dataset options.
        tasks: list of task names to load, e.g. ['rgb', 'normal', 'mask_valid'].
        num_val_imgs: preview samples drawn per source (default 20).

    Returns:
        ``(valset, val_imgs)``: a ``ConcatDataset`` over the five sources, and a
        ``defaultdict(list)`` mapping each source name to ``num_val_imgs``
        samples drawn uniformly *with replacement* using the global ``random``
        module. The samples therefore differ between runs unless that module
        is seeded.

    Raises:
        ValueError: if any source split is empty (from ``random.randint``).
    """
    from collections import defaultdict

    source_names = ('taskonomy', 'replica', 'hypersim', 'replica_gso', 'blended_mvg')

    subsets = {}
    for source in source_names:
        options = TaskonomyReplicaGsoDataset.Options(
            split='val',
            taskonomy_variant=taskonomy_variant,
            tasks=tasks,
            datasets=[source],
            transform='DEFAULT',
            image_size=image_size,
            normalize_rgb=normalize_rgb,
            randomize_views=False,
        )
        subset = TaskonomyReplicaGsoDataset(options=options)
        subset.randomize_order(seed=99)
        subsets[source] = subset

    dataset_sample_count = torch.tensor([len(subset) for subset in subsets.values()])
    weight = 1. / dataset_sample_count.float()

    print("dataset weight ", weight)
    print("sample count ", dataset_sample_count)

    # Only the concatenated set is returned; no sampler is needed here.
    valset = ConcatDataset(list(subsets.values()))

    # Draw preview samples in a fixed source order; each sample is loaded once.
    val_imgs = defaultdict(list)
    for source in ('hypersim', 'replica', 'taskonomy', 'replica_gso', 'blended_mvg'):
        subset = subsets[source]
        images = val_imgs[source]  # create the key even when num_val_imgs == 0
        for _ in range(num_val_imgs):
            idx = random.randint(0, len(subset) - 1)
            images.append(subset[idx])

    return valset, val_imgs



""" Load all buildings """
def load_all(tasks, buildings=None, batch_size=64, split_file="data/split.txt", batch_transforms=cycle):

    data = yaml.load(open(split_file))
    buildings = buildings or (data["train_buildings"] + data["val_buildings"])

    data_loader = torch.utils.data.DataLoader(
        TaskDataset(buildings=buildings, tasks=tasks),
        batch_size=batch_size,
        num_workers=0, shuffle=True, pin_memory=True
    )

    return data_loader


def load_test(all_tasks, buildings=("almena", "albertville", "espanola"), sample=4):
    """Return a fixed test batch made of the first ``sample`` images of each building.

    Each building gets its own unshuffled ``TaskDataset`` (index files sorted),
    and the first batch of every building is concatenated along dim 0, task by task.

    Args:
        all_tasks: task objects or names (names are resolved with ``get_task``).
        buildings: building names; every building listed contributes one batch.
        sample: number of images taken per building.

    Returns:
        Tuple with one concatenated tensor per task, holding up to
        ``len(buildings) * sample`` images (fewer if a building is smaller).
    """
    all_tasks = [get_task(t) if isinstance(t, str) else t for t in all_tasks]

    # Build every loader first, then fetch, so the dataset size reports print together.
    loaders = []
    for building in buildings:
        print(f"number of images in {building}:")
        loaders.append(torch.utils.data.DataLoader(
            TaskDataset(buildings=[building], tasks=all_tasks, shuffle=False),
            batch_size=sample,
            num_workers=6, shuffle=False, pin_memory=True,
        ))

    first_batches = [list(itertools.islice(loader, 1))[0] for loader in loaders]
    return tuple(torch.cat(per_task, dim=0) for per_task in zip(*first_batches))


def load_ood(tasks=(tasks.rgb,), ood_path=OOD_DIR, sample=21):
    """Return the first batch (at most ``sample`` images) from an OOD image folder.

    Args:
        tasks: tasks whose ``file_loader`` is applied to each image.
        ood_path: directory handed to ``ImageDataset``.
        sample: batch size; also used as the number of DataLoader workers.

    Returns:
        The first collated batch: a tuple with one entry per task, in sorted
        file-path order.

    Raises:
        IndexError: if the folder holds no *.png / *.jpg / *.jpeg images.
    """
    ood_loader = torch.utils.data.DataLoader(
        ImageDataset(tasks=tasks, data_dir=ood_path),
        batch_size=sample,
        num_workers=sample, shuffle=False, pin_memory=True
    )
    ood_images = next(iter(ood_loader), None)
    if ood_images is None:
        # IndexError keeps the exception type raised for an empty folder.
        raise IndexError(f"no OOD images found in {ood_path!r}")
    return ood_images



class TaskDataset(Dataset):
    """Paired multi-task image dataset indexed by the files of the first task.

    On construction it scans each directory in ``data_dirs`` for entries named
    "{building}_{task}" (recorded in ``file_map``). It then lists the files of
    ``tasks[0]`` for every requested building via ``building_files``. An item
    is a tuple with one loaded image per task; each task's path is derived from
    the index file via ``convert_path``.
    """

    def __init__(self, buildings, tasks=[get_task("rgb"), get_task("normal")], data_dirs=DATA_DIRS,
            building_files=None, convert_path=None, use_raid=USE_RAID, resize=None, unpaired=False, shuffle=True):
        """
        Args:
            buildings: building names whose files make up the dataset.
            tasks: task objects; ``tasks[0]`` determines which files are indexed.
            data_dirs: directories scanned for the file map and passed to ``get_files``.
            building_files: optional replacement for ``building_files(task, building)``.
            convert_path: optional replacement for ``convert_path(source_file, task)``.
            use_raid: if true, use the "{task}/{building}/{view}" layout; this
                overrides both callables above.
            resize: forwarded to each task's ``file_loader``.
            unpaired: stored on the instance; consulted by ``reset_unpaired``.
            shuffle: if False, the index files are sorted; otherwise they keep
                the order returned by ``get_files``.
        """
        super().__init__()
        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
        self.building_files = building_files or self.building_files
        self.convert_path = convert_path or self.convert_path
        self.resize = resize
        # Must be stored: reset_unpaired() reads it.
        self.unpaired = unpaired
        if use_raid:
            self.convert_path = self.convert_path_raid
            self.building_files = self.building_files_raid

        # Map "{building}_{task}" entry names to the data_dir that contains them.
        self.file_map = {}
        for data_dir in self.data_dirs:
            prefix_len = len(data_dir) + 1
            for file in glob.glob(f'{data_dir}/*'):
                entry = file[prefix_len:]
                if parse.parse("{building}_{task}", entry) is None:
                    continue
                self.file_map[entry] = data_dir

        task = tasks[0]
        task_files = []
        for building in buildings:
            task_files += self.building_files(task, building)
        print(f"    {task.name} file len: {len(task_files)}")
        self.idx_files = task_files if shuffle else sorted(task_files)

        print ("    Intersection files len: ", len(self.idx_files))

    def reset_unpaired(self):
        """When ``unpaired``, draw a fresh random index permutation for every task.

        NOTE(review): ``__getitem__`` below does not consult ``task_indices``;
        presumably a subclass or caller does — confirm.
        """
        if self.unpaired:
            n = len(self.idx_files)
            keys = getattr(self, "task_indices", self.tasks)
            self.task_indices = {task: random.sample(range(n), n) for task in keys}

    def building_files(self, task, building):
        """ Gets all the tasks in a given building (grouping of data) """
        return get_files(f"{building}_{task.file_name}/{task.file_name}/*.{task.file_ext}", self.data_dirs)

    def building_files_raid(self, task, building):
        """List ``task`` files of ``building`` in the "{task}/{building}/*.{ext}" layout."""
        return get_files(f"{task.file_name}/{building}/*.{task.file_ext}", self.data_dirs)

    def convert_path(self, source_file, task):
        """Map a file of the index task to the matching file of ``task``.

        The last three path components of ``source_file`` must match
        "{building}_{task}/{task}/{view}_domain_{task2}.{ext}". Returns "" when
        the target "{building}_{file_name}" entry is missing from ``file_map``.
        Can be overridden by subclasses.
        """
        source_file = "/".join(source_file.split('/')[-3:])
        result = parse.parse("{building}_{task}/{task}/{view}_domain_{task2}.{ext}", source_file)
        # `result` is None for a non-matching path; the subscript then raises and
        # __getitem__ retries with another index.
        building, view = result["building"], result["view"]
        dest_file = f"{building}_{task.file_name}/{task.file_name}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
        key = f"{building}_{task.file_name}"
        if key not in self.file_map:
            print (f"{building}_{task.file_name} not in file map")
            return ""
        return f"{self.file_map[key]}/{dest_file}"

    def convert_path_raid(self, full_file, task):
        """Map a "{task}/{building}/{view}.{ext}" file to the same view of ``task``.

        The result keeps ``full_file``'s directory prefix. Can be overridden by subclasses.
        """
        source_file = "/".join(full_file.split('/')[-3:])
        result = parse.parse("{task}/{building}/{view}.{ext}", source_file)
        building, view = result["building"], result["view"]
        dest_file = f"{task.file_name}/{building}/{view}.{task.file_ext}"
        return f"{full_file[:-len(source_file)-1]}/{dest_file}"

    def __len__(self):
        return len(self.idx_files)

    def __getitem__(self, idx):
        """Load one image per task for ``idx``, retrying random indices on failure.

        Any exception while converting paths or loading images triggers a retry
        with a uniformly random index. After 200 failed attempts the last
        exception is re-raised.
        """
        for attempt in range(200):
            try:
                res = []

                # One seed shared by all tasks of this item; randint needs int
                # bounds (floats raise TypeError on Python >= 3.12, which the
                # retry loop would otherwise swallow until it gave up).
                seed = random.randint(0, 10**10)

                for task in self.tasks:
                    file_name = self.convert_path(self.idx_files[idx], task)
                    if len(file_name) == 0: raise Exception("unable to convert file")
                    image = task.file_loader(file_name, resize=self.resize, seed=seed)

                    res.append(image)
                return tuple(res)
            except Exception as e:
                idx = random.randrange(0, len(self.idx_files))
                if attempt == 199: raise e


# class TrainTaskDataset(TaskDataset):

#     def __getitem__(self, idx):

#         for i in range(200):
#             try:
#                 res = []

#                 seed = random.randint(0, 1e10)
#                 crop = random.randint(int(0.7*512), 512) if bool(random.getrandbits(1)) else 512
                
#                 for task in self.tasks:
#                     jitter = random.random()<0.2 if task.name == 'rgb' else False
#                     distortion_ind = random.randint(0,1) if task.name == 'rgb' and (not jitter and random.random()<0.3) else None
#                     blur_noise_jpeg = [None]*2
#                     # samples_blur_noise_jpeg = [random.uniform(1,10), random.uniform(0.05,0.5), random.randint(5,30)]
#                     samples_blur_noise_jpeg = [random.uniform(1,10), random.uniform(0.05,0.5)]
#                     if distortion_ind is not None: blur_noise_jpeg[distortion_ind] = samples_blur_noise_jpeg[distortion_ind]
#                     file_name = self.convert_path(self.idx_files[idx], task)
#                     if len(file_name) == 0: raise Exception("unable to convert file")
#                     image = task.file_loader(file_name, resize=self.resize, seed=seed, crop=crop, jitter=jitter, blur_radius=blur_noise_jpeg[0], noise=blur_noise_jpeg[1], jpeg=None)
#                     res.append(image)

#                 return tuple(res)
#             except Exception as e:
#                 idx = random.randrange(0, len(self.idx_files))
#                 if i == 199: raise (e)

class TrainTaskDataset(TaskDataset):
    """``TaskDataset`` variant that passes random augmentation settings to each loader."""

    def __getitem__(self, idx):
        """Load one augmented image per task for ``idx``, retrying random indices on failure.

        Per attempt, every task receives the same ``seed`` and ``crop``. ``crop``
        is 512 with probability 1/2 and otherwise a random int in [358, 512];
        presumably a square crop size in pixels — TODO confirm against
        ``file_loader``. Only the 'rgb' task gets ``jitter``, which is true with
        probability 1/2. After 200 failed attempts the last exception is re-raised.
        """
        for attempt in range(200):
            try:
                res = []

                # int bounds: float bounds raise TypeError on Python >= 3.12.
                seed = random.randint(0, 10**10)
                crop = random.randint(int(0.7*512), 512) if bool(random.getrandbits(1)) else 512

                for task in self.tasks:
                    jitter = bool(random.getrandbits(1)) if task.name == 'rgb' else False
                    file_name = self.convert_path(self.idx_files[idx], task)
                    if len(file_name) == 0: raise Exception("unable to convert file")
                    image = task.file_loader(file_name, resize=self.resize, seed=seed, crop=crop, jitter=jitter)
                    res.append(image)

                return tuple(res)
            except Exception as e:
                idx = random.randrange(0, len(self.idx_files))
                if attempt == 199: raise e


class ImageDataset(Dataset):
    """Flat folder of images, e.g. out-of-distribution test images.

    An item is a tuple with one result per task, produced by each task's
    ``file_loader`` on the same file with a shared seed. A result whose first
    dimension is 1 is expanded to 3 along that dimension.
    """

    def __init__(
        self,
        tasks=[tasks.rgb],
        data_dir=f"data/ood_images",
        files=None,
    ):
        """
        Args:
            tasks: task objects whose ``file_loader`` is applied to every file.
            data_dir: folder globbed for *.png, *.jpg and *.jpeg files (used
                only when ``files`` is empty or None).
            files: explicit list of file paths.
        """
        self.tasks = tasks
        self.files = files or sorted(
            glob.glob(f"{data_dir}/*.png")
            + glob.glob(f"{data_dir}/*.jpg")
            + glob.glob(f"{data_dir}/*.jpeg")
        )

        print("number of ood images: ", len(self.files))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load ``self.files[idx]`` once per task with a shared random seed."""
        path = self.files[idx]
        # int bounds: float bounds raise TypeError on Python >= 3.12.
        seed = random.randint(0, 10**10)
        res = []
        for task in self.tasks:
            image = task.file_loader(path, seed=seed)
            if image.shape[0] == 1:
                image = image.expand(3, -1, -1)
            res.append(image)
        return tuple(res)




if __name__ == "__main__":

    logger = VisdomLogger("data", env=JOB)
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],
        batch_size=32,
    )
    print ("created dataset")
    logger.add_hook(lambda logger, data: logger.step(), freq=32)

    for i, _ in enumerate(train_dataset):
        logger.update("epoch", i)
