from algos import QValueTransformer, VValueTransformer, Transformer, RewardTransformer
from envs import generate_darkroom_goals, generate_darkroom_env
from trainers import PreferenceDiTPolicyModelTrainer, PreferenceValueModelTrainer, PreferenceDatasetBatch
from ctrls import MetaWorldTransformerController

import hydra
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig, OmegaConf

import torch
import os
import wandb

from utils import seed_everything
from utils_metaworld import online_evaluate_policy_with_preference, offline_evaluate_policy_with_preference

import datetime

import pickle

import random

from collections import defaultdict, OrderedDict

import numpy as np

# wandb.init(
#     project="In-Context-RL",     # change to your project name
#     name="Preference-Based-Training-Larger-Model-Size",           # optional: name for this run
#     config={                       # optional: hyperparameters
#         "num_pairs": 5000,
#         "loss_function": "Average Loss",
#         "batch_size": 64,
#         "learning_rate": 1e-3,
#         "Dropout": 0.2
#     }
# )


def _load_trajs_by_task(paths):
    """Load pickled trajectory lists and group the trajectories by task.

    Args:
        paths: Iterable of pickle file paths. Each file is expected to hold a
            sequence of trajectories that can be indexed with ``"task_id"``.

    Returns:
        defaultdict(list): maps each ``task_id`` to the trajectories carrying
        it, in file order and then in-file order.
    """
    grouped = defaultdict(list)
    for path in paths:
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point these paths at datasets produced by a trusted source.
        with open(path, 'rb') as fh:
            for traj in pickle.load(fh):
                grouped[traj["task_id"]].append(traj)
    return grouped


# Trajectories passed to the evaluators as the "test" set, keyed by task id.
test_dataset_path = ["./datasets/preference_p80.pkl"]
test_trajs = _load_trajs_by_task(test_dataset_path)

# Trajectories passed to the evaluators as the "optimal" set, keyed by task id.
optimal_dataset_path = ["./datasets/preference_p20.pkl"]
optimal_trajs = _load_trajs_by_task(optimal_dataset_path)


# Run identity: every launch writes its checkpoints to a fresh, timestamped directory.
experiment_name = "Preference_experiment"  # This could be customized later
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f"./checkpoints/{experiment_name}_{timestamp}"
save_checkpoints = True

# Training data location.
dataset_path = "./datasets/preference_train.pkl"  # This could be customized later

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The reward model's config lives in its own Hydra config directory and is
# composed here, separately from the policy config handled by @hydra.main.
_reward_conf_dir = os.path.abspath("./reward_conf")
with initialize_config_dir(config_dir=_reward_conf_dir):
    reward_cfg = compose(config_name="config", overrides=["seed=0"])

