import torch
import torch.nn as nn

import math, os, copy
import numpy as np
from scipy.optimize import bisect

import warnings
# Install a blanket "ignore" filter: warnings raised after this module is
# imported are suppressed rather than printed.
warnings.filterwarnings("ignore")

# def bounded_cross_entropy(x, y, eps=-1):
#     pred = F.log_softmax(x, dim=-1)
#     pred = torch.log(torch.exp(pred) + math.exp(eps))
#     return F.nll_loss(pred, y, reduce=False)

def evaluation(model, criterion, dataset, log, lr):
    """Run ``model`` over ``dataset.test`` and return the mean per-sample loss.

    The model is put in eval mode and gradients are disabled. For every batch
    the loss and the correctness flags are moved to the CPU and forwarded to
    ``log`` as ``log(None, loss, correct, lr)``.

    Args:
        model: network; the device of its first parameter is used for the data.
        criterion: loss callable. Presumably returns un-reduced per-sample
            losses, since the result is summed here -- verify against caller.
        dataset: object whose ``test`` attribute is a sized iterable of
            ``(inputs, targets)`` batches.
        log: callable logger that also provides ``eval(len_dataset=...)``.
        lr: value handed to ``log`` unchanged.

    Returns:
        The summed loss divided by the number of evaluated samples.
    """
    model.eval()
    target_device = next(model.parameters()).device
    log.eval(len_dataset=len(dataset.test))

    total_loss = 0
    num_samples = 0
    with torch.no_grad():
        for batch in dataset.test:
            inputs, targets = (tensor.to(target_device) for tensor in batch)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            # Predicted class is the arg-max along dimension 1.
            hits = outputs.argmax(dim=1) == targets
            log(None, batch_loss.cpu(), hits.cpu(), lr)
            total_loss += batch_loss.sum().item()
            num_samples += inputs.shape[0]
    return total_loss / num_samples

def noise_injection(model, p):
    """Add clipped Gaussian noise to every parameter of ``model`` in place.

    For each parameter tensor a standard-normal sample clipped to ``[-2, 2]``
    is drawn. When ``p`` is a tensor with at least one dimension, it is read
    as a flat vector laid out in ``model.parameters()`` order and the noise is
    multiplied element-wise by ``exp(p)``; otherwise the noise is multiplied
    by ``p`` directly.

    Args:
        model: network whose parameters are perturbed.
        p: flat tensor of log-scales (one entry per parameter element) or a
            scalar / zero-dimensional scale.

    Returns:
        ``(noises, noises_scaled)``: two lists with one tensor per parameter,
        the raw clipped noise and the noise actually added.
    """
    device = next(model.parameters()).device
    per_element = torch.is_tensor(p) and p.dim() > 0

    noises, noises_scaled = [], []
    offset = 0
    for param in model.parameters():
        count = len(param.view(-1))
        shape = param.data.size()
        raw = torch.clip(torch.randn(shape, device=device), min=-2, max=2)

        if per_element:
            # Slice this parameter's log-scales out of the flat vector.
            scale = torch.exp(p[offset:offset + count]).data
            scaled = torch.mul(torch.reshape(scale, shape), raw)
        else:
            scaled = raw * p

        param.data += scaled
        noises.append(raw)
        noises_scaled.append(scaled)
        offset += count
    return noises, noises_scaled

def rm_injected_noises(model, noises_scaled):
    """Subtract ``noises_scaled[i]`` in place from the i-th model parameter.

    Parameters are visited in ``model.parameters()`` order. Nothing is
    returned.
    """
    for index, param in enumerate(model.parameters()):
        param.data.sub_(noises_scaled[index])

def weight_decay(model, w0):
    """Return the squared L2 distance between the model parameters and ``w0``.

    Args:
        model: network whose parameters are flattened in
            ``model.parameters()`` order.
        w0: flat reference vector indexed with the same layout.

    Returns:
        Sum over parameters of ``||param - w0_slice||**2`` (the integer ``0``
        if the model has no parameters).
    """
    total = 0
    start = 0
    for param in model.parameters():
        flat = param.view(-1)
        stop = start + len(flat)
        total += torch.norm(flat - w0[start:stop]) ** 2
        start = stop
    return total

