import numpy as np
import torch
import time
import jacinle.io as io
import pytorch_lightning as pl
from src.models.porus_media_gmap import Porus_gmap
from src.models.porus_media_Tmap import Porus_Tmap
import src.utils.plot_utils as PLU
import src.datamodules.generate_data as g_data
import matplotlib.pyplot as plt
import copy


class Aggreg_1step_gmap(Porus_gmap):
    """Single-step variant: the map is trained on ``w2_loss`` plus an interaction term.

    Helpers that are used but not defined here (``map_forward``,
    ``map_forward_nograd``, ``nn_loss``, ``optimizers``, ``_basic_plot``,
    ``plot_1d_density``, ``new_Pk_generator``, ...) are expected to come from
    the parent class or the training framework.
    """

    def training_step(self, real_data, _):
        """Run one manually-optimised update of the map.

        Args:
            real_data: batch tensor; ``real_data[:, :, 0]`` is pushed through
                the map and ``real_data[:, :, 2]`` is the second sample set
                handed to :meth:`aggreg_loss`.
            _: second positional argument supplied by the caller, unused.

        Returns:
            None. The optimizer step is performed inside this method.
        """
        optimizer_t, _ = self.optimizers()
        Pk_data = real_data[:, :, 0]
        # Use the real batch length rather than cfg.BATCH_SIZE so that a
        # shorter final batch is averaged the same way as ``norm_Tx`` below.
        batch_size = Pk_data.shape[0]

        optimizer_t.zero_grad()
        Tx = self.map_forward(Pk_data)

        norm_Tx = Tx.pow(2).sum(dim=1).mean() / 2
        inner_prod = torch.dot(Pk_data.reshape(-1), Tx.reshape(-1)) / batch_size

        w2_loss = (norm_Tx - inner_prod) / self.cfg.step_a
        # nn_loss() is defined outside this class; per the original note it
        # is 0 for a plain T map.
        nn_loss = self.nn_loss()
        aggreg_loss = self.aggreg_loss(Tx, real_data[:, :, 2])

        remaining_map_loss = w2_loss + aggreg_loss - nn_loss
        remaining_map_loss.backward()
        # Update the map parameters.
        optimizer_t.step()

        self.log_dict(
            {'epoch': self.epoch, 'aggreg_loss': aggreg_loss, 'w2_loss': w2_loss, 'nn_loss': nn_loss}, prog_bar=True)
        return None

    def aggreg_loss(self, Pk_1_pushed, Pk_2):
        """Mean pairwise interaction energy between two pushed sample sets.

        ``Pk_2`` is first pushed through ``self.map_forward``; every sample of
        ``Pk_1_pushed`` is then paired with every pushed sample of ``Pk_2``.
        The interaction function is currently hard-coded:

        * ``INPUT_DIM == 2``: ``|z|^4 / 4 - |z|^2 / 2``
        * ``INPUT_DIM == 1``: ``|z|^2 / 2 - log(|z| + 1e-6)``

        Args:
            Pk_1_pushed: already-pushed samples, first axis is the batch.
            Pk_2: samples that still have to be pushed through the map.

        Returns:
            Scalar tensor, the mean over all pairs.

        Raises:
            NotImplementedError: if ``cfg.INPUT_DIM`` is neither 1 nor 2.
        """
        Pk_2_pushed = self.map_forward(Pk_2)
        # Actual batch lengths; they equal cfg.BATCH_SIZE for full batches
        # but stay correct for a shorter trailing batch.
        n_first = Pk_1_pushed.shape[0]
        n_second = Pk_2_pushed.shape[0]

        if self.cfg.INPUT_DIM == 2:
            inside_w = Pk_1_pushed.reshape(n_first, 1, -1) - \
                Pk_2_pushed.reshape(1, n_second, -1)
            squared_norm = (inside_w**2).sum(axis=2).reshape(-1)
            expec_w = squared_norm**2 / 4 - squared_norm / 2
        elif self.cfg.INPUT_DIM == 1:
            inside_w = Pk_1_pushed.reshape(n_first, 1) - \
                Pk_2_pushed.reshape(1, n_second)
            norm = inside_w.abs().reshape(-1)
            # The 1e-6 offset keeps log() finite when two samples coincide.
            expec_w = norm**2 / 2 - (norm + 1e-6).log()
        else:
            raise NotImplementedError(
                f"aggreg_loss only supports INPUT_DIM 1 or 2, got {self.cfg.INPUT_DIM}")
        return expec_w.mean()

    def on_train_epoch_end(self):
        """Plot the pushed training samples; every ``cfg.epochs`` epochs save and advance.

        When ``self.epoch`` is a multiple of ``cfg.epochs`` the map weights are
        saved, the 1-D density plot is drawn (1-D only), and the samples
        returned by :meth:`iterate_dataloader` are dumped to disk.
        """
        Tx = self.map_forward_nograd(self.train_data[:, :, 0].cuda())
        self._basic_plot(Tx)
        if self.epoch % self.cfg.epochs == 0:
            torch.save(self.map_t.state_dict(), self.P_save_path + f'/map_{self.k}.pt')
            if self.cfg.INPUT_DIM == 1:
                self.plot_1d_density()

            # iterate_dataloader() increments self.k, hence the ``k - 1`` in
            # the file name below.
            P_k = self.iterate_dataloader()
            io.dump(self.P_save_path + f'/P_{self.k-1}.pt', P_k)
        self.epoch += 1

    def plot_2d_epoch(self, P_k, before_epoch='after'):
        """Save a 2-D histogram and a surface plot of ``P_k`` for the current epoch.

        Args:
            P_k: tensor of samples; columns 0 and 1 are plotted.
            before_epoch: tag inserted into the output file names.
        """
        if self.epoch == 1:
            # Also save a joint plot of the training data on the first epoch.
            PLU.sns_jointplot_alone(
                self.train_data[:, 0, 0], self.train_data[:, 1, 0], self.image_save_path + '/epoch0.png')

        Pk_plot = P_k.detach().cpu().numpy()
        # hist2d returns (counts, xedges, yedges, image); only counts are kept.
        Pk_hist = plt.hist2d(
            Pk_plot[:, 0], Pk_plot[:, 1], bins=self.cfg.num_grid,
            range=[[-self.cfg.plot_size, self.cfg.plot_size]] * 2, density=True, alpha=0.5)[0]
        plt.savefig(self.image_save_path + f'/epoch{self.epoch}_Pk_2d_{before_epoch}.png')
        x_grid, y_grid, _ = PLU.grid_NN_2_generator(
            self.handle.ax_param.num_grid, self.handle.ax_param.left_place, self.handle.ax_param.right_place)
        PLU.surface_alone(
            x_grid, y_grid, Pk_hist, ax_params=self.handle.ax_param,
            save_path=self.image_save_path + f'/epoch{self.epoch}_Pk_surface_{before_epoch}.png')

    def plot_gt_rho_aggreg_1d(self, x_plot):
        """Plot the curve ``sqrt(max(2 - x^2, 0)) / pi`` after grid normalisation.

        The curve is divided by its sum times the cell width
        ``2 * cfg.plot_size / cfg.num_grid``; that constant is stored in
        ``self.P0_nml_c``.

        Args:
            x_plot: tensor of x positions at which the curve is evaluated.
        """
        inside_sqrt = torch.relu(2 - x_plot**2)
        gt_rho = inside_sqrt.sqrt() / np.pi
        self.P0_nml_c = (gt_rho.sum() * (2 * self.cfg.plot_size / self.cfg.num_grid))
        gt_rho /= self.P0_nml_c
        plt.plot(x_plot, gt_rho, c='C1', label='steady density')

    def iterate_dataloader(self):
        """Increment ``self.k`` and return the samples from ``new_Pk_generator()``."""
        self.k += 1
        P_iterate = self.new_Pk_generator()
        return P_iterate


