"""
Data module that draws from a never-ending supply of Gaussian process (GP) data.
"""
import os
import random
from typing import Optional, Tuple

import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, TensorDataset

from krt import KRT_PATH


class RBFGPIterator:
    """Iterator that samples fresh batches of functions from an RBF-kernel GP.

    Every call to ``next`` draws new inputs, new kernel hyperparameters and new
    function values, so no two batches are the same. One pass over the iterator
    (an "epoch") yields ``batches_per_epoch`` batches.
    """

    def __init__(
        self,
        dim_x: int,
        batches_per_epoch: int,
        batch_size: int,
        min_points_per_function: int,
        max_points_per_function: int,
        lengthscale_range: Tuple[float, float],
        scale_range: Tuple[float, float],
        x_bounds: Tuple[float, float],
        noise: float,
    ):
        """Constructor.

        Args:
            dim_x: The size of the x dimension.
            batches_per_epoch: How many batches to feed to the network per epoch.
            batch_size: Number of functions per batch.
            min_points_per_function: Smallest number of points to draw per
                function (inclusive).
            max_points_per_function: Largest number of points to draw per
                function (inclusive).
            lengthscale_range: (low, high) of the uniform distribution the
                per-dimension lengthscales are drawn from.
            scale_range: (low, high) of the uniform distribution the kernel
                output scale is drawn from.
            x_bounds: (low, high) bounds for drawing x points uniformly.
            noise: Standard deviation of the observation noise; its square is
                added to the diagonal of the covariance.
        """
        self.dim_x = dim_x
        self.batches_per_epoch = batches_per_epoch
        self.batch_size = batch_size
        self.min_points_per_function = min_points_per_function
        self.max_points_per_function = max_points_per_function
        # Uniform ranges are stored as (low, width) so that a sample is
        # ``torch.rand(...) * width + low``.
        self.lscale_diam = lengthscale_range[1] - lengthscale_range[0]
        self.lscale_low = lengthscale_range[0]
        self.scale_diam = scale_range[1] - scale_range[0]
        self.scale_low = scale_range[0]
        self.x_diam = x_bounds[1] - x_bounds[0]
        self.x_low = x_bounds[0]
        self.noise = noise
        self.counter = 0
        self.num_val = 0

    def __iter__(self):
        # Start every pass from zero. Without this, an epoch that was abandoned
        # early (e.g. via ``break``) would leave a stale counter behind and the
        # next pass would yield fewer than ``batches_per_epoch`` batches.
        self.counter = 0
        return self

    def __next__(self):
        """Sample one batch.

        Returns:
            Tuple ``(x_sample, y_sample)`` with shapes
            ``(batch_size, L, dim_x)`` and ``(batch_size, L, 1)``, where ``L``
            is drawn uniformly from
            ``[min_points_per_function, max_points_per_function]`` and is
            shared by every function in the batch.

        Raises:
            StopIteration: Once ``batches_per_epoch`` batches have been
                produced in the current pass.
        """
        if self.counter >= self.batches_per_epoch:
            self.counter = 0
            raise StopIteration
        self.counter += 1
        points_per_function = random.randint(self.min_points_per_function,
                                             self.max_points_per_function)
        x_sample = (torch.rand(self.batch_size, points_per_function, self.dim_x)
                    * self.x_diam + self.x_low)
        # One lengthscale per input dimension and one output scale per function.
        lscales = (torch.rand(self.batch_size, 1, self.dim_x)
                   * self.lscale_diam + self.lscale_low)
        scales = torch.rand(self.batch_size, 1, 1) * self.scale_diam + self.scale_low
        # Pairwise distances between lengthscale-normalised inputs.
        scaled_x = x_sample / lscales
        dists = torch.cdist(scaled_x, scaled_x,
                            compute_mode='donot_use_mm_for_euclid_dist')  # (B, L, L)
        # RBF kernel: scale^2 * exp(-0.5 * d^2), plus noise^2 on the diagonal.
        cov = scales.pow(2) * (-0.5 * dists.pow(2)).exp()
        cov = cov + self.noise ** 2 * torch.eye(points_per_function).unsqueeze(0)
        mean = torch.zeros(self.batch_size, points_per_function)
        y_sample = MultivariateNormal(mean, cov).rsample().unsqueeze(-1)
        return x_sample, y_sample