def weight_decay_mulb(model, b, w0):
    """Squared distance to ``w0``, weighted per parameter tensor by ``exp(-2*b[i])``.

    Args:
        model: network whose parameters are flattened in
            ``model.parameters()`` order.
        b: indexable holding one tensor entry per parameter tensor; each entry
            is cast to double before exponentiation.
        w0: flat reference vector indexed with the same layout as the
            flattened parameters.

    Returns:
        Sum over parameters of ``||param - w0_slice||**2 * exp(-2 * b[i])``.
    """
    total = 0
    start = 0
    for layer_idx, param in enumerate(model.parameters()):
        flat = param.view(-1)
        stop = start + len(flat)
        sq_dist = torch.norm(flat - w0[start:stop]) ** 2
        total += sq_dist * torch.exp(-2 * b[layer_idx].double())
        start = stop
    return total

def get_kl_term(weight_decay, p, samples=5e4, we=None, layers=1, maxwe=1e5):
    """Compute the KL value and the derived bound term for log-scales ``p``.

    With ``d = len(p)`` and ``D = weight_decay + ||exp(p)||**2`` the function
    evaluates ``kl = 0.5 * (sum(-2p) - d*log(we) - d + we*D)``.

    Args:
        weight_decay: squared-distance term added to the denominator.
        p: flat tensor of log-scales.
        samples: divisor used inside the returned bound term.
        we: optional precomputed weight (must be a tensor, as ``torch.log`` is
            applied to it). When ``None`` it is set to ``d / D``.
        layers: multiplier of the constant ``60`` added to ``kl``.
        maxwe: upper clip for the automatically chosen ``we``; ``None``
            disables clipping. Ignored when ``we`` is supplied.

    Returns:
        ``(sqrt(6 * (kl + 60*layers) / samples), kl, 1 / sqrt(we))``.
        NOTE(review): the origin of the constants 6 and 60 is not visible in
        this file.
    """
    dim = len(p)
    denominator = weight_decay + torch.norm(torch.exp(p)) ** 2

    if we is None:
        we = dim / denominator
        if maxwe is not None:
            we = torch.clip(we, max=maxwe)

    kl = 0.5 * ((-2 * p).sum() - dim * torch.log(we) - dim + we * denominator)
    bound = (6 * (kl + 60 * layers) / samples) ** 0.5
    return bound, kl, 1 / we ** 0.5

def get_kl_term_with_b(weight_decay, p, b):
    """Compute the KL value for log-scales ``p`` with a single shared ``b``.

    With ``d = len(p)`` the result is
    ``0.5 * (d * (exp(-2b) * sum(exp(2p)) / d - (2*sum(p)/d - 2b + 1))
    + weight_decay * exp(-2b))``.
    The bracketed part is evaluated in double precision; the factor applied
    to ``weight_decay`` uses ``b`` in its own dtype.

    Args:
        weight_decay: squared-distance term.
        p: flat tensor of log-scales.
        b: tensor holding the shared log-scale (presumably of the prior --
            confirm with caller).
    """
    d = len(p)
    p64 = p.double()
    b64 = b.double()
    variance_part = torch.exp(-2 * b64) * torch.exp(2 * p64).sum() / d
    log_part = 2 * p64.sum() / d - 2 * b64 + 1
    KL = variance_part - log_part
    return (KL * d + weight_decay * torch.exp(-2 * b)) / 2

def get_kl_term_layer_pb(model, wdecay_mulb, p, b):
    """Compute the KL value with one ``b[i]`` per parameter tensor.

    ``p`` is sliced in ``model.parameters()`` order. For parameter tensor
    ``i`` with ``t`` elements the loop accumulates
    ``exp(-2*b[i]) * sum(exp(2*p_slice))`` and ``2*b[i]*t``; the result is
    ``0.5 * (acc1 - (2*sum(p) - acc2 + len(p)) + wdecay_mulb)``.

    Args:
        model: network; only the sizes of its parameters are used.
        wdecay_mulb: precomputed weighted squared-distance term.
        p: flat tensor of log-scales.
        b: indexable with one tensor entry per parameter tensor.
    """
    start = 0
    scaled_var_sum = 0
    weighted_b_sum = 0
    for layer_idx, param in enumerate(model.parameters()):
        width = len(param.view(-1))
        layer_b = b[layer_idx].double()
        layer_p = p[start:start + width].double()
        scaled_var_sum += torch.exp(-2 * layer_b) * torch.exp(2 * layer_p).sum()
        weighted_b_sum += 2 * layer_b * width
        start += width

    KL = scaled_var_sum - (2 * p.double().sum() - weighted_b_sum + len(p))
    return (KL + wdecay_mulb) / 2