class Aggreg_2step_base(pl.LightningModule):
    """Shared logic for the two-step scheme: an explicit step followed by a map step.

    Odd values of the step counter correspond to the explicit update in
    :meth:`map_forward_gd`; even values correspond to a push through the
    learned map. ``self.half_k`` counts completed pairs of steps.

    Attributes and helpers used but not defined here (``cfg``, ``epoch``,
    ``k``, ``train_data``, ``map_t``, ``P_save_path``, ``map_forward_nograd``,
    ``_basic_plot``, ``plot_2d_epoch``, ``plot_highD_Pk``,
    ``w_gradient_part1``, ...) must be supplied by the classes this one is
    combined with.
    """

    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
        self.half_k = 0

    def on_train_epoch_start(self) -> None:
        """Apply the explicit step to the training samples at the start of each round.

        Runs when ``cfg.epochs == 1`` or ``self.epoch % cfg.epochs == 1``.
        The updated samples overwrite ``self.train_data[:, :, 0]``, are dumped
        under the current ``self.k``, and ``self.k`` is then set to
        ``2 * half_k + 2``.
        """
        if self.cfg.epochs == 1 or self.epoch % self.cfg.epochs == 1:
            print("begin mapping data")
            t_tr_strt = time.perf_counter()
            P_k = self.map_forward_gd(self.train_data[:, :, 0])
            print("finish mapping data takes time ", time.perf_counter() - t_tr_strt)
            self.train_data[:, :, 0] = P_k
            io.dump(self.P_save_path + f'/P_{self.k}.pt', P_k)
            self.k = self.half_k * 2 + 2
            if self.cfg.INPUT_DIM == 2:
                self.plot_2d_epoch(P_k[:4000], before_epoch='before')
            elif self.cfg.INPUT_DIM > 2:
                self.plot_highD_Pk(P_k[:4000], before_epoch='before')

    def map_forward_gd(self, data):
        """Return ``data - cfg.step_a * expe_W_x_y(data)``."""
        expe_W_x_y = self.expe_W_x_y(data)
        P_k = data - self.cfg.step_a * expe_W_x_y
        return P_k

    def expe_W_x_y(self, Pk_x):
        """Estimate, per sample, the mean interaction term against another chunk.

        The samples are processed in chunks of 10000 rows. Each chunk of
        ``x`` is paired with a different chunk ``y`` of the same tensor, and
        the pairwise term is averaged over ``y``:

        * ``INPUT_DIM == 1``: ``(x - y) - 1 / (x - y)``
        * otherwise: ``w_gradient_part1(|x - y|^2) * (x - y)``

        Only full chunks are processed; rows past the last multiple of 10000
        are not included in the result.

        Args:
            Pk_x: sample tensor whose first axis indexes samples.

        Returns:
            Tensor with one row per processed sample (chunks concatenated
            along axis 0).
        """
        single_num = 10000
        iter_times = Pk_x.shape[0] // single_num
        expe_list = []
        for d in range(iter_times):
            Pk_x_batch = Pk_x[d * single_num:(d + 1) * single_num]
            # The y chunk sits two chunks behind the x chunk; negative starts
            # index from the end of the tensor. For d == 1 the end index would
            # be 0 (an empty slice), so the slice is left open-ended instead.
            if d == 1:
                Pk_y_batch = Pk_x[((d - 2) * single_num):]
            else:
                Pk_y_batch = Pk_x[((d - 2) * single_num):((d - 1) * single_num)]

            if self.cfg.INPUT_DIM == 1:
                Pk_x_ext = Pk_x_batch.reshape(single_num, 1)
                Pk_y_ext = Pk_y_batch.reshape(1, single_num)

                x_minus_y = Pk_x_ext - Pk_y_ext
                # Work on a copy so the un-perturbed differences stay available.
                x_minus_y_lit = x_minus_y.clone()
                # Replace exact zeros with tiny independent noise so that the
                # reciprocal below stays finite. The number of draws must be
                # the number of masked entries, generated on the same device
                # and dtype as the tensor being filled.
                zero_mask = x_minus_y_lit == 0
                num_zeros = int(zero_mask.sum())
                if num_zeros > 0:
                    x_minus_y_lit[zero_mask] = torch.randn(
                        num_zeros, device=x_minus_y_lit.device,
                        dtype=x_minus_y_lit.dtype) * 1e-10
                inverse_xy = 1 / x_minus_y_lit
                total_expe = torch.mean(x_minus_y - inverse_xy, 1, True)
            else:
                # Generic vector branch: the trailing ``-1`` keeps whatever
                # feature dimension the samples have, so it covers
                # INPUT_DIM >= 2 (on_train_epoch_start handles INPUT_DIM > 2
                # right after calling this method).
                Pk_x_ext = Pk_x_batch.reshape(single_num, 1, -1)
                Pk_y_ext = Pk_y_batch.reshape(1, single_num, -1)
                diff_xy = Pk_x_ext - Pk_y_ext
                squ_norm_xy = (diff_xy**2).sum(axis=-1)[:, :, None]
                total_expe = (self.w_gradient_part1(squ_norm_xy) * diff_xy).mean(axis=1)
            expe_list.append(total_expe)
        expe = torch.cat(expe_list, dim=0)
        return expe

    def on_train_epoch_end(self):
        """Plot the pushed samples; every ``cfg.epochs`` epochs save the map and advance."""
        Tx = self.map_forward_nograd(self.train_data[:, :, 0].cuda())
        self._basic_plot(Tx[:10000])

        if self.epoch % self.cfg.epochs == 0:
            torch.save(self.map_t.state_dict(), self.P_save_path + f'/map_{self.k}.pt')
            print("start iterate_dataloader")
            # iterate_dataloader() moves self.k forward by one, hence ``k - 1``.
            P_k = self.iterate_dataloader()
            io.dump(self.P_save_path + f'/P_{self.k-1}.pt', P_k)
            print("finish_dataloader")
        self.epoch += 1

    def iterate_dataloader(self):
        """Advance ``half_k``, set ``k = 2 * half_k + 1`` and regenerate the samples."""
        self.half_k += 1
        self.k = self.half_k * 2 + 1
        P_iterate = self.new_Pk_generator()
        return P_iterate

    def sampler_p0(self):
        """Load the initial data set selected by ``cfg.type_data``.

        Returns:
            The first element of what ``g_data.import_aggre`` (for
            ``'aggreg'``) or ``g_data.import_aggre_diffusion_2d`` (otherwise)
            returns.
        """
        if self.cfg.type_data == 'aggreg':
            P0_data = g_data.import_aggre(self.cfg)[0]
        else:
            P0_data = g_data.import_aggre_diffusion_2d(self.cfg)[0]
        return P0_data

    def new_Pk_generator(self):
        """Rebuild the current samples by replaying steps ``1 .. k-1`` on fresh initial data.

        Odd steps apply :meth:`map_forward_gd`; even steps load the saved map
        checkpoint for that step, push the samples through
        ``map_forward_nograd`` and shuffle the rows. Unless ``cfg.debug_jko``
        is set, the result also overwrites ``self.train_data[:, :, 0]``.

        Returns:
            The detached sample tensor after the replay.
        """
        P0_data = self.sampler_p0()
        P_iterate = P0_data[:, :, 0].cuda()
        iterated_map = copy.deepcopy(self.map_t)
        for idx in range(1, self.k):
            if idx % 2 == 0:
                # NOTE(review): the checkpoint is loaded into the local copy
                # ``iterated_map``, but the push-forward below goes through
                # ``self.map_forward_nograd`` (defined elsewhere). Unless that
                # helper somehow uses this copy, the loaded weights have no
                # effect -- confirm.
                iterated_map.load_state_dict(torch.load(
                    self.P_save_path + f'/map_{idx}.pt', map_location='cuda:0'))
                P_iterate = self.map_forward_nograd(P_iterate.detach()).detach()
                P_iterate = P_iterate[torch.randperm(P_iterate.shape[0])]
            else:
                P_iterate = self.map_forward_gd(P_iterate)
        if self.cfg.debug_jko:
            print("debug: not updating data")
        else:
            print("updating data")
            self.train_data[:, :, 0] = P_iterate.detach().cpu()

        return P_iterate.detach()


