import torch
import torch.nn as nn

import math, os, copy
import numpy as np
from scipy.optimize import bisect

import warnings

from torch.utils.data import DataLoader
import sys

sys.path.append('./')
from myDataLoader import bucket_dataset

warnings.filterwarnings("ignore")


def bounded_cross_entropy(x, y, eps=-1):
    """Per-sample cross entropy whose value is bounded above by ``-eps``.

    The softmax probabilities of ``x`` (taken over the last dimension) are
    shifted by ``exp(eps)`` before the logarithm, so the log-term never
    drops below ``eps`` and the loss of a single sample is at most ``-eps``.

    Args:
        x: logits; classes are on the last dimension.
        y: integer class targets accepted by ``nll_loss``.
        eps: log of the constant added to every probability.

    Returns:
        Tensor of un-reduced (per-sample) losses.
    """
    # ``nn.functional`` is used because this file never imports ``F``.
    log_probs = nn.functional.log_softmax(x, dim=-1)
    shifted = torch.log(torch.exp(log_probs) + math.exp(eps))
    # ``reduction='none'`` is the supported spelling of the deprecated
    # ``reduce=False``.
    return nn.functional.nll_loss(shifted, y, reduction='none')


def evaluation(model, criterion, dataset, log, lr):
    """Run the model over ``dataset.test`` and return the mean loss.

    Args:
        model: network to evaluate; it is switched to eval mode.
        criterion: callable ``criterion(predictions, targets)`` whose result
            supports ``.sum()`` and ``.cpu()``.
        dataset: object whose ``test`` attribute is a sized iterable of
            ``(inputs, targets)`` batches.
        log: logger exposing ``eval(len_dataset=...)`` and callable as
            ``log(None, loss, correct, lr)``.
        lr: value forwarded unchanged to ``log``.

    Returns:
        Sum of all losses divided by the number of evaluated samples.
    """
    model.eval()
    target_device = next(model.parameters()).device
    log.eval(len_dataset=len(dataset.test))

    total_loss = 0
    n_samples = 0
    with torch.no_grad():
        for batch in dataset.test:
            inputs, targets = (item.to(target_device) for item in batch)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            hits = torch.argmax(outputs, 1) == targets
            log(None, batch_loss.cpu(), hits.cpu(), lr)
            n_samples += inputs.shape[0]
            total_loss += batch_loss.sum().item()
    return total_loss / n_samples


def noise_injection(model, p):
    """Add clipped Gaussian noise to every trainable parameter in place.

    For each parameter with ``requires_grad`` a standard-normal sample is
    drawn and clipped to ``[-2, 2]``. When ``p`` is a tensor with at least
    one dimension it is read as a flat vector and ``exp(p)`` of the matching
    slice scales the sample element-wise; otherwise the sample is simply
    multiplied by ``p``.

    Args:
        model: module whose trainable parameters are perturbed.
        p: scalar scale, or flat tensor of log-scales laid out in the order
            of the trainable parameters.

    Returns:
        ``(noises, noises_scaled)``: the raw clipped samples and the
        perturbations that were actually added, one entry per trainable
        parameter.
    """
    device = next(model.parameters()).device
    raw_samples, applied = [], []
    offset = 0
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        count = len(weight.view(-1))
        shape = weight.data.size()
        sample = torch.clip(torch.randn(shape, device=device), min=-2, max=2)
        raw_samples.append(sample)

        if torch.is_tensor(p) and p.dim() > 0:
            # ``.data`` keeps the perturbation out of the autograd graph.
            scale = torch.reshape(torch.exp(p[offset:(offset + count)]).data, shape)
            perturbation = torch.mul(scale, sample)
        else:
            perturbation = sample * p
        applied.append(perturbation)

        weight.data += perturbation
        offset += count
    return raw_samples, applied


def rm_injected_noises(model, noises_scaled):
    """Subtract previously injected perturbations from the model in place.

    ``noises_scaled`` is consumed in order, one entry per parameter with
    ``requires_grad``; frozen parameters are skipped and consume nothing.
    The list and its tensors are left untouched.

    Args:
        model: module whose trainable parameters are restored.
        noises_scaled: sequence of perturbation tensors, e.g. the second
            value returned by ``noise_injection``.

    Raises:
        IndexError: if ``noises_scaled`` has fewer entries than the model
            has trainable parameters.
    """
    # The tensors are only read, so no deep copy is needed. Copying would
    # double memory and fails outright for non-leaf tensors that carry
    # autograd history.
    position = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # detach(): the subtraction must never be recorded by autograd.
        param.data -= noises_scaled[position].detach()
        position += 1


