import numpy as np
import scipy.stats as ss
from scipy.linalg import orth
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

from torch import nn
import torch.optim as optim

from tqdm import tqdm

from pactl.projectors import LazyRandom,IDModule,RoundedKron, FixedNumpySeed, FixedPytorchSeed,RoundedDoubleKron
from pactl.projectors import _delchainattr, _setchainattr
from pactl.quantize_fns import \
    get_random_symbols_and_codebook, \
    get_kmeans_symbols_and_codebook, \
    Quantize, \
    get_message_len, \
    do_arithmetic_encoding


import mine

import time

class QuantizingWrapper(nn.Module):
    """Run a subspace-parameterised network on quantized subspace parameters.

    The wrapper owns a trainable copy of ``net.subspace_params`` and a trainable
    ``centroids`` tensor. On every forward pass it writes
    ``quantizer(subspace_params, centroids)`` onto the wrapped network under the
    attribute name ``"subspace_params"`` and then calls that network.

    Args:
        net: network exposing ``d`` and ``subspace_params``.
        quantizer: callable invoked as ``quantizer(subspace_params, centroids)``.
        centroids: tensor wrapped into a trainable ``nn.Parameter``.
    """

    def __init__(self, net, quantizer, centroids):
        super().__init__()

        # Stored inside a plain list, which nn.Module does not register as a
        # submodule (presumably to keep the wrapped net's own parameters out of
        # this wrapper's parameters() -- confirm). This is also why to() below
        # has to move the wrapped net by hand.
        self._forward_net = [net]
        self.d = net.d

        # Parameters are registered in the order: subspace_params, centroids.
        detached_copy = net.subspace_params.detach().clone()
        self.subspace_params = nn.Parameter(detached_copy, requires_grad=True)
        self.quantizer = quantizer
        self.centroids = nn.Parameter(centroids, requires_grad=True)

        # Drop the attribute from the wrapped net; forward() sets it again on
        # every call with the quantized values.
        _delchainattr(net, "subspace_params")

    def to(self, *args, **kwargs):
        """Move the wrapped network, then the wrapper itself; returns the wrapper."""
        wrapped = self._forward_net[0]
        wrapped.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Install the quantized parameters on the wrapped net and call it."""
        wrapped = self._forward_net[0]
        quantized = self.quantizer(self.subspace_params, self.centroids)
        _setchainattr(wrapped, "subspace_params", quantized)
        return wrapped(*args, **kwargs)

# Define the neural network
class FCNN(nn.Module):
    """Fully connected network: three linear layers, ReLU after the first two.

    Args:
        input_dim: size of each input vector.
        output_dim: size of the output; no activation is applied to it.
        middle_dim: width of both hidden layers (default 100).

    Attributes:
        num_parameters: number of weights plus biases across the three layers.
    """

    def __init__(self, input_dim, output_dim, middle_dim=100):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.middle_dim = middle_dim

        # Layers are created in this fixed order (fc1, fc2, fc3).
        self.fc1 = nn.Linear(input_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, middle_dim)
        self.fc3 = nn.Linear(middle_dim, output_dim)

        layer_shapes = (
            (input_dim, middle_dim),
            (middle_dim, middle_dim),
            (middle_dim, output_dim),
        )
        # Per layer: fan_in * fan_out weights plus one bias per output unit.
        self.num_parameters = sum(
            fan_in * fan_out + fan_out for fan_in, fan_out in layer_shapes
        )

    def forward(self, x):
        """Return the output of the last linear layer for input ``x``."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

