from algos import QValueTransformer, VValueTransformer, Transformer
from envs import generate_darkroom_goals, generate_darkroom_env
from trainers import NormalDatasetBatch, DiTPolicyModelTrainer, ValueModelTrainer
from ctrls import DarkroomTransformerController

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import os
import wandb

from utils import seed_everything
from utils_eval import online_evaluate_policy, offline_evaluate_policy

import datetime

import pickle


# Held-out trajectory files; they are read only to collect the goals used for evaluation.
test_dataset_path = ["./datasets/standard_trajs_darkroom_step_envs100000_rp0.2_hists1_samples1_H100_d10_test.pkl"]
trajs = []
for test_file in test_dataset_path:
    with open(test_file, 'rb') as handle:
        trajs.extend(pickle.load(handle))

# De-duplicate goals (as hashable tuples) across every loaded trajectory.
goals = list({tuple(trajectory["goal"]) for trajectory in trajs})

# Disabled alternative: derive the test goals from the environment helper instead.
# _, darkroom_test_goals = generate_darkroom_goals(dim=10, split=0.8, seed=0)

# Each launch gets its own timestamped checkpoint folder.
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
experiment_name = "QV_Test_Experiment"  # This could be customized later
checkpoint_dir = "./checkpoints/{}_{}".format(experiment_name, timestamp)
save_checkpoints = True

dataset_path = "./datasets/standard_trajs_darkroom_step_envs100000_rp0.2_hists1_samples1_H100_d10_train.pkl"  # This could be customized later

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def _average_max_and_mean(eval_results):
    """
    Average the per-entry evaluation statistics of one evaluation run.

    Args:
        eval_results: Mapping whose values are indexable; element 0 of each
            value is accumulated as the "max" statistic and element 1 as the
            "mean" statistic (presumably per-goal returns -- confirm against
            the evaluation helpers).

    Returns:
        tuple: ``(average_of_element_0, average_of_element_1)``. Both are NaN
        when ``eval_results`` is empty, so an empty evaluation does not crash
        the training loop with a ZeroDivisionError.
    """
    count = len(eval_results)
    if count == 0:
        return float("nan"), float("nan")
    total_max = sum(stats[0] for stats in eval_results.values())
    total_mean = sum(stats[1] for stats in eval_results.values())
    return total_max / count, total_mean / count


@hydra.main(config_path="policy_conf", config_name="config", version_base=None)
def train_policy_model(cfg: DictConfig) -> None:
    """
    Hydra entry point: train the V/Q value models, then the policy model.

    The value trainer runs for ``cfg.value_epochs`` epochs, followed by the
    policy trainer for ``cfg.policy_epochs`` epochs. After every policy epoch
    the policy is evaluated online and offline on the module-level ``goals``;
    the averaged results are logged to Weights & Biases when ``cfg.wandb`` is
    truthy. Checkpoints are written every ``cfg.save_interval`` epochs when
    the module-level ``save_checkpoints`` flag is set.

    Args:
        cfg (DictConfig): Configuration object. This function reads ``seed``,
            ``wandb``, ``batch_size``, ``value_epochs``, ``policy_epochs``,
            ``save_interval``, ``horizon``, ``eval_episodes``, ``dataset`` and
            the ``policy_model`` / ``v_model`` / ``q_model`` sub-configs.
    """
    # Number of trajectories handed to the dataset; also used in the run name.
    num_trajectory = 20000

    # Only start a wandb run when logging is enabled, matching the gate on wandb.log below.
    if cfg.wandb:
        wandb.init(
            project=experiment_name,  # required
            # Derived from the actual settings so the label cannot drift from the config.
            name=f"num_trajs-{num_trajectory}_batch_size-{cfg.batch_size}",
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    print(f"Seed is {cfg.seed}")
    seed_everything(cfg.seed)
    if save_checkpoints:
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Checkpoint directory: {checkpoint_dir}")

    policy_model = Transformer(cfg.policy_model).to(device)
    v_model = VValueTransformer(cfg.v_model, gamma=cfg.v_model.gamma).to(device)
    q_model = QValueTransformer(cfg.q_model, gamma=cfg.q_model.gamma).to(device)
    model = {"policy": policy_model, "v": v_model, "q": q_model}

    policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=cfg.policy_model.learning_rate, weight_decay=1e-4)
    v_optimizer = torch.optim.AdamW(v_model.parameters(), lr=cfg.v_model.learning_rate, weight_decay=1e-4)
    q_optimizer = torch.optim.AdamW(q_model.parameters(), lr=cfg.q_model.learning_rate, weight_decay=1e-4)
    optimizer = {"policy": policy_optimizer, "v": v_optimizer, "q": q_optimizer}

    # Both trainers share the same model and optimizer dictionaries.
    value_trainer = ValueModelTrainer(model, optimizer)
    policy_trainer = DiTPolicyModelTrainer(model, optimizer)

    dataset = NormalDatasetBatch(dataset_path, cfg.dataset, device, num_trajectory=num_trajectory)  # We actually don't need gamma or num_pairs here
    # NOTE(review): shuffle=False is kept from the original; confirm that unshuffled batches are intended.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False)

    # Phase 1: value models.
    for epoch in range(cfg.value_epochs):
        train_loss = value_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.value_epochs}, Train Loss: {train_loss}")
        # The checkpoint directory only exists when save_checkpoints is set, so gate saving on it too.
        if save_checkpoints and (epoch + 1) % cfg.save_interval == 0:
            value_trainer.save_checkpoint(os.path.join(checkpoint_dir, f"value_model_epoch_{epoch+1}.pth"))
            print(f"Value Model saved at epoch {epoch+1}")

    # Phase 2: policy model, evaluated after every epoch.
    for epoch in range(cfg.policy_epochs):
        train_loss = policy_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.policy_epochs}, Train Loss: {train_loss}")

        # Switch to eval mode for the rollouts, then back to train mode for the next epoch.
        model["policy"].eval()
        online_eval_reward = online_evaluate_policy(goals, DarkroomTransformerController(model["policy"]), cfg.horizon, n_episodes=cfg.eval_episodes)
        offline_eval_reward = offline_evaluate_policy(goals, DarkroomTransformerController(model["policy"]), cfg.horizon, n_episodes=cfg.eval_episodes)
        model["policy"].train()

        if save_checkpoints and (epoch + 1) % cfg.save_interval == 0:
            policy_trainer.save_checkpoint(os.path.join(checkpoint_dir, f"policy_model_epoch_{epoch+1}.pth"))
            print(f"Policy Model saved at epoch {epoch+1}")

        if cfg.wandb:
            offline_max, offline_mean = _average_max_and_mean(offline_eval_reward)
            online_max, online_mean = _average_max_and_mean(online_eval_reward)

            wandb.log({"train_loss": train_loss,
                       "offline_evaluated_return_max": offline_max,
                       "offline_evaluated_return_mean": offline_mean,
                       "online_evaluated_return_max": online_max,
                       "online_evaluated_return_mean": online_mean})
            
# Script entry point: train_policy_model is wrapped by @hydra.main, so it is
# called without arguments here and receives its config from the decorator.
if __name__ == "__main__":
    train_policy_model()