# Resolve the project root before anything else is imported: the search starts
# at this file and looks for a directory containing `.git`.
# NOTE(review): per the pyrootutils docs, pythonpath=True puts the root on the
# import path (presumably what makes the `src.*` imports below resolvable) and
# dotenv=True loads a `.env` file from the root — confirm against the installed
# pyrootutils version.
import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import os
import glob
import shutil

import torch
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import instantiate
from lightning.pytorch import seed_everything
from lightning.pytorch.utilities import rank_zero_only
from src.utils.misc import shuffle_dataset
# from src.utils.eval import linear_probe
import src.utils.register_resolvers   # load custom resolvers


def get_last_ckpt(log_root_dir: str):
    """Return the most recently modified checkpoint below ``log_root_dir``.

    Candidates are all files matching ``<log_root_dir>/*/checkpoints/*.ckpt``;
    the one with the latest modification time wins.

    Args:
        log_root_dir: Directory whose immediate sub-directories may each
            contain a ``checkpoints`` folder.

    Returns:
        Path of the newest ``.ckpt`` file, or ``None`` if none was found.
    """
    # Escape the root so glob metacharacters in the directory name itself
    # (``[``, ``]``, ``*``, ``?``) are matched literally; without this a root
    # such as ``run[seed=1]`` is read as a character class and every existing
    # checkpoint is silently missed.
    pattern = f'{glob.escape(str(log_root_dir))}/*/checkpoints/*.ckpt'
    ckpts = sorted(glob.glob(pattern), key=os.path.getmtime)
    return ckpts[-1] if ckpts else None


def get_train_loaders(config: DictConfig):
    """Instantiate the loader that is handed to the trainer as training data.

    When ``config.test`` is truthy the loader is built from ``config.data`` on
    the ``'test'`` split without a ``shuffle`` override; otherwise it is built
    on the ``'train'`` split with ``shuffle=True``.

    Args:
        config: Run configuration providing ``test`` and ``data``.

    Returns:
        Whatever ``instantiate(config.data, ...)`` produces.
    """
    if not config.test:
        return instantiate(config.data, dataset={'split': 'train'}, shuffle=True)
    return instantiate(config.data, dataset={'split': 'test'})


def get_loaders(config: DictConfig):
    """Create the training loader and one validation loader per split.

    Args:
        config: Run configuration. ``config.valdatasets`` lists the split
            names to validate on; an optional ``config.valoptions.dataset``
            node supplies extra dataset overrides shared by all of them.

    Returns:
        A ``(train_loader, val_loaders)`` tuple, where ``val_loaders`` is a
        list ordered like ``config.valdatasets``.
    """
    def dataset_overrides(split_name):
        # Built anew for every split so each loader gets its own dict.
        has_shared = (hasattr(config, 'valoptions')
                      and hasattr(config.valoptions, 'dataset'))
        if has_shared:
            # Turn the config.valoptions.dataset node into a plain dict.
            overrides = OmegaConf.to_container(config.valoptions.dataset, resolve=True)
        else:
            overrides = {}
        overrides['split'] = split_name
        return overrides

    train_loader = get_train_loaders(config)
    val_loaders = [
        instantiate(config.data, dataset=dataset_overrides(split_name))
        for split_name in config.valdatasets
    ]
    return train_loader, val_loaders


@hydra.main(version_base=None, config_path='../config', config_name='train')
def main(config: DictConfig):
    """Hydra entry point: build trainer, model and loaders, then fit.

    Steps, in order:
      1. instantiate the trainer and print the config on global rank 0;
      2. seed everything if ``config.seed`` is set;
      3. instantiate the model and store ``config`` as its hyperparameters;
      4. with ``config.resume`` look up the newest checkpoint below the
         logger's root directory, otherwise wipe that directory;
      5. build the loaders and call ``trainer.fit``.

    Args:
        config: Composed Hydra configuration (``trainer``, ``model``,
            ``data``, ``seed``, ``resume``, ``test``, ``valdatasets``, ...).

    Returns:
        The model instance that was passed to ``trainer.fit``.
    """
    # Faster, but less precise
    torch.set_float32_matmul_precision("high")

    trainer = instantiate(config.trainer)
    is_rank_zero = trainer.global_rank == 0
    if is_rank_zero:
        print(OmegaConf.to_yaml(config))

    if config.seed is not None:
        seed_everything(config.seed, workers=True)
    model = instantiate(config.model)
    model.save_hyperparameters(config)

    # A trainer configured without a logger has `trainer.logger is None`;
    # in that case there is no log directory to resume from or to clean.
    logger = trainer.logger
    log_root_dir = logger.root_dir if logger is not None else None

    last_ckpt = None
    if log_root_dir is not None:
        if config.resume:
            last_ckpt = get_last_ckpt(log_root_dir)
        elif is_rank_zero:
            # Fresh run: drop old logs/checkpoints. Only rank 0 deletes so
            # several processes do not remove the same tree concurrently.
            shutil.rmtree(log_root_dir, ignore_errors=True)

    train_loader, val_loaders = get_loaders(config)

    fit_kwargs = dict(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loaders,
    )
    if last_ckpt is not None:
        # Only pass ckpt_path when there is actually something to resume from.
        fit_kwargs['ckpt_path'] = last_ckpt
        print(f'set to load model: {last_ckpt}')
    trainer.fit(**fit_kwargs)

    return model


# Run the Hydra-decorated entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()