from algos import RewardTransformer
from trainers import PreferenceDatasetBatch, RewardModelTrainer
import hydra
from omegaconf import DictConfig

import torch
import os
import wandb

from utils import seed_everything

import datetime

# Checkpoints for this process go to a directory named after the experiment and
# the wall-clock time at which this module was imported (second resolution).
experiment_name = "Reward_Model_experiment"  # TODO: expose via config
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
checkpoint_dir = f"./checkpoints/{experiment_name}_{timestamp}"

# Default pickled preference dataset. TODO: expose via config.
dataset_path = "./datasets/preference_trajs_darkroom_step_envs10000_rp0.2_hists5_samples10_H100_d10_train.pkl"
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@hydra.main(config_path="reward_conf", config_name="config", version_base=None)
def train_reward_model(cfg: DictConfig) -> None:
    """
    Main training function using Hydra for the reward model.

    Builds a ``RewardTransformer`` from ``cfg.reward_model``, trains it with Adam
    for ``cfg.epochs`` epochs on a ``PreferenceDatasetBatch``, writes a checkpoint
    every ``cfg.save_interval`` epochs (and always after the final epoch), and,
    when ``cfg.wandb`` is truthy, logs the per-epoch training loss to Weights & Biases.

    Args:
        cfg (DictConfig): Configuration object containing all necessary parameters.
            Required keys read here: ``seed``, ``reward_model`` (including
            ``learning_rate``), ``dataset``, ``batch_size``, ``epochs``,
            ``save_interval``, ``wandb``.
            Optional keys: ``dataset_path`` (defaults to the module-level
            ``dataset_path``), ``shuffle`` (defaults to ``False``), and
            ``wandb_project`` (defaults to wandb's own project resolution).
            A ``save_interval`` of 0 disables periodic saves; only the final
            checkpoint is written.
    """
    seed_everything(cfg.seed)
    os.makedirs(checkpoint_dir, exist_ok=True)
    print(f"Checkpoint directory: {checkpoint_dir}")

    use_wandb = bool(cfg.wandb)
    if use_wandb:
        # wandb.log() fails unless a run was started first; record the fully
        # resolved Hydra config with the run so it is reproducible.
        from omegaconf import OmegaConf

        wandb.init(
            project=cfg.get("wandb_project"),
            name=os.path.basename(checkpoint_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    try:
        model = RewardTransformer(cfg.reward_model).to(device)
        # NOTE(review): the original author noted num_pairs (and gamma) are not
        # actually needed here — confirm against PreferenceDatasetBatch.
        dataset = PreferenceDatasetBatch(
            cfg.get("dataset_path", dataset_path), cfg.dataset, device, num_pairs=10000
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            # Fixed order by default (previous behavior); set cfg.shuffle=true
            # to reshuffle the training data every epoch.
            shuffle=cfg.get("shuffle", False),
            collate_fn=PreferenceDatasetBatch.batch_collate_fn,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.reward_model.learning_rate)
        trainer = RewardModelTrainer(model, optimizer)

        save_interval = cfg.save_interval
        for epoch in range(1, cfg.epochs + 1):
            train_loss = trainer.train_epoch(dataloader)
            print(f"Epoch {epoch}/{cfg.epochs}, Train Loss: {train_loss:.4f}")

            # Save on the configured interval, and always after the last epoch so
            # the trained model is kept even when epochs % save_interval != 0.
            periodic = bool(save_interval) and epoch % save_interval == 0
            if periodic or epoch == cfg.epochs:
                trainer.save_checkpoint(os.path.join(checkpoint_dir, f"reward_model_epoch_{epoch}.pth"))
                print(f"Model saved at epoch {epoch}")

            if use_wandb:
                wandb.log({"train_loss": train_loss, "epoch": epoch})
    finally:
        # Close the wandb run even if training raises.
        if use_wandb:
            wandb.finish()
            
if __name__ == "__main__":
    # No arguments: the @hydra.main decorator composes the config
    # (config_path="reward_conf", config_name="config") plus any command-line
    # overrides and passes it in as ``cfg``.
    train_reward_model()