
import numpy as np
import matplotlib as mpl

import os, sys, math, random, tarfile, glob, time, itertools
import parse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from utils import *
# from logger import Logger, VisdomLogger
from task_configs import get_task, tasks

from PIL import Image
from io import BytesIO
from sklearn.model_selection import train_test_split
import IPython


# Default data loading configuration for evaluation on 2D common corruptions.
def load_cc(distortion='gaussian_noise', severity=1, target_task="normal", shuffle=True, buildings=None):
    """Build a TaskDataset whose rgb inputs come from the 2D-common-corruption tree.

    Each item of the returned dataset is a tuple ``(rgb, mask_valid, target)``.
    The three elements are loaded for the tasks ``get_task("rgb")``,
    ``get_task("mask_valid")`` and ``get_task(target_task)``. When ``severity``
    is not None, TaskDataset resolves the rgb path with ``convert_path_cc``
    (the corrupted images for ``distortion``/``severity``).

    Args:
        distortion: corruption name; forwarded to TaskDataset.
        severity: corruption severity; forwarded to TaskDataset.
        target_task: task name for the third tuple element.
        shuffle: forwarded to TaskDataset (False sorts the file list).
        buildings: buildings to load; defaults to ["almena", "albertville"].

    Returns:
        The constructed TaskDataset (a Dataset, not a DataLoader).
    """
    if buildings is None:
        buildings = ["almena", "albertville"]
    # TaskDataset prints the actual file counts right after this header.
    print("number of train images:")
    dataset_tasks = [get_task(name) for name in ("rgb", "mask_valid", target_task)]
    return TaskDataset(buildings=buildings, tasks=dataset_tasks,
                       distortion=distortion, severity=severity, shuffle=shuffle)

# Default data loading configuration for evaluation on 3D common corruptions.
def load_3dcc(distortion='motion_blur_3d_v2', severity=1, target_task="normal", shuffle=True, buildings=None):
    """Build a CC3DDataset whose rgb inputs come from the 3D-common-corruption tree.

    Each item of the returned dataset is a tuple ``(rgb, mask_valid, target)``.
    The three elements are loaded for the tasks ``get_task("rgb")``,
    ``get_task("mask_valid")`` and ``get_task(target_task)``. When ``severity``
    is not None, CC3DDataset resolves the rgb path with ``convert_path_3dcc``.

    Args:
        distortion: corruption name; forwarded to CC3DDataset.
        severity: corruption severity; forwarded to CC3DDataset.
        target_task: task name for the third tuple element.
        shuffle: forwarded to CC3DDataset (False sorts the file list).
        buildings: buildings to load; defaults to ["almena", "albertville"].

    Returns:
        The constructed CC3DDataset (a Dataset, not a DataLoader).
    """
    if buildings is None:
        buildings = ["almena", "albertville"]
    # The dataset constructor prints the actual file counts right after this header.
    print("number of train images:")
    dataset_tasks = [get_task(name) for name in ("rgb", "mask_valid", target_task)]
    return CC3DDataset(buildings=buildings, tasks=dataset_tasks,
                       distortion=distortion, severity=severity, shuffle=shuffle)


