import numpy as np
import openturns as ot
import matplotlib.pyplot as plt

class copula_generator:
    '''
    Generate the synthetic data used in the experiments.

    A D-dimensional distribution is assembled from randomly parameterised
    marginals that are coupled through a Gaussian (normal) copula.

    Correlation structure: D random index pairs (j, k) are drawn; every pair
    with j != k gets correlation 0.25, so at most D off-diagonal pairs are
    non-zero (fewer when a draw hits the diagonal or repeats a pair).

    Marginal layout by dimension index j (first matching rule wins):
      * j < D // 4            -> Normal
      * j < 3 * D // 8        -> equal-weight mixture of 2 Normals
      * j < D // 2            -> equal-weight mixture of 3 Normals
      * j >= D - num_heavy    -> equal-weight mixture of 2 Student-t (df)
      * otherwise             -> equal-weight mixture of 2 Normals
    Because of this ordering, heavy-tailed marginals only ever occupy
    indices >= D // 2, even when num_heavy is larger than D / 2.

    Every location is drawn uniformly from [-4, 4) and every scale uniformly
    from [1, 2).

    Parameters
    ----------
    D : int
        Number of dimensions.
    num_heavy : int
        Number of trailing dimensions that get Student-t mixture marginals.
    df : float
        Degrees of freedom passed to each Student-t component.
    seed : int, optional
        Seed for NumPy's *global* random generator, which drives the
        correlation pattern and the marginal parameters.
        NOTE(review): sampling in get_data goes through OpenTURNS, which
        presumably uses its own generator not governed by this seed - confirm.
    '''

    def __init__(self, D, num_heavy, df, seed=1):
        self.D = D
        self.marginals = []

        # Seeds the process-wide NumPy generator (kept as-is: callers may rely
        # on this side effect for reproducibility of later NumPy draws).
        np.random.seed(seed)

        # Correlation matrix with 0.25 on randomly chosen off-diagonal pairs.
        self.R = ot.CorrelationMatrix(D)
        for _ in range(D):  # D attempts -> at most D non-zero pairs
            j = np.random.randint(0, self.D)
            k = np.random.randint(0, self.D)
            if j != k:
                self.R[j, k] = 0.25
                self.R[k, j] = 0.25
        self.copula = ot.NormalCopula(self.R)

        # Index thresholds separating the marginal families.
        normal_end = self.D // 4
        mix2_end = (3 * self.D) // 8
        mix3_end = self.D // 2
        heavy_start = self.D - num_heavy

        def student(mean, sd):
            # Student-t component with the requested degrees of freedom.
            return ot.Student(df, mean, sd)

        for j in range(self.D):
            if j < normal_end:
                (mean,), (sd,) = self._draw_params(1)
                marginal = ot.Normal(mean, sd)
            elif j < mix2_end:
                marginal = self._mixture(ot.Normal, 2)
            elif j < mix3_end:
                marginal = self._mixture(ot.Normal, 3)
            elif j >= heavy_start:
                marginal = self._mixture(student, 2)
            else:
                marginal = self._mixture(ot.Normal, 2)
            self.marginals.append(marginal)

        self.dist = ot.ComposedDistribution(self.marginals, self.copula)
        # Placeholder until get_data() is called.
        self.mv_samps = np.array([])

    @staticmethod
    def _draw_params(n_components):
        '''Draw n_components means in [-4, 4), then n_components sds in [1, 2).'''
        # All means are drawn before any sd so the sequence of random draws
        # (and hence the result for a given seed) stays stable.
        means = [np.random.rand() * 8 - 4 for _ in range(n_components)]
        sds = [np.random.rand() + 1 for _ in range(n_components)]
        return means, sds

    @classmethod
    def _mixture(cls, make_component, n_components):
        '''Build an equal-weight mixture of n_components randomly parameterised components.'''
        means, sds = cls._draw_params(n_components)
        components = [make_component(m, s) for m, s in zip(means, sds)]
        weights = [1.0 / n_components] * n_components
        return ot.Mixture(components, weights)

    def get_data(self, n):
        '''
        Draw n samples from the joint distribution.

        The samples are also stored on the instance as ``mv_samps`` so that
        visualize_marginals can plot them.

        Returns
        -------
        numpy.ndarray
            The sample converted to a NumPy array.
        '''
        self.mv_samps = np.array(self.dist.getSample(n))
        return self.mv_samps

    def visualize_marginals(self, save=False, range_x=5):
        '''
        Plot a histogram of every marginal of the last sample drawn.

        Parameters
        ----------
        save : bool, optional
            If true, write the figure to "plots/marginals_copula" (the
            "plots" directory is created when missing); otherwise show it.
        range_x : float, optional
            Histograms cover the interval (-range_x, range_x).

        Raises
        ------
        ValueError
            If no sample with at least D columns is available, i.e. get_data
            has not been called yet.
        '''
        samples = np.asarray(self.mv_samps)
        if samples.ndim != 2 or samples.shape[1] < self.D:
            raise ValueError(
                "no samples to plot: call get_data(n) before visualize_marginals()"
            )

        # Round the column count up so an odd D still shows its last marginal;
        # squeeze=False keeps axs two-dimensional even for a single column.
        n_cols = (self.D + 1) // 2
        fig, axs = plt.subplots(2, n_cols, squeeze=False)
        fig.suptitle("Dependency induced by Gaussian Copula")
        for j, ax in enumerate(axs.flat):
            if j < self.D:
                ax.hist(samples[:, j], range=(-range_x, range_x), bins=30, density=True)
            else:
                # Spare cell of the grid when D is odd.
                ax.set_visible(False)

        if save:
            import os

            # savefig does not create missing directories.
            os.makedirs("plots", exist_ok=True)
            plt.savefig("plots/marginals_copula")
        else:
            plt.show()

    def get_R(self):
        '''Return the correlation matrix of the Gaussian copula.'''
        return self.R