from data_paths import TOMOTWIN_TOMO_BASE_DIR, SHREC2021_BASE_DIR, FIXED_TOMOTWIN_TOMOS_PROMPTS_JSON, FIXED_SHREC2021_PROMPTS_JSON, TOMOTWIN_MODEL_FILE
from backup_tools import try_load_model_from_backup_file
import os

#%%
# Setting coldstart = True will train a model with the same architecture as the model_ckpt, but with randomly initialized weights.
coldstart = False
# Upper bound on the number of training epochs (presumably forwarded to the trainer; confirm).
max_epochs = 15
# NOTE(review): presumably the GPU device indices to train on; confirm against the training script.
gpus = [3]
# Forwarded to the datamodule as "limit_to_pdbs" and joined into the logger name below.
limit_to_pdbs = []

# Entries stored under datamodule_args["class_args"].
_datamodule_class_args = {
    # TomoTwin tomograms: no runs are selected in this configuration.
    "tomotwin_tomo_base_dir": TOMOTWIN_TOMO_BASE_DIR,
    "train_val_tomotwin_runs": [],
    "exclusive_val_tomotwin_runs": [],

    # SHREC2021: model_0 ... model_7 are listed for train/val, model_8 is
    # listed as exclusive validation model.
    "shrec2021_base_dir": SHREC2021_BASE_DIR,
    "train_val_shrec2021_models": [f"model_{idx}" for idx in range(8)],
    "exclusive_val_shrec2021_models": ["model_8"],

    "train_val_subtomos_dir": "data/propicker_fine_tuning_data",

    "limit_to_pdbs": limit_to_pdbs,

    "subtomo_size": 64,
    "subtomo_extraction_strides": [32] * 3,
    "max_classes_per_tomo": 10,

    # Active prompt setup: fixed prompts read from the SHREC2021 prompt JSON.
    "prompt_type": "dict",
    "prompt_dict_json": FIXED_SHREC2021_PROMPTS_JSON,
    "fixed_prompts": True,

    # Alternative prompt setup (disabled):
    # "prompt_type": "tomotwin_reference_embedding",
    # "prompt_dict_json": None,
    # "fixed_prompts": False,

    "val_frac": 0.01,
    "train_batch_size": 4,
    "val_batch_size": 8,
    "num_workers": 8,
    "seed": 42,
}

# Datamodule configuration: "class_args" plus "prepare_data_args"
# (presumably consumed by the datamodule's data-preparation step; confirm).
datamodule_args = {
    "class_args": _datamodule_class_args,
    "prepare_data_args": {
        "setup_tomotwin_reference_embeddings": True,
        "skip_existing": True,
        "tomotwin_model_file": TOMOTWIN_MODEL_FILE,
        # Specify a function to crop tomograms, useful to fine-tune on little
        # data. The full-range slice used here selects the whole volume.
        "crop_tomo_fn": lambda x: x[:, :, :],
    },
}

# Placeholder: replace with the path of the checkpoint that should be fine-tuned.
model_ckpt = "PATH TO MODEL CHECKPOINT YOU WANT TO FINE-TUNE"

# Model configuration; the "model" entry is filled in by the loader call below,
# which runs as soon as this configuration module is executed.
model_args = {"ckpt": model_ckpt, "backup_file": None}
model_args["model"] = try_load_model_from_backup_file(
    model_ckpt=model_args["ckpt"],
    backup_file=model_args["backup_file"],
)

# Loss configuration for training: the loss class is given as a dotted path
# string, "class_args" holds its arguments, and "mask_empty_targets" is an
# additional flag stored next to them.
train_loss_args = {
    "class": "torch.nn.BCELoss",
    "class_args": {
        "reduction": 'none',
    },
    "mask_empty_targets": False,
}

# Validation uses the same settings, but as an independent copy (including the
# nested "class_args" dict) so that modifying one configuration in place can
# never silently change the other.
val_loss_args = {
    **train_loss_args,
    "class_args": dict(train_loss_args["class_args"]),
}


# Optimizer configuration: the optimizer class as a dotted path string, plus
# its "class_args" (presumably passed to the optimizer constructor; confirm).
optimizer_args = {"class": "torch.optim.Adam"}
optimizer_args["class_args"] = {"lr": 1e-3}


# Augmentation configuration: one "data.augmentation.Flip" entry per pair of
# axes, each with "prob" set to 0.3.
augmentation_args = [
    {
        "class": "data.augmentation.Flip",
        "class_args": {"axis": axis_pair, "prob": 0.3},
    }
    for axis_pair in ((0, 1), (0, 2), (1, 2))
]

# Root directory handed to the TensorBoard logger.
logdir = "lightning_logs/fine_tune_propicker"
#logdir = "./trash"

# Tag the run with the checkpoint's file name, extension stripped.
model_tag = os.path.splitext(os.path.basename(model_ckpt))[0]
#model_tag = f"{model_tag}_1_upsample_layers"

_dm_class_args = datamodule_args["class_args"]

# Total number of train/val sources: TomoTwin runs plus SHREC2021 models.
n_runs = len(_dm_class_args["train_val_tomotwin_runs"]) + len(_dm_class_args["train_val_shrec2021_models"])

# The logger name encodes the main settings of this run as nested path segments.
logger_name = "/".join(
    [
        f"fine_tune_{model_tag}",
        "shrec2021",
        f"pdbs={','.join(_dm_class_args['limit_to_pdbs'])}",
        f"coldstart={coldstart}",
        f"runs={n_runs}",
        f"subtomo_extraction_strides={_dm_class_args['subtomo_extraction_strides']}",
        f"val_frac={_dm_class_args['val_frac']}",
    ]
)
print(f"Logging to {logdir}/{logger_name}")

# NOTE(review): `pl` is not imported in the visible part of this file; it must
# be bound (presumably to PyTorch Lightning; confirm) before this line runs,
# otherwise a NameError is raised.
logger = pl.loggers.TensorBoardLogger(logdir, name=logger_name)