from easydict import EasyDict as edict

cfg = edict()

# Data: input files and which text-embedding variant to load.
cfg.data = edict(
    train_parquet_path="data/occurance/GBIF.parquet",
    # one of 'class', 'order', 'family', 'genus', 'species'
    taxonomy_level="species",
    # one of 'Llama-2-7b-hf', 'Llama-2-13b-hf', 'Llama-2-70b-hf'
    llm_type="Llama-2-7b-hf",
)
# Derived from taxonomy_level and llm_type once, when this module executes;
# changing either field afterwards does NOT update this path.
cfg.data.text_embeddings_path = (
    f"data/text_embeddings/{cfg.data.taxonomy_level}_{cfg.data.llm_type}.npy"
)
cfg.data.env_cov_path = "data/env_cov/bioclim_elevation_scaled_v2.npy"

# Model architecture settings.
cfg.model = edict(
    transform="sht",  # one of 'sht', 'fft'
    filter_type="non-linear",  # one of 'linear', 'non-linear'
    operator_type="vector",  # one of 'diagonal', 'vector'
    num_layers=2,
    encoder_layers=1,
    scale_factor=4,
    embed_dim=128,
    # Input channels: 20 for env_cov, 1 for occurrence.
    # NOTE(review): confirm this agrees with the env_cov flags set for
    # best_model / eval, which are currently False.
    in_chans=20,
    out_chans=1,
    img_size=(900, 1800),  # (height, 2*height)
)

# Training / runtime settings.
cfg.train = edict(
    seed=42,
    batch_size=1,
    shuffle=True,
    num_workers=12,
    num_epochs=10,
    lr=1e-5,
    # Presumably gradients are accumulated over 64 batches per optimiser
    # step (an effective batch of 64 at batch_size=1) — confirm in trainer.
    accumulate_grad_batches=64,
    device="cuda",
    devices=1,
)

# Loss selection and its hyper-parameters.
cfg.loss = edict(
    type="RAL",  # one of 'RAL', 'ASL', 'AN-full', 'ME-full'
    gamma_neg=4,
    gamma_pos=2,
    alpha=10,
)

# Checkpoint directory and save frequency (freq is measured in epochs).
cfg.checkpoint = edict(dirpath="./checkpoints", freq=1)

# Best-model settings: output path, threshold and env_cov flag.
cfg.best_model = edict(path="./best_model", threshold=0.5, env_cov=False)

# Evaluation settings.
cfg.eval = edict(
    model_path="checkpoints/epoch=9-val_loss=0.01.ckpt",
    threshold=0.95,
    env_cov=False,
    test_parquet_path="data/occurance/GBIF_test.parquet",
    unseen_parquet_path="data/occurance/GBIF_unseen.parquet",
)
# Built from cfg.data's taxonomy_level and llm_type once, when this module
# executes; changing cfg.data afterwards does NOT update this path.
cfg.eval.text_embeddings_path = (
    "data/unseen_text_embeddings/"
    f"{cfg.data.taxonomy_level}_{cfg.data.llm_type}.npy"
)