class Aggreg_2step_gmap(Aggreg_2step_base, Aggreg_1step_gmap):
    """Two-step variant in which only the explicit step changes the samples.

    The training step does nothing and the interaction loss is constant, so
    the samples evolve solely through ``map_forward_gd``.
    """

    def training_step(self, real_data, _):
        """Intentionally empty: no optimisation is performed."""
        return None

    def aggreg_loss(self, Pk_1_pushed, Pk_2):
        """Return the constant 0 regardless of the inputs."""
        return 0

    def on_train_epoch_end(self):
        """Every ``cfg.epochs`` epochs regenerate the samples and dump them to disk."""
        round_finished = self.epoch % self.cfg.epochs == 0
        if round_finished:
            regenerated = self.iterate_dataloader()
            io.dump(self.P_save_path + f'/P_{self.k-1}.pt', regenerated)
        self.epoch += 1

    def new_Pk_generator(self):
        """Replay the explicit step on fresh initial data once per odd index below ``k``.

        Returns:
            The detached samples after all replays.
        """
        samples = g_data.import_aggre(self.cfg)[0][:, :, 0]
        # Even indices are no-ops here, so only odd indices 1, 3, 5, ... are visited.
        for _odd_step in range(1, self.k, 2):
            samples = self.map_forward_gd(samples)
        return samples.detach()

    def w_gradient_part1(self, squ_norm):
        """Return ``squ_norm - 1`` when ``cfg.INPUT_DIM == 2``, otherwise ``None``.

        The interaction function is hard-coded for now and may change later.
        """
        if self.cfg.INPUT_DIM != 2:
            return None
        return squ_norm - 1


