import functools
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as ss
import torch
import torch
import torch.optim as optim
from scipy.linalg import orth
from sklearn.linear_model import LogisticRegression
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

#import knnie
import mine

# Define the neural network
class FCNN(nn.Module):
    """Three-layer ReLU network whose weights are confined to a random subspace.

    The network does not train its weight matrices directly. In ``forward`` the
    flat weight vector is rebuilt as ``W = Theta @ actual_parameters``, where
    ``Theta`` is a ``(num_parameters, d)`` projection supplied by the caller and
    ``actual_parameters`` is a trainable ``(d, 1)`` parameter. ``W`` is laid out
    as ``[W1, W2, W3, b1, b2, b3]``, with each matrix flattened row-major.

    Args:
        d: Dimension of the trainable subspace.
        input_dim: Number of input features.
        output_dim: Number of outputs (logits).
        middle_dim: Width of both hidden layers.
    """

    def __init__(self, d, input_dim, output_dim, middle_dim=100):
        super(FCNN, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.middle_dim = middle_dim

        # NOTE(review): fc1-fc3 are registered, so they show up in parameters()
        # and state_dict(), but forward() never uses them.
        self.fc1 = nn.Linear(input_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, middle_dim)
        self.fc3 = nn.Linear(middle_dim, output_dim)

        self.d = d

        # Sizes of the consecutive chunks of the flat weight vector:
        # W1, W2, W3 (weight matrices) followed by b1, b2, b3 (biases).
        self._chunk_sizes = [
            input_dim * middle_dim,
            middle_dim * middle_dim,
            middle_dim * output_dim,
            middle_dim,
            middle_dim,
            output_dim,
        ]
        self.num_parameters = sum(self._chunk_sizes)
        # Subspace coordinates, initialised uniformly in [0, 1).
        self.actual_parameters = torch.nn.parameter.Parameter(data=torch.rand(self.d, 1), requires_grad=True)

    def forward(self, x, Theta):
        """Apply the network to ``x`` using the weights ``Theta @ actual_parameters``.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.
            Theta: Projection matrix of shape ``(num_parameters, d)``.

        Returns:
            Logits whose last dimension is ``output_dim``.

        Raises:
            ValueError: If ``Theta`` is ``None`` or not of shape ``(num_parameters, d)``.
                An explicit raise is used (not ``assert``) so the check survives ``python -O``.
        """
        expected = (self.num_parameters, self.d)
        if Theta is None or tuple(Theta.shape) != expected:
            got = None if Theta is None else tuple(Theta.shape)
            raise ValueError(
                "[!] Theta has incorrect shape {}, should be [{}, {}]".format(got, *expected)
            )

        W = torch.matmul(Theta, self.actual_parameters).view(-1)
        w1, w2, w3, b1, b2, b3 = torch.split(W, self._chunk_sizes)
        W1 = w1.reshape(self.input_dim, self.middle_dim)
        W2 = w2.reshape(self.middle_dim, self.middle_dim)
        W3 = w3.reshape(self.middle_dim, self.output_dim)

        x = torch.relu(torch.matmul(x, W1) + b1)
        x = torch.relu(torch.matmul(x, W2) + b2)
        return torch.matmul(x, W3) + b3

def train(model, train_loader, Theta, optimizer, criterion, epoch, num_epochs):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: Module called as ``model(data, Theta)``.
        train_loader: Iterable of ``(data, target)`` batches.
        Theta: Passed through unchanged to ``model``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function. It is assumed to average over the batch (the
            PyTorch default ``reduction='mean'``).
        epoch, num_epochs: Unused; accepted for call-site compatibility.

    Returns:
        Mean training loss per sample over the epoch, with batches weighted by
        their size. Returns ``nan`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for data, target in train_loader:
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        # BaseDataset stores targets as floats; class-index losses such as
        # CrossEntropyLoss need integer targets on every device, not only CUDA.
        target = target.long()

        optimizer.zero_grad()
        output = model(data, Theta)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        batch_size = target.shape[0]
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        return float("nan")
    return total_loss / total_samples

def evaluate(model, train_loader, Theta, criterion, epoch, num_epochs):
    """Compute the mean loss and the 0-1 error rate of ``model`` over a loader.

    Despite its name, ``train_loader`` may be any loader of ``(data, target)``
    batches; the script passes both the training and the test loader.

    Args:
        model: Module called as ``model(data, Theta)``.
        train_loader: Iterable of ``(data, target)`` batches.
        Theta: Passed through unchanged to ``model``.
        criterion: Loss function. It is assumed to average over the batch (the
            PyTorch default ``reduction='mean'``).
        epoch, num_epochs: Unused; accepted for call-site compatibility.

    Returns:
        ``(mean_loss, error_rate)``. Both are per-sample averages over the whole
        loader, so a smaller final batch does not skew them. ``error_rate`` is the
        fraction of samples whose arg-max prediction differs from the target
        (a risk, not an accuracy). Both values are ``nan`` for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_errors = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in train_loader:
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            # BaseDataset stores targets as floats; class-index losses such as
            # CrossEntropyLoss need integer targets on every device, not only CUDA.
            target = target.long()

            output = model(data, Theta)
            loss = criterion(output, target)

            batch_size = target.shape[0]
            total_loss += loss.item() * batch_size
            total_errors += (output.argmax(1) != target).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        return float("nan"), float("nan")
    return total_loss / total_samples, total_errors / total_samples


@functools.lru_cache(maxsize=2)
def _load_raw_dataset(name):
    """Load the full train/test splits of ``name`` as read-only numpy arrays.

    Cached with one entry per dataset name, so repeated ``load_data`` calls reuse
    the arrays instead of rebuilding the torchvision datasets every time.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``.

    Raises:
        ValueError: If ``name`` is not ``"cifar10"`` or ``"mnist"``.
    """
    if name == "cifar10":
        dataset_cls = datasets.CIFAR10
    elif name == "mnist":
        dataset_cls = datasets.MNIST
    else:
        raise ValueError("Dataset name not recognized")

    trainset = dataset_cls(root='./data', train=True, download=True)
    testset = dataset_cls(root='./data', train=False, download=True)

    arrays = (
        np.array(trainset.data),
        np.array(trainset.targets),
        np.array(testset.data),
        np.array(testset.targets),
    )
    for arr in arrays:
        # The arrays are shared through the cache, so forbid in-place edits.
        # The fancy indexing in load_data returns fresh, writeable copies.
        arr.flags.writeable = False
    return arrays


def load_data(n_samples, name="cifar10", *, normalize=False):
    """Return a random training subset and the shuffled full test set, flattened.

    Args:
        n_samples: Number of training samples to draw without replacement.
        name: ``"cifar10"`` or ``"mnist"``.
        normalize: If True, map pixels to ``[-1, 1]`` via ``x / 127.5 - 1``,
            assuming raw values in 0-255. This matches ``ToTensor`` followed by
            ``Normalize(0.5, 0.5)``. If False (the default), pixel values are
            returned unchanged; no torchvision transform is applied, because
            the raw ``.data`` arrays are used.

    Returns:
        ``(Xtrain, ytrain, Xtest, ytest)``. ``Xtrain`` has ``n_samples`` rows and
        ``Xtest`` holds the whole test split. Each row is one flattened image.

    Raises:
        ValueError: If ``name`` is unknown, or if ``n_samples`` is negative or
            larger than the training split.
    """
    train_data, train_labels, test_data, test_labels = _load_raw_dataset(name)

    if not 0 <= n_samples <= len(train_data):
        raise ValueError(
            "n_samples must be in [0, {}], got {}".format(len(train_data), n_samples)
        )

    # Randomly shuffle the data and labels (train permutation drawn first).
    train_indices = np.random.permutation(len(train_data))
    test_indices = np.random.permutation(len(test_data))

    chosen = train_indices[:n_samples]
    Xtrain, ytrain = train_data[chosen], train_labels[chosen]
    Xtest, ytest = test_data[test_indices], test_labels[test_indices]

    Xtrain = Xtrain.reshape(Xtrain.shape[0], -1)
    Xtest = Xtest.reshape(Xtest.shape[0], -1)

    if normalize:
        Xtrain = Xtrain.astype(np.float32) / 127.5 - 1.0
        Xtest = Xtest.astype(np.float32) / 127.5 - 1.0

    return Xtrain, ytrain, Xtest, ytest

class BaseDataset(Dataset):
    """Wrap paired numpy arrays as a torch ``Dataset``.

    Both arrays are converted to float32 tensors. Targets are kept as floats, so
    callers that need class indices (e.g. for ``CrossEntropyLoss``) must cast
    them with ``.long()``.

    Args:
        x: Feature array; the first axis indexes samples.
        y: Target array; the first axis indexes samples.
        transform: Ignored; accepted for call-site compatibility.
        device: Device each item is moved to in ``__getitem__``. Defaults to
            CUDA when available, otherwise CPU.

    Raises:
        ValueError: If ``x`` and ``y`` have different numbers of samples.
    """

    def __init__(self, x, y, transform=None, device=None):
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same number of samples, got {} and {}".format(len(x), len(y))
            )
        self.data = torch.from_numpy(x).float()
        self.target = torch.from_numpy(y).float()
        # Resolve the device here instead of reading a module-level `device`
        # global, which only exists when this file is run as a script.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def __getitem__(self, index):
        x = self.data[index].to(self.device)
        y = self.target[index].to(self.device)
        return x, y

    def __len__(self):
        return len(self.data)


def estimate_mine(X, Y, *, batch_size=64, n_epoch=200, lr_pa=1e-3, hidden_dim=100):
    """Estimate the mutual information between paired samples ``X`` and ``Y`` with MINE.

    Row ``i`` of ``X`` is paired with row ``i`` of ``Y``. This function only builds
    a shuffled data loader and a ``mine.Net_S`` network; the estimate itself comes
    from ``mine.mine``.

    Args:
        X: 2-D numpy array, one sample per row.
        Y: 2-D numpy array with the same number of rows as ``X``.
        batch_size: Mini-batch size of the shuffled loader.
        n_epoch: Epoch count passed to ``mine.mine``.
        lr_pa: Fourth argument passed to ``mine.mine`` (presumably its learning
            rate; confirm against the ``mine`` module).
        hidden_dim: Third argument passed to ``mine.Net_S`` (presumably the hidden
            width; confirm against the ``mine`` module).

    Returns:
        Whatever ``mine.mine`` returns. The script uses it as a scalar MI estimate.

    Raises:
        ValueError: If ``X`` or ``Y`` is not 2-D, or their row counts differ.
    """
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError("X and Y must be 2-D, got shapes {} and {}".format(X.shape, Y.shape))
    if X.shape[0] != Y.shape[0]:
        raise ValueError("X and Y must have the same number of rows, got {} and {}".format(X.shape[0], Y.shape[0]))

    dimX = X.shape[1]
    dimY = Y.shape[1]

    trainset = BaseDataset(X, Y)
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    net = mine.Net_S(dimX, dimY, hidden_dim)
    if torch.cuda.is_available():
        net.cuda()

    return mine.mine(trainloader, net, n_epoch, lr_pa)


if __name__ == "__main__":
    # Kept at module scope (not inside a function) so that code reading a
    # module-level `device` global can find it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_name = "cifar10"  # or "mnist"
    # Small probe load, used only to report shapes and infer input_dim.
    Xtrain, ytrain, Xtest, ytest = load_data(5, dataset_name)

    print('Xtrain size:', Xtrain.shape)
    print('ytrain:', ytrain)

    print('Xtest size:', Xtest.shape)
    print('ytest:', ytest)

    # Parameters
    input_dim = Xtrain.shape[-1]
    print("input_dim", input_dim)
    output_dim = 10
    print("output_dim", output_dim)

    nsamples = np.array([50000])

    nruns = 20  # training runs per projection, each on a freshly drawn data subset
    nproj = 20  # random projections Theta per (d, n) pair
    ds = np.array([100, 1000, 9000])  # subspace dimensions d
    if np.any(ds <= 0):
        raise ValueError("All subspace dimensions in ds must be positive")

    DEBUG = True
    num_epochs = 20
    batch_size = 128
    lr = 0.01
    hidden_dim = 200

    # Result arrays indexed [d index, n index, projection, run].
    gen_error = np.zeros((len(ds), len(nsamples), nproj, nruns))
    train_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    test_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    mi_bound_mine = np.zeros((len(ds), len(nsamples), nproj))

    # Built only to obtain D, the length of the full flat weight vector.
    reference_model = FCNN(1, input_dim, output_dim, hidden_dim)
    D = reference_model.num_parameters
    print("[!] D", D, "ds", ds, "nsamples", nsamples, "nruns", nruns, "nproj", nproj)

    # One placeholder per argument, in order.
    PREFIX = "{}_deb{}_nruns{}_nproj{}_epochs{}_batch_size{}_lr{}_inputdim{}_outputdim{}_hiddendim{}".format(
        dataset_name, DEBUG, nruns, nproj, num_epochs, batch_size, lr, input_dim, output_dim, hidden_dim)

    # Dimensions d >= D are skipped below, so they contribute no runs.
    total_num_runs = len(nsamples) * int(np.sum(ds < D)) * nproj * nruns
    print("Total number of runs", total_num_runs)
    run_times = []

    for j in tqdm(list(range(len(nsamples)))):
        n = nsamples[j]
        print("Number of samples : ", n)
        for i, d in enumerate(ds):
            print("\t Dimension : ", d)
            if d >= D:
                continue

            for l in range(nproj):
                print("\t Proj ", l+1)

                weights_j = np.zeros((nruns, d))
                # MINE only sees the first training sample of each run, so store
                # just that row (features followed by label) per run.
                first_samples_j = np.zeros((nruns, input_dim + 1))

                # Entries i.i.d. N(0, 1): the same distribution as multivariate_normal
                # with cov=np.eye(d), without building a d x d covariance matrix.
                Theta = np.random.standard_normal((D, d))
                Theta = orth(Theta)  # Assumption: Theta.T @ Theta = Identity. Theta is of size D x d
                Theta = torch.from_numpy(Theta).float().to(device)

                for k in range(nruns):
                    time_now = time.time()

                    if k % 10 == 0:
                        print("\t\t Run ", k)
                        if run_times:  # avoid the mean of an empty list (nan + warning)
                            print("\t\t average run time: ", np.mean(run_times), ' of ', len(run_times), "/", total_num_runs)

                    Xtrain, ytrain, Xtest, ytest = load_data(n, dataset_name)

                    # Train
                    trainset = BaseDataset(Xtrain, ytrain)
                    train_loader = DataLoader(
                        trainset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=0,
                    )

                    testset = BaseDataset(Xtest, ytest)
                    # Evaluation only, so shuffling is unnecessary.
                    test_loader = DataLoader(
                        testset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=0,
                    )

                    model = FCNN(Theta.shape[1], input_dim, output_dim, hidden_dim).to(device)
                    assert(len(model.actual_parameters.view(-1)) == d)

                    criterion = nn.CrossEntropyLoss()
                    optimizer = optim.Adam(model.parameters(), lr=lr)

                    for epoch in range(num_epochs):
                        loss = train(model, train_loader, Theta, optimizer, criterion, epoch, num_epochs)
                        if DEBUG:
                            print("Epoch", epoch, "loss", loss, "time", (time.time() - time_now)/(epoch+1))

                    # Compute training/test errors and generalization error.
                    # Use num_epochs - 1 rather than the loop variable, which is
                    # undefined when num_epochs == 0.
                    last_epoch = num_epochs - 1
                    _, train_risk = evaluate(model, train_loader, Theta, criterion, last_epoch, num_epochs)
                    test_loss, test_risk = evaluate(model, test_loader, Theta, criterion, last_epoch, num_epochs)
                    if DEBUG:
                        print("Train risk", train_risk, "Test risk", test_risk, "(test loss)", test_loss)

                    train_risks[i, j, l, k] = train_risk
                    test_risks[i, j, l, k] = test_risk
                    gen_error[i, j, l, k] = test_risk - train_risk
                    # Store samples of data and parameters for mutual information estimation
                    first_samples_j[k] = np.concatenate((Xtrain[0], np.array([ytrain[0]])), axis=0)
                    weights_j[k] = model.actual_parameters.reshape(-1).detach().cpu().numpy()

                    run_times.append(time.time() - time_now)

                # Estimate MI with MINE
                mi_mine = estimate_mine(weights_j, first_samples_j)
                if DEBUG:
                    print("mi_mine", mi_mine)
                # Assumes estimate_mine returns a scalar. A sample-based MI estimate
                # may come out negative while true MI is >= 0; clamp it so sqrt
                # does not produce nan.
                mi_bound_mine[i, j, l] = np.sqrt(max(float(mi_mine), 0.0) / 2)

    print("MI bound with MINE", mi_bound_mine)
    print("Generalization error", gen_error)

    # Save results
    np.savez('{}_results'.format(PREFIX), mi_bound_mine=mi_bound_mine, gen_error=gen_error, train_risks=train_risks, test_risks=test_risks, ds=ds, nsamples=nsamples)

    # Plot
    plt.figure()
    mean_gen_error = gen_error.mean(axis=(2, 3))
    mean_mi_bound_mine = mi_bound_mine.mean(axis=2)
    for i, d in enumerate(ds):
        plt.plot(nsamples, mean_gen_error[i], '-o', label="Generalization error for d = {}".format(d))
        plt.plot(nsamples, mean_mi_bound_mine[i], '-x', label="ISMI bound (MINE) for d = {}".format(d))
    plt.xlabel(r"$n$")
    plt.legend()
    namefig = "{}_gen_error.pdf".format(PREFIX)
    plt.savefig(namefig)

    # The stored train/test values are 0-1 error rates (risks), not accuracies.
    plt.figure()
    mean_train_risks = train_risks.mean(axis=(2, 3))
    mean_test_risks = test_risks.mean(axis=(2, 3))
    for i, d in enumerate(ds):
        plt.plot(nsamples, mean_train_risks[i], '-o', label="Training risk for d = {}".format(d))
        plt.plot(nsamples, mean_test_risks[i], '--o', label="Test risk for d = {}".format(d))

    plt.xlabel(r"$n$")
    plt.legend()
    namefig = "{}_train_test_acc.pdf".format(PREFIX)
    plt.savefig(namefig)