def kl_term_backward(kl_loss, model, p, noises):
    """Back-propagate ``kl_loss`` and add the noise-path gradient to ``p.grad``.

    Every parameter must already carry a ``.grad``. After
    ``kl_loss.backward()`` the slice of ``p.grad`` belonging to parameter
    ``i`` is incremented by ``noises[i] * param.grad * exp(p)`` (flattened),
    where ``param.grad`` is the value it had on entry.

    Args:
        kl_loss: scalar tensor to back-propagate; it has to populate
            ``p.grad``.
        model: network whose parameter gradients are read.
        p: flat tensor of log-scales laid out in ``model.parameters()`` order.
        noises: one noise tensor per parameter.
    """
    # Snapshot noise * grad first, in case kl_loss.backward() also
    # accumulates into the parameters' .grad buffers.
    noise_times_grad = [
        torch.mul(noises[idx], param.grad).view(-1)
        for idx, param in enumerate(model.parameters())
    ]
    kl_loss.backward()

    start = 0
    for idx, param in enumerate(model.parameters()):
        stop = start + len(param.grad.view(-1))
        extra = torch.mul(noise_times_grad[idx].view(-1), torch.exp(p.data[start:stop]))
        p.grad[start:stop] += extra
        start = stop

def kl_term_backward_mean(kl_loss, model, p, noises):
    """Like the per-element update, but each parameter's slice shares one gradient.

    After ``kl_loss.backward()`` the slice of ``p.grad`` belonging to
    parameter ``i`` is incremented by ``noises[i] * param.grad * exp(p)``
    (using the ``param.grad`` present on entry) and then overwritten with the
    mean of that slice, so every element of one parameter tensor ends up with
    the same gradient.

    Args:
        kl_loss: scalar tensor to back-propagate; it has to populate
            ``p.grad``.
        model: network whose parameter gradients are read.
        p: flat tensor of log-scales laid out in ``model.parameters()`` order.
        noises: one noise tensor per parameter.
    """
    # Snapshot noise * grad first, in case kl_loss.backward() also
    # accumulates into the parameters' .grad buffers.
    noise_times_grad = [
        torch.mul(noises[idx], param.grad).view(-1)
        for idx, param in enumerate(model.parameters())
    ]
    kl_loss.backward()

    start = 0
    for idx, param in enumerate(model.parameters()):
        width = len(param.grad.view(-1))
        stop = start + width
        extra = torch.mul(noise_times_grad[idx].view(-1), torch.exp(p.data[start:stop]))
        p.grad[start:stop] += extra
        # Broadcast the slice mean back over the whole slice.
        slice_mean = p.grad[start:stop].mean()
        p.grad[start:stop] = slice_mean * torch.ones(width, device=p.device)
        start = stop

def initialization(model, w0decay=1.0):
    """Scale the model weights and build the reference vector and log-scales.

    Every parameter is first multiplied in place by ``w0decay``. The scaled
    parameters are then copied and concatenated into one flat vector ``w0``,
    and ``p`` is created as a trainable vector of the same length filled with
    ``log(mean(|w0|))``.

    Args:
        model: network to rescale and snapshot.
        w0decay: factor applied to every parameter.

    Returns:
        ``(w0, p, num_layer)`` where ``num_layer`` is the number of parameter
        tensors yielded by ``model.parameters()``.
    """
    for param in model.parameters():
        param.data *= w0decay

    device = next(model.parameters()).device
    flat_copies = [param.data.view(-1).detach().clone() for param in model.parameters()]
    num_layer = len(flat_copies)
    w0 = torch.cat(flat_copies)

    log_scale = torch.log(w0.abs().mean())
    p = nn.Parameter(torch.ones(len(w0), device=device) * log_scale, requires_grad=True)
    return w0, p, num_layer