class Aggreg_diffusion_2step(Aggreg_2step_base):
    """Two-step scheme whose kernel factor is chosen by ``cfg.keller_segel``."""

    def iterate_dataloader(self):
        """Advance the step counters, regenerate the samples and refresh auxiliary state.

        After regenerating, the first 10000 samples are passed to
        ``update_q_2d`` (2-D, moved to CPU first) or ``update_mu``
        (more than two dimensions).

        Returns:
            The regenerated samples.
        """
        self.half_k += 1
        self.k = 2 * self.half_k + 1
        regenerated = self.new_Pk_generator()

        dim = self.cfg.INPUT_DIM
        if dim == 2:
            self.update_q_2d(regenerated.cpu()[:10000])
        elif dim > 2:
            self.update_mu(regenerated[:10000])
        return regenerated

    def w_gradient_part1(self, squ_norm):
        """Return the scalar factor that multiplies ``x - y`` for a squared distance.

        With ``cfg.keller_segel`` set this is ``1 / (2 * pi * squ_norm)``;
        otherwise it is ``2 * exp(-squ_norm) / pi``.
        """
        if not self.cfg.keller_segel:
            return 2 * torch.exp(-squ_norm) / np.pi
        return 1 / (2 * np.pi * squ_norm)


class Aggreg_diffusion_2step_gmap(Aggreg_diffusion_2step, Porus_gmap):
    """Two-step aggregation-diffusion scheme combined with ``Porus_gmap``.

    No training loss is defined here or in ``Aggreg_diffusion_2step``; it is
    inherited from the remaining bases.
    NOTE(review): the original comment said the loss is the ``Porus_Tmap``
    loss, although this class derives from ``Porus_gmap`` -- confirm which
    one is meant.
    """


class Aggreg_diffusion_2step_Tmap(Aggreg_diffusion_2step, Porus_Tmap):
    """Two-step aggregation-diffusion scheme combined with ``Porus_Tmap``.

    Adds nothing of its own; all behaviour comes from the two base classes.
    """
