import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from pbb.models import NNet4l, trainNNet_NTK,CNNet4l, ProbNNet4l, ProbCNNet4l, ProbCNNet9l, CNNet9l, CNNet13l, ProbCNNet13l, ProbCNNet15l, CNNet15l, trainNNet, testNNet, Lambda_var, trainPNNet, computeRiskCertificates, testPosteriorMean, testStochastic, testEnsemble
from pbb.bounds import PBBobj
from pbb import data

def runexp(
    name_data,
    objective,
    prior_type,
    model,
    sigma_prior,
    pmin,
    learning_rate,
    momentum,
    learning_rate_prior=0.01,
    momentum_prior=0.95,
    delta=0.025,
    layers=9,
    delta_test=0.01,
    mc_samples=1000,
    samples_ensemble=100,
    kl_penalty=1,
    initial_lamb=6.0,
    train_epochs=100,
    prior_dist='gaussian',
    verbose=False,
    device='cuda',
    prior_epochs=20,
    dropout_prob=0.2,
    perc_train=1.0,
    verbose_test=False,
    perc_prior=0.2,
    batch_size=250,
    shot_per_class=60,
):
    """Build (and optionally pre-train) a network, then return its gradient Gram matrix.

    The function seeds the RNGs, loads the dataset, instantiates a
    deterministic network ``net0``, optionally trains it with SGD when
    ``prior_type == 'learnt'``, and finally calls ``trainNNet_NTK`` on
    ``val_bound_one_batch`` to obtain a 2-D gradient tensor whose Gram
    matrix is returned.

    Args:
        name_data: Dataset name forwarded to ``data.loaddataset``. The value
            ``'cifar10'`` selects the deeper CNN architectures when
            ``model == 'cnn'``.
        prior_type: ``'rand'`` (network left at its initialisation, dropout
            forced to 0.0) or ``'learnt'`` (network trained for
            ``prior_epochs`` epochs on the loader returned in third position
            by ``data.loadbatches``).
        model: ``'cnn'`` for a convolutional network; any other value selects
            ``NNet4l``.
        sigma_prior: Must be positive; only used to evaluate
            ``log(exp(sigma_prior) - 1)``, whose result is not consumed here.
        learning_rate_prior, momentum_prior: SGD hyper-parameters used when
            ``prior_type == 'learnt'``.
        layers: CNN depth for ``'cifar10'``; one of 9, 13 or 15.
        verbose: Forwarded to ``trainNNet`` and ``trainNNet_NTK``.
        device: Device the network is moved to and that is forwarded to the
            train/test helpers.
        prior_epochs: Number of training epochs when ``prior_type == 'learnt'``.
        dropout_prob: Dropout probability passed to the network constructor
            (overridden with 0.0 when ``prior_type == 'rand'``).
        perc_train, perc_prior, batch_size, shot_per_class: Forwarded to
            ``data.loadbatches``.
        objective, pmin, learning_rate, momentum, delta, delta_test,
        mc_samples, samples_ensemble, kl_penalty, initial_lamb, train_epochs,
        prior_dist, verbose_test: Accepted for interface compatibility; not
            used by this function.

    Returns:
        A 7-tuple ``(prior_train, prior_trainls, prior_trainls, NTK_x, ll, rr,
        NTK_x)`` where ``prior_train`` and ``prior_trainls`` are lists holding
        the first and second values returned by ``trainNNet`` for each prior
        epoch (both empty when ``prior_type == 'rand'``), ``NTK_x`` is the
        Gram matrix of the gradients returned by ``trainNNet_NTK``, and
        ``ll`` / ``rr`` are the second and third values returned by
        ``trainNNet_NTK``, passed through unchanged.

    Raises:
        ValueError: If ``prior_type`` is neither ``'rand'`` nor ``'learnt'``,
            or if ``sigma_prior`` is not positive.
        RuntimeError: If ``model == 'cnn'``, ``name_data == 'cifar10'`` and
            ``layers`` is not 9, 13 or 15.
    """
    # Fail fast: without this check an unknown prior_type would skip both
    # data-loading branches and crash much later with a NameError on
    # val_bound_one_batch, after the dataset and network were already built.
    if prior_type not in ('rand', 'learnt'):
        raise ValueError(
            f"Unknown prior_type {prior_type!r}; expected 'rand' or 'learnt'")

    # this makes the initialised prior the same for all bounds
    torch.manual_seed(7)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        loader_kargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kargs = {}

    train, test = data.loaddataset(name_data)
    # Not used further in this function. Kept because the expression raises
    # ValueError for sigma_prior <= 0, which callers may rely on as validation.
    rho_prior = math.log(math.exp(sigma_prior) - 1.0)

    prior_train = []
    prior_test = []
    prior_trainls = []
    if prior_type == 'rand':
        dropout_prob = 0.0

    # initialise model
    if model == 'cnn':
        if name_data == 'cifar10':
            # only cnn models are tested for cifar10, fcns are only used
            # with mnist
            cifar_cnns = {9: CNNet9l, 13: CNNet13l, 15: CNNet15l}
            if layers not in cifar_cnns:
                raise RuntimeError(f'Wrong number of layers {layers}')
            net0 = cifar_cnns[layers](dropout_prob=dropout_prob).to(device)
        else:
            net0 = CNNet4l(dropout_prob=dropout_prob).to(device)
    else:
        net0 = NNet4l(dropout_prob=dropout_prob, device=device).to(device)

    if prior_type == 'rand':
        train_loader, test_loader, _, val_bound_one_batch, _, val_bound = data.loadbatches(
            train, test, loader_kargs, batch_size, prior=False,
            perc_train=perc_train, perc_prior=perc_prior,
            shot_per_class=shot_per_class)
        # The returned test error is not used; the call is kept so the
        # function behaves as before (same evaluation pass over test_loader).
        testNNet(net0, test_loader, device=device)
    else:  # prior_type == 'learnt' (validated above)
        train_loader, test_loader, valid_loader, val_bound_one_batch, _, val_bound = data.loadbatches(
            train, test, loader_kargs, batch_size, prior=True,
            perc_train=perc_train, perc_prior=perc_prior,
            shot_per_class=shot_per_class)
        optimizer = optim.SGD(
            net0.parameters(), lr=learning_rate_prior, momentum=momentum_prior)
        for epoch in trange(prior_epochs):
            traine01, trainls = trainNNet(net0, optimizer, epoch, valid_loader,
                                          device=device, verbose=verbose)
            # Evaluate on the test loader after every epoch and record the
            # per-epoch training/test statistics.
            errornet0 = testNNet(net0, test_loader, device=device)
            prior_train.append(traine01)
            prior_test.append(errornet0)
            prior_trainls.append(trainls)

    grads_left_x, ll, rr = trainNNet_NTK(
        net0, 1, val_bound_one_batch, device=device, verbose=verbose)
    # Gram matrix of the gradient rows: NTK_x[n, m] = sum_c g[n, c] * g[m, c].
    NTK_x = torch.einsum('nc,mc->nm', grads_left_x, grads_left_x)
    # NOTE(review): prior_test is collected above but never returned, while
    # prior_trainls and NTK_x each appear twice. This looks unintentional
    # (position 2 was presumably meant to be prior_test) -- confirm with the
    # callers before changing; the tuple is left as-is to stay compatible.
    return prior_train, prior_trainls, prior_trainls, NTK_x, ll, rr, NTK_x

def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are excluded from the count.
        if param.requires_grad:
            total += param.numel()
    return total
