import abc

import numpy as np
import torch
from torch.distributions import Beta

from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_scaler import DataScaler
from models.ClassificationModel import ClassProbabilities
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.data_mask_estimators.OracleDataMasker import OracleDataMasker


class OracleDataMaskerWithDeltaMinMax(DataMaskEstimator, abc.ABC):
    """Abstract oracle masker paired with a noise generator bounded by ``delta_min``/``delta_max``.

    Predictions are delegated unchanged to a wrapped ``OracleDataMasker``.
    Concrete subclasses only choose how values between ``delta_min`` and
    ``delta_max`` are distributed (``generate_noise``) and how that choice is
    labelled (``display_name``).
    """

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(dataset_name, x_dim, z_dim)
        # forward() answers every query with this wrapped oracle.
        self.oracle_model = OracleDataMasker(data_scaler, data_masker, dataset_name, x_dim, z_dim)
        # Lower / upper bound of the range the subclasses' generate_noise() maps into.
        self.delta_min = delta_min
        self.delta_max = delta_max

    def forward(self, x, z) -> ClassProbabilities:
        """Return the wrapped oracle's ``forward(x, z)`` result unchanged."""
        oracle = self.oracle_model
        return oracle.forward(x, z)

    @abc.abstractmethod
    def generate_noise(self, size: int, device):
        """Return ``size`` noise values placed on ``device`` (distribution chosen by the subclass)."""

    @property
    @abc.abstractmethod
    def display_name(self) -> str:
        """Short human-readable label of the noise distribution."""


class OracleDataMaskerWithUniformDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise is drawn uniformly from ``[delta_min, delta_max)``."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        """Identifier built from ``base_name`` (presumably set by the parent class) and the rounded bounds."""
        prefix = self.base_name
        low = np.round(self.delta_min, 4)
        high = np.round(self.delta_max, 4)
        return f"{prefix}oracle_with_uniform_delta_min={low}_delta_max={high}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values uniformly from ``[delta_min, delta_max)`` and move them to ``device``."""
        width = self.delta_max - self.delta_min
        unit_draws = torch.rand(size)
        return (unit_draws * width + self.delta_min).to(device)

    @property
    def display_name(self) -> str:
        """Human-readable label of this noise distribution."""
        return "Uniform"


class OracleDataMaskerWithLeftSidedDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise is concentrated near ``delta_min``."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_left_sided_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Draw ``size`` values mostly from the low end of ``[delta_min, delta_max]``.

        Each sample comes, with probability 0.9, from a uniform distribution on
        the lowest 5% of the interval. Otherwise it comes from a uniform
        distribution on the whole interval.

        Args:
            size: number of values to draw.
            device: device the returned tensor is moved to.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        span = self.delta_max - self.delta_min
        uniform_noise = torch.rand(size) * span + self.delta_min
        # Upper edge of the low tail: delta_min + 5% of the span.
        min_extreme_val = 0.95 * self.delta_min + 0.05 * self.delta_max
        one_sided_noise = torch.rand(size) * (min_extreme_val - self.delta_min) + self.delta_min
        rnd_idx = torch.rand(size)
        # torch.where assigns every index to exactly one component. The old strict
        # `< 0.9` / `> 0.9` masks skipped rnd_idx == 0.9 and left 0.0 there,
        # which may lie outside [delta_min, delta_max].
        noise = torch.where(rnd_idx < 0.9, one_sided_noise, uniform_noise)
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Left sided"

class OracleDataMaskerWithRightSidedDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise is concentrated near ``delta_max``."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_right_sided_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Draw ``size`` values mostly from the high end of ``[delta_min, delta_max]``.

        Each sample comes, with probability 0.9, from a uniform distribution on
        the highest 5% of the interval. Otherwise it comes from a uniform
        distribution on the whole interval.

        Args:
            size: number of values to draw.
            device: device the returned tensor is moved to.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        span = self.delta_max - self.delta_min
        uniform_noise = torch.rand(size) * span + self.delta_min
        # Lower edge of the high tail: delta_max - 5% of the span.
        min_extreme_val = 0.95 * self.delta_max + 0.05 * self.delta_min
        one_sided_noise = torch.rand(size) * (self.delta_max - min_extreme_val) + min_extreme_val
        rnd_idx = torch.rand(size)
        # torch.where assigns every index to exactly one component. The old strict
        # `< 0.9` / `> 0.9` masks skipped rnd_idx == 0.9 and left 0.0 there,
        # which may lie outside [delta_min, delta_max].
        noise = torch.where(rnd_idx < 0.9, one_sided_noise, uniform_noise)
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Right sided"

class OracleDataMaskerWithExtremeTailsDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise is concentrated at both ends of ``[delta_min, delta_max]``."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_extreme_tails_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Draw ``size`` values mostly from the two tails of ``[delta_min, delta_max]``.

        Each sample independently comes from one of three uniform components:
        the lowest 5% of the interval (probability 0.45), the highest 5%
        (probability 0.45), or the whole interval (probability 0.1).

        Args:
            size: number of values to draw.
            device: device the returned tensor is moved to.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        span = self.delta_max - self.delta_min
        uniform_noise = torch.rand(size) * span + self.delta_min
        # Inner edges of the low and high tails (5% of the span from each bound).
        min_extreme_val1 = 0.95 * self.delta_min + 0.05 * self.delta_max
        min_extreme_val2 = 0.95 * self.delta_max + 0.05 * self.delta_min
        one_sided_noise1 = torch.rand(size) * (min_extreme_val1 - self.delta_min) + self.delta_min
        one_sided_noise2 = torch.rand(size) * (self.delta_max - min_extreme_val2) + min_extreme_val2
        rnd_idx = torch.rand(size)
        # The nested torch.where assigns every index to exactly one component. The
        # old strict `<`/`>` masks skipped rnd_idx exactly equal to 0.45 or 0.9 and
        # left 0.0 there, which may lie outside [delta_min, delta_max].
        noise = torch.where(
            rnd_idx < 0.45,
            one_sided_noise1,
            torch.where(rnd_idx < 0.9, one_sided_noise2, uniform_noise),
        )
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Extreme tails"

class OracleDataMaskerWithSmallTailsDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise is a clipped Gaussian centred in ``[delta_min, delta_max]``."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        """Identifier built from ``base_name`` (presumably set by the parent class) and the rounded bounds."""
        prefix = self.base_name
        low = np.round(self.delta_min, 4)
        high = np.round(self.delta_max, 4)
        return f"{prefix}oracle_with_small_tails_delta_min={low}_delta_max={high}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from a clipped Gaussian mapped onto ``[delta_min, delta_max]``.

        A N(0, 0.05^2) draw is clipped to [-0.17, 0.17] and rescaled to [0, 1],
        so values cluster around the midpoint of the interval. Clipped draws land
        exactly on ``delta_min`` or ``delta_max``.
        """
        gaussian = torch.randn(size) * 0.05
        truncated = gaussian.clamp(min=-0.17, max=0.17)
        unit = (truncated + 0.17) / 0.34
        scaled = unit * (self.delta_max - self.delta_min) + self.delta_min
        return scaled.to(device)

    @property
    def display_name(self) -> str:
        """Human-readable label of this noise distribution."""
        return "Small tails"

class OracleDataMaskerWithBetaUDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise follows a rescaled Beta(0.5, 0.5) distribution."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_beta_u_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from Beta(0.5, 0.5) rescaled to ``[delta_min, delta_max]``.

        Beta(0.5, 0.5) is U-shaped, so mass concentrates near both bounds.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        # The sample has shape (size, 1). Squeeze only the trailing parameter axis,
        # so that size == 1 still gives a length-1 1-D tensor rather than a 0-d scalar.
        noise = m.sample(sample_shape=torch.Size([size])).squeeze(-1)
        noise = noise * (self.delta_max - self.delta_min) + self.delta_min
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Beta U"


class OracleDataMaskerWithBetaRightDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise follows a rescaled Beta(5, 1) distribution."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_beta_right_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from Beta(5, 1) rescaled to ``[delta_min, delta_max]``.

        Beta(5, 1) has density increasing toward 1, so mass concentrates near ``delta_max``.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        m = Beta(torch.tensor([5.]), torch.tensor([1.]))
        # The sample has shape (size, 1). Squeeze only the trailing parameter axis,
        # so that size == 1 still gives a length-1 1-D tensor rather than a 0-d scalar.
        noise = m.sample(sample_shape=torch.Size([size])).squeeze(-1)
        noise = noise * (self.delta_max - self.delta_min) + self.delta_min
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Beta right"


class OracleDataMaskerWithBetaLeftDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise follows a rescaled Beta(1, 5) distribution."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_beta_left_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from Beta(1, 5) rescaled to ``[delta_min, delta_max]``.

        Beta(1, 5) has density decreasing from 0, so mass concentrates near ``delta_min``.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        m = Beta(torch.tensor([1.]), torch.tensor([5.]))
        # The sample has shape (size, 1). Squeeze only the trailing parameter axis,
        # so that size == 1 still gives a length-1 1-D tensor rather than a 0-d scalar.
        noise = m.sample(sample_shape=torch.Size([size])).squeeze(-1)
        noise = noise * (self.delta_max - self.delta_min) + self.delta_min
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Beta left"

class OracleDataMaskerWithBetaRightHillDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise follows a rescaled Beta(5, 2) distribution."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_beta_right_hill_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from Beta(5, 2) rescaled to ``[delta_min, delta_max]``.

        Beta(5, 2) is unimodal with its mode at 0.8, i.e. 80% of the way from
        ``delta_min`` to ``delta_max``.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        m = Beta(torch.tensor([5.]), torch.tensor([2.]))
        # The sample has shape (size, 1). Squeeze only the trailing parameter axis,
        # so that size == 1 still gives a length-1 1-D tensor rather than a 0-d scalar.
        noise = m.sample(sample_shape=torch.Size([size])).squeeze(-1)
        noise = noise * (self.delta_max - self.delta_min) + self.delta_min
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Beta right hill"

class OracleDataMaskerWithBetaLeftHillDeltaMinMax(OracleDataMaskerWithDeltaMinMax):
    """Oracle masker whose noise follows a rescaled Beta(2, 5) distribution."""

    def __init__(self, data_scaler: DataScaler, data_masker: DataCorruptionMasker, dataset_name: str, x_dim: int,
                 z_dim: int,
                 delta_min: float = 1, delta_max: float = 1):
        super().__init__(data_scaler, data_masker, dataset_name, x_dim, z_dim, delta_min, delta_max)

    @property
    def name(self) -> str:
        return f"{self.base_name}oracle_with_beta_left_hill_delta_min={np.round(self.delta_min, 4)}_delta_max={np.round(self.delta_max, 4)}"

    def generate_noise(self, size: int, device):
        """Sample ``size`` values from Beta(2, 5) rescaled to ``[delta_min, delta_max]``.

        Beta(2, 5) is unimodal with its mode at 0.2, i.e. 20% of the way from
        ``delta_min`` to ``delta_max``.

        Returns:
            1-D tensor of length ``size`` on ``device``.
        """
        m = Beta(torch.tensor([2.]), torch.tensor([5.]))
        # The sample has shape (size, 1). Squeeze only the trailing parameter axis,
        # so that size == 1 still gives a length-1 1-D tensor rather than a 0-d scalar.
        noise = m.sample(sample_shape=torch.Size([size])).squeeze(-1)
        noise = noise * (self.delta_max - self.delta_min) + self.delta_min
        return noise.to(device)

    @property
    def display_name(self) -> str:
        return "Beta left hill"