import numpy as np
import torch
import pytorch_lightning as pl
import src.models.kl_Tmap as T_system
from torch.optim.lr_scheduler import StepLR as StepLR
from torch.autograd import Variable
from tqdm import tqdm
from IPython.display import clear_output


class Base_gmap(pl.LightningModule):
    """Base module for maps obtained as the gradient of a scalar potential.

    The methods below read attributes that this class does not define:
    ``self.map_t`` (called on the input batch) and ``self.g_positive_params``
    (an iterable of parameters). Subclasses or co-operating base classes are
    expected to provide them.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def _grad_of_potential(self, input_data, create_graph):
        """Differentiate ``self.map_t(input_data).sum()`` w.r.t. ``input_data``.

        Args:
            input_data: tensor the potential is evaluated at. Its
                ``requires_grad`` flag is switched on in place when it is off.
            create_graph: forwarded to ``torch.autograd.grad``; when true the
                returned gradient can itself be back-propagated through.

        Returns:
            The gradient tensor, same shape as ``input_data``.
        """
        # The original guard was `if ~input_data.requires_grad:`. `~` on a
        # Python bool is bitwise inversion (~True == -2, ~False == -1), so the
        # test was always truthy and the flag was set unconditionally.
        if not input_data.requires_grad:
            input_data.requires_grad_(True)
        g_of_x = self.map_t(input_data).sum()
        # `retain_graph` defaults to the value of `create_graph`, which matches
        # what both public wrappers asked for explicitly or implicitly.
        return torch.autograd.grad(
            g_of_x, input_data, create_graph=create_graph)[0]

    def map_forward(self, input_data):
        """Return the gradient of the summed ``map_t`` output at ``input_data``.

        The result stays attached to the autograd graph (``create_graph=True``)
        so that losses built from it can be back-propagated.
        """
        return self._grad_of_potential(input_data, create_graph=True)

    def map_forward_nograd(self, input_data):
        """Same gradient as ``map_forward`` but without building a graph for it.

        Note that ``input_data.requires_grad`` is still switched on in place,
        because the gradient w.r.t. the input has to be computed.
        """
        return self._grad_of_potential(input_data, create_graph=False)

    def nn_loss(self):
        """Return ``-0.1`` times the constraint penalty of ``g_positive_params``.

        NOTE(review): the negative sign is presumably deliberate because
        ``Gaussian_nabla_gmap.training_step`` negates the map gradients before
        stepping -- confirm before reusing this loss elsewhere.
        """
        return -0.1 * compute_constraint_loss(
            self.g_positive_params)


def compute_constraint_loss(list_of_params):
    """Sum of squared negative parts over every tensor in ``list_of_params``.

    For each tensor ``p`` the contribution is ``relu(-p) ** 2`` summed over all
    entries, so non-negative entries contribute nothing.

    Returns:
        A scalar tensor, or the integer ``0`` when ``list_of_params`` is empty.
    """
    penalty = 0
    for param in list_of_params:
        # Only the magnitude of the negative entries survives the relu.
        negative_part = torch.relu(param.neg())
        penalty += negative_part.pow(2).sum()
    return penalty


def id_pretrain_model(
        model, sampler, lr=1e-3, n_max_iterations=2000, batch_size=1024, loss_stop=1e-5, verbose=True):
    """Pre-train ``model`` so that ``model.push`` approximates the identity.

    Every iteration draws a batch from ``sampler``, takes one Adam step on the
    MSE between ``model.push(X)`` and ``X``, then calls ``model.convexify()``.

    Args:
        model: object exposing ``parameters()``, ``push(X)`` and
            ``convexify()``.
        sampler: object exposing ``sample_n(batch_size)``; 1-D samples are
            reshaped to a single column.
        lr: Adam learning rate.
        n_max_iterations: upper bound on the number of optimisation steps.
        batch_size: number of samples drawn per step.
        loss_stop: training stops as soon as a batch loss drops below this.
        verbose: show the tqdm bar and print the loss every 100 iterations
            and at early stop.

    Returns:
        The same ``model`` object, updated in place.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-8)
    # Drop any gradients left on the parameters by earlier code so they do not
    # leak into the first step.
    opt.zero_grad()
    for it in tqdm(range(n_max_iterations), disable=not verbose):
        X = sampler.sample_n(batch_size)
        if X.ndim == 1:
            X = X.view(-1, 1)
        X.requires_grad_(True)
        loss = torch.nn.functional.mse_loss(model.push(X), X)
        loss.backward()

        opt.step()
        opt.zero_grad()
        model.convexify()

        # Read the scalar once; each `.item()` call forces a host sync.
        loss_value = loss.item()
        if verbose and it % 100 == 99:
            clear_output(wait=True)
            print('Loss:', loss_value)

        # Early stopping used to be nested under `if verbose:`, so silent runs
        # always burned through all `n_max_iterations`.
        if loss_value < loss_stop:
            if verbose:
                clear_output(wait=True)
                print('Final loss:', loss_value)
            break
    return model