def weight_decay(model, w0):
    """Return the squared L2 distance between the trainable weights and ``w0``.

    Args:
        model: module whose parameters with ``requires_grad`` are measured.
        w0: flat tensor of reference values laid out in the order of the
            trainable parameters.

    Returns:
        Sum over trainable parameters of ``||param - w0_slice||_2 ** 2``
        (the integer ``0`` if there is no trainable parameter).
    """
    offset = 0
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            flat = param.view(-1)
            count = len(flat)
            reference = w0[offset:(offset + count)]
            total += torch.norm(flat - reference) ** 2
            offset += count
    return total


def get_kl_term(weight_decay, p, samples, we=None, layers=1, maxwe=1e5):
    """Compute the KL-based penalty, the raw KL value and ``we ** -0.5``.

    Args:
        weight_decay: squared distance term added to ``||exp(p)||^2``.
        p: flat tensor of log-scales.
        samples: divisor applied inside the square root of the penalty
            (presumably the training-set size; confirm with the caller).
        we: optional tensor weight; when ``None`` it is set to
            ``len(p) / denominator`` clipped from above at ``maxwe``.
        layers: multiplier of the constant ``60`` added to the KL value.
        maxwe: upper clip used only when ``we`` is ``None``.

    Returns:
        Tuple ``(sqrt(6 * (kl + 60 * layers) / samples), kl, 1 / sqrt(we))``.
    """
    dim = len(p)
    denominator = weight_decay + torch.norm(torch.exp(p)) ** 2
    if we is None:
        we = torch.clip(dim / denominator, max=maxwe)
    kl = 0.5 * ((-2 * p).sum() - dim * torch.log(we) - dim + we * denominator)
    penalty = (6 * (kl + 60 * layers) / samples) ** 0.5
    inverse_root_we = 1 / we ** 0.5
    return penalty, kl, inverse_root_we


def get_kl_term_with_b(weight_decay, p, b):
    """Return the KL-style term for log-scales ``p`` and log-scale ``b``.

    The ``p``/``b`` dependent part is evaluated in double precision.

    Args:
        weight_decay: squared distance term, weighted by ``exp(-2 * b)``.
        p: flat tensor of log-scales; ``len(p)`` is used as the dimension.
        b: tensor holding a single log-scale.

    Returns:
        ``(KL * d + weight_decay * exp(-2 * b)) / 2`` where
        ``KL = exp(-2b) * mean(exp(2p)) - (2 * mean(p) - 2b + 1)``.
    """
    d = len(p)
    p64 = (p).double()
    b64 = b.double()
    ratio_part = torch.exp(-2 * b64) * torch.exp(2 * p64).sum() / d
    log_part = 2 * p64.sum() / d - 2 * b64 + 1
    KL = ratio_part - log_part
    return (KL * d + weight_decay * torch.exp(-2 * b)) / 2


def kl_term_backward_mean(kl_loss, model, p, noises):
    """Back-propagate ``kl_loss`` and write layer-averaged gradients into ``p.grad``.

    For every trainable parameter the element-wise product
    ``noise * param.grad * exp(p_slice)`` is added to the matching slice of
    ``p.grad``; the slice is then overwritten with its own mean, so all
    entries of ``p`` that belong to one parameter tensor share one gradient.

    Args:
        kl_loss: scalar tensor; ``kl_loss.backward()`` is called here and is
            expected to populate ``p.grad``.
        model: module whose trainable parameters already hold ``.grad``.
        p: flat tensor of log-scales laid out in the order of the trainable
            parameters.
        noises: raw noise tensors, one per trainable parameter (only read).

    Raises:
        IndexError: if ``noises`` has fewer entries than the model has
            trainable parameters.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    # Captured BEFORE kl_loss.backward() so that whatever that call
    # accumulates into param.grad is not part of the product. The noise
    # tensors are only read, hence no deep copy.
    noise_weighted = [
        torch.mul(noises[idx], param.grad).view(-1)
        for idx, param in enumerate(trainable)
    ]

    kl_loss.backward()

    offset = 0
    for param, weighted in zip(trainable, noise_weighted):
        count = len(param.grad.view(-1))
        window = slice(offset, offset + count)
        p.grad[window] += torch.mul(weighted, torch.exp(p.data[window]))
        # A 0-dim tensor broadcasts over the slice: every entry of this
        # layer receives the layer mean.
        p.grad[window] = p.grad[window].mean()
        offset += count
    return


def kl_term_backward(kl_loss, model, p, noises):
    """Back-propagate ``kl_loss`` and add the noise-path gradient to ``p.grad``.

    For every trainable parameter the element-wise product
    ``noise * param.grad * exp(p_slice)`` is added to the matching slice of
    ``p.grad``.

    Args:
        kl_loss: scalar tensor; ``kl_loss.backward()`` is called here and is
            expected to populate ``p.grad``.
        model: module whose trainable parameters already hold ``.grad``.
        p: flat tensor of log-scales laid out in the order of the trainable
            parameters.
        noises: raw noise tensors, one per trainable parameter (only read).

    Raises:
        IndexError: if ``noises`` has fewer entries than the model has
            trainable parameters.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    # Captured BEFORE kl_loss.backward() so that whatever that call
    # accumulates into param.grad is not part of the product. The noise
    # tensors are only read, hence no deep copy.
    noise_weighted = [
        torch.mul(noises[idx], param.grad).view(-1)
        for idx, param in enumerate(trainable)
    ]

    kl_loss.backward()

    offset = 0
    for param, weighted in zip(trainable, noise_weighted):
        count = len(param.grad.view(-1))
        window = slice(offset, offset + count)
        p.grad[window] += torch.mul(weighted, torch.exp(p.data[window]))
        offset += count
    return


