import numpy as np
import scipy.stats as ss
from scipy.linalg import orth
from sklearn.linear_model import LogisticRegression
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader

import knnie
import mine

import os


def generate_data(nsample, dim, mean_params, cov_scale=4.0):
    """Draw a balanced two-class Gaussian sample.

    Each label is drawn uniformly from {0, 1}. The features for label ``c``
    are drawn from N(mean_params[c], cov_scale * I_dim).

    Args:
        nsample: Number of samples to draw.
        dim: Feature dimension.
        mean_params: Array-like indexable by label; ``mean_params[c]`` is the
            mean of class ``c``. A 1-D array of two scalars is accepted and
            interpreted as one scalar mean per class.
        cov_scale: Per-coordinate variance of the isotropic covariance.
            Defaults to 4 (standard deviation 2), as before.

    Returns:
        X: float array of shape (nsample, dim).
        y: int array of shape (nsample,) with labels in {0, 1}.
    """
    y = np.random.choice(2, size=nsample, replace=True, p=np.ones(2) / 2)
    centers = np.asarray(mean_params, dtype=float)[y]
    if centers.ndim == 1:
        # One scalar mean per class: make it a column so it broadcasts over dim.
        centers = centers[:, None]
    # Isotropic covariance => independent N(0, cov_scale) coordinates, so one
    # vectorised draw replaces a Python loop of per-sample scipy calls (and
    # keeps X 2-D even when dim == 1 or nsample == 0).
    noise = np.sqrt(cov_scale) * np.random.standard_normal((nsample, dim))
    X = centers + noise
    return X, y


class GaussianDataset(Dataset):
    """In-memory torch ``Dataset`` over two row-aligned numpy arrays.

    Both arrays are converted to float32 tensors and moved to ``device`` once,
    at construction time. Note the storage and return order: ``y`` is stored
    as ``self.data`` and ``x`` as ``self.target``. Indexing therefore returns
    ``(y[index], x[index])``, and ``len()`` is the number of rows of ``y``.

    NOTE(review): the (y, x) order is preserved unchanged because downstream
    consumers (``mine.mine`` via ``estimate_mine``) may depend on it; confirm
    whether the swap is intentional.

    Args:
        x: numpy array whose rows are returned second.
        y: numpy array whose rows are returned first.
        transform: Accepted for interface compatibility; currently unused.
        device: torch device for the tensors. If None, ``cuda:0`` is used when
            CUDA is available, otherwise the CPU.
    """

    def __init__(self, x, y, transform=None, device=None):
        if device is None:
            # Resolve locally instead of relying on a module-level `device`
            # that only exists when this file is executed as a script.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.data = torch.from_numpy(y).float().to(device)
        self.target = torch.from_numpy(x).float().to(device)

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)


