import os
import numpy as np
import torch
import src.models.generate_NN as g_NN
import src.utils.plot_utils as PLU
import src.datamodules.generate_data as g_data
import src.models.kl_Tmap as T_system
from porous_media import density_rho0
import jacinle.io as io
from torch.optim.lr_scheduler import StepLR as StepLR
import matplotlib.pyplot as plt
import copy


class Porus_Tmap(T_system.Base_Tmap):
    def __init__(self, cfg, density_q: float = 1.0, train_data: torch.Tensor = None, P_save_path: str = None, image_save_path: str = None):
        """Create the networks and initialise the bookkeeping state.

        Args:
            cfg: configuration object. The fields read here are ``ratio_h``,
                ``step_a`` and ``plot_size``; the object is also handed to
                ``g_NN.generate_monge_NN``.
            density_q: stored as ``self.density_q``. Per the original author
                note, this is Omega because q is a uniform distribution.
            train_data: tensor stored as ``self.train_data``.
            P_save_path: path prefix for saved samples. Per the original
                author note, models are temporarily saved under it as well.
            image_save_path: path prefix for saved figures.
        """
        super().__init__()

        # Plain bookkeeping state.
        self.cfg = cfg
        self.density_q = density_q
        self.train_data = train_data
        self.P_save_path = P_save_path
        self.image_save_path = image_save_path
        self.epoch = 1
        self.k = 1
        self.P0_nml_c = 0

        # Networks: ``h`` and the transport map.
        self.h, self.map_t = g_NN.generate_monge_NN(cfg)

        # Pick the loss for ``h`` once, according to the configuration flag.
        self._loss_h_function = (
            self.ratio_loss_h if self.cfg.ratio_h else self.basic_loss_h)

        # The optimizers are driven by hand inside ``training_step``.
        self.automatic_optimization = False

        # Optional reference data; the file name embeds -log10(step_a)
        # truncated to an integer.
        exponent = int(-np.log10(self.cfg.step_a))
        cvx_path = f'cvx_data/a{exponent}_discrete_gt.mat'
        self.cvx_data = (
            io.load(cvx_path)['discrete_record']
            if os.path.exists(cvx_path) else None)

        plot_half_width = self.cfg.plot_size
        self.handle = PLU.DIM2_PLOT(
            left_place=-plot_half_width, right_place=plot_half_width)

    def ratio_loss_h(self, q_data, Tx):
        """Loss for ``h`` that ``__init__`` selects when ``cfg.ratio_h`` is set.

        With ``m = cfg.porous_m`` this evaluates
        ``mean(h(q_data)**m) - mean(m * h(Tx)**(m - 1) / (m - 1))`` and
        divides the result by ``density_q**(m - 1)``.

        Args:
            q_data: batch passed to ``self.h`` as the reference samples.
            Tx: batch passed to ``self.h`` as the pushed-forward samples.

        Returns:
            The scaled loss value.
        """
        m = self.cfg.porous_m
        pushed_term = (m * self.h(Tx) ** (m - 1)) / (m - 1)
        reference_term = self.h(q_data) ** m
        objective = -pushed_term.mean() + reference_term.mean()
        normaliser = self.density_q ** (m - 1)
        return objective / normaliser

    def basic_loss_h(self, q_data, Tx):
        """Loss for ``h`` that ``__init__`` selects when ``cfg.ratio_h`` is not set.

        With ``m = cfg.porous_m`` this evaluates
        ``mean(((m - 1) / m * h(q_data))**(m / (m - 1))) - mean(h(Tx))`` and
        divides the result by ``density_q**(m - 1)``.

        Args:
            q_data: batch passed to ``self.h`` as the reference samples.
            Tx: batch passed to ``self.h`` as the pushed-forward samples.

        Returns:
            The scaled loss value.
        """
        m = self.cfg.porous_m
        pushed_mean = self.h(Tx).mean()
        outer_power = m / (m - 1)
        reference_term = ((m - 1) / m * self.h(q_data)) ** outer_power
        objective = -pushed_mean + reference_term.mean()
        return objective / self.density_q ** (m - 1)

    def training_step(self, real_data, _):
        """Run one manually-optimised step: update ``h``, then update the map.

        ``real_data[:, :, 0]`` is used as the samples that go through the map
        and ``real_data[:, :, 1]`` as the reference samples. When a third
        slice ``real_data[:, :, 2]`` exists it is passed to ``aggreg_loss``.

        ``h`` is stepped ``cfg.N_inner_ITERS`` times on ``_loss_h_function``.
        Unless ``cfg.debug_h`` is set, the map is then stepped
        ``cfg.N_outer_ITERS`` times; the gradients accumulated on the map
        parameters are negated before ``optimizer_t.step()``.

        Args:
            real_data: batch tensor, indexed as described above.
            _: unused (presumably the batch index -- confirm with the trainer).

        Returns:
            None. The optimizers are stepped here; only ``h_ot_loss``,
            ``w2_loss``, ``aggreg_loss`` and ``total`` are logged.
        """
        optimizer_t, optimizer_h = self.optimizers()
        Pk_data = real_data[:, :, 0]
        q_data = real_data[:, :, 1]
        # Use plain truthiness everywhere so the inner and the outer loop can
        # never disagree about whether debug mode is on.
        debug_h = bool(self.cfg.debug_h)

        # Defaults for every logged quantity: in debug mode (or with zero
        # outer iterations) the map update below never assigns them, and the
        # logging call would otherwise fail on an undefined name.
        h_ot_loss_value_batch = 0
        w2_loss = 0
        aggreg_loss = 0
        remaining_map_loss = 0

        optimizer_t.zero_grad()
        ######################################################
        #                Inner Loop Begin                    #
        ######################################################
        for _ in range(1, self.cfg.N_inner_ITERS + 1):
            optimizer_h.zero_grad()
            optimizer_t.zero_grad()

            Tx = Pk_data if debug_h else self.map_forward(Pk_data)
            loss_h = self._loss_h_function(q_data, Tx) * self.cfg.loss_amplifier
            # Accumulate a detached value: the sum is only logged, so it must
            # not keep every iteration's autograd graph alive.
            h_ot_loss_value_batch = h_ot_loss_value_batch + loss_h.detach()
            loss_h.backward(retain_graph=not debug_h)

            optimizer_h.step()
        ######################################################
        #                Inner Loop Ends                     #
        ######################################################

        if not debug_h:
            for _ in range(1, self.cfg.N_outer_ITERS + 1):
                optimizer_h.zero_grad()
                optimizer_t.zero_grad()
                Tx = self.map_forward(Pk_data)
                loss_h = self._loss_h_function(q_data, Tx) * self.cfg.loss_amplifier
                loss_h.backward(retain_graph=True)

                w2_loss = -(Pk_data - Tx).pow(2).sum(dim=1).mean() / 2 / self.cfg.step_a
                # *nn_loss is 0 if it's normal T otherwise
                nn_loss = self.nn_loss()
                try:
                    aggreg_loss = self.aggreg_loss(Tx, real_data[:, :, 2])
                except IndexError:
                    # The batch carries no third slice: no aggregation term.
                    # Any other failure is a real error and must surface.
                    aggreg_loss = 0
                remaining_map_loss = self.cfg.loss_amplifier * \
                    (w2_loss + nn_loss + aggreg_loss) / self.cfg.diffusion_coeff
                remaining_map_loss.backward()
                # Flip the sign of everything accumulated on the map
                # parameters (from both backward calls above).
                for p in self.map_t.parameters():
                    if p.grad is not None:
                        p.grad.neg_()

                # ! update T
                optimizer_t.step()

        if self.cfg.N_inner_ITERS > 0:
            h_ot_loss_value_batch = h_ot_loss_value_batch / self.cfg.N_inner_ITERS
        self.log_dict(
            {
                'h_ot_loss': h_ot_loss_value_batch,
                'w2_loss': w2_loss,
                'aggreg_loss': aggreg_loss,
                'total': remaining_map_loss,
            },
            prog_bar=True)
        return None

    def aggreg_loss(self, Pk_1_pushed, Pk_2):
        """Return the aggregation term added to the map loss; always 0 here.

        Both arguments are ignored by this implementation. ``training_step``
        calls it with the pushed-forward batch and ``real_data[:, :, 2]``.
        Presumably subclasses override it -- verify against the callers.
        """
        return 0

    def update_q_1d(self, P_k):
        """Redraw the reference samples on a box sized from ``P_k``.

        The half-width of the box is the per-column maximum of ``|P_k|``
        (taken along axis 0) times ``cfg.q_bound_scale``. Fresh uniform
        samples on ``[-half_width, half_width]`` are written to
        ``train_data[:, :, 1]`` and the box volume to ``self.density_q``.

        Args:
            P_k: sample tensor; per the original author note it should be on
                the CPU.
        """
        # * P_k should be on cpu
        half_width = P_k.abs().max(dim=0).values * self.cfg.q_bound_scale
        box_volume = torch.prod(half_width * 2)
        unit_samples = torch.rand(
            self.cfg.N_TRAIN_SAMPLES, self.cfg.INPUT_DIM)
        self.train_data[:, :, 1] = unit_samples * half_width * 2 - half_width
        self.density_q = box_volume.item()

    def update_q_2d(self, P_k):
        """Redraw the reference samples on a square sized from ``P_k``.

        The half-width is the global maximum of ``|P_k|`` times
        ``cfg.q_bound_scale``, capped at 5. The squared side length is stored
        in ``self.density_q`` (and printed), and fresh uniform samples on
        ``[-half_width, half_width]`` are written to ``train_data[:, :, 1]``.

        Args:
            P_k: sample tensor; per the original author note it should be on
                the CPU.
        """
        # * P_k should be on cpu
        scaled_extent = P_k.abs().max().item() * self.cfg.q_bound_scale
        half_width = min(5, scaled_extent)
        side_length = half_width * 2
        self.density_q = side_length**2
        print("volume", self.density_q)
        unit_samples = torch.rand(
            self.cfg.N_TRAIN_SAMPLES, self.cfg.INPUT_DIM)
        self.train_data[:, :, 1] = unit_samples * half_width * 2 - half_width

    def on_train_epoch_start(self):
        """At the start of epoch 1, save a joint plot of the initial samples.

        The plot uses the first two coordinates of the first 1000 rows of
        ``train_data[:, :, 0]`` and is written to
        ``image_save_path + '/epoch0.png'``. Nothing happens on later epochs,
        nor when the data has fewer than two coordinates.
        """
        if self.epoch != 1:
            return
        # The joint plot reads coordinates 0 and 1 of axis 1. One-dimensional
        # data has no coordinate 1, and indexing it would raise IndexError.
        if self.train_data.shape[1] < 2:
            return
        PLU.sns_jointplot_alone(
            self.train_data[:1000, 0, 0].numpy(),
            self.train_data[:1000, 1, 0].numpy(),
            self.image_save_path + '/epoch0.png')

    def on_train_epoch_end(self):
        """Plot the mapped samples and, every ``cfg.epochs`` epochs, roll over.

        Every epoch: push ``train_data[:, :, 0]`` (moved to CUDA) through
        ``map_forward_nograd`` and plot the result. When the epoch counter is
        a multiple of ``cfg.epochs``: save the map weights as
        ``map_{k}.pt``, call ``plot_1d_density`` for one-dimensional input,
        regenerate the samples through ``iterate_dataloader`` and dump them
        as ``P_{k-1}.pt``. The epoch counter is incremented last.
        """
        current_samples = self.train_data[:, :, 0].cuda()
        self._basic_plot(self.map_forward_nograd(current_samples))

        round_finished = self.epoch % self.cfg.epochs == 0
        if round_finished:
            map_file = self.P_save_path + f'/map_{self.k}.pt'
            torch.save(self.map_t.state_dict(), map_file)
            if self.cfg.INPUT_DIM == 1:
                self.plot_1d_density(self.cvx_data)
            # iterate_dataloader increments self.k, so ``k - 1`` in the file
            # name below matches the index of the map saved just above.
            P_k = self.iterate_dataloader()
            io.dump(self.P_save_path + f'/P_{self.k-1}.pt', P_k)
        self.epoch += 1

    def iterate_dataloader(self):
        """Advance ``self.k``, regenerate the samples and refresh ``q``.

        Returns:
            The tensor produced by ``new_Pk_generator``. For input dimension
            1 or 2 a CPU copy of it is also handed to the matching
            ``update_q_*`` method; other dimensions leave ``q`` untouched.
        """
        # P_k should be [cfg.NUM_TRAIN, cfg.INPUT_DIM]
        self.k += 1
        P_iterate = self.new_Pk_generator()
        dim = self.cfg.INPUT_DIM
        if dim in (1, 2):
            refresh_q = self.update_q_1d if dim == 1 else self.update_q_2d
            refresh_q(P_iterate.cpu())
        return P_iterate

    def new_Pk_generator(self):
        """Draw fresh initial samples and push them through the saved maps.

        The initial samples come from one of three importers:
        ``g_data.import_aggre`` when ``cfg.type_data == 'aggreg'``
        (Aggreg_1step_gmap use), ``g_data.import_aggre_diffusion_2d`` when
        ``cfg.aggreg`` is set (Aggreg_diffusion_1step_gmap use), and
        ``g_data.import_Barenblatt`` otherwise (diffusion). In the two
        aggregation cases slices 0 and 2 of the imported data are stacked
        along axis 0.

        The samples are then pushed through the maps saved as
        ``map_1.pt`` ... ``map_{k-1}.pt`` in order, and the result is written
        back into ``self.train_data`` (slice 0, plus slice 2 in the
        aggregation cases).

        Returns:
            The detached pushed-forward samples.
        """
        two_populations = True
        if self.cfg.type_data == 'aggreg':
            P0_data = g_data.import_aggre(self.cfg)[0]
        elif self.cfg.aggreg:
            P0_data = g_data.import_aggre_diffusion_2d(self.cfg)[0]
        else:
            two_populations = False
            P0_data = None

        if two_populations:
            P_iterate = torch.cat([P0_data[:, :, 0], P0_data[:, :, 2]], axis=0).cuda()
        else:
            P_iterate = g_data.import_Barenblatt(self.cfg, density_rho0)[
                0][:, :, 0].cuda()

        # The saved weights used to be loaded into a local copy that was
        # never evaluated, so the *current* map was applied k-1 times instead
        # of the chain of saved maps. Temporarily install the copy as
        # ``self.map_t`` so that ``map_forward_nograd`` evaluates it.
        # NOTE(review): this assumes ``map_forward_nograd`` (defined outside
        # this file) reads ``self.map_t`` -- confirm against the base class.
        trained_map = self.map_t
        iterated_map = copy.deepcopy(trained_map)
        self.map_t = iterated_map
        try:
            for idx in range(1, self.k):
                iterated_map.load_state_dict(torch.load(
                    self.P_save_path + f'/map_{idx}.pt', map_location='cuda:0'))
                P_iterate = self.map_forward_nograd(P_iterate.detach())
        finally:
            # Always put the map being trained back, even if loading fails.
            self.map_t = trained_map

        pushed_cpu = P_iterate.detach().cpu()
        if two_populations:
            self.train_data[:, :, 0] = pushed_cpu[:self.cfg.N_TRAIN_SAMPLES]
            self.train_data[:, :, 2] = pushed_cpu[-self.cfg.N_TRAIN_SAMPLES:]
        else:
            self.train_data[:, :, 0] = pushed_cpu
        return P_iterate.detach()

    def _basic_plot(self, Tx):
        """Forward ``Tx`` to the per-epoch plot matching ``cfg.INPUT_DIM``.

        Only dimensions 1 and 2 are handled; any other value plots nothing.
        """
        dim = self.cfg.INPUT_DIM
        if dim == 1:
            self.plot_1d_epoch(Tx)
        elif dim == 2:
            self.plot_2d_epoch(Tx)

    def plot_1d_density(self, cvx_data=None):
        """Hook for plotting the one-dimensional density; does nothing here.

        ``on_train_epoch_end`` calls this hook with ``self.cvx_data`` as a
        positional argument, so the signature has to accept it. The argument
        is optional so that calls without it keep working.

        Args:
            cvx_data: reference data handed over by the caller; ignored by
                this implementation.
        """
        return None

    def plot_1d_first_epoch(self, x_plot):
        """Save the "epoch 0" figure: initial histogram plus reference density.

        Draws a normalised histogram of ``train_data[:, :, 0]`` over
        ``[-cfg.plot_size, cfg.plot_size]``, overlays the reference density
        at time ``cfg.t0`` and saves the figure as ``epoch0.png``.

        Args:
            x_plot: grid passed on to ``plot_gt_rho_1d``.
        """
        half_range = self.cfg.plot_size
        initial_samples = self.train_data[:, :, 0].numpy()
        plt.hist(
            initial_samples,
            bins=self.cfg.num_grid,
            range=[-half_range, half_range],
            density=True,
            color='C0',
            alpha=0.5,
        )
        self.plot_gt_rho_1d(x_plot, t_now=self.cfg.t0)
        PLU.save_fig(self.image_save_path + '/epoch0.png')

    def plot_1d_epoch(self, P_k):
        """Save the per-epoch one-dimensional figures for the samples ``P_k``.

        On epoch 1 the "epoch 0" figure is produced first. Then a normalised
        histogram of ``P_k`` is drawn with the reference density on top and
        saved as ``epoch{epoch}.png``. Unless ``cfg.type_data`` is
        ``'aggreg'``, the comparison plot for ``h`` follows.

        Args:
            P_k: sample tensor; it is detached and moved to the CPU here.
        """
        half_range = self.cfg.plot_size
        n_bins = self.cfg.num_grid
        x_plot = torch.linspace(-half_range, half_range, n_bins).reshape(-1, 1)
        if self.epoch == 1:
            self.plot_1d_first_epoch(x_plot)

        hist_output = plt.hist(
            P_k.detach().cpu().numpy(),
            bins=n_bins,
            range=[-half_range, half_range],
            density=True,
            color='C0',
            alpha=0.5,
        )
        # First element of the return value: the bin heights.
        Pk_hist = hist_output[0]

        self.plot_gt_rho_1d(x_plot)
        PLU.save_fig(self.image_save_path + f'/epoch{self.epoch}.png')
        if self.cfg.type_data != 'aggreg':
            self.plot_h_comparison_1d(Pk_hist, x_plot)

    def plot_gt_rho_1d(self, x_plot, t_now=None):
        """Draw the reference density that matches ``cfg.type_data``.

        Args:
            x_plot: grid on which the density is drawn.
            t_now: time forwarded to the Barenblatt plot; not used in the
                ``'aggreg'`` case.
        """
        if self.cfg.type_data != 'aggreg':
            self.plot_gt_rho_Barenblatt_1d(x_plot, t_now)
            return
        self.plot_gt_rho_aggreg_1d(x_plot)

    def plot_gt_rho_Barenblatt_1d(self, x_plot, t_now=None):
        """Plot the normalised reference profile on the grid ``x_plot``.

        The profile is
        ``t**(-alpha) * max(C - k * x**2 * t**(-2 * beta), 0)**(1 / (m - 1))``
        with ``alpha``, ``beta``, ``C = C_constant``, ``k = k_value`` and
        ``m = porous_m`` read from ``cfg``. It is divided by its sum times
        ``2 * plot_size / num_grid``; that normalising constant is stored in
        ``self.P0_nml_c``.

        Args:
            x_plot: grid of positions.
            t_now: evaluation time; when omitted,
                ``cfg.t0 + self.k * cfg.step_a`` is used.
        """
        # Identity test: ``== None`` would compare elementwise if an array or
        # tensor were passed as the time.
        if t_now is None:
            t_now = self.cfg.t0 + self.k * self.cfg.step_a
        inside_relu = self.cfg.C_constant - self.cfg.k_value * \
            x_plot**2 * t_now**(-2 * self.cfg.beta)
        # Multiplying by the positivity mask zeroes the negative part.
        gt_rho = t_now**(-self.cfg.alpha) * (
            inside_relu * (inside_relu > 0))**(1 / (self.cfg.porous_m - 1))
        self.P0_nml_c = (gt_rho.sum() * (2 * self.cfg.plot_size / self.cfg.num_grid))
        gt_rho /= self.P0_nml_c
        plt.plot(x_plot, gt_rho, c='C1', label='exact density')

    def plot_h_comparison_1d(self, Pk_hist, x_plot):
        """Plot ``h`` on the grid, then the matching histogram-based curve.

        ``h`` is evaluated on CUDA and its values are clamped to
        ``[-1, 10]`` before plotting. The second curve (which also saves the
        figure) depends on ``cfg.ratio_h``.

        Args:
            Pk_hist: histogram values forwarded to the reference plot.
            x_plot: grid used as the horizontal axis.
        """
        grid_on_gpu = x_plot.reshape(-1, 1).cuda()
        h_values = self.h(grid_on_gpu).detach().cpu()
        plt.plot(x_plot, torch.clamp(h_values, min=-1, max=10), c='C0')

        draw_reference = (
            self.plot_gt_ratio_h_1d if self.cfg.ratio_h
            else self.plot_gt_original_h_1d)
        draw_reference(Pk_hist, x_plot)

    def plot_gt_ratio_h_1d(self, Pk_hist, x_plot):
        """Plot ``Pk_hist * density_q`` and save ``epoch{epoch}_h.png``."""
        reference_curve = Pk_hist * self.density_q
        plt.plot(x_plot, reference_curve, c='C1')
        figure_path = self.image_save_path + f'/epoch{self.epoch}_h.png'
        PLU.save_fig(figure_path)

    def plot_gt_original_h_1d(self, Pk_hist, x_plot):
        """Plot ``m * (Pk_hist * density_q)**(m - 1) / (m - 1)`` and save it.

        ``m`` is ``cfg.porous_m``; the figure goes to ``epoch{epoch}_h.png``.
        """
        m = self.cfg.porous_m
        density_ratio = Pk_hist * self.density_q
        reference_curve = m * density_ratio ** (m - 1) / (m - 1)
        plt.plot(x_plot, reference_curve, c='C1')
        PLU.save_fig(self.image_save_path + f'/epoch{self.epoch}_h.png')

    def plot_2d_epoch(self, P_k, before_epoch='after'):
        """Save the per-epoch two-dimensional figures for the samples ``P_k``.

        A normalised 2-D histogram of the first two coordinates is saved as
        ``epoch{epoch}_Pk_2d_{before_epoch}.png``. The reference density is
        plotted unless the run involves aggregation, and the comparison
        surface plot always follows.

        Args:
            P_k: sample tensor; it is detached and moved to the CPU here.
            before_epoch: tag used in the file names and forwarded to
                ``plot_h_comparison_2d``.
        """
        samples = P_k.detach().cpu().numpy()
        half_range = self.cfg.plot_size
        axis_range = [-half_range, half_range]
        hist_output = plt.hist2d(
            samples[:, 0],
            samples[:, 1],
            bins=self.cfg.num_grid,
            range=[axis_range, axis_range],
            density=True,
            alpha=0.5,
        )
        # First element of the return value: the 2-D array of bin heights.
        Pk_hist = hist_output[0]
        plt.savefig(self.image_save_path + f'/epoch{self.epoch}_Pk_2d_{before_epoch}.png')

        involves_aggregation = (
            self.cfg.type_data == 'aggreg' or self.cfg.aggreg == True)
        if not involves_aggregation:
            self.plot_gt_rho_2d()
        self.plot_h_comparison_2d(Pk_hist, before_epoch=before_epoch)

    def plot_gt_rho_2d(self):
        """Save a surface plot of the normalised reference density.

        ``density_rho_t`` is evaluated on the points from ``_get_xy_grid`` at
        time ``cfg.t0 + self.k * cfg.step_a``. It is divided by its sum times
        ``(2 * plot_size / num_grid)**2``; that constant is stored in
        ``self.P0_nml_c``. The figure is saved as ``epoch{epoch}_gt.png``.
        """
        x_grid, y_grid, xy_list = self._get_xy_grid()
        from porous_media import density_rho_t

        t_now = self.cfg.t0 + self.k * self.cfg.step_a
        gt_rho = density_rho_t(xy_list, t_now)

        cell_area = (2 * self.cfg.plot_size / self.cfg.num_grid)**2
        self.P0_nml_c = (gt_rho.sum() * cell_area)
        gt_rho /= self.P0_nml_c

        surface_values = gt_rho.reshape(self.cfg.num_grid, -1)
        target_file = self.image_save_path + f'/epoch{self.epoch}_gt.png'
        PLU.surface_alone(
            x_grid, y_grid, surface_values,
            ax_params=self.handle.ax_param, save_path=target_file)

    def plot_h_comparison_2d(self, Pk_hist, before_epoch='after'):
        """Save the histogram surface and, in ratio mode, the surface of ``h``.

        The histogram is multiplied by ``density_q`` when ``cfg.ratio_h`` is
        set and plotted as-is otherwise, then saved as
        ``epoch{epoch}_Pk_Q_surface_{before_epoch}.png``. ``plot_single_h_2d``
        is called only when ``before_epoch == 'after'`` and ``cfg.ratio_h``
        is set.

        Args:
            Pk_hist: 2-D histogram values.
            before_epoch: tag used in the file name.
        """
        x_grid, y_grid, xy_list = self._get_xy_grid()
        ratio_mode = self.cfg.ratio_h
        Pk_over_Q = Pk_hist * self.density_q if ratio_mode else Pk_hist

        surface_file = (
            self.image_save_path
            + f'/epoch{self.epoch}_Pk_Q_surface_{before_epoch}.png')
        PLU.surface_alone(
            x_grid, y_grid, Pk_over_Q,
            ax_params=self.handle.ax_param, save_path=surface_file)

        if before_epoch == 'after' and ratio_mode:
            grid_points = torch.from_numpy(xy_list).cuda().float()
            self.plot_single_h_2d(x_grid, y_grid, grid_points, self.handle.ax_param)
