from __future__ import print_function
import numpy as np
import torch
import pytorch_lightning as pl
from src.train import train
import src.utils.log_utils as LLU
import src.datamodules.generate_data as g_data
from configs.cfg import CfgPorous
cfg = CfgPorous()

if cfg.type_data == 'aggreg':
    # Pure aggregation problem: Gaussian initial density.
    if cfg.INPUT_DIM == 2:
        def density_rho0(x):
            """Gaussian initial density with std 0.5 per coordinate (2D case).

            The squared coordinates are summed over axis 1, so ``x`` is
            presumably a ``(batch, 2)`` tensor — confirm against callers.
            """
            # NOTE(review): 1 / (0.5 * sqrt(2*pi)) is the 1D normaliser; a
            # normalised 2D Gaussian with std 0.5 would divide by
            # 2 * pi * 0.5**2 instead. Kept unchanged — confirm it is intended.
            return torch.exp(-((x / 0.5)**2).sum(axis=1) / 2) / (0.5 * np.sqrt(2 * np.pi))
    elif cfg.INPUT_DIM == 1:
        def density_rho0(x):
            """Standard normal initial density (1D case), applied elementwise."""
            return torch.exp(-x**2 / 2) / np.sqrt(2 * np.pi)
    # No density_rho0 is defined here for other INPUT_DIM values; the
    # 'aggreg' branches of the __main__ driver load their data without it.

elif cfg.aggreg:
    def density_rho0(x):
        """Uniform initial density on the box ``max_i |x_i| < 3``.

        Args:
            x: torch.Tensor; the max is taken over axis 1, so presumably
                ``(batch, INPUT_DIM)`` — confirm against callers.

        Returns:
            Tensor equal to 1/36 inside the box and 0 outside.
        """
        # 1/36 == 1 / 6**2 normalises the box [-3, 3]^d only when d == 2.
        return 1 / 36 * (torch.max(x.abs(), axis=1)[0] < 3)
else:
    # Porous-media case: closed-form Barenblatt profile.
    # *Q is uniform determined by P. trial cannot determine the p0 anymore.
    # Barenblatt exponents for dimension d = INPUT_DIM and exponent m = porous_m:
    #   alpha = d / (d (m - 1) + 2),  beta = alpha / d,  k = alpha (m - 1) / (2 m d)
    cfg.alpha = cfg.INPUT_DIM / (cfg.INPUT_DIM * (cfg.porous_m - 1) + 2)
    cfg.beta = cfg.alpha / cfg.INPUT_DIM
    cfg.k_value = cfg.alpha * (cfg.porous_m - 1) / (2 * cfg.porous_m * cfg.INPUT_DIM)

    def density_rho_t(x, t_now):
        """Evaluate the Barenblatt density at time ``t_now``.

        rho(x, t) = t^(-alpha) * max(C - k |x|^2 t^(-2 beta), 0)^(1 / (m - 1))

        Args:
            x: torch.Tensor of points. For INPUT_DIM == 1 the square is taken
                elementwise (the output keeps x's shape); otherwise |x|^2 is
                summed over the last axis, so x is presumably
                ``(..., INPUT_DIM)`` — confirm against callers.
            t_now: time at which to evaluate the profile.

        Returns:
            Tensor of density values (zero outside the compact support).
        """
        # Squared Euclidean norm; any INPUT_DIM >= 2 reduces over the last
        # axis (previously only 1 and 2 were handled and other dimensions
        # failed with UnboundLocalError).
        if cfg.INPUT_DIM == 1:
            sq_norm = x**2
        else:
            sq_norm = (x**2).sum(axis=-1)
        inside_relu = cfg.C_constant - cfg.k_value * sq_norm * t_now**(-2 * cfg.beta)
        # inside_relu * (inside_relu > 0) is max(inside_relu, 0).
        return t_now**(-cfg.alpha) * (inside_relu * (inside_relu > 0))**(1 / (cfg.porous_m - 1))

    def density_rho0(x):
        """Initial density: the Barenblatt profile at the start time ``cfg.t0``."""
        return density_rho_t(x, cfg.t0)

if __name__ == "__main__":
    # Script entry point: pick the data loader and model class from cfg,
    # build the model, then either test it or train it.
    results_save_path, image_save_path, P_save_path, results = LLU.init_path(cfg)
    if cfg.fb_method:
        # cfg.fb_method selects the *_2step_* model classes.
        if cfg.type_data == 'aggreg':
            # * aggregation equation
            total_data, volume_q = g_data.import_aggre(cfg)
            from src.models.aggreg_map import Aggreg_2step_gmap as light_system
        else:
            # * aggregation-diffusion equation
            # NOTE(review): every non-'aggreg' case lands here when fb_method
            # is set, including cfg.aggreg == False — confirm that is intended.
            total_data, volume_q = g_data.import_aggre_diffusion_2d(cfg)
            if cfg.map_type == 'nabla_g':
                from src.models.aggreg_map import Aggreg_diffusion_2step_gmap as light_system
            else:
                from src.models.aggreg_map import Aggreg_diffusion_2step_Tmap as light_system
    else:
        # One-step model classes.
        if cfg.type_data == 'aggreg':
            # * aggregation equation
            total_data, volume_q = g_data.import_aggre(cfg)
            from src.models.aggreg_map import Aggreg_1step_gmap as light_system
        elif cfg.aggreg:
            # * aggregation-diffusion equation
            total_data, volume_q = g_data.import_aggre_diffusion_2d(cfg)
            from src.models.porus_media_gmap import Aggreg_diffusion_1step_gmap as light_system
        else:
            # * porous media equation, data built from density_rho0
            # Validate before the data load: an unsupported map_type used to
            # leave light_system unbound and fail later with a NameError.
            if cfg.map_type not in ('nabla_g', 'T'):
                raise ValueError(
                    f"Unsupported cfg.map_type {cfg.map_type!r} for porous media; "
                    "expected 'nabla_g' or 'T'")
            total_data, volume_q = g_data.import_Barenblatt(cfg, density_rho0)
            if cfg.map_type == 'nabla_g':
                from src.models.porus_media_gmap import Porus_gmap as light_system
            else:
                from src.models.porus_media_Tmap import Porus_Tmap as light_system
    model = light_system(
        cfg, volume_q, total_data, P_save_path, image_save_path=image_save_path)

    if cfg.test:
        # NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and
        # removed in 2.0 (use accelerator="gpu", devices=1). Left as-is because
        # the pinned Lightning version is not visible here.
        trainer = pl.Trainer(gpus=1, max_epochs=cfg.iter_proxi)
        trainer.test(model)
    else:
        train(model, cfg, results_save_path)
