import torch
import torch.nn as nn
from torch.autograd import grad

from tqdm import trange
import numpy as np

from src.attacks import pgd_rand
from src.context import ctx_noparamgrad_and_eval, ctx_eval
from src.utils_general import ep2itr
from src.soar import soar

def data_init(init, X, y, model):
    """Build the initial perturbation ``delta`` for a batch of inputs.

    Args:
        init: Initialization mode, one of:
            ``"rand"``  - uniform noise in [-8/255, 8/255], adjusted so that
                          ``X + delta`` stays inside [0, 1];
            ``"pgd1"``  - the perturbation produced by ``pgd_rand`` with
                          ``ord=inf``, ``epsilon = alpha = 4/255``, one
                          iteration and one restart;
            ``"none"``  - an all-zero perturbation.
        X: Input batch the perturbation is built for.
        y: Labels for ``X``; only used by the ``"pgd1"`` mode.
        model: Model handed to the attack; only used by the ``"pgd1"`` mode.

    Returns:
        A tensor with the same shape as ``X`` holding the perturbation.

    Raises:
        ValueError: If ``init`` is not one of the supported modes.
    """
    if init == "rand":
        X_ref = X.detach()
        noise = torch.empty_like(X_ref, requires_grad=False).uniform_(-8./255., 8./255.)
        # Clamp the perturbed input to [0, 1] and convert back to a
        # perturbation so that X + delta never leaves that range.
        delta = (X_ref + noise).clamp(min=0, max=1.0) - X_ref
    elif init == "pgd1":
        # The attack runs inside ctx_noparamgrad_and_eval(model).
        with ctx_noparamgrad_and_eval(model):
            param = {"ord": np.inf, "epsilon": 4./255., "alpha": 4./255., "num_iter": 1, "restart": 1}
            delta = pgd_rand(**param).generate(model, X, y)
    elif init == "none":
        delta = torch.zeros_like(X.detach(), requires_grad=False)
    else:
        # Previously an unknown mode fell through and crashed on `return delta`
        # with an UnboundLocalError; fail with an explicit message instead.
        raise ValueError(
            f"Unknown init mode {init!r}; expected 'rand', 'pgd1' or 'none'."
        )

    return delta

def train_standard(epoch, loader, model, opt, device):
    """Run one epoch of plain cross-entropy training.

    Args:
        epoch: Index of the current epoch. Not used by the training logic;
            kept so the signature stays compatible with existing callers.
        loader: Iterable yielding ``(X, y)`` batches; ``len(loader)`` sizes
            the progress bar.
        model: Network to train. It is switched to train mode.
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(acc, loss)``: accuracy in percent and mean per-sample
        cross-entropy loss, averaged over the samples actually processed.
        ``(0.0, 0.0)`` is returned if the loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss, total_correct = 0., 0.
    # Count the samples that were really seen: dividing by
    # len(loader.dataset) skews the averages when the loader drops the last
    # batch or only samples a subset of the dataset.
    num_seen = 0

    model.train()
    with trange(len(loader)) as t:
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            batch_size = X.shape[0]

            yp = model(X)
            loss = criterion(yp, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_loss = loss.item()
            batch_correct = (yp.argmax(dim=1) == y).sum().item()
            batch_acc = batch_correct / batch_size

            total_correct += batch_correct
            total_loss += batch_loss * batch_size
            num_seen += batch_size

            t.set_postfix(loss=batch_loss,
                          acc='{0:.2f}%'.format(batch_acc*100))
            t.update()

    if num_seen == 0:
        return 0., 0.

    acc = total_correct / num_seen * 100
    total_loss = total_loss / num_seen

    return acc, total_loss

def train_soar(epoch, loader, model, args, opt, device):
    """Run one epoch of training with the ``soar`` regularizer.

    For every batch a perturbation is created with ``data_init``, the model is
    evaluated on the perturbed input, and the optimizer minimizes the
    cross-entropy loss plus a scaled regularization term returned by ``soar``.

    Args:
        epoch: Index of the current epoch. Not used by the training logic;
            kept so the signature stays compatible with existing callers.
        loader: Iterable yielding ``(X, y)`` batches; ``len(loader)`` sizes
            the progress bar. ``X`` is indexed on dimensions 1..3, so a 4-D
            batch is expected.
        model: Network to train.
        args: Namespace providing ``s_init`` (mode passed to ``data_init``),
            ``s_clip`` and ``s_step_size`` (forwarded to ``soar``) and
            ``s_eps`` (used in the regularizer scaling).
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called per batch.
        device: Device the batches are moved to; also forwarded to ``soar``.

    Returns:
        Tuple ``(acc, loss, reg)``: accuracy in percent on the perturbed
        inputs, mean per-sample cross-entropy loss on the perturbed inputs,
        and the mean of the unscaled ``soar`` term, all averaged over the
        samples actually processed. ``(0.0, 0.0, 0.0)`` is returned if the
        loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss, total_correct, total_reg = 0., 0., 0.
    # Count the samples that were really seen: dividing by
    # len(loader.dataset) skews the averages when the loader drops the last
    # batch or only samples a subset of the dataset.
    num_seen = 0

    with trange(len(loader)) as t:
        for X, y in loader:
            # Re-assert train mode every batch: data_init's "pgd1" path runs
            # the model under ctx_noparamgrad_and_eval.
            model.train()
            X, y = X.to(device), y.to(device)
            batch_size = X.shape[0]

            # Perturbed input, rebuilt as a leaf tensor that tracks gradients
            # with respect to the input itself.
            delta = data_init(args.s_init, X, y, model)
            X_delta = X.detach() + delta.detach()
            X_delta.requires_grad = True

            yp_delta = model(X_delta)
            loss_delta = criterion(yp_delta, y)
            # Number of elements in one sample.
            _dim = X.shape[1] * X.shape[2] * X.shape[3]

            reg = soar(X_delta, y, loss_delta, args.s_clip, args.s_step_size, model, device)
            reg_delta = 0.5 * (args.s_eps**2 * _dim + 1) * reg

            loss_reg = loss_delta + reg_delta
            opt.zero_grad()
            loss_reg.backward()
            opt.step()

            # Read the scalars once per batch and reuse them below.
            batch_loss = loss_delta.item()
            batch_reg = reg.item()
            batch_correct = (yp_delta.argmax(dim=1) == y).sum().item()
            batch_acc = batch_correct / batch_size

            total_correct += batch_correct
            total_loss += batch_loss * batch_size
            total_reg += batch_reg * batch_size
            num_seen += batch_size

            t.set_postfix(loss=batch_loss,
                          reg=batch_reg,
                          acc="{0:.2f}%".format(batch_acc*100))
            t.update()

    if num_seen == 0:
        return 0., 0., 0.

    acc = total_correct / num_seen * 100
    total_loss = total_loss / num_seen
    total_reg = total_reg / num_seen

    return acc, total_loss, total_reg
