from __future__ import annotations

import os  # type: ignore
from typing import Any, Optional, Tuple

import numpy as np  # type: ignore
import torch
from torch.utils.data import Dataset

# shorthand alias for torch.Tensor, used in the type annotations throughout this module
T = torch.Tensor


class TabularDataset(Dataset):
    """
    base class for in-memory tabular datasets made of a feature tensor `x` and a target tensor `y`.

    the constructor only declares the attributes, it does not assign them. `x` and `y` have to be
    assigned (by a subclass or by the caller) before any other method is used. the normalization
    moments are populated by `standard_normalize_x` and `standard_normalize_y`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """declare the attributes of the dataset; `args` and `kwargs` are accepted but not used here"""
        super().__init__()
        self.x: T
        self.y: T
        # copy of y taken right before it is normalized, set by standard_normalize_y
        self.og_y: T

        # these are the parameters of the normalization
        self.mu: T
        self.sigma: T
        self.y_mu: T
        self.y_sigma: T

        # display name, assigned through set_name (or directly by a subclass)
        self.name: str

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """return the (x, y) pair stored at index `i`"""
        return self.x[i], self.y[i]

    def __len__(self) -> int:
        """return the size of the first dimension of x"""
        return self.x.size(0)

    def prune(self, idx: T) -> None:
        """keep only the entries of x and y selected by `idx`"""
        # NOTE(review): og_y is not indexed here, so after a call to standard_normalize_y followed by
        # prune it no longer lines up with y. confirm whether that is intended.
        self.x = self.x[idx]
        self.y = self.y[idx]

    def get_y_moments(self) -> Tuple[float, float]:
        """return (y_mu, y_sigma) as python numbers, both must be single element tensors"""
        return self.y_mu.item(), self.y_sigma.item()

    def get_moments(self) -> Tuple[T, ...]:
        """return (mu, sigma, y_mu, y_sigma)"""
        return self.mu, self.sigma, self.y_mu, self.y_sigma

    def get_x_moments(self) -> Tuple[T, ...]:
        """return (mu, sigma)"""
        return self.mu, self.sigma

    def set_name(self, name: str) -> None:
        """set the name which is returned by str() and repr()"""
        self.name = name

    def __str__(self) -> str:
        """return the name of the dataset, or the class name if no name was ever set"""
        # the base class never assigns `name` on its own, so fall back instead of raising
        # an AttributeError from inside print() or a debugger
        return getattr(self, "name", type(self).__name__)

    def __repr__(self) -> str:
        """same as __str__"""
        return getattr(self, "name", type(self).__name__)

    def get_feature_ranges(self) -> T:
        """
        get the feature ranges for all of the x features, this was originally used
        to determine the ranges of features for generating adversarial examples as done by
        deep ensembles paper https://arxiv.org/abs/1612.01474
        """
        return torch.abs(self.x.min(dim=0)[0] - self.x.max(dim=0)[0])

    def valid(self) -> None:
        """raise a ValueError if x or y contains an inf or a nan, x is checked first"""
        if not torch.isfinite(self.x).all():
            raise ValueError("x has invalid values")
        if not torch.isfinite(self.y).all():
            raise ValueError("y has invalid values")

    @staticmethod
    def _nonzero_sigma(sigma: T) -> T:
        """return `sigma` with every exact zero replaced by one, so dividing by it is a no-op there"""
        return torch.where(sigma == 0, torch.ones_like(sigma), sigma)

    def standard_normalize_y(
        self,
        y_mu: Optional[T] = None,
        y_sigma: Optional[T] = None,
    ) -> Tuple[Optional[T], ...]:
        """
        standard normalize the targets by (y - y_mu) / y_sigma

        if either moment is missing, both are computed from the current y (mean and std over all
        elements, a zero std is replaced by one). otherwise the given moments are used as they are.
        the y from before the normalization is kept in `og_y`.

        returns the (y_mu, y_sigma) which were used.
        raises ValueError (from `valid`) if x or y holds an inf or a nan afterwards.
        """
        if y_mu is None or y_sigma is None:
            y_mu = self.y.mean()
            y_sigma = self._nonzero_sigma(self.y.std())

        self.y_mu = y_mu
        self.y_sigma = y_sigma

        self.og_y = torch.clone(self.y)
        self.y = (self.y - self.y_mu) / self.y_sigma

        self.valid()
        return self.y_mu, self.y_sigma

    def standard_normalize_x(self, mu: Optional[T] = None, sigma: Optional[T] = None) -> Tuple[Optional[T], ...]:
        """
        standard normalize the features by (x - mu) / sigma

        if either moment is missing, both are computed from the current x along dim 0 (a zero std
        is replaced by one so that constant features do not divide by zero). otherwise the given
        moments are used as they are.

        returns the (mu, sigma) which were used.
        raises ValueError (from `valid`) if x or y holds an inf or a nan afterwards.
        """
        if mu is None or sigma is None:
            mu = self.x.mean(dim=0)
            sigma = self._nonzero_sigma(self.x.std(dim=0))

        self.mu = mu
        self.sigma = sigma
        self.x = (self.x - self.mu) / self.sigma

        self.valid()
        return self.mu, self.sigma


class PBPDataset(TabularDataset):
    """
    this is for the datasets from the MC Dropout repository which were first used
    in probabilistic backpropagation https://arxiv.org/abs/1502.05336
    """

    def __init__(
        self,
        *args: Any,
        x: Optional[T] = None,
        y: Optional[T] = None,
        cluster_idx: Optional[T] = None,
        name: str = "",
    ) -> None:
        """
        store the given tensors on the dataset

        x, y and name are keyword only and required, even though they carry defaults.
        positional `args` are accepted but not used here.

        x: features, stored as `self.x`
        y: targets, stored as `self.y`
        cluster_idx: optional, stored as `self.cluster_idx` only when it is given (the attribute
            does not exist otherwise)
        name: non-empty name which is returned by str() and repr()

        raises ValueError if x or y is None or if name is the empty string.
        """
        super().__init__()

        if x is None or y is None or name == "":
            raise ValueError("kwargs needs to have x, y and name for PBP dataset")

        self.x = x
        self.y = y

        # one entry per row of x holding 0, 1, ..., n - 1 (linspace gives a floating point tensor)
        n_rows = x.size(0)
        self.idx = torch.linspace(0, n_rows - 1, n_rows)

        if cluster_idx is not None:
            self.cluster_idx = cluster_idx

        self.name = name

    def __str__(self) -> str:
        """return the name of the dataset"""
        return self.name

    def __repr__(self) -> str:
        """return the name of the dataset"""
        return self.name
