from algos import QValueTransformer, VValueTransformer, PreferenceTransformer
from trainers import NormalDatasetBatch, DiTPolicyModelTrainer, ValueModelTrainer, PreferenceDatasetBatch
from ctrls import DarkroomTransformerController

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import os
import wandb

from utils import seed_everything
from utils_metaworld import online_evaluate_policy_DPT
import datetime

import pickle

import tqdm
from collections import defaultdict

import numpy as np


# Pickled trajectory files used for online evaluation; every file is expected
# to unpickle into a list of trajectory dicts carrying a "task_id" key.
test_dataset_path = ["./datasets/preference_p20.pkl"]
trajs = []
for pickle_path in test_dataset_path:
    with open(pickle_path, 'rb') as handle:
        trajs.extend(pickle.load(handle))

# Bucket the evaluation trajectories by the task they belong to.
test_trajs = defaultdict(list)
for episode in trajs:
    task_bucket = test_trajs[episode["task_id"]]
    task_bucket.append(episode)


# Legacy (Darkroom) goal generation, kept for reference:
# _, darkroom_test_goals = generate_darkroom_goals(dim=10, split=0.8, seed=0)

# Run identity and checkpoint location; the timestamp keeps runs from colliding.
experiment_name = "QV_Test_Experiment" # This could be customized later
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f"./checkpoints/{experiment_name}_{timestamp}"
save_checkpoints = False

# Training data location and the compute device (GPU when one is available).
dataset_path = "./datasets/preference_DPT_train.pkl" # This could be customized later
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def _mean_online_returns(online_eval_reward):
    """
    Average the per-task online evaluation returns.

    Args:
        online_eval_reward: Mapping from task key to an indexable value whose
            element 0 is accumulated as the "max" return and element 1 as the
            "mean" return.

    Returns:
        tuple: (average of element 0, average of element 1) over all tasks.
        Both are NaN when the mapping is empty, instead of dividing by zero.
    """
    if not online_eval_reward:
        return float("nan"), float("nan")
    num_tasks = len(online_eval_reward)
    online_max = sum(value[0] for value in online_eval_reward.values()) / num_tasks
    online_mean = sum(value[1] for value in online_eval_reward.values()) / num_tasks
    return online_max, online_mean


@hydra.main(config_path="policy_conf", config_name="config", version_base=None)
def train_policy_model(cfg: DictConfig) -> None:
    """
    Main training function using Hydra for the policy model.

    Loads a previously trained checkpoint, evaluates it online on the test
    trajectories, then continues training with an MSE loss between predicted
    and optimal actions, evaluating after every epoch.

    Args:
        cfg (DictConfig): Configuration object containing all necessary parameters.
            Fields read here: seed, policy_model (incl. learning_rate), dataset,
            batch_size, horizon, policy_epochs, eval_episodes, wandb.
    """
    print(f"Seed is {cfg.seed}")
    seed_everything(cfg.seed)
    if save_checkpoints:
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Checkpoint directory: {checkpoint_dir}")

    # wandb.log() requires an active run; start one only when logging is
    # enabled and no run has been started yet.
    if cfg.wandb and wandb.run is None:
        wandb.init(
            project="DPT_Experiment_Metaworld",
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    policy_model = PreferenceTransformer(cfg.policy_model).to(device)

    policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=cfg.policy_model.learning_rate, weight_decay=1e-4)

    dataset = PreferenceDatasetBatch(dataset_path, cfg.dataset, device, num_pairs=50000) # We actually don't need gamma or num_pairs here
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=PreferenceDatasetBatch.batch_collate_fn)

    # Warm-start from a fixed, previously trained checkpoint. map_location lets
    # a checkpoint saved on GPU load on a CPU-only machine as well.
    pretrained_path = "./checkpoints/QV_Test_Experiment_2025-05-15_23-54-31/policy_model_epoch_40.pth"
    state = torch.load(pretrained_path, map_location=device)
    policy_model.load_state_dict(state["policy_state_dict"])

    # Baseline evaluation of the loaded checkpoint before any further training.
    policy_model.eval()
    online_eval_reward, reward_li = online_evaluate_policy_DPT(test_trajs, policy_model, cfg.horizon, n_episodes=10)
    online_max, online_mean = _mean_online_returns(online_eval_reward)
    print(online_max, online_mean)
    print(np.mean(reward_li), np.std(reward_li))
    policy_model.train()

    # Guard the per-epoch average against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(cfg.policy_epochs):
        train_loss = 0
        for batch in tqdm.tqdm(dataloader, desc="Training Epoch"):
            pred_action = policy_model(batch, batch["query_states"])
            optimal_action = batch["optimal_actions"]
            # NOTE(review): 200 and 5 are hard-coded; presumably the number of
            # predictions per sample and the action dimension - confirm
            # against the model/dataset configuration.
            optimal_action = optimal_action.unsqueeze(1).repeat(1, 200, 1)
            pred_action = pred_action.reshape(-1, 5)
            optimal_action = optimal_action.reshape(-1, 5)
            loss = torch.nn.functional.mse_loss(pred_action, optimal_action)
            policy_optimizer.zero_grad()
            loss.backward()
            policy_optimizer.step()
            train_loss += loss.item()

        mean_train_loss = train_loss / num_batches
        print(f"Epoch {epoch+1}/{cfg.policy_epochs}, Train Loss: {mean_train_loss}")

        # The checkpoint directory is only created when save_checkpoints is
        # set, so saving must be guarded by the same flag.
        if save_checkpoints:
            state = {"policy_state_dict": policy_model.state_dict()}
            torch.save(state, os.path.join(checkpoint_dir, f"policy_model_epoch_{epoch+1}.pth"))

        policy_model.eval()
        # The evaluator returns (per-task rewards, reward list); only the
        # per-task mapping is needed for logging.
        online_eval_reward, _ = online_evaluate_policy_DPT(test_trajs, policy_model, cfg.horizon, n_episodes=cfg.eval_episodes)
        policy_model.train()

        if cfg.wandb:
            online_max, online_mean = _mean_online_returns(online_eval_reward)
            wandb.log({"train_loss": mean_train_loss,
                       "online_evaluated_return_max": online_max,
                       "online_evaluated_return_mean": online_mean})
            
# Script entry point: train_policy_model is wrapped by @hydra.main, so it is
# called without arguments here and receives its cfg from the decorator.
if __name__ == "__main__":
    train_policy_model()