import torch
from torch.distributions.multivariate_normal import MultivariateNormal

from .base import Data, IndexTensorDataset


class Toy(Data):
    """Two-group synthetic 2-D dataset for fair binary classification.

    Based on the toy example of Lohaus, Perrot and von Luxburg,
    "Too relaxed to be fair" (ICML 2020). Each of the two sensitive groups
    contributes 150 negatives and 150 positives, for 600 points in total.
    """

    name = "toy"
    feat_dim = 2
    sens_dim = 2
    simple_sens_cols = [0, 1]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def setup(self, stage: str) -> None:
        """Sample the dataset and assign it to the train/val/test splits.

        Points are drawn afresh (no fixed seed) on every call, and the very
        same tensors back all three splits. ``stage`` is not used.

        Produced tensors (600 rows each, in this row order):
          * rows   0-149: group 1, label 0
          * rows 150-299: group 1, label 1
          * rows 300-449: group 2, label 1
          * rows 450-599: group 2, label 0

        ``feat`` is float of shape ``(600, 2)``; ``sens`` is a float one-hot
        group indicator of shape ``(600, 2)`` (``[1, 0]`` for group 1,
        ``[0, 1]`` for group 2); ``labels`` is long of shape ``(600,)``.
        """
        # See Lohaus, Michael, Michael Perrot, and Ulrike Von Luxburg. "Too relaxed to be fair."
        # International Conference on Machine Learning. PMLR, 2020.
        n = 150  # points per (group, label) block
        eye = torch.eye(2)

        def gaussian(mean, cov=eye):
            """Draw ``n`` samples from N(mean, cov)."""
            return MultivariateNormal(torch.tensor(mean), cov).sample((n,))

        # Group 1 positives follow the equal-weight *mixture*
        #     1/2 N([3, -1], I) + 1/2 N([1, 4], 0.5 I),
        # i.e. each point comes from one component chosen by a fair coin flip.
        # Averaging one sample of each component instead would collapse the
        # two clusters into a single narrow Gaussian centred at [2, 1.5].
        pick_first = torch.rand(n, 1) < 0.5
        group1_pos = torch.where(
            pick_first,
            gaussian([3., -1.]),
            gaussian([1., 4.], eye * 0.5),
        )

        feat = torch.cat([
            gaussian([2., -2.]),    # group 1 - negatives
            group1_pos,             # group 1 - positives
            gaussian([2.5, 2.5]),   # group 2 - positives
            gaussian([4.5, -1.5]),  # group 2 - negatives
        ], dim=0).float()

        # Blockwise one-hot: the first 2n rows are group 1, the last 2n rows
        # are group 2, matching the layout of ``feat`` above. (Tiling the
        # identity would alternate [1, 0], [0, 1] row by row instead.)
        sens = torch.repeat_interleave(eye, 2 * n, dim=0).float()
        labels = torch.cat([torch.zeros(n), torch.ones(2 * n), torch.zeros(n)]).long()

        self.train_data = IndexTensorDataset(feat, sens, labels)
        self.val_data = IndexTensorDataset(feat, sens, labels)
        self.test_data = IndexTensorDataset(feat, sens, labels)