class InfiniteGPData:
    """Training data sampled on the fly from a GP, plus optional held-out sets.

    Training batches come from an ``RBFGPIterator``; validation and test data,
    when requested, are loaded from tensors stored under ``KRT_PATH``.
    """

    def __init__(
        self,
        dim_x: int,
        batches_per_epoch: int,
        batch_size: int,
        min_points_per_function: int,
        max_points_per_function: int,
        kernel_type: str,
        lengthscale_range: Tuple[float, float],
        scale_range: Tuple[float, float],
        x_bounds: Tuple[float, float],
        noise: float,
        val_path: Optional[str] = None,
        te_path: Optional[str] = None,
        **kwargs
    ):
        """Constructor.

        Args:
            dim_x: The size of the x dimension.
            batches_per_epoch: How many batches to feed to the network per epoch.
            batch_size: Number of functions per batch.
            min_points_per_function: Minimum points per function.
            max_points_per_function: Maximum points per function.
            kernel_type: The type of kernel to use. Only 'rbf' (any casing) is
                currently supported.
            lengthscale_range: Range of the uniform distribution to draw the
                lengthscales from.
            scale_range: Range of the uniform distribution to draw the kernel
                scale from.
            x_bounds: The bounds for drawing x points.
            noise: Observation noise passed through to the training iterator.
            val_path: Directory (relative to KRT_PATH) with validation data.
                'x_data.pt'/'y_data.pt' are used if 'x_data.pt' is present,
                otherwise 'te_x_data.pt'/'te_y_data.pt'. None disables
                validation data.
            te_path: Directory (relative to KRT_PATH) with test data:
                'te_x_data.pt', 'te_y_data.pt', 'cum_joint_logprob.pt' and
                'marginal_logprob.pt'. None disables test data.
            **kwargs: Only 'pin_memory' (default True) is read; it is forwarded
                to the validation/test DataLoaders. Other entries are ignored.

        Raises:
            ValueError: If kernel_type is not 'rbf'.
        """
        # TODO: Add additional kernels in the future.
        # Raised explicitly rather than asserted so the check is still
        # enforced when Python runs with assertions disabled (-O).
        if kernel_type.lower() != 'rbf':
            raise ValueError(
                f"Unsupported kernel_type {kernel_type!r}; "
                "only 'rbf' is currently implemented."
            )
        self.dim_x = dim_x
        self.dim_y = 1
        self.batches_per_epoch = batches_per_epoch
        self.kernel_type = kernel_type
        self.train_data = RBFGPIterator(
            dim_x=dim_x,
            batches_per_epoch=batches_per_epoch,
            batch_size=batch_size,
            min_points_per_function=min_points_per_function,
            max_points_per_function=max_points_per_function,
            lengthscale_range=lengthscale_range,
            scale_range=scale_range,
            x_bounds=x_bounds,
            noise=noise
        )
        # Training data is generated on demand, so there is no finite count.
        # NOTE(review): presumably "number of training examples" — confirm
        # against the consumers of this attribute.
        self.nu_tr = float('inf')
        pin_memory = kwargs.get('pin_memory', True)
        # Possibly load in validation data.
        if val_path is not None:
            if 'x_data.pt' in os.listdir(os.path.join(KRT_PATH, val_path)):
                val_xname, val_yname = 'x_data.pt', 'y_data.pt'
            else:
                val_xname, val_yname = 'te_x_data.pt', 'te_y_data.pt'
            val_x_data = self._load_tensor(val_path, val_xname)
            val_y_data = self._load_tensor(val_path, val_yname)
            self.val_data = self._make_loader(
                (val_x_data, val_y_data), batch_size, pin_memory)
            self.num_val = val_x_data.shape[0]
        else:
            self.val_data = None
            self.num_val = 0
        # Possibly load in test data.
        if te_path is not None:
            te_x_data = self._load_tensor(te_path, 'te_x_data.pt')
            te_y_data = self._load_tensor(te_path, 'te_y_data.pt')
            cjoint_ll = self._load_tensor(te_path, 'cum_joint_logprob.pt')
            marginal_ll = self._load_tensor(te_path, 'marginal_logprob.pt')
            self.test_data = self._make_loader(
                (te_x_data, te_y_data, cjoint_ll, marginal_ll),
                batch_size, pin_memory)
            self.num_te = te_x_data.shape[0]
        else:
            self.test_data = None
            self.num_te = 0
        self.L = max_points_per_function

    @staticmethod
    def _load_tensor(directory: str, filename: str):
        """Load ``filename`` from ``directory``, resolved relative to KRT_PATH."""
        return torch.load(os.path.join(KRT_PATH, directory, filename))

    @staticmethod
    def _make_loader(tensors, batch_size: int, pin_memory: bool) -> DataLoader:
        """Wrap tensors in an in-order DataLoader that keeps the last batch."""
        return DataLoader(
            TensorDataset(*tensors),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=pin_memory,
        )

    @property
    def train_num_batches(self):
        """Number of training batches per epoch."""
        return self.batches_per_epoch

    @property
    def val_num_batches(self):
        """Number of validation batches, or 0 without validation data."""
        return len(self.val_data) if self.val_data is not None else 0

    @property
    def te_num_batches(self):
        """Number of test batches, or 0 without test data."""
        return 0 if self.test_data is None else len(self.test_data)
