# Parameter sets (job definitions) for the AttNHP model (model_name = 'attnhp').

# Global switches shared by the job definitions in this file.
replace = False  # forwarded as "replace" to every train/evaluate job
wandb = False  # forwarded as "wandb" to the training jobs
# Number of copies of each training job to schedule; the training job list is
# multiplied by this value. Always 1 (the former `1 if replace else 1` had two
# identical branches).
retrain = 1
seed_during_evaluation = 12345
# NOTE: this name shadows the builtin `compile` inside this module; it is kept
# because it is referenced as the "compile" job option and may be imported
# elsewhere.
compile = False

model_name = 'attnhp'
dataloader_name = "generic"

# Hyperparameters used to train on the synthetic datasets.
syn_datasets_name = [
    "hawkes_1_v2",
    "hawkes_2_v2",
    "poisson_v2",
    "self_correct_v2",
    "stationary_renewal_v2",
]
syn_n_training_step = 50_000
syn_training_batch_size, syn_evaluation_batch_size = 32, 128
syn_learning_rate = 0.002
syn_model_config = f"syn/{model_name}.yml"
# Warm-up length and the evaluation interval are each 20% of the training budget.
syn_n_warmup_steps = int(0.2 * syn_n_training_step)
syn_n_evaluation_steps = int(syn_n_training_step * 0.2)

# Additional hyperparameter used for evaluation: the dataloader config passed
# to the synthetic evaluation job.
syn_dataloader_config = "syn/plot.yml"

# Train MTPP models and evaluate them on the synthetic datasets (hawkes_1,
# hawkes_2, poisson, self_correct and stationary_renewal).
# The training job is list-multiplied by `retrain` (the copies reference the
# same dict), followed by a single evaluation job.
default_train_and_evaluate_on_syn_datasets = [
    {
        "worker": "start.py",
        "job_type": "train",
        "static":
        {
            "no_seed": True,
            "dataloader_name": dataloader_name,
            "n_training_steps": syn_n_training_step,
            # The same interval is used for evaluation and for reporting.
            "n_evaluation_steps": syn_n_evaluation_steps,
            "n_report_steps": syn_n_evaluation_steps,
            "training_batch_size": syn_training_batch_size,
            "evaluation_batch_size": syn_evaluation_batch_size,
            "n_warmup_steps": syn_n_warmup_steps,
            "model_name": model_name,
            "lr": syn_learning_rate,
            "save_mode": "best",
            "lr_sched": True,
            "op_name": "AdamW",
            "optim_config": "optimizer.yml",
            "model_config": syn_model_config,
            # A float, consistent with the real-world training job's "n_cycles".
            "n_cycles": 0.5,
            "replace": replace,
            "wandb": wandb
        },
        "zip_style":
        {
            'loop_vars':
            {
                "dataset_name": syn_datasets_name,
            }
        }
    },] * retrain + \
    [{
        "worker": "start.py",
        "job_type": "evaluate",
        "static": {
            "seed": seed_during_evaluation,
            "model_name": model_name,
            "lr": syn_learning_rate,
            "dataloader_name": dataloader_name,
            "figure_count": 1,
            "n_training_steps": syn_n_training_step,
            "test_data_name": "test",
            "resolution": "100",
            "used_batch_size": syn_training_batch_size,
            "dataloader_config": syn_dataloader_config,
            "model_config": syn_model_config,
            "replace": replace,
            "combine_used_and_current_dataloader_config": True
            },
        'zip_style': {
            'loop_vars':
            {
                "dataset_name": syn_datasets_name,
            }
        },
        "counting_style": {
            'zip_style':{
                "loop_vars":
                {
                    # subtask_name[i] is paired with task_name[i]; keep both lists the same length.
                    "subtask_name": ["intensity", "probability", "debug", "spearman_and_l1", "mae_and_f1", "mae_e_and_f1"],
                    "task_name": ["evaluation_per_seq", "evaluation_per_seq", "evaluation_per_seq", "evaluation_dataset", "evaluation_dataset", "evaluation_dataset"]
                }
            }
        }
    }
]


