from algos import QValueTransformer, VValueTransformer, Transformer, RewardTransformer
from envs import generate_darkroom_goals, generate_darkroom_env
from trainers import PreferenceDiTPolicyModelTrainer, PreferenceValueModelTrainer, PreferenceDatasetBatch
from ctrls import DarkroomTransformerController

import hydra
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig, OmegaConf

import torch
import os
import wandb

from utils import seed_everything
from utils_eval import online_evaluate_policy_with_preference, offline_evaluate_policy_with_preference

import datetime

import pickle

import random


# Held-out preference trajectories. Only their goal positions are used (as the
# evaluation goals); the trajectories themselves are kept in `trajs`.
# NOTE: pickle.load can execute arbitrary code - only point these paths at
# trusted, locally generated dataset files.
test_dataset_path = [
    "./datasets/preference_trajs_darkroom_step_envs10000_rp0.2_hists5_samples10_H100_d10_test.pkl",
]

trajs = []
for path in test_dataset_path:
    with open(path, "rb") as handle:
        trajs.extend(pickle.load(handle))

# Unique goals across all loaded trajectories, converted to tuples so they are hashable.
goals = list({tuple(traj["goal"]) for traj in trajs})

# Unused alternative source of evaluation goals (procedurally generated split):
# _, darkroom_test_goals = generate_darkroom_goals(dim=10, split=0.8, seed=0)

# Per-run checkpoint location, suffixed with the launch timestamp.
experiment_name = "Preference_experiment"  # could be customized later
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f"./checkpoints/{experiment_name}_{timestamp}"
save_checkpoints = True

# Training split of the preference dataset (could be customized later).
dataset_path = (
    "./datasets/preference_trajs_darkroom_step_envs10000_rp0.2_hists5_samples10_H100_d10_train.pkl"
)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The reward model's architecture comes from its own Hydra config tree
# (./reward_conf), composed once at import time with seed=0, separately from
# the policy config that Hydra injects into the training entry point.
with initialize_config_dir(config_dir=os.path.abspath("./reward_conf")):
    reward_cfg = compose(config_name="config", overrides=["seed=0"])

def _average_eval_returns(eval_reward):
    """Average per-entry (max, mean) evaluation returns into two scalars.

    Args:
        eval_reward: Mapping returned by the ``*_evaluate_policy_with_preference``
            helpers. Each value is indexed as ``value[0]`` (treated as the max
            return) and ``value[1]`` (treated as the mean return); presumably
            one entry per goal - TODO confirm against ``utils_eval``.

    Returns:
        ``(avg_max, avg_mean)`` over all entries, or ``(nan, nan)`` when
        ``eval_reward`` is empty so logging never divides by zero.
    """
    if not eval_reward:
        return float("nan"), float("nan")
    count = len(eval_reward)
    total_max = sum(value[0] for value in eval_reward.values())
    total_mean = sum(value[1] for value in eval_reward.values())
    return total_max / count, total_mean / count


