import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable

def compute_loss_weights_simple(loss_orth, loss_cons, loss_smap, loss_sres, nbatch, kappa=1.0e1, eps=1.0e-6):
    """Compute normalized, cascaded weights for four loss terms.

    The orthogonality term always has raw weight 1. Every other term is
    damped by ``exp(-sqrt(nbatch) * kappa * m)``, where ``m`` is the largest
    of the losses that precede it in the cascade:

    * ``w_cons`` uses ``loss_orth``;
    * ``w_sres`` uses ``max(loss_cons, loss_orth)``;
    * ``w_smap`` uses ``max(loss_cons, loss_orth, loss_sres)``.

    A term therefore only receives substantial weight once the terms ahead of
    it are already small. The four raw weights are finally divided by their
    sum, so they add up to 1.

    Args:
        loss_orth, loss_cons, loss_smap, loss_sres: single-element loss
            tensors. Their values are detached and copied to the CPU, so no
            gradient flows through the weights. ``loss_smap`` does not
            influence any weight; it is accepted for symmetry with the others.
        nbatch: batch size; ``sqrt(nbatch)`` scales the damping exponent.
        kappa: damping strength.
        eps: unused; kept only for backward compatibility of the signature.

    Returns:
        Tuple ``(w_orth, w_cons, w_smap, w_sres)`` of NumPy scalars that sum
        to 1.
    """
    def _as_numpy(loss):
        # ``0.0 +`` promotes integer tensors to floating point before copying.
        return (0.0 + loss).detach().cpu().numpy()

    orth = _as_numpy(loss_orth)
    cons = _as_numpy(loss_cons)
    sres = _as_numpy(loss_sres)

    # Same association order as the original expression:
    # ((-sqrt(nbatch)) * kappa) * loss.
    scale = -np.sqrt(nbatch) * kappa

    w_orth = 1.0
    w_cons = np.exp(scale * orth)
    w_sres = np.exp(scale * max(cons, orth))
    w_smap = np.exp(scale * max(cons, orth, sres))

    # Summation order kept as in the original so results are bit-identical.
    # w_orth == 1 guarantees w_summ >= 1, so no epsilon guard is needed.
    w_summ = w_orth + w_smap + w_cons + w_sres
    return (w_orth / w_summ, w_cons / w_summ, w_smap / w_summ, w_sres / w_summ)


def _weighted_loss(nnet, nbatch, w, x, f):
    """Run one forward pass and combine the four losses with adaptive weights.

    Returns:
        ``(loss, weights, values)``:
        * ``loss`` is the weighted total loss tensor;
        * ``weights`` is ``(w_orth, w_cons, w_smap, w_sres)``;
        * ``values`` is the list of Python floats
          ``[loss_orth, loss_cons, loss_smap, loss_sres, loss]``.
    """
    loss_orth, loss_cons, loss_smap, loss_sres = nnet(w, Variable(x), f)
    # Single call site, so the argument order matches the weighting contract
    # everywhere the losses are evaluated.
    weights = compute_loss_weights_simple(loss_orth, loss_cons, loss_smap, loss_sres, nbatch)
    w_orth, w_cons, w_smap, w_sres = weights
    loss = w_orth * loss_orth + w_cons * loss_cons + w_smap * loss_smap + w_sres * loss_sres
    # One .item() per tensor: each is a host sync on accelerators.
    values = [t.item() for t in (loss_orth, loss_cons, loss_smap, loss_sres, loss)]
    return loss, weights, values


def compute_neural_network_parameters(nnet, nepoch, nbatch, lr, w, x, f):
    """Train ``nnet`` with Adam on an adaptively weighted sum of its four losses.

    ``nnet(w, x, f)`` is expected to return the four loss tensors
    ``(loss_orth, loss_cons, loss_smap, loss_sres)``. Their weights are
    recomputed every epoch by ``compute_loss_weights_simple``.

    Args:
        nnet: model exposing ``fc_smat`` and ``regressor`` submodules; both are
            set to require gradients before training.
        nepoch: number of optimizer steps.
        nbatch: batch size, forwarded to the weighting function.
        lr: Adam learning rate.
        w, x, f: model inputs, passed through to ``nnet``.

    Returns:
        ``loss_data``, an ndarray of shape ``(nepoch + 1, 5)``. Columns are
        ``[loss_orth, loss_cons, loss_smap, loss_sres, weighted total]``.
        Row ``k < nepoch`` holds the values computed at epoch ``k``, before
        that epoch's optimizer step. The last row holds the values after
        training, weighted by the same rule as the training rows.
    """
    nnet.fc_smat.requires_grad_(True)
    nnet.regressor.requires_grad_(True)
    print('nbatch = ' + str(nbatch))
    optimizer = optim.Adam(nnet.parameters(), lr=lr, weight_decay=0.0)
    loss_data = np.zeros((nepoch + 1, 5))

    weight_names = ('w_orth', 'w_cons', 'w_smap', 'w_sres')
    loss_names = ('loss_orth.item()', 'loss_cons.item()', 'loss_smap.item()',
                  'loss_sres.item()', 'loss.item()')

    for kepoch in range(nepoch):
        optimizer.zero_grad()
        loss, weights, values = _weighted_loss(nnet, nbatch, w, x, f)
        loss_data[kepoch, :] = values
        loss.backward()
        prefix = 'kepoch = ' + str(kepoch) + ' '
        for name, value in zip(weight_names, weights):
            print(prefix + name + ' = ' + str(value))
        for name, value in zip(loss_names, values):
            print(prefix + name + ' = ' + str(value))
        optimizer.step()

    # Final evaluation after the last step. It previously passed loss_smap and
    # loss_orth in swapped positions to the weighting function.
    _, _, values = _weighted_loss(nnet, nbatch, w, x, f)
    loss_data[-1, :] = values
    return loss_data




















