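"""
Training script for the e-CRT baseline (trained with ``DAVTTrainer``).

Follows the project's Hydra experiment pattern: the config is composed from
``configs/config`` plus any command-line overrides, and metrics are logged to
Weights & Biases unless ``wandb.disabled`` is set.
"""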

import torch
import hydra
from trainer.davt_trainer import DAVTTrainer
from omegaconf import DictConfig, OmegaConf
import wandb
from hydra.utils import instantiate

# Custom ``${join:<sep>,<list>}`` interpolation: stringifies every item of a
# config list and joins them with ``sep``, e.g. ``${join:_,${path.to.list}}``.
# ``replace=True`` makes registration idempotent: without it OmegaConf raises
# ValueError when this module body runs a second time (reloads, or the script
# executed as __main__ and also imported) or if "join" is already registered.
OmegaConf.register_new_resolver(
    "join",
    lambda sep, xs: sep.join(str(x) for x in xs),
    replace=True,
)

def _init_wandb(cfg: DictConfig) -> None:
    """Start the W&B run for this job, or a disabled run if ``cfg.wandb.disabled``.

    The run name encodes the data type, the data seed and the training seed.
    """
    if cfg.wandb.disabled:
        # A disabled run keeps later wandb calls (log/finish) usable as no-ops.
        wandb.init(mode="disabled")
        return

    wandb.init(
        project=cfg.wandb.get('project', 'seq-kci'),
        group=cfg.wandb.group,
        name=f"davt_{cfg.data.type}_dseed-{cfg.data.data_seed}_tseed-{cfg.train.seed}",
        tags=cfg.wandb.tags,
        # Log the fully resolved config; missing ("???") values fail fast here.
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    )


@hydra.main(config_path='configs', config_name='config', version_base='1.1')
def train_davt_pipeline(cfg: DictConfig):
    """Train a DAVT model for a single Hydra job.

    Args:
        cfg: Composed Hydra config. ``cfg.wandb`` controls logging,
            ``cfg.data`` is built into the data generator with Hydra's
            ``instantiate``, and ``cfg.train`` is passed to ``DAVTTrainer``.
    """
    print(OmegaConf.to_yaml(cfg))

    _init_wandb(cfg)
    try:
        # Initialize data generator
        datagen = instantiate(cfg.data)

        # Initialize device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Initialize DAVT trainer
        print("Using DAVT trainer")
        trainer = DAVTTrainer(cfg.train, datagen, device)

        # Run training
        trainer.train()
    except BaseException:
        # Close the run as failed instead of leaving it open (under Hydra
        # multirun the process outlives this job), then propagate the error.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


if __name__ == "__main__":
    # No arguments: the @hydra.main decorator builds ``cfg`` from the config
    # files and command-line overrides and passes it in.
    train_davt_pipeline()  # pylint: disable=no-value-for-parameter