import os

import torch

from flow import RectifiedFlow


class TrainConfig:
    """Hyper-parameters, paths and model choices for a training/sampling run.

    Every setting is a plain attribute assigned in ``__init__``. The object also
    supports dict-style access (``cfg['batch_size']`` / ``cfg['batch_size'] = 64``),
    which forwards to ``getattr`` / ``setattr``.

    Side effect: constructing an instance creates ``data_dir`` and ``flow_dir``
    (relative to the current working directory) if they do not exist yet.
    """

    def __init__(self):
        # Hard-coded to the first CUDA device.
        self.device = torch.device("cuda:0")

        # --- Optimisation ---
        self.learning_rate = 3e-4
        self.batch_size = 128
        self.iterations = 10000
        self.save_freq = 10000
        self.current_iteration = 0
        # NOTE(review): presumably only relevant when loss_mode == 'pde_loss' — confirm.
        self.pde_loss_weight = 10

        # Alternative loss mode: 'pde_loss'.
        self.loss_mode = 'vanilla'
        self.rectified_type = 'generation'
        self.condition_type = 'a'
        # Set to True to use the conditional backbone ('fno_cond') below.
        self.conditional_model = False

        # --- Sampling ---
        # Other sampling methods used in earlier experiments: 'vanilla', 'dflow',
        # 'dps' (with dps_weight = 50 and data_weight = 0.2, or 0.5 for SDE),
        # and 'impainting'.
        self.sampling = 'eci'
        self.eci_n_mix = 5
        # Same number of sampling steps for every sampling method.
        self.sampling_steps = 100

        # --- Data ---
        self.dataset = 'poisson'
        self.data_channel = 2
        self.train_data_config = {}
        self.test_data_config = {}
        # The output directory name encodes the main optimisation settings.
        self.data_dir = rf"./saved/{self.dataset}-iter{self.iterations}-b{self.batch_size}-lr{self.learning_rate}-lm{self.loss_mode}"
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.data_dir, exist_ok=True)

        self.probability_flow = True

        # --- Flow ---
        self.trajectory = 'linear'
        # Alternative reparameterisations: 'epsilon_pred', 'x_pred'.
        self.repara = None

        self.flow = RectifiedFlow
        self.flow_name = rf"{self.trajectory}-{self.repara}"
        self.flow_dir = rf"{self.data_dir}/{self.flow_name}"
        os.makedirs(self.flow_dir, exist_ok=True)

        # --- Base model ---
        # Other backbones used before: 'unet_small' (model_config={'d': 1}) and
        # 'songunet' (model_config={'img_resolution': 128, 'in_channels': 1,
        # 'out_channels': 1}).
        if self.conditional_model:
            self.model = 'fno_cond'
            self.uncondition_ratio = 0.2
            self.cfg_weight = 1
        else:
            self.model = 'fno'

        self.model_config = {'n_modes': [32, 32], 'emb_channels': 32, 'hidden_channels': 64,
                             'channels': self.data_channel,
                             'proj_channels': 256,
                             'n_layers': 4}
        self.model_path = rf"{self.flow_dir}/{self.model}"
        if self.conditional_model:
            self.model_path += rf"_cond_{self.condition_type}"

        # --- Rectified model ---
        self.rectified_model = 'fno_vanilla'
        self.rectified_model_config = {'n_modes': [32, 32], 'hidden_channels': 64,
                                       'proj_channels': 256,
                                       'n_layers': 4, 'channels': self.data_channel}
        self.rectified_epoch_num = 100
        self.rectified_batch_size = 128
        # 128 for eci sde
        self.rectified_sampling_num = 128 if self.sampling == 'eci' else 1024
        self.rectified_save_interval = 10
        # 800 for eci sde
        self.rectified_epoch_repeat = 800 if self.sampling == 'eci' else 100
        # The 'rectifed' spelling is kept deliberately: other code may look these
        # attributes up by name (e.g. via cfg['rectifed_pde_loss_weight']).
        # Other weights tried: pde 1000 or 1, data 1.
        # NOTE(review): an older comment said "100 for eci", yet sampling is 'eci'
        # and the value is 1e8 — confirm which is intended.
        self.rectifed_pde_loss_weight = 1e8
        self.rectifed_data_loss_weight = 100
        self.rectified_current_epoch = 0
        self.rectified_patience = 100
        # Rectified checkpoints sit next to the base model, keyed by sampling
        # method, plus the step count when it differs from the default 100.
        self.rectified_model_path = self.model_path + f'_rectified_{self.sampling}'
        if self.sampling_steps != 100:
            self.rectified_model_path += f'_ns_{self.sampling_steps}'

    def __getitem__(self, key):
        """Return the attribute named ``key``.

        Raises:
            AttributeError: if no such attribute exists.
        """
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Set the attribute named ``key`` to ``value``."""
        setattr(self, key, value)