@hydra.main(config_path="policy_conf", config_name="config", version_base=None)
def train_policy_model(cfg: DictConfig) -> None:
    """
    Train the value (V/Q) models, then the policy model, guided by a pretrained reward model.

    The ``RewardTransformer`` (built from the module-level ``reward_cfg`` and
    restored from a fixed checkpoint; no optimizer is created for it here) is
    passed to both trainers as the preference model. Training runs in two
    phases over the same dataloader:

    1. ``cfg.value_epochs`` epochs of value training, saving a value checkpoint
       every ``cfg.save_interval`` epochs when ``save_checkpoints`` is enabled.
    2. ``cfg.policy_epochs`` epochs of policy training, each followed by online
       and offline evaluation on the module-level ``goals``; the averaged
       returns are logged to Weights & Biases when ``cfg.wandb`` is true.

    Args:
        cfg (DictConfig): Hydra config from ``policy_conf/config``. Reads ``seed``,
            ``policy_model``, ``v_model``, ``q_model`` (each with ``learning_rate``),
            ``dataset``, ``batch_size``, ``value_epochs``, ``policy_epochs``,
            ``save_interval``, ``horizon``, ``eval_episodes`` and ``wandb``.
    """
    # Don't create a W&B run at all when cfg.wandb turns logging off.
    wandb.init(
        project="Preference_Experiment",
        name="num_trajs-2000_batch_size-32",
        config=OmegaConf.to_container(cfg, resolve=True),
        mode=None if cfg.wandb else "disabled",
    )

    seed_everything(cfg.seed)
    if save_checkpoints:
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Checkpoint directory: {checkpoint_dir}")

    policy_model = Transformer(cfg.policy_model).to(device)
    v_model = VValueTransformer(cfg.v_model).to(device)
    q_model = QValueTransformer(cfg.q_model).to(device)

    reward_model = RewardTransformer(reward_cfg.reward_model).to(device)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    reward_model_state = torch.load(
        "./checkpoints/Reward_Model_experiment_2025-05-12_21-30-20/reward_model_epoch_30.pth",
        map_location=device,
    )
    reward_model.load_state_dict(reward_model_state["model_state_dict"])

    model = {"policy": policy_model, "v": v_model, "q": q_model}

    policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=cfg.policy_model.learning_rate, weight_decay=1e-4)
    v_optimizer = torch.optim.AdamW(v_model.parameters(), lr=cfg.v_model.learning_rate, weight_decay=1e-4)
    q_optimizer = torch.optim.AdamW(q_model.parameters(), lr=cfg.q_model.learning_rate, weight_decay=1e-4)
    optimizer = {"policy": policy_optimizer, "v": v_optimizer, "q": q_optimizer}

    value_trainer = PreferenceValueModelTrainer(model, optimizer, preference_model=reward_model)
    policy_trainer = PreferenceDiTPolicyModelTrainer(model, optimizer, preference_model=reward_model)

    # We actually don't need gamma or num_pairs here
    dataset = PreferenceDatasetBatch(dataset_path, cfg.dataset, device, num_pairs=5000)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=PreferenceDatasetBatch.batch_collate_fn,
    )

    # Phase 1: value (V/Q) training.
    for epoch in range(cfg.value_epochs):
        train_loss = value_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.value_epochs}, Train Loss: {train_loss}")
        # checkpoint_dir is only created when save_checkpoints is enabled.
        if save_checkpoints and (epoch + 1) % cfg.save_interval == 0:
            value_trainer.save_checkpoint(os.path.join(checkpoint_dir, f"value_model_epoch_{epoch+1}.pth"))
            print(f"Model saved at epoch {epoch+1}")
    # To reuse previously trained value models instead of retraining, e.g.:
    # value_trainer.load_checkpoint(os.path.join("checkpoints", "Preference_experiment_2025-05-10_16-54-16", "value_model_epoch_20.pth"))

    # Phase 2: policy training, evaluated after every epoch.
    for epoch in range(cfg.policy_epochs):
        train_loss = policy_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.policy_epochs}, Train Loss: {train_loss}")

        model["policy"].eval()
        online_eval_reward = online_evaluate_policy_with_preference(
            goals,
            DarkroomTransformerController(model["policy"]),
            cfg.horizon,
            n_episodes=cfg.eval_episodes,
            preference_model=reward_model,
        )
        offline_eval_reward = offline_evaluate_policy_with_preference(
            goals,
            DarkroomTransformerController(model["policy"]),
            cfg.horizon,
            n_episodes=cfg.eval_episodes,
            preference_model=reward_model,
        )
        model["policy"].train()

        if cfg.wandb:
            offline_max, offline_mean = _average_eval_returns(offline_eval_reward)
            online_max, online_mean = _average_eval_returns(online_eval_reward)
            wandb.log({
                "train_loss": train_loss,
                "offline_evaluated_return_max": offline_max,
                "offline_evaluated_return_mean": offline_mean,
                "online_evaluated_return_max": online_max,
                "online_evaluated_return_mean": online_mean,
            })
            
            
if __name__ == "__main__":
    # Hydra parses CLI overrides and supplies `cfg` to the decorated entry point.
    train_policy_model()