from dataclasses import dataclass, field


@dataclass
class CfgBase:
    """Default settings shared by the experiment configurations.

    ``CfgGMM`` and ``CfgPorous`` subclass this, override some of the defaults
    below and append their own fields. Declaration order defines the
    positional order of the generated ``__init__``, so existing fields should
    not be reordered.
    """

    fb_method: int = 0  # forward backward method
    debug_jko: int = 0

    # Optimiser, data sizes and learning-rate schedule.
    optimizer: str = "Adam"
    BATCH_SIZE: int = 100
    N_TRAIN_SAMPLES: int = 60000
    LR_g: float = 0.0001
    LR_h: float = 0.0001
    INPUT_DIM: int = 2
    log_interval: int = 10
    N_TEST: int = 10000
    nn_activation: str = "Prelu"
    schedule_learning_rate: int = 1  # used as an on/off flag
    lr_schedule_scale: float = 0.1
    lr_schedule_per_epoch: int = 20
    final_actv: str = "Prelu"
    full_actv: int = 0

    # Network sizes, iteration counts and plotting options.
    epochs: int = 1  # how many epochs per P_k
    NUM_NEURON_h: int = 32
    NUM_NEURON_map: int = 32
    NUM_LAYERS_h: int = 3
    NUM_LAYERS_g: int = 3
    TRIAL: float = 4.0
    N_inner_ITERS: int = 3
    N_outer_ITERS: int = 1
    load_iter: int = 50
    repeat: int = 0
    batch_nml: int = 1  # used as an on/off flag
    dropout: float = 0.001
    save_h: int = 1
    opacity: float = 1
    scatter_size: float = 8
    dump_epoch: int = 1
    step_a: float = 0.1
    g_activation: str = "Prelu"
    iter_proxi: int = 80  # how many P_k

    # Map parameterisation.
    map_type: str = "T"
    T_res: int = 1
    T_linear: int = 0
    quadr: int = 1

    # Data / initial-distribution settings.
    type_data: str = "None"
    plot_size: float = 3
    band_width: float = 0.4
    P0_std: float = 1
    P0_equal_Q: int = 0
    NUM_GMM_COMPONENT: int = 0
    # Mutable defaults need a factory so instances never share one list.
    MEAN: list = field(default_factory=list)
    COV: list = field(default_factory=list)

    debug_h: int = 0
    seed: int = 1


