import os
import logging
import os.path as osp
from logging import Logger

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from util.train_utils import find_last_checkpoint_path
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
import torch.nn as nn
from tasks.TSTT_pretrain import TSTTPretrainTask
from util.environment_setup import setup_environment
# Environment setup runs at import time, i.e. before the OmegaConf resolvers
# below are registered and before Hydra composes the config inside ``run``.
# NOTE(review): presumably this exports variables later read through the
# ``${env:...}`` resolver — confirm in util/environment_setup.py. The relative
# path is presumably resolved against the working directory at import time.
setup_environment(config_file='util/environment.yaml')

import hashlib

def generate_version_name_finetune(cfg):
    """Return a deterministic run id (MD5 hex digest) for a fine-tuning config.

    The id hashes a string built from ``finetune_pretrained``,
    ``freeze_layers``, ``token_mask_ratio`` and ``channel_mask_ratio``. When
    ``float(token_mask_ratio) > 0.3`` the string gets a ``_test`` suffix, so
    those runs map to a different id than otherwise-identical configs.
    """
    # Evaluate the threshold first (same order as before) so a non-numeric
    # ratio fails before any other attribute is touched.
    add_test_suffix = float(cfg.token_mask_ratio) > 0.3
    fingerprint = (
        f"pretrained={cfg.finetune_pretrained}"
        f"_freeze_layers={cfg.freeze_layers}"
        f"_token_mask_ratio={cfg.token_mask_ratio}"
        f"_channel_mask_ratio={cfg.channel_mask_ratio}"
    )
    if add_test_suffix:
        fingerprint += "_test"
    return hashlib.md5(fingerprint.encode()).hexdigest()

# Register custom OmegaConf resolvers so configs can use ``${env:VAR}`` (value
# of os.getenv(VAR), i.e. None when VAR is unset) and ``${get_method:dotted.path}``
# (resolved via hydra.utils.get_method). ``replace=True`` makes registration
# idempotent: without it, a second import of this module in the same process
# (e.g. once as ``__main__`` and once by name) raises "already registered".
OmegaConf.register_new_resolver("env", lambda key: os.getenv(key), replace=True)

OmegaConf.register_new_resolver("get_method", hydra.utils.get_method, replace=True)

logger: Logger = logging.getLogger(__name__)

def train(cfg: DictConfig):
    """Restore a fine-tuned checkpoint and run validation and testing on it.

    Despite its name, no fitting happens here. The function seeds RNGs, builds
    the data module and the task model from ``cfg`` via Hydra, reloads weights
    from the run's ``epoch*`` checkpoint, and then calls ``Trainer.validate``
    followed by ``Trainer.test``. ``cfg`` is not mutated.

    Args:
        cfg: Composed Hydra config. Keys read here: ``seed``, ``data_module``,
            ``task``, ``token_mask_ratio``, ``tag_finetune``,
            ``io.checkpoint_dirpath``, ``trainer`` and, for DDP,
            ``find_unused_parameters`` (plus the keys used by
            ``generate_version_name_finetune``).

    Raises:
        FileNotFoundError: If the run directory is missing or contains no
            file whose name starts with ``epoch``.
    """
    seed_everything(cfg.seed)

    version_name = generate_version_name_finetune(cfg)

    # DataLoader
    print("===> Loading datasets")
    data_module = hydra.utils.instantiate(cfg.data_module)

    # Pytorch Lightning module
    print("===> Start building model")
    model = hydra.utils.instantiate(cfg.task, cfg)
    print(model)

    pretrained_checkpoint = _find_pretrained_checkpoint(cfg, version_name)
    print(f"===> Loading pretrained_checkpoint from {pretrained_checkpoint}")
    # load_from_checkpoint is a classmethod returning a *new* module; call it
    # on the class, since recent Lightning versions reject instance calls.
    model = type(model).load_from_checkpoint(checkpoint_path=pretrained_checkpoint)

    # Trainer
    print("===> Instantiate trainer")
    trainer = _build_trainer(cfg)

    print("===> Start validation")
    trainer.validate(model, data_module)

    print("===> Start testing")
    trainer.test(model, data_module)


def _find_pretrained_checkpoint(cfg: DictConfig, version_name: str) -> str:
    """Return the path of the ``epoch*`` checkpoint file for this run."""
    tag = cfg.tag_finetune
    # Runs with token_mask_ratio > 0.3 are looked up under "<tag_finetune>_new".
    if float(cfg.token_mask_ratio) > 0.3:
        tag = tag + "_new"
    run_dir = os.path.join(cfg.io.checkpoint_dirpath, tag, version_name)
    # os.listdir order is arbitrary; sort so the choice is deterministic.
    candidates = sorted(name for name in os.listdir(run_dir) if name.startswith("epoch"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint starting with 'epoch' found in {run_dir}")
    if len(candidates) > 1:
        print(f"===> Warning: {len(candidates)} 'epoch*' checkpoints in {run_dir}; using {candidates[0]}")
    return os.path.join(run_dir, candidates[0])


def _build_trainer(cfg: DictConfig) -> Trainer:
    """Create the Lightning ``Trainer`` from ``cfg.trainer``.

    A ``strategy`` of ``"ddp"`` is replaced by a ``DDPStrategy`` configured
    with ``cfg.find_unused_parameters``.
    """
    # Copy into a plain dict instead of deleting from cfg: Hydra configs are in
    # struct mode, where key deletion raises, and cfg stays untouched.
    trainer_kwargs = dict(cfg.trainer)
    if trainer_kwargs.get("strategy") == "ddp":
        trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
    return Trainer(**trainer_kwargs)



@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
    """Hydra entry point: print the Lightning version and resolved config, then evaluate.

    Hydra composes ``cfg`` from the ``defaults`` config under ``./config`` and
    passes it in; the actual work is delegated to ``train``.
    """
    lightning_version = pl.__version__
    print(f"PyTorch-Lightning Version: {lightning_version}")
    resolved_yaml = OmegaConf.to_yaml(cfg, resolve=True)
    print(resolved_yaml)
    train(cfg)


if __name__ == "__main__":
    # Environment setup already ran at import time (the setup_environment call
    # near the top of this module), before Hydra composes the config in run().
    # Tip: export HYDRA_FULL_ERROR=1 in the shell for full Hydra stack traces.
    run()