def initialization(model, w0decay=1.0):
    """Scale the trainable weights and build the reference vector and ``p``.

    Every parameter with ``requires_grad`` is multiplied in place by
    ``w0decay``. The scaled values are then flattened and concatenated into
    ``w0``, and ``p`` is created as a trainable vector of the same length
    filled with ``log(mean(|w0|))``. (The original source also lists a
    constant ``-log(10)`` fill as a commented-out alternative.)

    Args:
        model: module to initialise from.
        w0decay: factor applied to every trainable parameter.

    Returns:
        ``(w0, p, num_layer)`` where ``num_layer`` is the number of entries
        yielded by ``model.named_parameters()`` (frozen ones included).
    """
    for param in model.parameters():
        if param.requires_grad:
            param.data *= w0decay

    device = next(model.parameters()).device

    named = list(model.named_parameters())
    num_layer = len(named)
    flat_weights = [
        param.data.view(-1).detach().clone()
        for _, param in named
        if param.requires_grad
    ]
    w0 = torch.cat(flat_weights)

    log_mean_abs = torch.log(w0.abs().mean())
    p = nn.Parameter(torch.ones(len(w0), device=device) * log_mean_abs, requires_grad=True)
    return w0, p, num_layer


def save_model(model, w0, p, epoch, prior, opt1, opt2, sch1,
               file_name, others=None, folder='logs/'):
    """Write a training checkpoint to ``folder + '/' + file_name + '.pt'``.

    The checkpoint is a dict with the keys ``epoch``, ``w0``,
    ``model_state_dict``, ``p``, ``prior``, ``opt1``, ``opt2``, ``others``
    and, only when ``sch1`` is not ``None``, ``sch1``.

    Args:
        model: module whose ``state_dict()`` is stored.
        w0, p, epoch, prior, others: stored as given.
        opt1, opt2: objects whose ``state_dict()`` is stored.
        sch1: optional object whose ``state_dict()`` is stored.
        file_name: checkpoint name without the ``.pt`` extension.
        folder: target directory; created (with parents) when missing.

    Raises:
        OSError: if ``folder`` cannot be created.
    """
    # exist_ok=True tolerates a concurrent creation of the directory without
    # hiding real failures the way a bare ``except: pass`` would.
    os.makedirs(folder, exist_ok=True)

    checkpoint = {
        'epoch': epoch, 'w0': w0,
        'model_state_dict': model.state_dict(),
        'p': p, 'prior': prior,
        'opt1': opt1.state_dict(),
        'opt2': opt2.state_dict(),
        'others': others,
    }
    if sch1 is not None:
        checkpoint['sch1'] = sch1.state_dict()

    torch.save(checkpoint, folder + '/' + file_name + '.pt')


######################################################################
######################################################################
######################################################################