class Gaussian_nabla_gmap(Base_gmap):
    """Alternating training of the map ``map_t`` and an auxiliary ``h`` model.

    Relies on members defined outside this class: ``self.map_t``,
    ``self.cfg``, ``self._loss_h_function``, ``self.log_total_loss`` and the
    two optimizers returned by ``self.optimizers()`` (map first, ``h`` second).
    """

    def __init__(self, *kargs, **kwargs):
        super(Gaussian_nabla_gmap, self).__init__(*kargs, **kwargs)
        # Parameters tagged with a `be_positive` attribute are the ones the
        # soft penalty in `nn_loss()` is applied to.
        # NOTE: marked temporary (`#!tmp`) by the original author. An earlier,
        # disabled variant re-initialised every `map_t` parameter with
        # `randn / sqrt(cfg.NUM_NEURON_map)` at this point.
        self.g_positive_params = [
            p for p in self.map_t.parameters() if hasattr(p, 'be_positive')]

    def training_step(self, real_data, _):
        """Run one manual optimisation step on a batch.

        ``real_data[:, :, 0]`` is used as the map input and
        ``real_data[:, :, 1]`` as the target samples. The step performs
        ``cfg.N_inner_ITERS`` updates of ``h`` followed by one update of the
        map, then logs the individual loss terms.

        Args:
            real_data: 3-D batch tensor indexed as described above.
            _: batch index, unused.

        Returns:
            ``None``; the optimizers are stepped explicitly inside.
        """
        optimizer_map, optimizer_h = self.optimizers()
        # Fresh leaf tensor sharing the batch storage; replaces the deprecated
        # `Variable(x, requires_grad=True)` wrapper.
        x_i = real_data[:, :, 0].detach().requires_grad_(True)
        y_data = real_data[:, :, 1]
        n_inner = self.cfg.N_inner_ITERS

        h_ot_loss_value_batch = 0
        optimizer_map.zero_grad()
        ######################################################
        #                Inner Loop Begin                    #
        ######################################################
        for inner_iter in range(1, n_inner + 1):
            optimizer_h.zero_grad()
            Tx = x_i if self.cfg.debug_h else self.map_forward(x_i)
            loss_h = self._loss_h_function(y_data, Tx)
            # Accumulate a detached copy: the running sum is only logged, and
            # keeping it attached would hold every iteration's graph alive.
            h_ot_loss_value_batch = h_ot_loss_value_batch + loss_h.detach()
            # The graph through `Tx` is reused below for the map loss, so it
            # must survive this backward pass.
            loss_h.backward(retain_graph=True)

            # ! update h
            optimizer_h.step()
            # Just for the last iteration keep the gradient on the map intact
            if inner_iter != n_inner:
                optimizer_map.zero_grad()

        ######################################################
        #                Inner Loop Ends                     #
        ######################################################
        norm_Tx = Tx.pow(2).sum(dim=1).mean() / 2
        # NOTE(review): this divides by the configured batch size while
        # `norm_Tx` averages over the actual batch; the two differ on a
        # short final batch -- confirm whether `x_i.shape[0]` is intended.
        inner_prod = torch.dot(
            x_i.reshape(-1), Tx.reshape(-1)) / self.cfg.BATCH_SIZE

        w2_loss = (- norm_Tx + inner_prod) / self.cfg.step_a

        # ! Convexity is encouraged through the soft penalty `nn_loss`.
        # NOTE: marked temporary (`#!tmp`) by the original author. An earlier,
        # disabled variant left `nn_loss` out and instead called
        # `self.map_t.convexify()` right after `optimizer_map.step()`.
        nn_loss = self.nn_loss()
        if self.cfg.mu_equal_q:
            remaining_g_loss = w2_loss + nn_loss
            log_loss = torch.zeros_like(w2_loss)
        else:
            log_loss = self.log_total_loss(Tx)
            remaining_g_loss = log_loss + w2_loss + nn_loss

        remaining_g_loss.backward()
        # Flip the sign of the accumulated map gradients before stepping.
        # Parameters that received no gradient (e.g. when `cfg.debug_h` makes
        # `Tx` independent of the map) have `grad is None` and are skipped.
        for p in self.map_t.parameters():
            if p.grad is not None:
                p.grad.neg_()
        # ! update T
        optimizer_map.step()
        h_ot_loss_value_batch /= n_inner

        self.log_dict({'h_ot_loss': h_ot_loss_value_batch,
                       'log_loss': log_loss, 'w2_loss': w2_loss, 'nn_loss': nn_loss}, prog_bar=True)
        return None


class GM_nabla_gmap(Gaussian_nabla_gmap, T_system.GM_Tmap):
    """Combination of ``Gaussian_nabla_gmap`` and ``T_system.GM_Tmap``.

    Defines no members of its own. ``Gaussian_nabla_gmap`` comes first in the
    base list, so its methods take precedence over ``GM_Tmap``'s on lookup.
    """
