import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import os
import glob
import shutil

import torch
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import instantiate
from lightning.pytorch import seed_everything
from lightning.pytorch.utilities import rank_zero_only

from src.utils.animate import load_config_and_model
from src.train import get_loaders, get_last_ckpt


# def unfreeze_modules(model, module_names):
#     '''
#     Args:
#         model (nn.Module): model to unfreeze
#         module_names (list): list of module names to unfreeze
#     Output:
        
#     '''
#     for name, module in model.named_children():
#         if name in module_names:
#             for param in module.parameters():
#                 param.requires_grad = True
#         else:
#             unfreeze_modules(module, module_names)


# def instantiate_modules(config_modules):
#     new_modules = dict()
#     for key, value in OmegaConf.to_container(config_modules, resolve=True):
#         new_modules[key] = instantiate(value)
#     return new_modules


def freeze(model):
    """Disable gradient tracking for every parameter of ``model``, in place.

    Each tensor yielded by ``model.parameters()`` gets
    ``requires_grad = False``; nothing is returned.
    """
    trainable = model.parameters()
    for tensor in trainable:
        tensor.requires_grad = False
    return None


@hydra.main(version_base=None, config_path='../config', config_name='finetune')
def main(config: DictConfig):
    """Fine-tune a previously trained model using the Hydra ``finetune`` config.

    Steps:
      1. Load the saved config and model from ``config.model_path``.
      2. Merge the saved config with the fine-tuning config; values from the
         fine-tuning config take precedence.
      3. Seed all RNGs (when ``config.seed`` is not None) before anything
         that may draw random numbers.
      4. Freeze every loaded parameter, then call ``model.update(config)``
         and record the merged config as hyperparameters.
      5. Build the trainer and data loaders and run ``trainer.fit``,
         resuming from a checkpoint when ``config.resume`` is truthy and one
         is found.

    Args:
        config: Hydra-composed fine-tuning configuration.

    Returns:
        The model after ``trainer.fit`` has run.
    """
    config_model, model = load_config_and_model(config.model_path)
    # Clear the struct flag on both configs so the merge below may introduce
    # keys that only one of the two configs defines.
    OmegaConf.set_struct(config, None)
    OmegaConf.set_struct(config_model, None)

    # OmegaConf.merge: later arguments win, so fine-tuning values override
    # the values stored with the model.
    config = OmegaConf.merge(config_model, config)

    # Seed before freezing/updating the model and instantiating the trainer,
    # so any random initialization performed in those steps is reproducible.
    if config.seed is not None:
        seed_everything(config.seed, workers=True)

    # Freeze everything that was loaded; model.update(config) runs afterwards,
    # presumably so only what it adds/unfreezes is trained.
    # NOTE(review): confirm against model.update's implementation.
    freeze(model)
    model.update(config)
    model.save_hyperparameters(config)

    trainer = instantiate(config.trainer)
    if trainer.global_rank == 0:
        print(OmegaConf.to_yaml(config))

    train_loader, val_loaders = get_loaders(config)

    fit_kwargs = dict(model=model,
                      train_dataloaders=train_loader,
                      val_dataloaders=val_loaders)

    # When resuming, look for a checkpoint under the logger's root directory;
    # get_last_ckpt may return None, in which case training starts fresh.
    last_ckpt = get_last_ckpt(trainer.logger.root_dir) if config.resume else None
    if last_ckpt is not None:
        fit_kwargs['ckpt_path'] = last_ckpt
        print(f'set to load model: {last_ckpt}')
    trainer.fit(**fit_kwargs)

    return model


if __name__ == '__main__':
    # @hydra.main composes the config (config_path='../config',
    # config_name='finetune') plus any command-line overrides and passes it
    # in, so main() is called here without arguments.
    main()