def func_sum(x, gamma, error_list, error_mean_list):
    """Sum, over sampled models, the gap ``exp(3 g^2 x^2 / 2) - mean(exp(g (m - e)))``.

    ``m`` is the mean of ``error_mean_list`` and ``e`` the errors of one
    sampled model; the gap is evaluated for every ``g`` in ``gamma`` using
    extended precision for the exponentials.

    Args:
        x: candidate constant being tested.
        gamma: sequence of gamma values.
        error_list: per-model sequences of errors.
        error_mean_list: per-model mean errors; its length sets how many
            entries of ``error_list`` are used.

    Returns:
        Array of shape ``(len(gamma), 1)`` with the summed gaps, or the
        integer ``0`` when ``error_mean_list`` is empty.
    """
    def _gap_column(sample_errors, center):
        column = np.zeros((len(gamma), 1))
        for row, g in enumerate(gamma):
            empirical = np.mean(np.exp(np.longdouble(g * (center - sample_errors))))
            reference = np.exp(np.longdouble(3 * (g) ** 2 * (x ** 2) / 2))
            column[row] = -(empirical - reference)
        return column

    total = 0
    for idx, _ in enumerate(error_mean_list):
        total += _gap_column(error_list[idx], np.mean(error_mean_list))
    return total


def gen_output_transformer(args, model, prior, dataset, n):
    """Collect training losses of ``n`` randomly perturbed copies of ``model``.

    Each copy has ``randn * prior`` added to every trainable parameter and is
    then run, without gradients, over a freshly built
    ``bucket_dataset(args, args.train_data)`` loader. The model is called as
    ``model(batch, True)`` and the fourth returned value is taken as the
    per-sample loss.

    Args:
        args: namespace providing ``device``, ``train_data`` and
            ``batch_size``.
        model: network to copy; the original is left unmodified.
        prior: scale of the Gaussian perturbation.
        dataset: not used by this function.
        n: number of perturbed copies to evaluate.

    Returns:
        ``(error_list, error_mean_list)``: per-copy lists of losses and their
        means.
    """
    all_errors = []
    all_means = []

    with torch.no_grad():
        for _ in range(n):
            # Draw one random network around the given model.
            perturbed = copy.deepcopy(model)
            for weight in perturbed.parameters():
                if weight.requires_grad:
                    weight.data += torch.randn(weight.data.size(), device=args.device) * prior

            # The dataset and loader are rebuilt for every sampled network.
            loader = DataLoader(bucket_dataset(args, args.train_data),
                                batch_size=args.batch_size,
                                collate_fn=lambda x: x)
            sample_errors = []
            for batch in loader:
                _, _, _, loss = perturbed(batch, True)
                sample_errors.extend(list(loss.cpu().numpy()))

            all_errors.append(sample_errors)
            all_means.append(np.mean(sample_errors))
    return all_errors, all_means


def compute_K_sample_transformer(args, model, dataset, min_gamma, max_gamma):
    """Estimate the constant K on a grid of prior scales (transformer variant).

    For each of 8 prior scales, log-spaced between ``exp(-6)`` and
    ``exp(-2)``, K is searched by first growing it by 1.5x while the minimum
    of ``func_sum`` is below ``-1e-20`` and then shrinking it by 1.1x while
    that minimum is positive. Each search starts from the previous result
    (``1e-3`` for the first). The resulting K values are finally replaced by
    their running maximum so that they never decrease.

    Args:
        args, model, dataset: forwarded to ``gen_output_transformer``.
        min_gamma, max_gamma: end points of the 10-point log-spaced gamma
            grid.

    Returns:
        ``(priors, ks)``: lists of prior scales and matching K values.
    """
    def est_K(prior, x):
        # Search K for one prior scale over the whole gamma grid.
        gamma_grid = np.exp(np.linspace(np.log(min_gamma), np.log(max_gamma), 10))
        print('searching for K4....')
        error_list, error_mean_list = gen_output_transformer(args, model, prior, dataset, 10)

        def smallest_gap(candidate):
            return min(func_sum(candidate, gamma_grid, error_list, error_mean_list))

        while smallest_gap(x) < -1e-20:
            x = x * 1.5
        while smallest_gap(x) > 0:
            x = x / 1.1
        return x

    prior_list = np.exp(np.linspace(-6, -2, 8))
    K_list = []
    warm_start = 1e-3
    for prior in prior_list:
        warm_start = est_K(prior, warm_start)
        K_list.append(warm_start)

    # Enforce a non-decreasing sequence of K values.
    ks, priors = [], []
    running_max = 0
    for k, prior in zip(K_list, prior_list):
        if not (k < running_max):
            running_max = k
        ks.append(running_max)
        priors.append(prior)

    return priors, ks


