import os
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR as StepLR
import jacinle.io as io
import src.models.porus_media_Tmap as T_system
import src.models.kl_gmap as g_system
import src.utils.plot_utils as PLU
from porous_media import density_rho0
import matplotlib.pyplot as plt
import copy


class Porus_gmap(g_system.Base_gmap, T_system.Porus_Tmap):

    def __init__(self, *kargs, **kwargs):
        """Run the parent initialisers and collect the flagged parameters.

        All positional and keyword arguments are forwarded unchanged through
        ``super().__init__``.  Afterwards every parameter of ``self.map_t``
        that carries a ``be_positive`` attribute is stored, in iteration
        order, in ``self.g_positive_params``.
        """
        super().__init__(*kargs, **kwargs)
        self.g_positive_params = [
            param for param in self.map_t.parameters()
            if hasattr(param, 'be_positive')
        ]

    def plot_1d_density(self, cvx_data=None):
        """Draw the 1-D density figure for step ``self.k`` and save it.

        Drawn in this order on the current matplotlib figure:

        1. a normalised histogram of the training particles pushed through
           the map restored from ``map_{self.k}.pt``;
        2. the curve returned by ``get_density`` on a uniform grid, rescaled
           so that it sums to one with cell width ``2 * plot_size / num_grid``;
        3. whatever ``plot_gt_rho_1d`` draws for the same grid;
        4. entry ``self.k - 1`` of ``cvx_data`` when it is supplied.

        The result is written to ``density{self.k}.png`` under
        ``self.image_save_path``.

        Args:
            cvx_data: optional indexable collection of reference curves;
                nothing extra is drawn when it is ``None``.
        """
        half_width = self.cfg.plot_size
        n_grid = self.cfg.num_grid
        grid = torch.linspace(-half_width, half_width, n_grid).reshape(-1, 1)

        # Restore the map saved for the current step before pushing particles.
        saved_state = torch.load(self.P_save_path + f'/map_{self.k}.pt')
        self.map_t.load_state_dict(saved_state)
        pushed = self.map_forward(self.train_data[:, :, 0].cuda())
        plt.hist(
            pushed.detach().cpu().numpy(),
            bins=n_grid,
            range=[-half_width, half_width],
            density=True,
            color='C0',
            alpha=0.5,
        )

        density = self.get_density(grid.cuda())
        cell_width = 2 * half_width / n_grid
        density /= (density.sum() * cell_width)
        plt.plot(grid, density, c='C0', label='Ours')

        self.plot_gt_rho_1d(grid)
        if cvx_data is not None:
            plt.plot(grid, cvx_data[self.k - 1], c='C2', label='discrete solution')

        # The legend is only added on step 8 (hard-coded step index).
        if self.k == 8:
            plt.legend(prop={"size": 10}, loc='upper right')
        PLU.save_fig(self.image_save_path + f'/density{self.k}.png')

    def plot_2d_density(self):
        """Render the density for step ``self.k`` as a surface image.

        The grid comes from ``self._get_xy_grid()``.  ``get_density`` is
        evaluated on the returned point list, the values are rescaled so they
        sum to one with cell area ``(2 * plot_size / num_grid) ** 2``, and the
        result is reshaped to ``self.handle.ax_param.num_grid`` rows before
        being handed to ``PLU.surface_alone``, which receives the output path
        ``density{self.k}.png`` under ``self.image_save_path``.
        """
        x_grid, y_grid, xy_list = self._get_xy_grid()

        query_points = torch.from_numpy(xy_list).cuda().float()
        density = self.get_density(query_points)
        cell_area = (2 * self.cfg.plot_size / self.cfg.num_grid)**2
        density /= (density.sum() * cell_area)

        n_rows = self.handle.ax_param.num_grid
        surface = density.reshape(n_rows, -1).numpy()
        out_file = self.image_save_path + f'/density{self.k}.png'
        PLU.surface_alone(
            x_grid,
            y_grid,
            surface,
            ax_params=self.handle.ax_param,
            save_path=out_file,
        )

    # * Past densities are no longer recorded.

    def get_density(self, x_now):
        """Evaluate the step-``self.k`` density at the query points ``x_now``.

        The query points are pulled back through the saved maps with
        ``get_past_particles``.  The log-density starts from the logarithm of
        ``density_rho0`` at the last entry of that history (divided by
        ``self.P0_nml_c`` when the constant is positive).  Then, for every
        ``idx`` in ``1..self.k``, the map stored in ``map_{idx}.pt`` is loaded
        and the value of ``get_logHessian`` at the matching history entry is
        subtracted.

        Args:
            x_now: query points.  # assumes shape (N, INPUT_DIM) -- TODO confirm

        Returns:
            A detached CPU tensor with the exponentiated log-density.
        """
        particle_history = self.get_past_particles(x_now)
        P0_unml = density_rho0(particle_history[-1])
        if self.P0_nml_c > 0:
            log_density = (P0_unml / self.P0_nml_c).log()
        else:
            log_density = P0_unml.log()
        iterated_map = copy.deepcopy(self.map_t)

        for idx in range(1, self.k + 1):
            iterated_map.load_state_dict(torch.load(self.P_save_path + f'/map_{idx}.pt'))
            # The history is stored newest-first, so [-idx] is the input of map idx.
            x_variable = particle_history[-idx].clone().requires_grad_(True)

            # No optimizer is needed here: x_variable is a fresh leaf without
            # stale gradients, and parameter gradients are never read.
            g_of_x = iterated_map(x_variable).sum()
            # create_graph=True keeps the first derivative differentiable so
            # get_logHessian can take the second derivative.
            grad_x = torch.autograd.grad(
                g_of_x, x_variable, create_graph=True)[0]

            log_density -= self.get_logHessian(grad_x, x_variable)
        return log_density.exp().detach().cpu()

    def get_past_particles(self, x_now):
        # * particle_history is [x_{k-1}, ... x_0]
        particle_history = []
        iterated_map = copy.deepcopy(self.map_t)
        for idx in range(self.k, 0, -1):
            iterated_map.load_state_dict(torch.load(self.P_save_path + f'/map_{idx}.pt'))
            x_now = self.back_particles(x_now, iterated_map).detach()
            print(idx - 1, x_now.max())
            particle_history.append(x_now)
        return particle_history

    def back_particles(self, x_now, g_now):
        x_variable = x_now.detach().clone().cuda()
        x_variable.requires_grad = True
        # TODO this lr is 1e-5/1e-4 for aggreg+diffusion
        optimizer_x = optim.Adam([x_variable], lr=1e-4)
        optimizer_g = optim.Adam(g_now.parameters())
        x_before_update = torch.zeros_like(x_variable)
        idx = 0
        error = 1
        # TODO this tol is 0.002 for aggreg+diffusion
        while idx < 100 or error > 2e-4:
            if idx > 50000:
                break
            idx += 1
            x_before_update = x_variable.detach().clone()
            if self.cfg.INPUT_DIM == 1:
                loss_pushback = (-x_variable * x_now + g_now(x_variable)).sum()
            else:
                loss_pushback = (-(x_variable * x_now).sum(axis=1,
                                 keepdim=True) + g_now(x_variable)).sum()
            optimizer_x.zero_grad()
            optimizer_g.zero_grad()
            loss_pushback.backward()
            optimizer_x.step()
            error = torch.norm(x_before_update - x_variable)
            print(error)
        return x_variable

    def get_logHessian(self, grad_x, x_variable):
        if self.cfg.INPUT_DIM == 1:
            grad_x.sum().backward()
            return x_variable.grad.log()
        else:
            return self.get_logHessian_2d(grad_x, x_variable)

    def get_logHessian_2d(self, grad_x, x_variable):
        hessian = torch.cat(
            [
                torch.autograd.grad(outputs=grad_x[:, d], inputs=x_variable, grad_outputs=torch.ones(
                    x_variable.size()[0]).float().to(x_variable),
                    retain_graph=True)[0][:, None, :]
                for d in range(2)
            ],
            dim=1
        )
        return torch.logdet(hessian)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.cfg.N_TEST, shuffle=False, num_workers=4)

    def test_step(self, *args, **kwargs):
        pass

    def on_test_epoch_end(self):
        norm_const = self.cfg.plot_size * 2 / self.cfg.num_grid * self.cfg.N_TRAIN_SAMPLES
        x_plot = torch.linspace(-self.cfg.plot_size, self.cfg.plot_size,
                                self.cfg.num_grid).reshape(-1, 1)
        self.plot_1d_first_epoch(x_plot, norm_const)
        cvx_path = 'cvx_data/a' + \
            f'{int(-np.log10(self.cfg.step_a))}' + '_discrete_gt.mat'
        if os.path.exists(cvx_path):
            cvx_data = io.load(cvx_path)['discrete_record']

        for _ in range(self.cfg.iter_proxi):
            with torch.enable_grad():
                self.plot_1d_density(cvx_data)
                self.iterate_dataloader()