def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Only a prefix at the very start of a key is dropped, so parameter names
    that merely contain the text elsewhere are left intact.
    """
    cleaned = OrderedDict()
    for key, tensor in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = tensor
    return cleaned


def _average_max_and_mean(eval_reward):
    """Average element 0 and element 1 of every value in ``eval_reward``.

    Args:
        eval_reward: Mapping whose values are indexable; element 0 is treated
            as the per-task "max" figure and element 1 as the per-task "mean".

    Returns:
        tuple: ``(avg_of_element_0, avg_of_element_1)``. An empty mapping
        yields ``(nan, nan)`` instead of raising ZeroDivisionError.
    """
    if not eval_reward:
        return float("nan"), float("nan")
    n_tasks = len(eval_reward)
    avg_max = sum(value[0] for value in eval_reward.values()) / n_tasks
    avg_mean = sum(value[1] for value in eval_reward.values()) / n_tasks
    return avg_max, avg_mean


@hydra.main(config_path="policy_conf", config_name="config", version_base=None)
def train_policy_model(cfg: DictConfig) -> None:
    """
    Main training function using Hydra for the value and policy models.

    Builds the policy, V and Q models, loads a pretrained reward model that is
    handed to the trainers and evaluators as the preference model, evaluates a
    previously saved policy checkpoint offline, then trains the value models
    for ``cfg.value_epochs`` epochs and the policy for ``cfg.policy_epochs``
    epochs, evaluating the policy after every policy epoch.

    Args:
        cfg (DictConfig): Configuration object containing all necessary parameters.
    """
    seed_everything(cfg.seed)
    if save_checkpoints:
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Checkpoint directory: {checkpoint_dir}")

    # wandb.log raises unless a run exists, so start one when logging is enabled
    # and nothing else has already initialised wandb.
    if cfg.wandb and wandb.run is None:
        wandb.init(
            project="Metaworld_Preference_Experiment",
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    policy_model = Transformer(cfg.policy_model).to(device)
    v_model = VValueTransformer(cfg.v_model).to(device)
    q_model = QValueTransformer(cfg.q_model).to(device)
    reward_model = RewardTransformer(reward_cfg.reward_model).to(device)

    # map_location lets a checkpoint saved from GPU load on a CPU-only host.
    reward_model_state = torch.load(
        "./checkpoints/Reward_Model_experiment_2025-05-14_18-03-32/reward_model_epoch_1.pth",
        map_location=device,
    )
    # print("2025-05-14_18-03-32")
    # reward_model_state = torch.load("./checkpoints/Reward_Model_experiment_2025-05-14_20-20-10/reward_model_epoch_5.pth")

    # The checkpoint keys may carry a "module." wrapper prefix; remove it so
    # the keys match the bare RewardTransformer.
    reward_model.load_state_dict(_strip_module_prefix(reward_model_state["model_state_dict"]))

    model = {"policy": policy_model, "v": v_model, "q": q_model}

    policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=cfg.policy_model.learning_rate, weight_decay=1e-4)
    v_optimizer = torch.optim.AdamW(v_model.parameters(), lr=cfg.v_model.learning_rate, weight_decay=1e-4)
    q_optimizer = torch.optim.AdamW(q_model.parameters(), lr=cfg.q_model.learning_rate, weight_decay=1e-4)
    optimizer = {"policy": policy_optimizer, "v": v_optimizer, "q": q_optimizer}

    n_episodes = 10
    value_trainer = PreferenceValueModelTrainer(model, optimizer, preference_model=reward_model)
    policy_trainer = PreferenceDiTPolicyModelTrainer(model, optimizer, preference_model=reward_model)
    policy_trainer.load_checkpoint(os.path.join("checkpoints", "Preference_experiment_2025-05-15_09-06-49", "policy_model_epoch_30.pth"))
    # policy_trainer.load_checkpoint(os.path.join("checkpoints", "Preference_experiment_2025-05-15_07-45-48", "policy_model_epoch_30.pth"))

    # Offline evaluation of the loaded checkpoint before any further training.
    offline_eval_reward, reward_li = offline_evaluate_policy_with_preference(
        test_trajs,
        optimal_trajs,
        MetaWorldTransformerController(model["policy"]),
        cfg.horizon,
        n_episodes=n_episodes,
        preference_model=reward_model,
    )
    offline_max, offline_mean = _average_max_and_mean(offline_eval_reward)
    print(offline_max, offline_mean)
    print(np.mean(reward_li), np.std(reward_li))
    # A leftover breakpoint() used to sit here and stopped every run before
    # training; it was removed so the script can run unattended.

    dataset = PreferenceDatasetBatch(dataset_path, cfg.dataset, device, num_pairs=5000)  # We actually don't need gamma or num_pairs here
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=PreferenceDatasetBatch.batch_collate_fn)

    # Stage 1: value training.
    for epoch in range(cfg.value_epochs):
        train_loss = value_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.value_epochs}, Train Loss: {train_loss}")
        if (epoch + 1) % cfg.save_interval == 0:
            value_trainer.save_checkpoint(os.path.join(checkpoint_dir, f"value_model_epoch_{epoch+1}.pth"))
            print(f"Model saved at epoch {epoch+1}")
    # value_trainer.load_checkpoint(os.path.join("checkpoints", "Preference_experiment_2025-05-10_16-54-16", "value_model_epoch_20.pth"))

    # Stage 2: policy training, with online and offline evaluation every epoch.
    for epoch in range(cfg.policy_epochs):
        train_loss = policy_trainer.train_epoch(dataloader)
        print(f"Epoch {epoch+1}/{cfg.policy_epochs}, Train Loss: {train_loss}")

        model["policy"].eval()
        online_result = online_evaluate_policy_with_preference(
            test_trajs,
            optimal_trajs,
            MetaWorldTransformerController(model["policy"]),
            cfg.horizon,
            n_episodes=cfg.eval_episodes,
            preference_model=reward_model,
        )
        # NOTE(review): the online evaluator's return shape is not visible from
        # this file. If it mirrors the offline evaluator's (dict, list) pair,
        # only the per-task dict is needed here; confirm against utils_metaworld.
        online_eval_reward = online_result[0] if isinstance(online_result, tuple) else online_result

        # The offline evaluator returns (per-task dict, reward list), as in the
        # pre-training evaluation above; only the dict is used for logging.
        offline_eval_reward, _ = offline_evaluate_policy_with_preference(
            test_trajs,
            optimal_trajs,
            MetaWorldTransformerController(model["policy"]),
            cfg.horizon,
            n_episodes=cfg.eval_episodes,
            preference_model=reward_model,
        )
        model["policy"].train()

        if (epoch + 1) % cfg.save_interval == 0:
            policy_trainer.save_checkpoint(os.path.join(checkpoint_dir, f"policy_model_epoch_{epoch+1}.pth"))
            print(f"Model saved at epoch {epoch+1}")

        if cfg.wandb:
            offline_max, offline_mean = _average_max_and_mean(offline_eval_reward)
            online_max, online_mean = _average_max_and_mean(online_eval_reward)

            wandb.log({"train_loss": train_loss,
                       "offline_evaluated_return_max": offline_max,
                       "offline_evaluated_return_mean": offline_mean,
                       "online_evaluated_return_max": online_max,
                       "online_evaluated_return_mean": online_mean})
            
# Script entry point: train_policy_model is wrapped by @hydra.main, so it is
# called without arguments here and receives its `cfg` from the decorator.
if __name__ == "__main__":
    train_policy_model()