class TaskDataset(Dataset):
    """Dataset yielding one tuple of per-task images for each indexed source file.

    The file index is built from the first task in ``tasks`` that is not
    ``mask_valid``. Every task's file (including that one) is derived from the
    indexed source path via the ``convert_path*`` methods:

    * ``rgb`` when ``severity`` is not None -> ``convert_path_cc``
    * ``mask_valid``                         -> ``convert_path_masks``
    * anything else                          -> ``convert_path``

    An item is ``tuple(task.file_loader(path, resize=..., seed=...) for task in tasks)``.
    """

    # Fixed dataset roots, exposed as class attributes so they can be overridden
    # per subclass or instance instead of being buried in the methods.
    MASK_ROOT = '/datasets/taskonomymask'
    CC_ROOT = '/datasets/taskonomy-2d3dcc/2dcc'
    CC3D_ROOT = '/datasets/taskonomy-2d3dcc/3dcc'
    # How many attempts __getitem__ makes before giving up. After a failed
    # attempt, a random index is tried instead.
    MAX_LOAD_ATTEMPTS = 200

    def __init__(self, buildings, tasks=None, data_dirs=None,
            distortion='gaussian_noise', severity=1,
            building_files=None, convert_path=None, resize=None, unpaired=False, shuffle=True):
        """
        Args:
            buildings: building names whose files are indexed.
            tasks: task objects to load per item; defaults to
                ``[get_task("rgb"), get_task("normal")]``.
            data_dirs: root directories to scan; defaults to ``DATA_DIRS``.
            distortion, severity: corruption name/level used to build rgb paths
                (``severity=None`` reads clean rgb through ``convert_path``).
            building_files, convert_path: optional callables replacing the
                methods of the same name.
            resize: forwarded to every ``task.file_loader``.
            unpaired: stored for ``reset_unpaired``.
            shuffle: despite the name, True keeps ``building_files`` order as-is;
                False sorts the file list.
        """
        super().__init__()
        # None sentinels: avoid a shared list default evaluated at import time.
        if tasks is None:
            tasks = [get_task("rgb"), get_task("normal")]
        if data_dirs is None:
            data_dirs = DATA_DIRS
        self.buildings, self.tasks, self.data_dirs = buildings, tasks, data_dirs
        self.building_files = building_files or self.building_files
        self.convert_path = convert_path or self.convert_path
        self.resize = resize

        self.severity = severity
        self.distortionname = distortion

        # Map each "<building>_<task>" directory name to the data_dir containing it.
        self.file_map = {}
        for data_dir in self.data_dirs:
            prefix_len = len(data_dir) + 1
            for path in glob.glob(f'{data_dir}/*'):
                entry = path[prefix_len:]
                if parse.parse("{building}_{task}", entry) is None:
                    continue
                self.file_map[entry] = data_dir

        # mask_valid files are derived, never indexed, so skip it as the index task.
        index_task = tasks[1] if tasks[0].name == 'mask_valid' else tasks[0]
        task_files = []
        for building in buildings:
            task_files += self.building_files(index_task, building)
        print(f"    {index_task.name} file len: {len(task_files)}")
        self.idx_files = task_files if shuffle else sorted(task_files)

        print ("    Intersection files len: ", len(self.idx_files))

        # Previously `unpaired` was dropped, so reset_unpaired() always raised
        # AttributeError. NOTE(review): __getitem__ does not consume
        # task_indices yet.
        self.unpaired = unpaired
        self.task_indices = (
            {task: list(range(len(self.idx_files))) for task in self.tasks} if unpaired else {}
        )

    def reset_unpaired(self):
        """Re-draw an independent random permutation of file indices per task (if unpaired)."""
        if self.unpaired:
            self.task_indices = {task:random.sample(range(len(self.idx_files)), len(self.idx_files)) for task in self.task_indices}

    def building_files(self, task, building):
        """ Gets all the tasks in a given building (grouping of data) """
        return get_files(f"{building}_{task.file_name}/{task.file_name}/*.{task.file_ext}", self.data_dirs)

    def building_files_raid(self, task, building):
        """Alternative layout lookup: ``<file_name>/<building>/*.<ext>``."""
        return get_files(f"{task.file_name}/{building}/*.{task.file_ext}", self.data_dirs)

    @staticmethod
    def _parse_source(source_file):
        """Return ``(building, view)`` parsed from the last three components of *source_file*.

        The expected shape is ``<building>_<task>/<task>/<view>_domain_<task2>.<ext>``.

        Raises:
            ValueError: if the path does not match that shape.
        """
        tail = "/".join(source_file.split('/')[-3:])
        result = parse.parse("{building}_{task}/{task}/{view}_domain_{task2}.{ext}", tail)
        if result is None:
            raise ValueError(f"cannot parse source file path: {source_file!r}")
        return result["building"], result["view"]

    def convert_path(self, source_file, task):
        """ Converts a file from task A to task B. Can be overriden by subclasses

        Returns "" when ``<building>_<task.file_name>`` was not found under any data_dir.
        """
        building, view = self._parse_source(source_file)
        key = f"{building}_{task.file_name}"
        if task.file_name == 'mask_valid':
            dest_file = f"{key}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
        else:
            dest_file = f"{key}/{task.file_name}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
        if key not in self.file_map:
            print (f"{key} not in file map")
            return ""
        return f"{self.file_map[key]}/{dest_file}"

    def convert_path_masks(self, source_file, task):
        """Map a source file to its mask under ``MASK_ROOT`` (named after the depth_zbuffer view)."""
        building, view = self._parse_source(source_file)
        dest_file = f"{building}_{task.file_name_alt}/{view}_domain_depth_zbuffer.{task.file_ext}"
        return f"{self.MASK_ROOT}/{dest_file}"

    def _convert_path_corrupted(self, source_file, task, root):
        """Map a source file to ``<root>/<distortion>/<severity>/<building>/<view>_domain_<alt>.<ext>``.

        Returns "" when the clean ``<building>_<task.file_name>`` directory was
        not discovered, so only buildings present in the clean data are used.
        """
        building, view = self._parse_source(source_file)
        dest_file = f"{self.distortionname}/{self.severity}/{building}/{view}_domain_{task.file_name_alt}.{task.file_ext}"
        key = f"{building}_{task.file_name}"
        if key not in self.file_map:
            print (f"{key} not in file map")
            return ""
        return f"{root}/{dest_file}"

    def convert_path_cc(self, source_file, task):
        """Path of the 2D-common-corruption version of *source_file* (see ``CC_ROOT``)."""
        return self._convert_path_corrupted(source_file, task, self.CC_ROOT)

    def convert_path_3dcc(self, source_file, task):
        """Path of the 3D-common-corruption version of *source_file* (see ``CC3D_ROOT``)."""
        return self._convert_path_corrupted(source_file, task, self.CC3D_ROOT)

    def _convert_task_path(self, source_file, task):
        """Choose the converter for *task* and return the resolved file path."""
        if task.name == 'rgb' and self.severity is not None:
            return self.convert_path_cc(source_file, task)
        if task.name == 'mask_valid':
            return self.convert_path_masks(source_file, task)
        return self.convert_path(source_file, task)

    def _load_sample(self, source_file):
        """Load every task's image for one source file and return them as a tuple."""
        # One seed shared by all tasks of the sample. Presumably this lets
        # file_loader apply matching random transforms; confirm in task_configs.
        # Must be an int: random.randint with a float bound raises on Python 3.12+.
        seed = random.randint(0, 10**10)
        images = []
        for task in self.tasks:
            file_name = self._convert_task_path(source_file, task)
            if not file_name:
                raise Exception("unable to convert file")
            images.append(task.file_loader(file_name, resize=self.resize, seed=seed))
        return tuple(images)

    def __len__(self):
        return len(self.idx_files)

    def __getitem__(self, idx):
        """Return the sample for *idx*.

        If loading fails, a random other index is tried instead, up to
        ``MAX_LOAD_ATTEMPTS`` attempts in total. The last error is re-raised.
        """
        for attempt in range(self.MAX_LOAD_ATTEMPTS):
            try:
                return self._load_sample(self.idx_files[idx])
            except Exception:
                if attempt == self.MAX_LOAD_ATTEMPTS - 1:
                    raise
                idx = random.randrange(0, len(self.idx_files))