def save_model(model, w0, p, epoch, prior, opt1, opt2, sch1,
                file_name, others=None, folder='logs/'):
    """Write a training checkpoint to ``<folder>/<file_name>.pt``.

    The checkpoint is a dict with the keys ``epoch``, ``w0``,
    ``model_state_dict``, ``p``, ``prior``, ``opt1``, ``opt2`` and ``others``;
    ``sch1`` is added last only when a scheduler is supplied.

    Args:
        model: network whose ``state_dict()`` is stored.
        w0, p, epoch, prior, others: stored as given.
        opt1, opt2: optimizers; their ``state_dict()`` is stored.
        sch1: scheduler providing ``state_dict()``, or ``None`` to omit it.
        file_name: checkpoint name without the ``.pt`` extension.
        folder: target directory; created (including parents) if missing.

    Raises:
        OSError: if the directory cannot be created. Previously such failures
            were swallowed and only surfaced later, from ``torch.save``.
    """
    # exist_ok covers the race where another process creates the directory
    # between the check and the call, without hiding genuine errors.
    if folder:
        os.makedirs(folder, exist_ok=True)

    checkpoint = {
        'epoch': epoch, 'w0': w0,
        'model_state_dict': model.state_dict(),
        'p': p, 'prior': prior,
        'opt1': opt1.state_dict(),
        'opt2': opt2.state_dict(),
        'others': others,
    }
    if sch1 is not None:
        checkpoint['sch1'] = sch1.state_dict()

    torch.save(checkpoint, os.path.join(folder, file_name + '.pt'))


###############################
########K Computation##########
###############################

def func_sum(x, gamma, error_list, error_mean_list):
    """Sum, over sampled models, the gap between two exponential terms.

    For each sampled model ``i`` and each ``g`` in ``gamma`` the quantity
    ``exp(3 * g**2 * x**2 / 2) - mean(exp(g * (mu - err_i)))`` is evaluated
    in extended precision, where ``mu`` is the mean of ``error_mean_list``
    and ``err_i = error_list[i]``. The per-model columns are added together.

    Args:
        x: candidate value being tested.
        gamma: sequence of gamma values.
        error_list: per-model sequences of per-sample errors.
        error_mean_list: per-model mean errors; its length sets how many
            entries of ``error_list`` are used.

    Returns:
        Array of shape ``(len(gamma), 1)``, or the integer ``0`` when
        ``error_mean_list`` is empty.
    """
    num_models = len(error_mean_list)
    if num_models == 0:
        return 0

    def func(err, err_mu):
        out = np.zeros((len(gamma), 1))
        for row, g in enumerate(gamma):
            empirical = np.mean(np.exp(np.longdouble(g * (err_mu - err))))
            reference = np.exp(np.longdouble(3 * g ** 2 * (x ** 2) / 2))
            out[row] = -(empirical - reference)
        return out

    grand_mean = np.mean(error_mean_list)
    total = 0
    for idx in range(num_models):
        total += func(error_list[idx], grand_mean)
    return total


def gen_output(model, prior, dataset, n, criterion):
    """Collect per-sample errors of ``n`` randomly perturbed copies of ``model``.

    Each copy is a deep copy of ``model`` whose parameters get Gaussian noise
    scaled by ``prior`` added; ``model`` itself is left untouched. The copy is
    evaluated without gradients on ``dataset.train.dataset`` in batches of
    1000.

    Args:
        model: base network.
        prior: scale applied to the standard-normal perturbation.
        dataset: object exposing ``train.dataset``, a dataset of
            ``(inputs, targets)`` items.
        n: number of perturbed models to draw.
        criterion: loss callable; must return an iterable of per-sample
            values (its output is converted to a list).

    Returns:
        ``(error_list, error_mean_list)``: one list of per-sample errors and
        one mean error for each perturbed model.
    """
    device = next(model.parameters()).device
    loader = torch.utils.data.DataLoader(dataset.train.dataset, batch_size=1000)

    error_list, error_mean_list = [], []
    with torch.no_grad():
        for _ in range(n):
            # Draw one random network around the given model.
            perturbed = copy.deepcopy(model)
            for param in perturbed.parameters():
                param.data += torch.randn(param.data.size(), device=device) * prior

            sample_errors = []
            for batch in loader:
                inputs, targets = (tensor.to(device) for tensor in batch)
                batch_err = criterion(perturbed(inputs), targets)
                sample_errors.extend(list(batch_err.cpu().numpy()))

            error_list.append(sample_errors)
            error_mean_list.append(np.mean(sample_errors))
    return error_list, error_mean_list