def estimate_mine(X, Y, *, n_epoch=200, batch_size=64, lr=1e-3, hidden=100):
    """Estimate the mutual information between ``X`` and ``Y`` with MINE.

    Wraps ``X``/``Y`` in a shuffled ``GaussianDataset`` loader, builds a
    ``mine.Net_S(X.shape[1], Y.shape[1], hidden)`` statistics network (moved
    to CUDA when available), and returns whatever ``mine.mine`` returns for
    it.

    Args:
        X: 2-D numpy array (rows are samples); ``X.shape[1]`` sizes the net.
        Y: 2-D numpy array with the same number of rows as ``X``.
        n_epoch: Training epochs passed to ``mine.mine`` (default 200).
        batch_size: Mini-batch size of the loader (default 64).
        lr: Learning rate passed to ``mine.mine`` (default 1e-3).
        hidden: Hidden width passed to ``mine.Net_S`` (default 100).

    Returns:
        The MI estimate produced by ``mine.mine`` (presumably a float in
        nats; TODO confirm against the ``mine`` module).
    """
    dimX = X.shape[1]
    dimY = Y.shape[1]

    # Single-process loading: the dataset tensors may already live on the GPU.
    trainloader = DataLoader(
        GaussianDataset(X, Y),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    net = mine.Net_S(dimX, dimY, hidden)
    if torch.cuda.is_available():
        net.cuda()

    return mine.mine(trainloader, net, n_epoch, lr)


if __name__ == "__main__":
    # Experiment: for every training size n and projected dimension d < D,
    # draw nproj random orthonormal D x d projections Theta. For each Theta,
    # fit nruns logistic regressions on projected data (optionally
    # compressing the weights), record train/test risk and their gap, and
    # bound it via MINE estimates: smi_bound = sqrt(MI / 2) per projection,
    # bu_bound = sqrt(MI / 2) for the lifted weights Theta @ w.

    # Module-level global so any code reading a global `device` still finds it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    name_folder = os.path.join('smi_compression', 'results', 'log_reg', 'proj')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(name_folder, exist_ok=True)

    # Parameters
    D = 20                                       # ambient feature dimension
    nsamples = np.array([100, 250, 500, 1000])   # training-set sizes
    nruns = 30                                   # datasets per projection
    nproj = 30                                   # projections per (d, n)
    means = np.array([-np.ones(D), np.ones(D)])  # class means
    ds = np.array([1, 5, 10, 15])                # projected dimensions

    gen_error = np.zeros((len(ds), len(nsamples), nproj, nruns))
    train_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    test_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    smi_bound = np.zeros((len(ds), len(nsamples), nproj))
    bu_bound = np.zeros((len(ds), len(nsamples)))

    # Optional post-training weight compression.
    uniform_spacing = False
    kmeans = False
    # Compression constants, defined up front because the MI-bound code after
    # the run loop uses them (previously they only existed inside the loop).
    nc = 3  # k-means clusters (alternative tried: int(np.sqrt(d/2)))
    b = 8   # bits per weight for uniform quantisation

    for j, n in enumerate(nsamples):
        ntest = int(n * 2/8)  # test set of size n/4 => train:test = 80:20
        print("Number of samples : ", n)
        for i, d in enumerate(ds):
            print("\t Dimension : ", d)
            if d < D:
                weights_j = np.zeros((nproj, nruns, d))
                theta_weights_j = np.zeros((nproj, nruns, D))
                training_data_j = np.zeros((nproj, nruns, n, D+1))
                for l in range(nproj):
                    print("\t Proj ", l+1)
                    Theta = np.random.multivariate_normal(mean=np.zeros(d), cov=np.eye(d), size=D)
                    Theta = orth(Theta)  # Assumption: Theta.T @ Theta = Identity. Theta is of size D x d
                    for k in range(nruns):
                        if k % 50 == 0:
                            print("\t\t Run ", k)
                        # Generate training and test data
                        Xtrain, ytrain = generate_data(n, D, mean_params=means)
                        Xtest, ytest = generate_data(ntest, D, mean_params=means)
                        # Train logistic regression on the projected features
                        clf = LogisticRegression(random_state=0, fit_intercept=False).fit(Xtrain @ Theta, ytrain)
                        if kmeans:
                            # Cluster the scalar weights (coef + intercept) and
                            # replace each by its cluster centre.
                            W = np.hstack((clf.coef_, clf.intercept_[:, None])).T
                            # Separate name so the boolean `kmeans` flag is not
                            # overwritten by the fitted estimator.
                            km = KMeans(n_clusters=nc, n_init="auto").fit(W)
                            idx = km.predict(W)
                            W_hat = km.cluster_centers_[idx]
                            clf.coef_ = W_hat[:-1].T
                            clf.intercept_ = W_hat[-1]
                        if uniform_spacing:
                            W = np.hstack((clf.coef_, clf.intercept_[:, None]))
                            beta = max(np.abs(W.min()), np.abs(W.max()))  # fix
                            alpha = -beta
                            S = (beta - alpha) / (2 ** b - 1)
                            # W_hat is kept in units of S (not multiplied back);
                            # binary predict() depends only on the sign of the
                            # decision function, so the positive factor S does
                            # not change predictions.
                            W_hat = np.round(W/S)
                            clf.coef_ = W_hat[:, :-1]
                            clf.intercept_ = W_hat[:, -1]
                        # Compute training/test errors and generalization error
                        train_risk = np.array(clf.predict(Xtrain @ Theta) != ytrain, dtype=float).mean()
                        test_risk = np.array(clf.predict(Xtest @ Theta) != ytest, dtype=float).mean()
                        train_risks[i, j, l, k] = train_risk
                        test_risks[i, j, l, k] = test_risk
                        gen_error[i, j, l, k] = test_risk - train_risk
                        # Store samples of data and parameters for mutual information estimation
                        training_data_j[l, k] = np.concatenate((Xtrain, ytrain[:, None]), axis=1)
                        # weights_j[l, k] = np.concatenate((clf.coef_, clf.intercept_[:, None]), axis=1)
                        weights_j[l, k] = clf.coef_
                        theta_weights_j[l, k] = (Theta @ clf.coef_.T)[:, 0]
                    # # Estimate mutual information (for fixed Theta)
                    # mi_bound_knn[i, j] += np.sqrt(knnie.kraskov_mi(weights_j, training_data_j[:, 0, :]) / 2)
                    # Estimate MI with MINE (or use the closed-form bound for
                    # compressed weights)
                    if kmeans:
                        mi_estimate = d * np.log(nc)
                        if uniform_spacing:
                            mi_estimate += nc * b
                        else:
                            print("Warning: The bound might be wrong")
                    elif uniform_spacing:
                        mi_estimate = d * b / n
                    else:
                        # MI between the weights and the first training sample
                        mi_estimate = estimate_mine(weights_j[l], training_data_j[l, :, 0, :])
                    smi_bound[i, j, l] = np.sqrt(mi_estimate / 2)
                mi_estimate_bu = estimate_mine(theta_weights_j.reshape(nproj*nruns, D), training_data_j[:, :, 0, :].reshape(nproj*nruns, D+1))
                bu_bound[i, j] = np.sqrt(mi_estimate_bu / 2)
            else:
                # No projection: the model is fit on all D features plus an
                # intercept, so each stored weight vector has D + 1 entries
                # (sizing this by d broke whenever d > D).
                weights_j = np.zeros((nruns, D+1))
                training_data_j = np.zeros((nruns, n, D+1))
                for k in range(nruns):
                    if k % 50 == 0:
                        print("\t\t Run ", k)
                    # Generate training and test data
                    Xtrain, ytrain = generate_data(n, D, mean_params=means)
                    Xtest, ytest = generate_data(ntest, D, mean_params=means)
                    # Train logistic regression
                    clf = LogisticRegression(random_state=0).fit(Xtrain, ytrain)
                    # Compute training/test errors and generalization error
                    train_risk = np.array(clf.predict(Xtrain) != ytrain, dtype=float).mean()
                    test_risk = np.array(clf.predict(Xtest) != ytest, dtype=float).mean()
                    train_risks[i, j, :, k] = train_risk
                    test_risks[i, j, :, k] = test_risk
                    gen_error[i, j, :, k] = test_risk - train_risk
                    # Store samples of data and parameters for mutual information estimation
                    training_data_j[k] = np.concatenate((Xtrain, ytrain[:, None]), axis=1)
                    weights_j[k] = np.concatenate((clf.coef_, clf.intercept_[:, None]), axis=1)
                # # Estimate mutual information with KNN estimator
                # mi_knn = knnie.kraskov_mi(weights_j, training_data_j[:, 0, :])
                # mi_bound_knn[i, j] = np.sqrt(mi_knn / 2)
                # Estimate MI with MINE
                mi_mine = estimate_mine(weights_j, training_data_j[:, 0, :])
                smi_bound[i, j, :] = np.sqrt(mi_mine / 2)

    # Save results. The filename embeds the numpy reprs of `ds` and
    # `nsamples` (brackets/spaces included); kept unchanged so existing
    # loaders keep finding the file.
    np.savez(os.path.join(name_folder, 'res_D={}_d={}_n={}_nproj={}_nruns={}'.format(D, ds, nsamples, nproj, nruns)),
             smi_bound=smi_bound, bu_bound=bu_bound, gen_error=gen_error, train_risks=train_risks, test_risks=test_risks, D=D)