def train(model, train_loader, optimizer, criterion, epoch, num_epochs):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module put into training mode and called on each batch.
        train_loader: sized iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(model(data), target)``.
        epoch: unused; kept for call-site compatibility.
        num_epochs: unused; kept for call-site compatibility.

    Returns:
        The mean of the per-batch loss values (unweighted by batch size).

    Raises:
        ZeroDivisionError: if ``train_loader`` has length 0.
    """
    model.train()
    use_cuda = torch.cuda.is_available()
    running_loss = 0.0

    for data, target in train_loader:
        if use_cuda:
            data = data.cuda()
            target = target.cuda()
        # Cast to integer class indices on every device. Previously the cast
        # only happened on the CUDA path, so float labels reached the loss
        # unchanged when running on CPU.
        target = target.long()

        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

def evaluate(model, train_loader, criterion, epoch, num_epochs):
    """Compute mean loss and mean misclassification rate over a loader.

    Despite the parameter name, any loader can be passed (the script also
    passes its test loader).

    Args:
        model: module put into eval mode and called on each batch.
        train_loader: sized iterable yielding ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``.
        epoch: unused; kept for call-site compatibility.
        num_epochs: unused; kept for call-site compatibility.

    Returns:
        ``(loss, error_rate)``. Both are unweighted means of per-batch values,
        so a smaller final batch counts as much as a full one. The second
        value is the fraction of predictions that DIFFER from the target
        (an error rate, not an accuracy).

    Raises:
        ZeroDivisionError: if ``train_loader`` has length 0.
    """
    model.eval()
    use_cuda = torch.cuda.is_available()
    loss_sum = 0.0
    error_sum = 0.0

    with torch.no_grad():
        for data, target in train_loader:
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            # Cast to integer class indices on every device. Previously the
            # cast only happened on the CUDA path, so float labels reached the
            # loss unchanged when running on CPU.
            target = target.long()

            output = model(data)
            batch_loss = criterion(output, target)

            # Share of samples in this batch whose argmax class is wrong.
            error_sum += (output.argmax(1) != target).float().mean().item()
            loss_sum += batch_loss.item()

    n_batches = len(train_loader)
    return loss_sum / n_batches, error_sum / n_batches


def load_data(n_samples, name="cifar10"):
    """Load a torchvision dataset and return flattened, shuffled numpy arrays.

    Args:
        n_samples: number of randomly chosen training examples to keep
            (used as a slice bound on a random permutation).
        name: ``"cifar10"`` or ``"mnist"``.

    Returns:
        ``(Xtrain, ytrain, Xtest, ytest)``. ``Xtrain`` holds the selected
        training examples and ``Xtest`` the whole shuffled test split, each
        reshaped to ``(num_examples, -1)``.

    Raises:
        ValueError: if ``name`` is not one of the two supported datasets.
    """
    # NOTE(review): the arrays below are read straight from the datasets'
    # `.data` attribute, so this transform presumably never touches the values
    # that are returned -- confirm against torchvision.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Pick the dataset class; both splits are downloaded into ./data if needed.
    if name == "cifar10":
        dataset_cls = datasets.CIFAR10
    elif name == "mnist":
        dataset_cls = datasets.MNIST
    else:
        raise ValueError("Dataset name not recognized")

    trainset = dataset_cls(root='./data', train=True, download=True, transform=transform)
    testset = dataset_cls(root='./data', train=False, download=True, transform=transform)

    # Pull the stored examples and labels out as numpy arrays.
    train_data = np.array(trainset.data)
    train_labels = np.array(trainset.targets)
    test_data = np.array(testset.data)
    test_labels = np.array(testset.targets)

    # Two draws from numpy's global RNG: first for train, then for test.
    train_order = np.random.permutation(len(train_data))
    test_order = np.random.permutation(len(test_data))

    # Keep the first n_samples of the shuffled training set, all of the test set.
    chosen = train_order[:n_samples]
    Xtrain, ytrain = train_data[chosen], train_labels[chosen]
    Xtest, ytest = test_data[test_order], test_labels[test_order]

    # Flatten every example into a single vector.
    Xtrain = Xtrain.reshape(Xtrain.shape[0], -1)
    Xtest = Xtest.reshape(Xtest.shape[0], -1)

    return Xtrain, ytrain, Xtest, ytest

class BaseDataset(Dataset):
    """In-memory dataset over two numpy arrays, served as float32 tensors.

    Args:
        x: numpy array of inputs, indexed along its first axis.
        y: numpy array of targets, indexed along its first axis.
        transform: accepted for interface compatibility; not used.
        device: device each returned item is moved to. Defaults to CUDA when
            available, otherwise CPU. It is resolved here rather than read
            from a module-level ``device`` global, so the class also works
            when this file is imported instead of run as a script.
    """

    def __init__(self, x, y, transform=None, device=None):
        # Both arrays are converted to float32, targets included.
        self.data = torch.from_numpy(x).float()
        self.target = torch.from_numpy(y).float()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __getitem__(self, index):
        """Return ``(x, y)`` for ``index``, both moved to ``self.device``."""
        x = self.data[index].to(self.device)
        y = self.target[index].to(self.device)
        return x, y

    def __len__(self):
        """Return the number of examples (length of the input tensor)."""
        return len(self.data)


def estimate_mine(X, Y):
    """Train a ``mine.Net_S`` network on paired samples and return ``mine.mine``'s result.

    Presumably a MINE-style mutual-information estimate between ``X`` and
    ``Y``; the exact meaning of the return value is defined by the project's
    ``mine`` module -- confirm there.

    Args:
        X: 2-D numpy array; ``X.shape[1]`` is passed as the first network size.
        Y: 2-D numpy array; ``Y.shape[1]`` is passed as the second network size.

    Returns:
        Whatever ``mine.mine(trainloader, net, n_epoch, lr_pa)`` returns.
    """
    # Fixed hyper-parameters.
    lr_re = 2e-4  # not referenced anywhere below
    lr_pa = 1e-3
    batch_size = 64
    n_epoch = 200
    H = 100  # third argument to mine.Net_S

    dimX, dimY = X.shape[1], Y.shape[1]

    # Shuffled mini-batches over the paired samples, loaded in the main process.
    trainloader = DataLoader(
        BaseDataset(X, Y),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    net = mine.Net_S(dimX, dimY, H)
    if torch.cuda.is_available():
        net.cuda()

    return mine.mine(trainloader, net, n_epoch, lr_pa)


if __name__ == "__main__":
    # Experiment driver. For every (sample size n, subspace dimension d,
    # projection seed l, run k) it trains a quantized, subspace-projected FCNN,
    # records train/test risk and their gap, and derives a bound from the
    # length of the arithmetic-coded subspace parameters. Results are written
    # to an .npz file and two PDF figures.

    # Deliberately left as a module-level global (no main() wrapper): code
    # outside this block may read `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_name = "cifar10" # or "mnist"
    # Small probe load, used only to print shapes and to read the input size.
    Xtrain, ytrain, Xtest, ytest = load_data(5, dataset_name)

    print('Xtrain size:', Xtrain.shape)
    print('ytrain:', ytrain)

    print('Xtest size:', Xtest.shape)
    print('ytest:', ytest)

    # Parameters
    input_dim = Xtrain.shape[-1]
    print("input_dim", input_dim)
    output_dim = 10
    print("output_dim", output_dim)
    output_dim = int(output_dim)

    nsamples = np.array([50000])
    nruns = 1
    nproj = 20
    ds = np.array([2500, 12500, 15000])

    DEBUG = True
    num_epochs = 30
    batch_size = 128
    lr = 0.01
    hidden_dim=200

    # Result arrays indexed as [dimension, sample size, projection(, run)].
    gen_error = np.zeros((len(ds), len(nsamples), nproj, nruns))
    train_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    test_risks = np.zeros((len(ds), len(nsamples), nproj, nruns))
    mi_bound_mine = np.zeros((len(ds), len(nsamples), nproj))

    reference_model = FCNN(input_dim, output_dim, hidden_dim) # .to(device)
    D = reference_model.num_parameters # This is the number Theta has to conform to
    print("[!] D", D, "ds", ds, "nsamples", nsamples, "nruns", nruns, "nproj", nproj)

    PREFIX = "QUANT3_{}_deb_{}_nruns{}_nproj{}_epochs{}_batch_size{}_lr{}_inputdim{}_outputdim{}_hiddendim{}_firstd{}".format(dataset_name, DEBUG, nruns, nproj, num_epochs, batch_size, lr, input_dim, output_dim, hidden_dim, ds[0])
    # Count every run that will actually execute: the loops below also iterate
    # over `ds` and skip dimensions with d >= D (the old count ignored `ds`).
    total_num_runs = int(np.count_nonzero(ds < D)) * len(nsamples) * nproj * nruns
    print("Total number of runs", total_num_runs)
    run_times = []

    for j in tqdm(range(len(nsamples))):
        n = nsamples[j]
        print("Number of samples : ", n)
        for i, d in enumerate(ds):
            print("\t Dimension : ", d)
            if d >= D:
                # Skipped dimensions keep their zero-initialised result entries.
                continue

            for l in range(nproj):
                print("\t Proj ", l+1)

                weights_j = np.zeros((nruns, d))
                training_data_j = np.zeros((nruns, 1, input_dim+1)) # this includes the labels

                assert d > 0

                for k in range(nruns):
                    time_now = time.time()

                    if k % 10 == 0:
                        print("\t\t Run ", k)
                        # Nothing has been timed before the very first run;
                        # report nan directly instead of taking the mean of an
                        # empty array (which emits a RuntimeWarning).
                        average_run_time = np.mean(run_times) if run_times else float("nan")
                        print("\t\t average run time: ", average_run_time, ' of ', len(run_times), "/", total_num_runs)

                    # Generate training and test data
                    Xtrain, ytrain, Xtest, ytest = load_data(n, dataset_name)
                    # Train
                    trainset = BaseDataset(Xtrain, ytrain)
                    train_loader = DataLoader(
                                    trainset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0
                                )

                    testset = BaseDataset(Xtest, ytest)
                    test_loader = DataLoader(
                                    testset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0
                                )

                    # Wrap a fresh FCNN in the subspace projector; the
                    # projection index l is passed as the seed.
                    projector = RoundedDoubleKron
                    model = FCNN(input_dim, output_dim, hidden_dim).to(device)
                    if DEBUG:
                        print("seed l", l)
                    model = IDModule(model,projector,d, seed=l).to(device) # , seed=l).to(device) # seed is projector

                    # Initial codebook for the quantizer, built from the
                    # current subspace parameters.
                    use_kmeans = False
                    levels = 7
                    cluster_fn = get_kmeans_symbols_and_codebook if use_kmeans else get_random_symbols_and_codebook
                    _, centroids = cluster_fn(model.subspace_params.cpu().data.numpy(), levels=levels, codebook_dtype=np.float16)
                    centroids = torch.from_numpy(centroids).float()

                    qw = QuantizingWrapper(model, quantizer=Quantize().apply, centroids=centroids).to(device) # .to(device_id)

                    # Sanity checks on the subspace dimension.
                    assert(qw.d == d)
                    assert(len(qw.subspace_params.reshape(-1).detach().cpu().numpy()) == d)

                    criterion = nn.CrossEntropyLoss()
                    optimizer = optim.Adam(qw.parameters(), lr=lr)

                    for epoch in range(num_epochs):
                        loss = train(qw, train_loader, optimizer, criterion, epoch, num_epochs)
                        if DEBUG:
                            # Elapsed time since the start of this run (data
                            # loading included) divided by epochs completed.
                            print("Epoch", epoch, "loss", loss, "time", (time.time() - time_now)/(epoch+1))

                    # Compute training/test errors and generalization error
                    _, train_risk = evaluate(qw, train_loader, criterion, epoch, num_epochs)
                    test_loss, test_risk = evaluate(qw, test_loader, criterion, epoch, num_epochs)
                    if DEBUG:
                        print("Train risk", train_risk, "Test risk", test_risk, "(test loss)", test_loss)

                    quantized_vec = qw.quantizer(qw.subspace_params, qw.centroids)
                    quantized_vec = quantized_vec.cpu().detach().numpy()
                    # Symbol of each parameter = index of its nearest centroid
                    # (smallest squared distance).
                    vec = (qw.centroids.unsqueeze(-2) - qw.subspace_params.unsqueeze(-1))**2.0
                    symbols = torch.min(vec, -1)[-1]
                    symbols = symbols.cpu().detach().numpy()
                    centroids = qw.centroids.cpu().detach().numpy()
                    # Empirical frequency of each of the `levels` symbols.
                    probabilities = np.array([np.mean(symbols == i) for i in range(levels)])
                    _, coded_symbols_size = do_arithmetic_encoding(symbols, probabilities,
                                                                    qw.centroids.shape[0])
                    message_len = get_message_len(
                        coded_symbols_size=coded_symbols_size,
                        codebook=centroids,
                        max_count=len(symbols),
                    )

                    print("message_len", message_len)

                    train_risks[i, j, l, k] = train_risk
                    test_risks[i, j, l, k] = test_risk
                    gen_error[i, j, l, k] = test_risk - train_risk
                    # Store samples of data and parameters for mutual information estimation
                    training_data_j[k] = np.concatenate((Xtrain[0], np.array([ytrain[0]])), axis=0)

                    weights_j[k] = qw.subspace_params.reshape(-1).detach().cpu().numpy()
                    if DEBUG:
                        print("subspace_params", qw.subspace_params.reshape(-1).detach().cpu().numpy()[:10])

                    this_run_time = time.time() - time_now
                    run_times.append(this_run_time)

                # NOTE(review): only the message length of the LAST run is used
                # here; with nruns > 1 the earlier runs are ignored -- confirm
                # this is intended.
                # NOTE(review): despite the "mine" naming, the value comes from
                # the coded message length: bound = sqrt(message_len / (2 n)).
                mi_mine = message_len / n
                if DEBUG:
                    print("mi_mine", mi_mine)
                mi_bound_mine[i, j, l] = np.sqrt(mi_mine / 2)
    print("MI bound with MINE", mi_bound_mine)
    print("Generalization error", gen_error)

    # Save results
    np.savez('{}_results'.format(PREFIX), mi_bound_mine=mi_bound_mine, gen_error=gen_error, train_risks=train_risks, test_risks=test_risks, ds=ds, nsamples=nsamples)

    # Plot: generalization gap and bound, averaged over projections (and runs).
    plt.figure()
    mean_gen_error = gen_error.mean(axis=(2,3))
    mean_mi_bound_mine = mi_bound_mine.mean(axis=2)
    for i in range(len(ds)):
        plt.plot(nsamples, mean_gen_error[i], '-o', label="Generalization error for d = {}".format(ds[i]))
        plt.plot(nsamples, mean_mi_bound_mine[i], '-x', label="ISMI bound (MINE) for d = {}".format(ds[i]))
    plt.xlabel(r"$n$")
    plt.legend()
    namefig = "{}_gen_error.pdf".format(PREFIX)
    plt.savefig(namefig)

    # Plot: train/test risk. The plotted values are the risks returned by
    # evaluate(), so the legend says "risk"; the output file name is left
    # unchanged so existing paths keep working.
    plt.figure()
    mean_train_risks = train_risks.mean(axis=(2,3))
    mean_test_risks = test_risks.mean(axis=(2,3))
    for i in range(len(ds)):
        plt.plot(nsamples, mean_train_risks[i], '-o', label="Training risk for d = {}".format(ds[i]))
        plt.plot(nsamples, mean_test_risks[i], '--o', label="Test risk for d = {}".format(ds[i]))

    plt.xlabel(r"$n$")
    plt.legend()
    namefig = "{}_train_test_acc.pdf".format(PREFIX)
    plt.savefig(namefig)