class CC3DDataset(TaskDataset):
    """TaskDataset variant whose rgb input comes from the 3D-common-corruption tree.

    Identical to TaskDataset except that, when ``severity`` is not None, rgb
    paths are resolved with ``convert_path_3dcc`` instead of ``convert_path_cc``.
    """

    def __getitem__(self, idx):
        """Return the sample for *idx*.

        If loading fails, a random other index is tried instead, up to 200
        attempts in total. The last error is re-raised.
        """
        for attempt in range(200):
            try:
                res = []

                # One seed shared by all tasks of the sample. Must be an int:
                # random.randint with a float bound (1e10) raises TypeError on Python 3.12+.
                seed = random.randint(0, 10**10)
                source_file = self.idx_files[idx]

                for task in self.tasks:
                    if task.name == 'rgb' and self.severity is not None:
                        file_name = self.convert_path_3dcc(source_file, task)
                    elif task.name == 'mask_valid':
                        file_name = self.convert_path_masks(source_file, task)
                    else:
                        file_name = self.convert_path(source_file, task)
                    if not file_name:
                        raise Exception("unable to convert file")
                    res.append(task.file_loader(file_name, resize=self.resize, seed=seed))
                return tuple(res)
            except Exception:
                if attempt == 199:
                    raise
                idx = random.randrange(0, len(self.idx_files))


if __name__ == "__main__":
    # Manual check: build datasets via load_train_val, then log an "epoch"
    # counter for every item of train_dataset.
    # NOTE(review): VisdomLogger, JOB and load_train_val are not defined in this
    # file (the `logger` import near the top is commented out). They must come
    # from `from utils import *` for this script to run; confirm.
    logger = VisdomLogger("data", env=JOB)
    demo_tasks = [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)]
    train_dataset, val_dataset, train_step, val_step = load_train_val(demo_tasks, batch_size=32)
    print ("created dataset")
    logger.add_hook(lambda logger, data: logger.step(), freq=32)

    for step, _ in enumerate(train_dataset):
        logger.update("epoch", step)