@dataclass
class CfgGMM(CfgBase):
    """Configuration for the Gaussian-mixture experiments (``type_data='GM'``).

    Adds options selecting the reference measure ``mu``, the loss tag and the
    transform applied to ``h``, and overrides several ``CfgBase`` defaults.
    :meth:`get_save_path` encodes most settings into a results directory so
    runs with different hyper-parameters are written to different folders.
    """

    # Choice of mu, checked in this order by get_save_path: mu_equal_q
    # ('equal_Q'), then mu_equal_Pk ('equal_Pk'), otherwise 'N0' + mu_var.
    mu_equal_q: int = 0
    mu_equal_Pk: int = 1
    mu_var: float = 3
    bregman_loss: int = 0  # path tag 'bregman' when truthy, else 'kl'

    log_kl: int = 0  # path tag 'log_kl' when truthy, else 'no_log_kl'
    # Together select the h tag: 'exp_add_small', 'exp' or 'no_exp'.
    exp_h: int = 1
    exp_h_add_small: int = 1

    # Overrides of CfgBase defaults (dataclasses keep their base position).
    NUM_NEURON_map: int = 16
    type_data: str = 'GM'
    TRIAL: float = 13.0
    T_linear: int = 0
    N_inner_ITERS: int = 3
    plot_size: float = 10
    # Original note: "this is only for test h ability". Points at a 'P_60.pt'
    # file from an earlier GM run -- NOTE(review): confirm it is still valid.
    Pk_path: str = 'data/Results_of_GM/dim_8/GMM_component_[1, 4]/T_nonlinear/res/trial_8.0/no_log_kl/exp_add_small_h/mu_equal_Pk/step_a_0.3/epoch1per_step/innerIter_3_out_1/layers_h3_g2/neuron_512/lr_t_0.001lr_h_0.001/schedule_learning_rate:Yes/lr_schedule:20/iter_proxi_80/batchnml:Yes_dropout:Yes/activ_h_Prelu_g_Prelu/batch_100/storing_P/P_60.pt'

    def _mu_tag(self):
        """Return the tag used in the ``mu_<tag>`` path segment."""
        if self.mu_equal_q:
            return 'equal_Q'
        if self.mu_equal_Pk:
            return 'equal_Pk'
        return 'N0' + str(self.mu_var)

    def _map_tag(self):
        """Return the map-type path segment.

        ``'nabla_g'`` is used verbatim. Any other map type gets a
        ``_linear``/``_nonlinear`` suffix plus an extra ``/res`` or
        ``/no_res`` directory level.
        """
        if self.map_type == 'nabla_g':
            return self.map_type
        return (self.map_type
                + ('_linear' if self.T_linear else '_nonlinear')
                + ('/res' if self.T_res else '/no_res'))

    def _h_tag(self):
        """Return 'exp_add_small', 'exp' or 'no_exp' for the ``<tag>_h`` segment."""
        if not self.exp_h:
            return 'no_exp'
        return 'exp_add_small' if self.exp_h_add_small else 'exp'

    def get_save_path(self):
        """Return the results directory that encodes this run's settings.

        With ``debug_h`` set, the path is rooted at ``./data/debug_h``. It then
        omits the loss tag ('kl'/'bregman'), the forward-backward tag
        ('fb'/'non_fb') and ``N_outer_ITERS``. Otherwise the path is rooted at
        ``./data`` and includes them. The directory is not created here.

        Returns:
            str: '/'-separated relative path without a trailing slash.
        """
        if self.debug_h:
            head = [
                './data/debug_h',
                f'Results_of_{self.type_data}',
                f'dim_{self.INPUT_DIM}',
                f'GMM_component_{self.NUM_GMM_COMPONENT}',
            ]
            inner_iters = f'innerIter_{self.N_inner_ITERS}'
            neurons = f'neuron_map{self.NUM_NEURON_map}_h{self.NUM_NEURON_h}'
        else:
            head = [
                './data',
                f'Results_of_{self.type_data}',
                'bregman' if self.bregman_loss else 'kl',
                f'dim_{self.INPUT_DIM}',
                f'GMM_component_{self.NUM_GMM_COMPONENT}',
                'fb' if self.fb_method else 'non_fb',
            ]
            inner_iters = f'innerIter_{self.N_inner_ITERS}_out_{self.N_outer_ITERS}'
            # NOTE(review): this layout has always spelled the segment
            # 'neuron_ map' (with a space), unlike the debug layout. Kept
            # verbatim so directories of existing runs still match.
            neurons = f'neuron_ map{self.NUM_NEURON_map}_h{self.NUM_NEURON_h}'

        kl_tag = 'log' if self.log_kl else 'no_log'
        schedule = 'Yes' if self.schedule_learning_rate else 'No'
        batch_norm = 'Yes' if self.batch_nml else 'No'
        use_dropout = 'Yes' if self.dropout > 0 else 'No'

        # NOTE(review): ':' in segment names is not allowed on Windows file systems.
        segments = head + [
            self._map_tag(),
            f'trial_{self.TRIAL}',
            f'{kl_tag}_kl',
            f'{self._h_tag()}_h',
            f'mu_{self._mu_tag()}',
            f'step_a_{self.step_a}',
            f'epoch{self.epochs}per_step',
            inner_iters,
            f'layers_h{self.NUM_LAYERS_h}_g{self.NUM_LAYERS_g}',
            neurons,
            f'lr_t_{self.LR_g}lr_h_{self.LR_h}',
            f'schedule_learning_rate:{schedule}',
            f'lr_schedule:{self.lr_schedule_per_epoch}',
            f'iter_proxi_{self.iter_proxi}',
            f'batchnml:{batch_norm}_dropout:{use_dropout}',
            f'activ_h_{self.nn_activation}_g_{self.g_activation}',
            f'batch_{self.BATCH_SIZE}',
        ]
        return '/'.join(segments)


