import math
import os
from random import shuffle

import matplotlib.pyplot as plt  # for plotting stuff
import numpy as np
import torch
from scipy.stats import multivariate_normal  # generating synthetic data
from torch.utils.data import Dataset

from plot import plot_logistic_regression, plot2d


def generate_synthetic_data(plot_data=False, n_samples=1000, disc_factor=math.pi / 4.0):
    """Generate a synthetic binary-classification dataset with one sensitive feature.

    Two non-sensitive features are drawn from one Gaussian cluster per class,
    and a binary sensitive feature is then sampled for every example.
    A sensitive feature value of 0.0 means the example is considered to be in
    the protected group (e.g., female) and 1.0 means it is in the
    non-protected group (e.g., male).

    Args:
        plot_data: If True, show a scatter plot of the first 200 examples.
        n_samples: Number of data points to generate per class (must be >= 1).
        disc_factor: Angle in radians by which a copy of the features is
            rotated before the sensitive feature is sampled. It determines the
            initial discrimination in the data -- decrease it to generate more
            discrimination.

    Returns:
        A tuple ``(X, y, x_control)`` where ``X`` has shape
        ``(2 * n_samples, 2)``, ``y`` holds the class labels (1.0 positive,
        0.0 negative) and ``x_control`` is a dictionary mapping ``"s1"`` to
        the array of sensitive feature values.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    def gen_gaussian(mean_in, cov_in, class_label):
        nv = multivariate_normal(mean=mean_in, cov=cov_in)
        # rvs() squeezes its output for a single sample; force a 2-D array.
        points = nv.rvs(n_samples).reshape(n_samples, len(mean_in))
        labels = np.ones(n_samples, dtype=float) * class_label
        return nv, points, labels

    # Generate the non-sensitive features: one Gaussian cluster per class.
    mu1, sigma1 = [2, 2], [[5, 1], [1, 5]]
    mu2, sigma2 = [-2, -2], [[10, 1], [1, 3]]
    nv1, X1, y1 = gen_gaussian(mu1, sigma1, 1)  # positive class
    nv2, X2, y2 = gen_gaussian(mu2, sigma2, 0)  # negative class

    # Join the positive and negative class clusters.
    X = np.vstack((X1, X2))
    y = np.hstack((y1, y2))

    # Shuffle the data so the two classes are interleaved.
    perm = list(range(n_samples * 2))
    shuffle(perm)
    X = X[perm]
    y = y[perm]

    # Rotate a copy of the features; the sensitive feature is sampled from the
    # cluster densities evaluated at these rotated points.
    cos_a, sin_a = math.cos(disc_factor), math.sin(disc_factor)
    rotation_mult = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    X_aux = np.dot(X, rotation_mult)

    # Generate the sensitive feature: normalise the density of the positive
    # cluster against both clusters and compare with one uniform draw per
    # example.
    p1 = nv1.pdf(X_aux)
    p2 = nv2.pdf(X_aux)
    p1 = p1 / (p1 + p2)
    r = np.random.uniform(size=len(X))
    # 1.0 -> non-protected group (male), 0.0 -> protected group (female)
    x_control = np.where(r < p1, 1.0, 0.0)

    if plot_data:
        _plot_synthetic_data(X, y, x_control)

    # All the sensitive features are stored in a dictionary.
    return X, y, {"s1": x_control}


def _plot_synthetic_data(X, y, x_control, num_to_draw=200):
    """Show a scatter plot of the first ``num_to_draw`` examples by group and label."""
    # Only draw a small number of points to avoid clutter.
    x_draw = X[:num_to_draw]
    y_draw = y[:num_to_draw]
    x_control_draw = x_control[:num_to_draw]

    # (sensitive value, class label, scatter style) for each plotted group.
    groups = (
        (0.0, 1.0, dict(color='green', marker='x', s=30, linewidth=1.5, label="Prot. +ve")),
        (0.0, 0.0, dict(color='red', marker='x', s=30, linewidth=1.5, label="Prot. -ve")),
        (1.0, 1.0, dict(color='green', marker='o', facecolors='none', s=30, label="Non-prot. +ve")),
        (1.0, 0.0, dict(color='red', marker='o', facecolors='none', s=30, label="Non-prot. -ve")),
    )
    for s_value, y_value, style in groups:
        points = x_draw[(x_control_draw == s_value) & (y_draw == y_value)]
        plt.scatter(points[:, 0], points[:, 1], **style)

    # The ticks are not needed to see the data distribution. tick_params
    # expects booleans; the string 'off' is truthy and would keep them visible.
    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    plt.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    plt.legend(loc=2, fontsize=15)
    plt.xlim((-15, 10))
    plt.ylim((-10, 15))
    plt.show()


def read_data():
    """Build the feature tensor and the pair of target tensors for the dataset.

    The synthetic features and labels are converted to float32 tensors, with
    the labels given a trailing singleton dimension. A second target is
    derived by inverting every label whose random draw exceeds 0.2, so most
    labels of the second target are flipped.

    Returns:
        A tuple ``(data, (target, flipped_target))``.
    """
    features, labels, _ = generate_synthetic_data()

    features = torch.from_numpy(features.astype(np.float32))
    labels = torch.from_numpy(labels.astype(np.float32)).unsqueeze(-1)

    # 1.0 where the label is to be inverted, 0.0 where it is kept.
    flip_mask = (torch.rand(labels.size()) > 0.2).float()
    flipped_labels = labels * (1 - flip_mask) + (1 - labels) * flip_mask

    return features, (labels, flipped_labels)


class FairDataset(Dataset):
    """Synthetic two-feature dataset with a clean and a mostly-flipped target.

    ``self.data`` and ``self.target`` come from ``read_data()``; ``self.target``
    is a tuple of target tensors and every item returns one entry per target.
    """

    # Number of input features per sample.
    input_size = 2

    # Task name -> [index, names]. The index is presumably the position in the
    # target tuple and the names presumably identify the loss and metric used
    # by the training code -- confirm against the callers.
    tasks = {
        'right': [0, ['bce', 'acc']],
        'wrong': [1, ['bce', 'acc']]
    }

    def __init__(self, tag):
        """Create the dataset.

        Args:
            tag: Accepted but not used by this dataset.
        """
        super().__init__()
        self.data, self.target = read_data()

    def __getitem__(self, index):
        """Return ``(features, [target_0, target_1, ...])`` for ``index``."""
        return self.data[index], [t[index] for t in self.target]

    def __len__(self):
        """Return the number of samples."""
        return self.data.size(0)

    @torch.no_grad()
    def plot(self, model, tasks, title):
        """Plot logistic-regression curves and 2D predictions and save them under ``plots/``.

        Args:
            model: Mapping with a ``'rep'`` module and one module per task
                name; each task module is applied to the output of ``'rep'``.
            tasks: Sequence of task objects exposing ``name`` and ``index``.
                The 2D prediction plot reads the first two tasks, so at least
                two tasks are required.
            title: Title forwarded to ``plot_logistic_regression``.
        """
        rep = model['rep'](self.data)
        preds = [model[task_i.name](rep) for task_i in tasks]

        # Logistic regression plots
        fig, ax = plt.subplots(1, len(tasks), figsize=(3.5 * 2, 3.5))

        plt.rc('xtick', labelsize='x-small')
        plt.rc('ytick', labelsize='x-small')

        # Dot product of every sample with the vector from (-2, -2) to (2, 2).
        vy = torch.tensor([2, 2.]) - torch.tensor([-2., -2])
        dist = torch.einsum('bi,i->b', self.data, vy)

        for i, task_i in enumerate(tasks):
            # Only two axes are addressed, so any further task reuses one of them.
            plot_logistic_regression(dist, self.target[task_i.index], alpha=0.01, sigmoid=False, above=i != 0,
                                     ax=ax[i % 2], xmin=-3, xmax=3, title=title,
                                     task_model=torch.nn.Sequential(model['rep'], model[task_i.name]))

        fig.tight_layout()
        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)
        plt.savefig('plots/plot_lr.svg')
        plt.show()

        # 2D plots
        fig, ax = plt.subplots(1, len(tasks) + 1)
        fig.suptitle('Predictions')

        plot2d(self.data, self.target[0].squeeze(), title='Original', ax=ax[0])

        # The second panel thresholds with '<', i.e. it shows the inverted prediction.
        plot2d(self.data, preds[0].squeeze() > 0.5, title='right', ax=ax[1])
        plot2d(self.data, preds[1].squeeze() < 0.5, title='wrong', ax=ax[2])

        plt.savefig('plots/plot_2d.png')
        plt.show()