import copy
import sys
from copy import deepcopy

import numpy as np
import torch
from torch.optim.optimizer import Optimizer
from tqdm import tqdm

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import compute_loss
from src.tools.sharpness_tools.utils import get_param_dim, get_device, load_weights
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast


class EntropySGD(Optimizer):
    """Entropy-SGD style optimizer that rewrites ``.grad`` instead of stepping.

    ``step`` runs ``L`` inner noisy-gradient iterations on the parameters of
    the first parameter group, restores the original weights, and leaves the
    resulting direction in each parameter's ``.grad``. It never applies a
    final update to the parameters themselves.

    Recognised ``config`` keys and their defaults:

    * ``lr`` (0.01) -- stored, but not read by ``step``.
    * ``momentum`` (0), ``damp`` (0), ``nesterov`` (True) -- momentum settings.
    * ``weight_decay`` (0) -- coefficient of the ``w`` term added to the gradient.
    * ``L`` (0) -- number of inner iterations per ``step``.
    * ``eps`` (1e-4) -- noise magnitude used in the inner iterations.
    * ``g0`` (1e-2), ``g1`` (0) -- coupling ``g = g0 * (1 + g1) ** t``.
    """

    def __init__(self, params, config=None):
        """Build the optimizer.

        Args:
            params: iterable of parameters (or parameter groups) forwarded to
                ``torch.optim.Optimizer``.
            config: optional dict of settings. Keys that are missing or set to
                ``None`` are filled in with the defaults, in place, and the
                same dict object is kept as ``self.config``.
        """
        defaults = dict(lr=0.01, momentum=0, damp=0,
                        weight_decay=0, nesterov=True,
                        L=0, eps=1e-4, g0=1e-2, g1=0)
        # A fresh dict per call: a shared ``{}`` default would be filled in
        # below and leak one instance's settings into every later instance.
        if config is None:
            config = {}
        for key, value in defaults.items():
            if config.get(key) is None:
                config[key] = value

        super().__init__(params, config)
        self.config = config
        self.scalar = GradScaler()

    @staticmethod
    def _apply_wd_momentum(w, dw, mdw, wd, mom, damp, nesterov):
        """Fold weight decay and momentum into ``dw``; return the direction.

        ``dw`` and ``mdw`` are modified in place. With plain (non-Nesterov)
        momentum the returned tensor *is* the momentum buffer ``mdw``.
        """
        if wd > 0:
            dw.add_(w.data, alpha=wd)
        if mom > 0:
            mdw.mul_(mom).add_(dw, alpha=1 - damp)
            if nesterov:
                dw.add_(mdw, alpha=mom)
            else:
                dw = mdw
        return dw

    def step(self, closure=None, model=None, criterion=None):
        """Compute the Entropy-SGD direction and store it in ``.grad``.

        Args:
            closure: callable that zeroes the gradients, runs forward and
                backward, and returns the loss. It is called once up front
                and once per inner iteration.
            model: must not be ``None``; not otherwise used here.
            criterion: must not be ``None``; not otherwise used here.

        Returns:
            The value returned by the last ``closure()`` call.

        Raises:
            AssertionError: if any of the three arguments is ``None``.
        """
        assert (closure is not None) and (model is not None) and (criterion is not None), \
            'attach closure for Entropy-SGD, model and criterion'
        loss = closure()

        c = self.config
        mom = c['momentum']
        wd = c['weight_decay']
        damp = c['damp']
        nesterov = c['nesterov']
        L = int(c['L'])
        eps = c['eps']
        g0 = c['g0']
        g1 = c['g1']

        # Only the first parameter group is processed.
        params = self.param_groups[0]['params']

        state = self.state
        # Lazily allocate the buffers on the first call; this needs the
        # gradients produced by the closure call above.
        if 't' not in state:
            state['t'] = 0
            state['wc'], state['mdw'] = [], []
            for w in params:
                state['wc'].append(deepcopy(w.data))
                state['mdw'].append(deepcopy(w.grad.data))

            state['langevin'] = dict(mw=deepcopy(state['wc']),
                                     mdw=deepcopy(state['mdw']),
                                     eta=deepcopy(state['mdw']),
                                     lr=0.1,
                                     beta1=0.75)

        # Reset the inner-loop buffers around the current weights.
        lp = state['langevin']
        for i, w in enumerate(params):
            state['wc'][i].copy_(w.data)
            lp['mw'][i].copy_(w.data)
            lp['mdw'][i].zero_()
            lp['eta'][i].normal_()

        state['debug'] = dict(wwpd=0, df=0, dF=0, g=0, eta=0)
        llr, beta1 = lp['lr'], lp['beta1']
        # NOTE(review): state['t'] is never advanced in this class, so the
        # exponent is always 0 and g == g0 regardless of g1 -- confirm intent.
        g = g0 * (1 + g1) ** state['t']
        noise_scale = eps / np.sqrt(0.5 * llr)

        for _ in range(L):
            loss = closure()
            for wc, w, mw, mdw, eta in zip(state['wc'], params,
                                           lp['mw'], lp['mdw'], lp['eta']):
                # NOTE(review): with non-Nesterov momentum ``dw`` is the
                # momentum buffer itself, so the coupling and noise terms
                # below are accumulated into it -- behaviour kept as is.
                dw = self._apply_wd_momentum(w, w.grad.data, mdw,
                                             wd, mom, damp, nesterov)

                # Coupling towards the saved weights plus fresh noise.
                eta.normal_()
                dw.add_(wc - w.data, alpha=-g).add_(eta, alpha=noise_scale)

                # Inner update, then running average of the visited weights.
                w.data.add_(dw, alpha=-llr)
                mw.mul_(beta1).add_(w.data, alpha=1 - beta1)

        if L > 0:
            # Restore the original weights; the direction becomes the
            # difference between them and the running average.
            for i, w in enumerate(params):
                w.data.copy_(state['wc'][i])
                w.grad.data.copy_(w.data - lp['mw'][i])

        for w, mdw in zip(params, state['mdw']):
            grad = w.grad.data
            dw = self._apply_wd_momentum(w, grad, mdw, wd, mom, damp, nesterov)
            # Copy rather than rebind: rebinding would make ``w.grad`` share
            # storage with the momentum buffer, so a later in-place
            # ``zero_grad`` would wipe the momentum.
            if dw is not grad:
                grad.copy_(dw)

        return loss