class Aggreg_diffusion_1step_gmap(Porus_gmap):
    def aggreg_loss(self, Pk_1_pushed, Pk_2):
        # * Now I fix the function W, I may need to change it later.
        Pk_2_pushed = self.map_forward(Pk_2)
        inside_w = Pk_1_pushed.reshape(self.cfg.BATCH_SIZE, 1, -1) - \
            Pk_2_pushed.reshape(1, self.cfg.BATCH_SIZE, -1)
        squared_norm = (inside_w**2).sum(axis=2).reshape(-1)
        expec_w = torch.exp(-squared_norm) / np.pi
        return expec_w.mean()

    def on_train_epoch_end(self):
        """Plot the pushed particles and, periodically, checkpoint the map.

        After every epoch the training particles are pushed through
        ``map_forward_nograd`` and handed to ``plot_2d_epoch``.  Whenever
        ``self.epoch`` is a multiple of ``cfg.epochs`` the state of
        ``self.map_t`` is saved as ``map_{self.k}.pt``, ``iterate_dataloader``
        is run, and its return value is dumped to ``P_{self.k-1}.pt``.  The
        epoch counter is incremented at the end in every case.
        """
        pushed = self.map_forward_nograd(self.train_data[:, :, 0].cuda())
        self.plot_2d_epoch(pushed)

        at_checkpoint = self.epoch % self.cfg.epochs == 0
        if at_checkpoint:
            map_state = self.map_t.state_dict()
            torch.save(map_state, self.P_save_path + f'/map_{self.k}.pt')
            # plot_2d_density() is deliberately not called here.
            particles = self.iterate_dataloader()
            # self.k is read only after iterate_dataloader has returned.
            dump_file = self.P_save_path + f'/P_{self.k-1}.pt'
            io.dump(dump_file, particles)
        self.epoch += 1