def gen_gcn_output(model, prior, dataset, n, criterion, split):
    """Graph variant: errors of ``n`` perturbed copies on the training nodes.

    Each copy is a deep copy of ``model`` with Gaussian noise scaled by
    ``prior`` added to its parameters. It is called as
    ``model(dataset.x, dataset.edge_index)`` without gradients, and
    ``criterion`` is applied to the rows selected by column ``split`` of
    ``dataset.train_mask``.

    Args:
        model: base network taking ``(x, edge_index)``.
        prior: scale applied to the standard-normal perturbation.
        dataset: object exposing ``x``, ``edge_index``, ``y`` and a
            two-dimensional ``train_mask``.
        n: number of perturbed models to draw.
        criterion: loss callable; must return an iterable of per-sample
            values (its output is converted to a list).
        split: column of ``train_mask`` to use.

    Returns:
        ``(error_list, error_mean_list)``: one list of per-node errors and
        one mean error for each perturbed model.
    """
    train_mask = dataset.train_mask[:, split]
    device = next(model.parameters()).device

    error_list, error_mean_list = [], []
    with torch.no_grad():
        for _ in range(n):
            # Draw one random network around the given model.
            perturbed = copy.deepcopy(model)
            for param in perturbed.parameters():
                param.data += torch.randn(param.data.size(), device=device) * prior

            logits = perturbed(dataset.x, dataset.edge_index)
            node_errors = criterion(logits[train_mask], dataset.y[train_mask])
            sample_errors = list(node_errors.cpu().numpy())

            error_list.append(sample_errors)
            error_mean_list.append(np.mean(sample_errors))

    return error_list, error_mean_list

def compute_K_sample(model, dataset, criterion, min_gamma, max_gamma, 
                        min_nu=-6, max_nu=-2,
                        gcn=False, split=0):
    """Estimate a K value for each prior scale on a log-spaced grid.

    The prior grid is ``exp(linspace(min_nu, max_nu, 2*(max_nu-min_nu)))``.
    For each prior, 10 perturbed models are sampled and the candidate K
    (starting from the previous prior's estimate, initially ``1e-3``) is
    multiplied by 1.1 until the minimum of ``func_sum`` over a 10-point
    log-spaced gamma grid is no longer negative. The resulting K sequence is
    finally made non-decreasing.

    Args:
        model, dataset, criterion: forwarded to the sampling routine.
        min_gamma, max_gamma: end points of the gamma grid (log-spaced).
        min_nu, max_nu: log end points of the prior grid; their difference is
            used as a sample count, so they must be integers.
        gcn: use the graph sampling routine instead of the batched one.
        split: forwarded to the graph sampling routine.

    Returns:
        ``(priors, ks)``: two lists of equal length.
    """

    def est_K(prior, x):
        # Estimate K over the gamma range for one prior, starting from x.
        gamma_grid = np.exp(np.linspace(np.log(min_gamma), np.log(max_gamma), 10))
        print('searching for K4....')
        if gcn:
            sampled = gen_gcn_output(model, prior, dataset, 10, criterion, split)
        else:
            sampled = gen_output(model, prior, dataset, 10, criterion)
        error_list, error_mean_list = sampled

        while min(func_sum(x, gamma_grid, error_list, error_mean_list)) < 0:
            x = x*1.1
        # while min(func_sum(x, gamma_grid, error_list, error_mean_list)) > 1e-4:
        #     x = x/1.1
        return x

    prior_list = np.exp(np.linspace(min_nu, max_nu, 2*(max_nu-min_nu) ))

    # Each search is warm-started from the previous prior's estimate.
    K_list = []
    current_k = 1e-3
    for prior in prior_list:
        current_k = est_K(prior, current_k)
        K_list.append(current_k)

    # Make the K list monotonically non-decreasing.
    ks, priors = [], []
    running_max = 0
    for k_est, prior in zip(K_list, prior_list):
        if k_est < running_max:
            ks.append(running_max)
        else:
            ks.append(k_est)
            running_max = k_est
        priors.append(prior)

    return priors, ks