def compute_K_sample(model, dataset, criterion, min_gamma, max_gamma):
    """Estimate the constant K on a grid of prior scales.

    For each of 8 prior scales, log-spaced between ``exp(-6)`` and
    ``exp(-2)``, K is searched by first growing it by 1.5x while the minimum
    of ``func_sum`` is below ``-1e-20`` and then shrinking it by 1.1x while
    that minimum is positive. Each search starts from the previous result
    (``1e-3`` for the first). The resulting K values are finally replaced by
    their running maximum so that they never decrease.

    Args:
        model, dataset, criterion: forwarded to ``gen_output``.
            NOTE(review): ``gen_output`` is not defined in the visible part
            of this file; confirm it is provided elsewhere.
        min_gamma, max_gamma: end points of the 10-point log-spaced gamma
            grid.

    Returns:
        ``(priors, ks)``: lists of prior scales and matching K values.
    """
    def est_K(prior, x):
        # Search K for one prior scale over the whole gamma grid.
        gamma_grid = np.exp(np.linspace(np.log(min_gamma), np.log(max_gamma), 10))
        print('searching for K4....')
        error_list, error_mean_list = gen_output(model, prior, dataset, 10, criterion)

        def smallest_gap(candidate):
            return min(func_sum(candidate, gamma_grid, error_list, error_mean_list))

        while smallest_gap(x) < -1e-20:
            x = x * 1.5
        while smallest_gap(x) > 0:
            x = x / 1.1
        return x

    prior_list = np.exp(np.linspace(-6, -2, 8))
    K_list = []
    warm_start = 1e-3
    for prior in prior_list:
        warm_start = est_K(prior, warm_start)
        K_list.append(warm_start)

    # Enforce a non-decreasing sequence of K values.
    ks, priors = [], []
    running_max = 0
    for k, prior in zip(K_list, prior_list):
        if not (k < running_max):
            running_max = k
        ks.append(running_max)
        priors.append(prior)

    return priors, ks


def fun_K_auto(x, exp_prior_list, K_list):
    """Differentiable piecewise-linear interpolation of K built from ReLUs.

    The curve passes through every knot ``(exp_prior_list[i], K_list[i])``.
    It is constant (``K_list[0]``) to the left of the first knot and
    continues with the slope of the last segment to the right of the last
    knot.

    Args:
        x: tensor at which the curve is evaluated.
        exp_prior_list: knot positions, in increasing order; at least two.
        K_list: knot values, same length as ``exp_prior_list``.

    Returns:
        Tensor with the interpolated value(s).
    """
    n = len(exp_prior_list)
    prev_slope = (K_list[1] - K_list[0]) / (exp_prior_list[1] - exp_prior_list[0])
    y = K_list[0] + torch.relu(x - exp_prior_list[0]) * prev_slope
    for i in range(1, n - 1):
        seg_slope = (K_list[i + 1] - K_list[i]) / (exp_prior_list[i + 1] - exp_prior_list[i])
        # Each hinge adds only the CHANGE of slope at knot i. The previous
        # recurrence (slope = -slope + seg_slope) was right for the first
        # hinge only and missed the knots from the fourth one onwards.
        y = y + torch.relu(x - exp_prior_list[i]) * (seg_slope - prev_slope)
        prev_slope = seg_slope
    return y


def fun_K_auto_new(x, exp_prior_list, K_list):
    """Linearly interpolate K at a single point ``x``.

    The first knot that is not smaller than ``x`` (capped at the last knot)
    selects the segment. When ``x`` does not exceed the first knot, the
    segment runs from ``(0, K_list[0] + exp_prior_list[0])`` to
    ``(exp_prior_list[0], K_list[0])``. Beyond the last knot the last
    segment is extrapolated.

    Args:
        x: scalar position.
        exp_prior_list: knot positions, in increasing order.
        K_list: knot values, same length as ``exp_prior_list``.

    Returns:
        Interpolated (or extrapolated) value at ``x``.
    """
    last = len(exp_prior_list) - 1
    upper = 0
    while x > exp_prior_list[upper]:
        upper += 1
        if upper == last:
            break

    if upper:
        lower = upper - 1
        a, b = exp_prior_list[lower], exp_prior_list[upper]
        fa, fb = K_list[lower], K_list[upper]
    else:
        a, b = 0, exp_prior_list[0]
        fa, fb = K_list[0] + exp_prior_list[0], K_list[0]

    left_weight = (b - x) / (b - a)
    right_weight = (x - a) / (b - a)
    return left_weight * fa + right_weight * fb
