from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from datetime import datetime
import random, torch, torchvision
from utils.fastmri import singleCoilFastMRIMultiSliceDataset
import numpy as np
import h5py, os, dival, tomosipo, tqdm

from dival.datasets.dataset import Dataset as divalDataset
from odl.discr import nonuniform_partition
from odl.tomo.geometry import Parallel2dGeometry
from odl.tomo.operators import RayTransform


def set_seed(seed):
    """Make a run reproducible.

    Forces cuDNN onto deterministic, non-autotuned kernels and seeds every
    random-number source used here: PyTorch (CPU and all CUDA devices),
    NumPy's global generator and Python's ``random`` module.

    Parameters
    ----------
    seed : int
        Seed value handed to each generator.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # only deterministic convolution algorithms
    cudnn.benchmark = False     # autotuning may pick different kernels per run

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)


def load_data(args):
    """Build the training-time data loaders for ``args.dataset``.

    Side effect: ``args.path`` is extended in place with ``args.file_name``
    (only for supported datasets); the returned save path uses that
    extended prefix plus a ``%Y%m%d-%H%M`` timestamp.

    Parameters
    ----------
    args : namespace
        Must provide ``dataset`` (``'MRI'``, ``'CelebA'`` or ``'CT'``),
        ``path``, ``file_name``, ``data_path``, ``batch_size`` and
        ``batch_size_val``.

    Returns
    -------
    tuple
        The layout depends on the dataset:

        - MRI: ``(train_loader, test_loader, load_path, save_path,
          n_train, n_test)``
        - CelebA: ``(train_loader, n_train, val_loader, n_val,
          test_loader, n_test, save_path)``
        - CT: ``(train_loader, n_train, val_loader, n_val, ray_trafo,
          save_path)``

    Raises
    ------
    ValueError
        If ``args.dataset`` is not a supported dataset name. (Previously the
        function silently returned ``None``, which only failed later when the
        caller tried to unpack the result.)
    """
    supported = ('MRI', 'CelebA', 'CT')
    if args.dataset not in supported:
        # Fail before mutating args so the caller's namespace is untouched.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {supported}")

    timeStamp = datetime.now().strftime("%Y%m%d-%H%M")
    args.path = args.path + args.file_name
    data_path = args.data_path  # e.g. "../../"

    if args.dataset == 'MRI':
        mri_path = os.path.join(data_path, 'singlecoil_train/')
        save_path = args.path + 'MRI_' + timeStamp
        load_path = args.path + 'MRI.t7'
        tr_size, total_data = 973, 973
        # With tr_size == total_data this is a random permutation of all
        # volume indices, not a proper subset.
        ind_tr = random.sample(range(total_data), k=tr_size)
        # NOTE(review): index 956 is also contained in ind_tr, so the "test"
        # volume is not held out from training — confirm this is intended.
        ind_ts = np.array([956])
        DATA_tr = singleCoilFastMRIMultiSliceDataset(mri_path, data_indices=ind_tr)
        DATA_ts = singleCoilFastMRIMultiSliceDataset(mri_path, data_indices=ind_ts)

        train_loader = DataLoader(DATA_tr, batch_size=args.batch_size, shuffle=True)
        test_loader = DataLoader(DATA_ts, batch_size=args.batch_size_val, shuffle=False)
        return train_loader, test_loader, load_path, save_path, len(DATA_tr), len(DATA_ts)

    if args.dataset == 'CelebA':
        # Downloaded from https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
        celebA_path = os.path.join(data_path, 'CelebA/')
        # NOTE(review): 'CelbeA_' is a typo, kept so existing output
        # directories keep their names.
        save_path = args.path + 'CelbeA_' + timeStamp
        transform = transforms.Compose([transforms.ToTensor()])  # 202599
        CelebA_dataset = datasets.ImageFolder(celebA_path, transform=transform)  # image size: 218 x 178
        full_length = len(CelebA_dataset)
        # Contiguous 80% / 10% / 10% split in dataset order (no shuffling).
        train_end = int(full_length * 0.8)
        val_end = int(full_length * 0.9)
        train_set = torch.utils.data.Subset(CelebA_dataset, range(train_end))
        val_set = torch.utils.data.Subset(CelebA_dataset, range(train_end, val_end))
        test_set = torch.utils.data.Subset(CelebA_dataset, range(val_end, full_length))

        tr_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, drop_last=False)
        # drop_last=True here means len(val_set) may exceed the number of
        # validation samples actually iterated.
        val_loader = DataLoader(dataset=val_set, batch_size=args.batch_size_val, shuffle=False, drop_last=True)
        ts_loader = DataLoader(dataset=test_set, batch_size=args.batch_size_val, shuffle=False, drop_last=False)
        return tr_loader, len(train_set), val_loader, len(val_set), ts_loader, len(test_set), save_path

    # args.dataset == 'CT'
    save_path = args.path + 'CT_' + timeStamp
    IMPL = 'astra_cuda'
    dataset = dival.get_standard_dataset('lodopab')
    # Keep every 5th angle index in [0, 1000) (sparse-view setting).
    dataset = AngleSubsetDataset(dataset, list(np.arange(0, 1000, 5)), IMPL)
    ray_trafo = dataset.ray_trafo
    train_data = dataset.create_torch_dataset("train")
    val_data = dataset.create_torch_dataset("validation")

    # NOTE(review): the training loader does not shuffle — confirm intended.
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=False, drop_last=True)
    val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size_val, shuffle=False, drop_last=True)
    return train_loader, len(train_data), val_loader, len(val_data), ray_trafo, save_path


def load_val_data(args):
    """Build the evaluation data loader for ``args.dataset``.

    Parameters
    ----------
    args : namespace
        Must provide ``dataset`` (``'MRI'`` or ``'CT'``), ``path``,
        ``file_name``, ``data_path`` and ``batch_size_val``.

    Returns
    -------
    tuple
        - MRI: ``(val_loader, save_path, n_val)`` built from
          ``singlecoil_val/``.
        - CT: ``(val_loader, save_path, n_val, ray_trafo)`` built from the
          LoDoPaB ``"test"`` split.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not a supported dataset name. (Previously the
        function silently returned ``None``.)

    Notes
    -----
    Only the MRI branch extends ``args.path`` in place with
    ``args.file_name``. The CT branch builds the same prefix locally and
    leaves ``args`` unchanged.
    """
    supported = ('MRI', 'CT')
    if args.dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {supported}")

    data_path = args.data_path
    if args.dataset == 'MRI':
        mri_path = os.path.join(data_path, "singlecoil_val/")
        timeStamp = datetime.now().strftime("%Y-%m-%d-%H%M")
        args.path = args.path + args.file_name
        save_path = args.path + 'MRI_val_' + timeStamp
        DATA_val = singleCoilFastMRIMultiSliceDataset(mri_path)
        val_loader = DataLoader(DATA_val, batch_size=args.batch_size_val, shuffle=False, drop_last=True)
        return val_loader, save_path, len(DATA_val)

    # args.dataset == 'CT'
    timeStamp = datetime.now().strftime("%Y-%m-%d-%H%M")
    save_path = args.path + args.file_name + 'CT_val_' + timeStamp
    IMPL = 'astra_cuda'
    dataset = dival.get_standard_dataset('lodopab')
    # Same angle subset as used for training: every 5th index in [0, 1000).
    dataset = AngleSubsetDataset(dataset, list(np.arange(0, 1000, 5)), IMPL)
    ray_trafo = dataset.ray_trafo
    DATA_val = dataset.create_torch_dataset("test")
    val_loader = DataLoader(dataset=DATA_val, batch_size=args.batch_size_val, shuffle=False, drop_last=True)
    return val_loader, save_path, len(DATA_val), ray_trafo


class AngleSubsetDataset(divalDataset):
    """
    CT dataset view that keeps only a subset of the projection angles of a
    basis dataset; ground truths are passed through unchanged.

    CT Code borrowed from:
    https://github.com/oterobaguer/dip-ct-benchmark/blob/0539c284c94089ed86421ea0892cd68aa1d0575a/dliplib/utils/helper.py#L220
    """

    def __init__(self, dataset, angle_indices, impl=None):
        """
        Parameters
        ----------
        dataset : `Dataset`
            Basis CT dataset. Its samples must be
            ``(observation, ground_truth)`` pairs, and its
            :meth:`get_ray_trafo` must return the matching ray transform.
        angle_indices : array-like or slice
            Which observation angles to keep.
        impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``}, optional
            Backend forwarded to :class:`odl.tomo.RayTransform` when
            building :attr:`ray_trafo`.
        """
        self.dataset = dataset
        if isinstance(angle_indices, slice):
            self.angle_indices = angle_indices
        else:
            self.angle_indices = np.asarray(angle_indices)

        base = self.dataset
        # Split lengths and access capabilities are inherited as-is.
        self.train_len = base.get_len('train')
        self.validation_len = base.get_len('validation')
        self.test_len = base.get_len('test')
        self.random_access = base.supports_random_access()
        self.num_elements_per_sample = base.get_num_elements_per_sample()

        # Rebuild the geometry using only the selected angles.
        base_geometry = base.get_ray_trafo(impl=impl).geometry
        angle_partition = nonuniform_partition(
            base_geometry.angles[self.angle_indices])
        self.geometry = Parallel2dGeometry(
            apart=angle_partition, dpart=base_geometry.det_partition)

        base_shape = base.get_shape()
        observation_shape = (angle_partition.shape[0], base_shape[0][1])
        self.shape = (observation_shape, base_shape[1])

        # get_ray_trafo reads self.space[1], so a provisional space must be
        # in place before the final one is handed to the base class.
        self.space = (None, base.space[1])
        self.ray_trafo = self.get_ray_trafo(impl=impl)
        super().__init__(space=(self.ray_trafo.range, base.space[1]))

    def get_ray_trafo(self, **kwargs):
        """
        Build the ray transform for the angle subset chosen in the
        constructor; ``kwargs`` are forwarded to `RayTransform`.
        """
        reco_space = self.space[1]
        return RayTransform(reco_space, self.geometry, **kwargs)

    def generator(self, part='train'):
        """Yield ``(observation, ground_truth)`` with observations cut to the
        selected angles and wrapped as elements of the observation space."""
        for sample in self.dataset.generator(part=part):
            obs, gt = sample
            yield (self.space[0].element(obs[self.angle_indices]), gt)

    def get_sample(self, index, part='train', out=None):
        """Fetch one sample and restrict its observation to the kept angles.

        ``out`` is an ``(out_obs, out_gt)`` pair. A bool ``out_obs`` selects
        whether a new observation is produced (``False`` yields ``None``).
        Otherwise ``out_obs`` is a buffer that is filled in place and
        returned.
        """
        out_obs, out_gt = (True, True) if out is None else out
        obs_basis, gt = self.dataset.get_sample(
            index, part=part, out=(out_obs is not False, out_gt))
        if not isinstance(out_obs, bool):
            out_obs[:] = obs_basis[self.angle_indices]
            return (out_obs, gt)
        if not out_obs:
            return (None, gt)
        return (self.space[0].element(obs_basis[self.angle_indices]), gt)

    def get_samples(self, key, part='train', out=None):
        """Batched version of :meth:`get_sample`. Angles are selected along
        axis 1, and the observations are returned as plain arrays."""
        out_obs, out_gt = (True, True) if out is None else out
        obs_arr_basis, gt_arr = self.dataset.get_samples(
            key, part=part, out=(out_obs is not False, out_gt))
        if not isinstance(out_obs, bool):
            out_obs[:] = obs_arr_basis[:, self.angle_indices]
            return (out_obs, gt_arr)
        if not out_obs:
            return (None, gt_arr)
        return (obs_arr_basis[:, self.angle_indices], gt_arr)