# Hyperparameters used to train on the real-world datasets.
# Every per-dataset list below has one entry per name in realworld_datasets_name,
# in the same order.
realworld_datasets_name = [
    "amazon", "bookorder", "retweet", "stackoverflow",
    "taobao", "taxi", "usearthquake", "yelp",
]
realworld_dataloader_config = [f"{name}/{model_name}_dl.yml" for name in realworld_datasets_name]
realworld_model_config = [f"{name}/{model_name}.yml" for name in realworld_datasets_name]
realworld_training_step = [400_000, 20_000, 400_000, 200_000, 80_000, 80_000, 80_000, 200_000]
realworld_training_batch_size = [32, 8, 16, 4, 32, 4, 32, 32]
realworld_evaluation_batch_size = [32, 8, 32, 4, 32, 32, 32, 32]
realworld_learning_rate = 0.002
# Warm-up length and evaluation interval are each 20% of the dataset's training budget.
realworld_n_warmup_steps = [int(steps * 0.2) for steps in realworld_training_step]
realworld_n_evaluation_step = [int(steps * 0.2) for steps in realworld_training_step]

# Additional hyperparameter used for evaluation: per-dataset evaluation dataloader configs.
realworld_dataloader_config_for_evaluation = [f"{name}/plot.yml" for name in realworld_datasets_name]

# Train on every real-world dataset, then evaluate. The training job is
# list-multiplied by `retrain`; per-dataset values come from the aligned lists
# under "zip_style" -> "loop_vars".
default_train_and_evaluate_on_realworld_datasets = (
    [
        {
            "worker": "start.py",
            "job_type": "train",
            "static": {
                "no_seed": True,
                "dataloader_name": dataloader_name,
                "model_name": model_name,
                "lr": realworld_learning_rate,
                "save_mode": "best",
                "lr_sched": True,
                "op_name": "AdamW",
                "optim_config": "optimizer.yml",
                "n_cycles": 0.5,
                "replace": replace,
                "wandb": wandb,
                "compile": compile,
            },
            "zip_style": {
                "loop_vars": {
                    "dataset_name": realworld_datasets_name,
                    "dataloader_config": realworld_dataloader_config,
                    "model_config": realworld_model_config,
                    "n_training_steps": realworld_training_step,
                    "n_evaluation_steps": realworld_n_evaluation_step,
                    "n_report_steps": realworld_n_evaluation_step,
                    "training_batch_size": realworld_training_batch_size,
                    "evaluation_batch_size": realworld_evaluation_batch_size,
                    "n_warmup_steps": realworld_n_warmup_steps,
                },
            },
        },
    ]
    * retrain
    + [
        {
            "worker": "start.py",
            "job_type": "evaluate",
            "static": {
                "seed": seed_during_evaluation,
                "model_name": model_name,
                "lr": realworld_learning_rate,
                "dataloader_name": dataloader_name,
                "figure_count": 1,
                "test_data_name": "test",
                "used_dataloader_config": f"{model_name}_dl.yml",
                "resolution": "100",
                "replace": replace,
                "combine_used_and_current_dataloader_config": True,
            },
            "zip_style": {
                "loop_vars": {
                    "dataset_name": realworld_datasets_name,
                    "model_config": realworld_model_config,
                    "n_training_steps": realworld_training_step,
                    "dataloader_config": realworld_dataloader_config_for_evaluation,
                    "used_batch_size": realworld_training_batch_size,
                },
            },
            "counting_style": {
                "zip_style": {
                    "loop_vars": {
                        # subtask_name[i] is paired with task_name[i].
                        "subtask_name": [
                            "intensity",
                            "probability",
                            "debug",
                            "mae_and_f1",
                            "mae_e_and_f1",
                        ],
                        "task_name": [
                            "evaluation_per_seq",
                            "evaluation_per_seq",
                            "evaluation_per_seq",
                            "evaluation_dataset",
                            "evaluation_dataset",
                        ],
                    },
                },
            },
        },
    ]
)
    
# Stand-alone evaluation job on every real-world dataset that runs only the
# mae_e dataset-level subtask; its "static" and "zip_style" settings mirror the
# default real-world evaluation job.
real_world_data_mae_e = {
    "worker": "start.py",
    "job_type": "evaluate",
    "static":
    {
        "seed": seed_during_evaluation,
        "model_name": model_name,
        "lr": realworld_learning_rate,
        "dataloader_name": dataloader_name,
        "figure_count": 1,
        "test_data_name": "test",
        "used_dataloader_config": f"{model_name}_dl.yml",
        "resolution": "100",
        "replace": replace,
        "combine_used_and_current_dataloader_config": True
    },
    'zip_style':
    {
        'loop_vars':
        {
            "dataset_name": realworld_datasets_name,
            "model_config": realworld_model_config,
            "n_training_steps": realworld_training_step,
            "dataloader_config": realworld_dataloader_config_for_evaluation,
            "used_batch_size": realworld_training_batch_size,
        }
    },
    "counting_style":
    {
        'zip_style':
        {
            'loop_vars':
            {
                # The job is named for mae_e, so it runs the "mae_e_and_f1"
                # subtask. NOTE(review): confirm "mae_and_f1" was not intended.
                "subtask_name": ['mae_e_and_f1',],
                "task_name": ['evaluation_dataset',]
            }
        }
    }
}


