import torch
from torch.distributions import multivariate_normal
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import orth
import mine
import argparse
import os
from sklearn import datasets as dsets
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from pactl.projectors import IDModule, RoundedDoubleKron, SparseOperator

from copy import deepcopy

# One-hidden-layer fully connected network (ReLU) with a shifted sigmoid output
class FC(torch.nn.Module):
    """One-hidden-layer fully connected network with a shifted sigmoid output.

    The network computes ``sigmoid(fc2(relu(fc1(x))) + b_act)`` where
    ``b_act = -log(exp(C) - 1)``.  With that shift, a zero pre-activation
    maps to ``sigmoid(b_act) = exp(-C)``.

    Args:
        n_inputs: size of each (flattened) input sample.
        n_middle: width of the hidden layer.
        n_outputs: number of output units.
        C: constant defining the output shift.  ``C`` must be >= 0;
            a negative value makes ``exp(C) - 1`` negative and ``b_act`` NaN.

    Attributes:
        num_parameters: total number of weights and biases of ``fc1`` and
            ``fc2``.
        b_act: one-element tensor holding the output shift.
    """

    def __init__(self, n_inputs, n_middle, n_outputs, C):
        super().__init__()
        self.input_dim = n_inputs
        self.output_dim = n_outputs
        self.middle_dim = n_middle

        self.fc1 = torch.nn.Linear(n_inputs, n_middle)
        self.fc2 = torch.nn.Linear(n_middle, n_outputs)

        # Registered as a buffer so that it follows the module through
        # .to()/.cuda()/.double(); non-persistent so that state_dict() keeps
        # exactly the same keys as before.
        self.register_buffer(
            "b_act",
            -torch.log(torch.exp(torch.Tensor([C])) - 1),
            persistent=False,
        )

        self.num_parameters = (n_inputs+1) * n_middle + (n_middle+1) * n_outputs

    def forward(self, x):
        """Return the shifted-sigmoid output for a batch ``x`` of shape (batch, n_inputs)."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return torch.sigmoid(logits + self.b_act)
    
class FCprime(torch.nn.Module):
    """Same architecture as a one-hidden-layer ReLU network with shifted
    sigmoid output, but evaluated with an externally supplied flat weight
    vector ``W`` instead of the module's own layers.

    ``W`` is read in the following order (row-major reshapes):

    1. ``n_inputs * n_middle`` entries  -> ``W1`` of shape (n_inputs, n_middle)
    2. ``n_middle * n_outputs`` entries -> ``W2`` of shape (n_middle, n_outputs)
    3. ``n_middle`` entries             -> ``b1``
    4. ``n_outputs`` entries            -> ``b2``

    NOTE(review): this layout is not the order produced by concatenating
    ``model.parameters()`` of two ``torch.nn.Linear`` layers (weight, bias,
    weight, bias, with weights stored as (out, in)) -- confirm that callers
    build ``W`` in the layout above.

    Args:
        n_inputs: size of each (flattened) input sample.
        n_middle: width of the hidden layer.
        n_outputs: number of output units.
        C: constant defining the output shift ``b_act = -log(exp(C) - 1)``;
            must be >= 0, a negative value yields NaN.
        W: tensor with ``num_parameters`` elements; any shape is accepted
            (e.g. ``(D,)`` or ``(D, 1)``), it is flattened before use.
    """

    def __init__(self, n_inputs, n_middle, n_outputs, C, W):
        super().__init__()

        self.input_dim = n_inputs
        self.output_dim = n_outputs
        self.middle_dim = n_middle

        # Not used by forward(); kept so that parameters()/state_dict()
        # stay the same as before.
        self.fc1 = torch.nn.Linear(n_inputs, n_middle)
        self.fc2 = torch.nn.Linear(n_middle, n_outputs)

        # Buffer so the shift follows .to()/.cuda(); non-persistent so the
        # state_dict keys are unchanged.
        self.register_buffer(
            "b_act",
            -torch.log(torch.exp(torch.Tensor([C])) - 1),
            persistent=False,
        )

        self.num_parameters = (n_inputs+1) * n_middle + (n_middle+1) * n_outputs

        self.W = W

    def forward(self, x):
        """Evaluate the network defined by ``self.W`` on a batch ``x`` of shape (batch, n_inputs).

        Raises:
            ValueError: if ``self.W`` does not hold exactly ``num_parameters`` values.
        """
        # Flatten first: with a (D, 1) column vector the bias slices would
        # otherwise have shape (k, 1) and broadcast incorrectly against the
        # (batch, k) activations.
        flat_w = self.W.reshape(-1)
        if flat_w.numel() != self.num_parameters:
            raise ValueError(
                "W has {} elements, expected {}".format(flat_w.numel(), self.num_parameters)
            )

        w1_end = self.input_dim * self.middle_dim
        bias_start = w1_end + self.middle_dim * self.output_dim
        b1_end = bias_start + self.middle_dim

        W1 = flat_w[:w1_end].reshape(self.input_dim, self.middle_dim)
        W2 = flat_w[w1_end:bias_start].reshape(self.middle_dim, self.output_dim)
        b1 = flat_w[bias_start:b1_end]
        b2 = flat_w[b1_end:]

        hidden = torch.relu(torch.matmul(x, W1) + b1)
        logits = torch.matmul(hidden, W2) + b2
        return torch.sigmoid(logits + self.b_act)


def evaluate(model, n_inputs, data_loader, criterion):
    """Compute the mean loss, accuracy and 0-1 risk of ``model`` over a loader.

    All three quantities are averaged per sample, so a smaller final batch
    does not get the same weight as a full one.

    Args:
        model: module mapping a (batch, n_inputs) tensor to values in [0, 1];
            it is switched to eval mode and left in that mode.
        n_inputs: size each sample is flattened to before the forward pass.
        data_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(outputs, target)`` on flattened
            tensors.  It is assumed to return the *mean* loss of the batch
            (e.g. ``BCELoss(reduction='mean')``), which is re-weighted by the
            batch size here.

    Returns:
        Tuple ``(loss, accuracy, risk)`` of Python floats, where predictions
        are the outputs rounded to the nearest integer and
        ``risk = 1 - accuracy``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for data, target in data_loader:
            outputs = model(data.view(-1, n_inputs)).view(-1)
            # Flatten so that a (batch, 1) target cannot broadcast against
            # the (batch,) predictions.
            target = target.reshape(-1)
            batch_size = target.numel()

            predicted = torch.round(outputs)
            n_correct += int((predicted == target).sum().item())

            batch_loss = criterion(outputs, target.float())
            total_loss += batch_loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("data_loader yielded no samples")

    mean_loss = total_loss / n_samples
    accuracy = n_correct / n_samples
    risk = (n_samples - n_correct) / n_samples
    return mean_loss, accuracy, risk