# def fun_K_auto(x, exp_prior_list, K_list):
#     n = len(exp_prior_list)
#     y = K_list[0] + torch.relu(x-exp_prior_list[0])*(K_list[1]-K_list[0])/(exp_prior_list[1]-exp_prior_list[0])
#     slope = (K_list[1]-K_list[0])/(exp_prior_list[1]-exp_prior_list[0])
#     for i in range(n-2):
#        slope = -slope + (K_list[i+2]-K_list[i+1])/(exp_prior_list[i+2]-exp_prior_list[i+1])
#        y += torch.relu(x-exp_prior_list[i+1])*slope
#     return y

def fun_K_auto(x,exp_prior_list,K_list):
    """Piecewise-linear interpolation of ``K_list`` over ``exp_prior_list``.

    The segment is chosen by walking the knots while ``x`` exceeds them,
    stopping at the last knot, so values beyond the last knot are
    extrapolated along the final segment. When ``x`` does not exceed the
    first knot, the segment runs from ``0`` (value
    ``K_list[0] + exp_prior_list[0]``) to the first knot (value
    ``K_list[0]``).

    Args:
        x: query point.
        exp_prior_list: knot positions, assumed increasing -- TODO confirm.
        K_list: values at the knots.

    Returns:
        The linearly interpolated (or extrapolated) value at ``x``.
    """
    last = len(exp_prior_list) - 1
    idx = 0
    while x > exp_prior_list[idx]:
        idx += 1
        if idx == last:
            break

    if idx == 0:
        lo, hi = 0, exp_prior_list[0]
        f_lo, f_hi = K_list[0] + exp_prior_list[0], K_list[0]
    else:
        lo, hi = exp_prior_list[idx - 1], exp_prior_list[idx]
        f_lo, f_hi = K_list[idx - 1], K_list[idx]

    span = hi - lo
    return (hi - x) / span * f_lo + (x - lo) / span * f_hi


# def resume(model, file):
#     checkpoint = torch.load(file+'_s1.pt')
#     model.load_state_dict(checkpoint['model_state_dict'])
#     start_epoch = checkpoint['epoch']
#     others = checkpoint['others']
#     p = checkpoint['p']
#     w0 = checkpoint['w0']
#     prior = checkpoint['prior']
#     k = checkpoint['k']
#     return start_epoch, p, w0, prior, k, others 

# def weight_decay_layer(model, w0):
#     # noise injection
#     k, weights = 0, []
#     for i, param in enumerate(model.parameters()):
#         t = len(param.view(-1))
#         weights.append(torch.norm(param.view(-1)-w0[k:(k+t)])**2)
#         k += t
#     return weights

# def weight_decay_layerp(model, w0):
#     k, weights = 0, 0
#     for i, param in enumerate(model.parameters()):
#         t = len(param.view(-1))
#         if torch.norm(w0[k:(k+t)])>1e-6:
#             alpha = (param.view(-1)*w0[k:(k+t)]).sum()/torch.norm(w0[k:(k+t)])**2
#         else:
#             alpha = 1
#         weights += torch.norm(param.view(-1)-alpha*w0[k:(k+t)])**2
#         k += t
#     return weights

# def get_kl_term_layer(K4, weights, p, model):
#     k, kl = 0, 0.0
#     for i, param in enumerate(model.parameters()):
#         t = len(param.view(-1))
#         denominator = weights[i] + torch.norm(torch.exp(p[k:(k+t)]))**2

#         we = torch.clip(t/denominator, max=1e3)
#         kl += -t*torch.log(we) -2*p[k:(k+t)].sum() -t +we*denominator

#         k += t 
#     loss2 = K4*( 6*(0.5*kl+30) /5e4 )**0.5
#     return loss2, kl, 1/we**0.5

# def get_children(model: torch.nn.Module):
#     children = list(model.children())
#     flatt_children = []
#     if children == []:
#         # if model has no children; model is last child! :O
#         return model
#     else:
#        # look for children from children... to the last child!
#        for child in children:
#             try:
#                 flatt_children.extend(get_children(child))
#             except TypeError:
#                 flatt_children.append(get_children(child))
#     return flatt_children