def entropy(model, data_loader, gamma, mcmc_itr):
    """Monte-Carlo estimate built from losses at randomly perturbed weights.

    For each of ``mcmc_itr`` iterations, every parameter is overwritten with
    its original value plus Gaussian noise of standard deviation
    ``1 / gamma``, and the first element returned by ``compute_loss`` is
    recorded. The original weights are put back with ``load_weights``
    before returning, even if an iteration raises or is interrupted.

    Args:
        model: module whose parameters are perturbed in place.
        data_loader: forwarded unchanged to ``compute_loss``.
        gamma: noise scale; also enters the result via ``log(gamma ** 0.5)``.
        mcmc_itr: number of perturbed evaluations. Must be at least 1: the
            result divides by it and concatenates one entry per iteration.

    Returns:
        float: ``-(logsumexp(losses) + log(1 / mcmc_itr)
        + model_dim / 2 * log(2 * pi) - log(gamma ** 0.5))``.
    """
    scalar = GradScaler()
    with torch.no_grad():
        # Snapshot of the weights to sample around and to restore afterwards.
        theta_star = [p.data.clone() for p in model.parameters()]
        model_dim = get_param_dim(model)

    out = []
    try:
        for _ in tqdm(range(mcmc_itr)):
            for mp, p in zip(model.parameters(), theta_star):
                noise = torch.zeros(p.shape, device=mp.data.device).normal_(0, 1 / gamma)
                mp.data.copy_(p + noise)
            out.append(torch.Tensor([compute_loss(model, data_loader, scalar)[0]]))
    finally:
        # Never leave the caller's model holding perturbed weights.
        load_weights(model, theta_star)

    return -(torch.logsumexp(torch.cat(out), 0, False) + np.log(1 / mcmc_itr) + model_dim / 2 * np.log(
        2 * np.pi) - np.log(gamma ** 0.5)).item()


def entropy_grad(model, data_loader):
    """Return the mean L2 norm of the per-batch ``EntropySGD`` direction.

    For every ``(inputs, targets)`` batch, an ``EntropySGD`` step (20 inner
    iterations, cross-entropy loss under autocast) is run. The optimizer
    leaves its direction in each parameter's ``.grad``; those are flattened
    into one vector whose 2-norm is recorded.

    Args:
        model: module to evaluate; batches are moved to ``get_device(model)``.
        data_loader: sized iterable of ``(inputs, targets)`` batches.

    Returns:
        float: average of the per-batch norms (``nan`` for an empty loader).
    """
    scalar = GradScaler()
    criterion = torch.nn.CrossEntropyLoss()
    device = get_device(model)
    opt = EntropySGD(model.parameters(), config=dict(lr=0.1, momentum=0.0,
                                                     nesterov=False, weight_decay=0.0, L=20, eps=1e-4,
                                                     g0=1e-4, g1=1e-3))
    # Only the norms are needed, so keep one scalar per batch instead of a
    # (num_batches x num_parameters) matrix of full gradient vectors.
    batch_norms = torch.empty(len(data_loader))
    for i, (inputs, targets) in enumerate(tqdm(data_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        def closure():
            opt.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            scalar.scale(loss).backward()
            # The optimizer reads ``.grad`` directly, so undo the loss
            # scaling here; get_scale() is 1.0 when scaling is disabled.
            scale = scalar.get_scale()
            if scale != 1.0:
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.data.div_(scale)
            return loss

        opt.step(closure, model, criterion)

        # torch.cat already copies, so no per-parameter deepcopy is needed.
        flat = torch.cat([p.grad.data.reshape(-1) for p in model.parameters()])
        batch_norms[i] = torch.norm(flat).item()

    return torch.mean(batch_norms).item()