class GaussianDataset(Dataset):
    """Paired dataset built from two NumPy arrays, stored as float32 tensors.

    NOTE(review): the arrays are stored swapped -- ``data`` holds ``y`` and
    ``target`` holds ``x`` -- so ``dataset[i]`` returns ``(y[i], x[i])``.
    This is kept as is because the consumer of the pairs is not visible
    here; confirm that it is intended.

    Args:
        x: NumPy array; its rows are returned as the second item of each pair.
        y: NumPy array; its rows are returned as the first item of each pair
            and its length defines ``len(dataset)``.
        transform: accepted for interface compatibility; not used.
        device: device the tensors are moved to.  Defaults to CUDA when
            available, otherwise CPU, so the class no longer needs a
            module-level ``device`` variable to exist.
    """

    def __init__(self, x, y, transform=None, device=None):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = torch.from_numpy(y).float().to(device)
        self.target = torch.from_numpy(x).float().to(device)

    def __getitem__(self, index):
        """Return the pair ``(data[index], target[index])``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)


def estimate_mine(X, Y):
    """Run the ``mine`` estimator on the paired samples ``X`` and ``Y``.

    Args:
        X: 2-D NumPy array; its second dimension is passed to ``mine.Net_S``
            as the first size argument.
        Y: 2-D NumPy array; its second dimension is passed to ``mine.Net_S``
            as the second size argument.

    Returns:
        Whatever ``mine.mine`` returns (presumably the mutual-information
        estimate -- confirm against the ``mine`` module).
    """
    # Fixed training settings.
    lr_re = 2e-4  # currently not passed to anything
    lr_pa = 1e-3
    batch_size = 64
    n_epoch = 200
    hidden_size = 100

    dimX, dimY = X.shape[1], Y.shape[1]

    # Shuffled mini-batches over the paired samples.
    loader = DataLoader(
        GaussianDataset(X, Y),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    net = mine.Net_S(dimX, dimY, hidden_size)
    if torch.cuda.is_available():
        net.cuda()

    return mine.mine(loader, net, n_epoch, lr_pa)


if __name__ == "__main__":
    # Experiment driver.  For every combination of projection dimension `d`,
    # constant `C` and regularization `factor`, train `nproj` x `nruns` FC
    # networks on a binary classification task, record train/test losses,
    # accuracies and the distance between the weights and their projection
    # P P^T w, and save everything to an .npz file.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    verbose = True

    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=2000)
    parser.add_argument("--nproj", type=int, default=5)
    parser.add_argument("--nruns", type=int, default=1)

    # setting hyperparameters
    args = parser.parse_args()
    n = args.n  # number of training samples
    nproj = args.nproj  # number of projections
    nruns = args.nruns  # number of runs per projections (useful to estimate MI)

    dataset = "cifar10"  # or "mnist"

    list_C = [1]

    # The data only depend on `dataset` and `n`, so they are loaded once here
    # instead of once per (d, C, factor) combination.
    transform = transforms.Compose([transforms.ToTensor()])
    class_1 = 1  # label treated as the positive class; every other kept label becomes 0

    if dataset == "mnist":
        n_inputs = 28*28  # flattened 28x28 image (784 values)
        n_outputs = 1
        list_d = [0, 1, 10, 100, 500, 1000, 5000]
        list_factor = [0.0, 0.01, 0.05, 0.1, 0.5, 1., 10.]

        # loading training data (digits 0, 1 and 8)
        train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
        idx = (train_dataset.targets==0) | (train_dataset.targets==1) | (train_dataset.targets==8)
        train_dataset.targets = (train_dataset.targets[idx] == class_1).float()
        train_dataset.data = train_dataset.data[idx]

        # loading test data
        # NOTE(review): the test split additionally keeps digit 2, which is
        # absent from the training split -- confirm this is intended.
        test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
        idx = (test_dataset.targets==0) | (test_dataset.targets==1) | (test_dataset.targets==8) | (test_dataset.targets==2)
        test_dataset.targets = (test_dataset.targets[idx] == class_1).float()
        test_dataset.data = test_dataset.data[idx]
    elif dataset == "cifar10":
        n_inputs = 3*32*32  # flattened 3x32x32 image (3072 values)
        n_outputs = 1
        list_d = [0, 10, 100, 500, 1000, 5000, 10000]
        list_factor = [0.0, 0.01, 0.1, 1.]

        # loading training data
        # 0 airplane 1 automobile 2 bird 3 cat 4 deer 5 dog 6 frog 7 horse 8 ship 9 truck
        train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
        train_targets = torch.Tensor(train_dataset.targets)
        idx = (train_targets==class_1) | (train_targets==3) | (train_targets==4)
        train_dataset.targets = (train_targets[idx] == class_1).float()
        # CIFAR10 keeps its images in a NumPy array, so use a NumPy mask
        train_dataset.data = train_dataset.data[idx.numpy()]

        # loading test data
        # NOTE(review): the test split additionally keeps class 9 (truck),
        # which is absent from the training split -- confirm this is intended.
        test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
        test_targets = torch.Tensor(test_dataset.targets)
        idx = (test_targets==class_1) | (test_targets==3) | (test_targets==4) | (test_targets==9)
        test_dataset.targets = (test_targets[idx] == class_1).float()
        test_dataset.data = test_dataset.data[idx.numpy()]
    else:
        raise ValueError("Unknown dataset: {}".format(dataset))

    ref_model = FC(n_inputs, 100, n_outputs, C=1)
    D = ref_model.num_parameters  # dimension of the full weight vector
    dimX = n_inputs

    # keep the first `n` samples of each split
    ntrain = n
    train_dataset.data = train_dataset.data[:ntrain]
    train_dataset.targets = train_dataset.targets[:ntrain]

    ntest = int(ntrain)
    test_dataset.data = test_dataset.data[:ntest]
    test_dataset.targets = test_dataset.targets[:ntest]

    bsize = 128

    # load train and test data samples into dataloader
    train_loader = DataLoader(dataset=train_dataset, batch_size=bsize, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=bsize, shuffle=False)

    for d in list_d:
        for C in list_C:
            name_folder = os.path.join('smi_compression', 'results', 'nn', 'rate_dis_proj_C={}'.format(C), dataset)
            os.makedirs(name_folder, exist_ok=True)
            for factor in list_factor:
                print("Factor: {}".format(factor))
                # defining loss
                criterion = torch.nn.BCELoss(reduction='mean')
                # criterion = torch.nn.BCEWithLogitsLoss()
                epochs = 30

                train_accs = np.zeros((nproj, nruns))
                test_accs = np.zeros((nproj, nruns))
                train_risks = np.zeros((nproj, nruns))
                test_risks = np.zeros((nproj, nruns))
                gen_error = np.zeros((nproj, nruns))

                # never filled in below (see the disabled "Further tightening" code)
                train_risks_d = np.zeros((nproj, nruns))

                # placeholders: saved as zeros, no bound is computed in this script
                mi_bound = np.zeros(nproj)
                smi_bound = np.zeros(nproj)
                dist_w = np.zeros((nproj, nruns))

                for l in range(nproj):
                    print("Proj ", l+1)

                    data = np.zeros((nruns, dimX+1))
                    for k in range(nruns):
                        print("\t Run ", k+1)

                        # define model
                        model = FC(n_inputs, 100, n_outputs, C=C)

                        if d > 0:
                            # Built after the model on purpose: the statement
                            # order of the original script is preserved.
                            P = SparseOperator(D, d, params=None, names=None, seed=l)
                            # To test orthonormality: uncomment
                            # P = P.to_dense()
                            # P /= np.sqrt(D)
                            # print(P.shape)
                            # aux = P.T @ P
                            # print(aux)

                            # defining the optimizer
                            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
                        else:
                            # without a projection, `factor` acts as weight decay
                            optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=factor)

                        for epoch in range(epochs):
                            for features, labels in train_loader:
                                def closure():
                                    if torch.is_grad_enabled():
                                        optimizer.zero_grad()
                                        outputs = model(features.view(-1, n_inputs))
                                        loss = criterion(outputs.view(-1), labels.float())
                                        # Regularization: penalize the distance between the
                                        # weights and their projection P P^T w
                                        if factor > 0 and d > 0:
                                            current_w = torch.cat([param.view(-1) for param in model.parameters()])
                                            wprime = P @ (P.T @ current_w)
                                            loss += factor * torch.linalg.norm(current_w - wprime)
                                        if loss.requires_grad:
                                            loss.backward()
                                        return loss
                                optimizer.step(closure)
                                # Make loss Lipschitz: rescale each parameter so that its
                                # norm along dim 0 is at most 1/sqrt(dimX).  This has to be
                                # done in place; rebinding the loop variable
                                # (`param = param * ...`) leaves the model untouched.
                                with torch.no_grad():
                                    for param in model.parameters():
                                        norm = param.norm(2, dim=0, keepdim=True)
                                        desired = torch.clamp(norm, 0, 1/np.sqrt(dimX))
                                        param.mul_(desired / (1e-8 + norm))

                                # calculate the loss again for monitoring
                                loss = closure()

                            _, current_test_acc, _ = evaluate(model, n_inputs, test_loader, criterion)
                            if verbose:
                                print('Epoch: {}. Train loss: {}. Test accuracy: {}'.format(epoch, loss.item(), current_test_acc))
                        # Compute training/test errors and generalization error
                        train_loss, train_acc, train_risk = evaluate(model, n_inputs, train_loader, criterion)
                        test_loss, test_acc, test_risk = evaluate(model, n_inputs, test_loader, criterion)

                        train_accs[l, k] = train_acc
                        test_accs[l, k] = test_acc

                        # the "risks" stored here are the mean losses, not the 0-1 risks
                        train_risks[l, k] = train_loss
                        test_risks[l, k] = test_loss

                        gen_error[l, k] = test_loss - train_loss

                        # Store data and weights to compute (S)MI-based bounds
                        data[k] = torch.cat((features[0].view(-1), labels[0].view(-1))).detach().numpy()
                        weights = torch.cat([param.view(-1) for param in model.parameters()]).detach()

                        weights = weights.unsqueeze(-1)

                        if d > 0:
                            th_w_prime = P @ (P.T @ weights)
                            # convert both operands to NumPy rather than mixing a
                            # torch tensor with a NumPy array in the subtraction
                            dist_w[l, k] = np.linalg.norm(weights.numpy() - th_w_prime.numpy())

                        # Further tightening (disabled)
                        # model_prime = FCprime(n_inputs, 100, n_outputs, C=C, W=th_w_prime)
                        # train_loss_d, _, _ = evaluate(model, n_inputs, train_loader, criterion)
                        # train_risks_d[l, k] = train_loss_d

                # Save results
                np.savez(os.path.join(name_folder, 'res_D={}_d={}_n={}_C={}_factor={}_nproj={}_nruns={}'.format(D, d, n, C, factor, nproj, nruns)),
                         mi_bound=mi_bound, smi_bound=smi_bound, dist_w=dist_w, C=C, gen_error=gen_error, train_risks=train_risks, test_risks=test_risks, train_risks_d=train_risks_d, factor=factor, d=d, D=D)

                print('Mean train risk: {}, Mean test risk: {}'.format(train_risks.mean(), test_risks.mean()))
                print('Mean train acc: {}, Mean test acc: {}'.format(train_accs.mean(), test_accs.mean()))
                print("Mean generalization error \n", gen_error.mean())
                print("Mean distance term: ", dist_w.mean())