# Hyperparameters for evaluating outlier detection performance on synthetic datasets.
# Each synthetic dataset is crossed with every missing rate. The two lists below
# are aligned: entry i of od_syn_datasets_name is the training dataset for the
# derived dataset at entry i of od_missing_syn_datasets.
missing_rates = [0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75]
od_syn_datasets_name = [name for name in syn_datasets_name for _rate in missing_rates]
od_missing_syn_datasets = [f"{name}_{rate}" for name in syn_datasets_name for rate in missing_rates]
od_dataloader_name = "od_generic"

# cppod outlier-detection evaluation on the synthetic datasets: the model trained
# on `training_dataset_name` is evaluated on the derived `dataset_name`.
cppod_evaluate_on_syn_datasets = {
    "worker": "start.py",
    "job_type": "evaluate",
    "static": {
        # Same int seed as the other evaluation jobs (this was the string "12345").
        "seed": seed_during_evaluation,
        "model_name": model_name,
        "lr": syn_learning_rate,
        "dataloader_name": od_dataloader_name,
        "figure_count": 1,
        "n_training_steps": syn_n_training_step,
        "test_data_name": "test",
        "resolution": "100",
        "used_batch_size": syn_training_batch_size,
        "dataloader_config": "syn/cppod_dl.yml",
        "model_config": syn_model_config,
        "replace": replace,
        "combine_used_and_current_dataloader_config": True
    },
    'zip_style': {
        'loop_vars':
        {
            "training_dataset_name": od_syn_datasets_name,
            "dataset_name": od_missing_syn_datasets,
        }
    },
    "counting_style": {
        'zip_style':{
            "loop_vars":
            {
                "subtask_name": ["cppod_evaluation"],
                "task_name": ["evaluation_dataset"]
            }
        }
    }
}


# Hyperparameters for evaluating outlier detection performance on real-world datasets.
# Every per-dataset setting is repeated once per missing rate so that all lists
# below stay aligned with od_missing_real_datasets.

od_missing_real_datasets = [f"{name}_{rate}" for name in realworld_datasets_name for rate in missing_rates]
od_realworld_datasets_name = [name for name in realworld_datasets_name for _ in missing_rates]
od_realworld_model_config = [config for config in realworld_model_config for _ in missing_rates]
od_realworld_training_step = [step for step in realworld_training_step for _ in missing_rates]
# Uses an od_-prefixed name. Reassigning `realworld_dataloader_config` here would
# overwrite the per-dataset training dataloader configs for anyone importing this
# module.
od_realworld_dataloader_config = [f"{name}/cppod_dl.yml" for name in realworld_datasets_name for _ in missing_rates]
od_realworld_training_batch_size = [batch_size for batch_size in realworld_training_batch_size for _ in missing_rates]

# cppod outlier-detection evaluation on the real-world datasets: the model trained
# on `training_dataset_name` is evaluated on the derived `dataset_name`.
cppod_evaluate_on_realworld_datasets = {
    "worker": "start.py",
    "job_type": "evaluate",
    "static": {
        "seed": seed_during_evaluation,
        "model_name": model_name,
        "lr": realworld_learning_rate,
        "dataloader_name": od_dataloader_name,
        "figure_count": 1,
        "test_data_name": "test",
        "used_dataloader_config": f"{model_name}_dl.yml",
        "resolution": "100",
        "replace": replace,
        "combine_used_and_current_dataloader_config": True
    },
    'zip_style': {
        'loop_vars':
        {
            "dataset_name": od_missing_real_datasets,
            "training_dataset_name": od_realworld_datasets_name,
            "model_config": od_realworld_model_config,
            "n_training_steps": od_realworld_training_step,
            "dataloader_config": od_realworld_dataloader_config,
            "used_batch_size": od_realworld_training_batch_size,
        }
    },
    "counting_style": {
        'zip_style':
        {
            'loop_vars':
            {
                "subtask_name": ["cppod_evaluation"],
                "task_name": ["evaluation_dataset"]
            }
        }
    }
}


# Hyperparameters for evaluating commission outlier detection on synthetic datasets.
# NOTE(review): despite the name, com_missing_rates holds integer suffixes used in
# the "<dataset>_com_<k>" dataset names; presumably variant indices, not rates.