@dataclass
class CfgPorous(CfgBase):
    """Configuration for the porous experiments (``type_data='porous'``).

    Adds model fields (``porous_m``, ``aggreg``, ``diffusion_coeff``,
    ``keller_segel``, ...), grid/plot settings and constants for the
    Barenblatt-Pattle example. It also overrides the ``CfgBase`` defaults
    ``INPUT_DIM``, ``step_a``, ``plot_size``, ``type_data``, ``TRIAL``,
    ``band_width``, ``map_type``, ``full_actv`` and ``quadr``.
    :meth:`get_save_path` encodes the main settings into a results directory.
    """

    INPUT_DIM: int = 2
    porous_m: int = 3  # the order m in generalized energy
    aggreg: int = 0
    diffusion_coeff: float = 1
    step_a: float = 0.05
    plot_size: float = 1.2
    type_data: str = 'porous'

    border_square: float = 3
    keller_segel: int = 0

    loss_amplifier: float = 1
    TRIAL: float = 6.0
    ratio_h: int = 1  # path tag 'ratio_h' when truthy, else 'no_ratio_h'
    q_bound_scale: float = 1.15
    band_width: float = 0.1
    num_grid: int = 300
    map_type: str = 'nabla_g'
    full_actv: int = 1
    test: int = 0

    # *only for Barenblatt-Pattle example
    C_constant: float = (3 / 16) ** (1 / 3)
    t0: float = 0.002
    # *determined later (placeholders; nothing in this class computes them)
    alpha: float = 0
    beta: float = 0
    k_value: float = 0
    quadr: int = 0

    def _map_tag(self):
        """Return the map-type path segment.

        ``'nabla_g'`` (the default here) is used verbatim. Any other map type
        gets a ``_linear``/``_nonlinear`` suffix plus an extra ``/res`` or
        ``/no_res`` directory level.
        """
        if self.map_type == 'nabla_g':
            return self.map_type
        return (self.map_type
                + ('_linear' if self.T_linear else '_nonlinear')
                + ('/res' if self.T_res else '/no_res'))

    def get_save_path(self):
        """Return the results directory that encodes this run's settings.

        The directory is not created here.

        Returns:
            str: '/'-separated path rooted at ``./data``, without a trailing
            slash.
        """
        schedule = 'Yes' if self.schedule_learning_rate else 'No'
        batch_norm = 'Yes' if self.batch_nml else 'No'
        use_dropout = 'Yes' if self.dropout > 0 else 'No'

        # NOTE(review): 'neuron_ map' (with a space) is how this layout has
        # always spelled the segment. Kept verbatim so directories of existing
        # runs still match. ':' in names is not allowed on Windows file systems.
        segments = [
            './data',
            f'Results_of_{self.type_data}',
            'fb' if self.fb_method else 'non_fb',
            f'dim_{self.INPUT_DIM}',
            f'm_{self.porous_m}',
            self._map_tag(),
            f'trial_{self.TRIAL}',
            f'{"ratio" if self.ratio_h else "no_ratio"}_h',
            f'step_a_{self.step_a}',
            f'epoch{self.epochs}per_step',
            f'innerIter_{self.N_inner_ITERS}_out_{self.N_outer_ITERS}',
            f'layers_h{self.NUM_LAYERS_h}_g{self.NUM_LAYERS_g}',
            f'neuron_ map{self.NUM_NEURON_map}_h{self.NUM_NEURON_h}',
            f'lr_t_{self.LR_g}lr_h_{self.LR_h}',
            f'schedule_learning_rate:{schedule}',
            f'lr_schedule:{self.lr_schedule_per_epoch}',
            f'iter_proxi_{self.iter_proxi}',
            f'batchnml:{batch_norm}_dropout:{use_dropout}',
            f'activ_h_{self.nn_activation}_g_{self.g_activation}',
            f'batch_{self.BATCH_SIZE}',
        ]
        return '/'.join(segments)
