import torch
from src.models.aggreg_map import Aggreg_diffusion_2step
from src.models.kl_Tmap import GM_Tmap
import src.datamodules.generate_data as g_data


class kl_2step_base(Aggreg_diffusion_2step):
    """Aggreg_diffusion_2step variant that moves samples by explicit gradient steps.

    The target density is read from ``self.density_q`` as a
    ``(means, covariances, weights)`` triple of a Gaussian mixture with
    ``self.cfg.NUM_GMM_COMPONENT[1]`` components.
    """

    def map_forward_gd(self, data, chunk_size=50000):
        """Apply ``x - cfg.step_a * nabla_Vx(x)`` to every row of ``data``.

        Rows are processed in chunks of at most ``chunk_size`` to bound peak
        memory of the per-component density computations. The final, possibly
        shorter, chunk is included, so the row count need not be a multiple of
        ``chunk_size``.

        Args:
            data: Tensor of samples, moved to the device of the mixture weights
                (``self.density_q[2]``). Presumably shaped (N, D) -- TODO confirm.
            chunk_size: Maximum number of rows passed to ``nabla_Vx`` at once.

        Returns:
            Tensor with one updated row per input row, concatenated along dim 0.
        """
        data = data.to(self.density_q[2].device)
        if data.shape[0] == 0:
            # Nothing to step; torch.cat on an empty list would raise.
            return data
        # torch.split yields the trailing remainder chunk as well, so no rows
        # are dropped when N is not a multiple of chunk_size (or N < chunk_size).
        stepped = [
            chunk - self.cfg.step_a * self.nabla_Vx(chunk)
            for chunk in torch.split(data, chunk_size)
        ]
        return torch.cat(stepped, dim=0)

    def nabla_Vx(self, Tx):
        """Return the posterior-weighted term sum_i r_i(x) (x - mu_i) Sigma_i^{-1}.

        Here ``r_i(x) = w_i N_i(x) / sum_j w_j N_j(x)``, with ``N_i`` given by
        ``self.gauss_density`` and ``(mu_i, Sigma_i, w_i)`` taken from
        ``self.density_q``. NOTE(review): if ``gauss_density`` is the Gaussian
        pdf and the covariances are symmetric, this equals grad_x(-log q(x))
        for the mixture q -- confirm against ``gauss_density``.

        Args:
            Tx: Tensor of query points, one per row (presumably (N, D)).

        Returns:
            Tensor of per-row values, broadcast to the shape of ``Tx``.
        """
        mean_q_list, cov_q, weights = self.density_q
        numer, density_denom = 0, 0
        for idx in range(self.cfg.NUM_GMM_COMPONENT[1]):
            weighted_density = weights[idx] * self.gauss_density(
                Tx, mean_q_list[idx], cov_q[idx]
            ).reshape(-1, 1)
            # Row-vector form of (x - mu) Sigma^{-1}, one row per point.
            inside_exp = (Tx - mean_q_list[idx]) @ torch.inverse(cov_q[idx])
            numer = numer + weighted_density * inside_exp
            density_denom = density_denom + weighted_density
        # Mathematically numer / (denom + 1e-16): the tiny offset only matters
        # where every component density has underflowed to zero.
        return numer * 1e10 / (density_denom * 1e10 + 1e-6)

    def log_q_loss(self, _):
        """Return 0: this class contributes no log-q loss term."""
        return 0


class GM_2step_Tmap(kl_2step_base, GM_Tmap):
    """Combine kl_2step_base with GM_Tmap, drawing P0 from the GMM data generator.

    Because kl_2step_base comes first in the bases, its methods take precedence
    over GM_Tmap's in the MRO.
    """

    def sampler_p0(self):
        """Return the initial sample set produced by ``g_data.importa_samp_data_gmm``."""
        return g_data.importa_samp_data_gmm(self.cfg)