com_missing_rates = [0, 1, 2, 3, 4]
com_od_syn_datasets_name = [name for name in syn_datasets_name for _k in com_missing_rates]
com_od_missing_syn_datasets = [name + f"_com_{k}" for name in syn_datasets_name for k in com_missing_rates]
com_od_dataloader_name = "commission"

# cppod commission outlier-detection evaluation on the synthetic datasets: the
# model trained on `training_dataset_name` is evaluated on `dataset_name`.
com_cppod_evaluate_on_syn_datasets = {
    "worker": "start.py",
    "job_type": "evaluate",
    "static": {
        # Same int seed as the other evaluation jobs (this was the string "12345").
        "seed": seed_during_evaluation,
        "model_name": model_name,
        "lr": syn_learning_rate,
        "dataloader_name": com_od_dataloader_name,
        "figure_count": 1,
        "n_training_steps": syn_n_training_step,
        "test_data_name": "test",
        "resolution": "100",
        "used_batch_size": syn_training_batch_size,
        "dataloader_config": "syn/cppod_dl.yml",
        "model_config": syn_model_config,
        "replace": replace,
        "combine_used_and_current_dataloader_config": True
    },
    'zip_style': {
        'loop_vars':
        {
            "training_dataset_name": com_od_syn_datasets_name,
            "dataset_name": com_od_missing_syn_datasets,
        }
    },
    "counting_style": {
        'zip_style':{
            "loop_vars":
            {
                "subtask_name": ["cppod_commission_evaluation"],
                "task_name": ["evaluation_dataset"]
            }
        }
    }
}


# Hyperparameters for evaluating commission outlier detection on real-world datasets.
# Each real-world setting is repeated once per entry of com_missing_rates so that
# every list below is aligned with com_od_missing_real_datasets.

com_od_missing_real_datasets = [
    f"{name}_com_{k}"
    for name in realworld_datasets_name
    for k in com_missing_rates
]
com_od_realworld_datasets_name = [name for name in realworld_datasets_name for _k in com_missing_rates]
com_od_realworld_model_config = [cfg for cfg in realworld_model_config for _k in com_missing_rates]
com_od_realworld_training_step = [steps for steps in realworld_training_step for _k in com_missing_rates]
com_realworld_dataloader_config = [f"{name}/cppod_dl.yml" for name in realworld_datasets_name for _k in com_missing_rates]
com_od_realworld_training_batch_size = [bs for bs in realworld_training_batch_size for _k in com_missing_rates]

# cppod commission outlier-detection evaluation on the real-world datasets: the
# model trained on `training_dataset_name` is evaluated on `dataset_name`.
com_cppod_evaluate_on_realworld_datasets = {
    "worker": "start.py",
    "job_type": "evaluate",
    "static": {
        # Shared evaluation seed instead of a duplicated literal.
        "seed": seed_during_evaluation,
        "model_name": model_name,
        "lr": realworld_learning_rate,
        "dataloader_name": com_od_dataloader_name,
        "figure_count": 1,
        "test_data_name": "test",
        "used_dataloader_config": f"{model_name}_dl.yml",
        "resolution": "100",
        "replace": replace,
        "combine_used_and_current_dataloader_config": True
    },
    'zip_style': {
        'loop_vars':
        {
            "dataset_name": com_od_missing_real_datasets,
            "training_dataset_name": com_od_realworld_datasets_name,
            "model_config": com_od_realworld_model_config,
            "n_training_steps": com_od_realworld_training_step,
            "dataloader_config": com_realworld_dataloader_config,
            "used_batch_size": com_od_realworld_training_batch_size,
        }
    },
    "counting_style": {
        'zip_style':
        {
            'loop_vars':
            {
                "subtask_name": ["cppod_commission_evaluation"],
                "task_name": ["evaluation_dataset"]
            }
        }
    }
}


# Registry of AttNHP experiments: maps each experiment name to its job
# definition (a single job dict or a list of job dicts). Despite the `_list`
# suffix, this is a dict.

attnhp_hyperparameter_list = dict(
    default_train_and_evaluate_on_syn_datasets=default_train_and_evaluate_on_syn_datasets,
    default_train_and_evaluate_on_realworld_datasets=default_train_and_evaluate_on_realworld_datasets,
    cppod_evaluate_on_syn_datasets=cppod_evaluate_on_syn_datasets,
    cppod_evaluate_on_realworld_datasets=cppod_evaluate_on_realworld_datasets,
    commission_cppod_evaluate_on_syn_datasets=com_cppod_evaluate_on_syn_datasets,
    commission_cppod_evaluate_on_realworld_datasets=com_cppod_evaluate_on_realworld_datasets,
    real_world_data_mae_e=real_world_data_mae_e,
)