import os
import pickle
import torch
import numpy as np
import gym
import d4rl

from decision_transformer.models.decision_transformer_dist import DecisionTransformer_Dist
from src.utils import save_dataset_embeddings, build_faiss_index
import argparse



# Per-environment settings, keyed by lowercase environment name.
#   scale      -> forwarded unchanged as variant['scale']
#   max_ep_len -> forwarded to the model constructor and variant['max_ep_len']
#   version    -> suffix of the gym/D4RL environment id
# The locomotion tasks use the "-v2" ids; the maze2d tasks are registered by
# D4RL as "-v1" (e.g. maze2d-medium-v1), so a hard-coded "-v2" cannot be used
# for every environment.
_ENV_SETTINGS = {
    'hopper': {'scale': 1000., 'max_ep_len': 1000, 'version': 'v2'},
    'halfcheetah': {'scale': 1000., 'max_ep_len': 1000, 'version': 'v2'},
    'walker2d': {'scale': 1000., 'max_ep_len': 1000, 'version': 'v2'},
    'maze2d': {'scale': 10, 'max_ep_len': 999, 'version': 'v1'},
}


def _lookup_env_settings(env_name):
    """Return the settings entry for *env_name* (matched case-insensitively).

    Raises:
        ValueError: if the environment is not listed in ``_ENV_SETTINGS``.
    """
    try:
        return _ENV_SETTINGS[env_name.lower()]
    except KeyError:
        supported = ', '.join(sorted(_ENV_SETTINGS))
        raise ValueError(
            f"Unsupported --env {env_name!r}; expected one of: {supported}"
        ) from None


def main():
    """Load a trained model, save dataset embeddings and build a FAISS index.

    Steps, in order:
      1. Parse ``--env``, ``--dataset`` and ``--model_type``.
      2. Resolve the per-environment settings (fails fast on unknown envs).
      3. Create the gym environment to read the state/action dimensions.
      4. Load the pickled trajectories and compute per-dimension state
         mean/std over all observations.
      5. Build ``DecisionTransformer_Dist`` and restore ``best.pt``.
      6. Call ``save_dataset_embeddings`` and ``build_faiss_index`` with the
         assembled ``variant`` dict.

    Raises:
        ValueError: if ``--env`` is not one of the supported environments.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)

    args = parse_args()
    env_name = args.env
    dataset = args.dataset
    model_type = args.model_type

    # Validate the environment before touching the filesystem or gym so that
    # an unsupported name produces a clear error instead of an unbound-local
    # failure much later.
    settings = _lookup_env_settings(env_name)
    scale = settings['scale']
    max_ep_len = settings['max_ep_len']

    # Model hyper-parameters; these must match the ones the checkpoint was
    # trained with, otherwise load_state_dict below will fail.
    K = 20
    embed_dim = 128
    n_layer = 3
    n_head = 1
    dropout = 0.1

    # Checkpoints and output codes live next to this script.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_save_dir = os.path.join(base_dir, 'saved_models_maze')
    code_save_dir = os.path.join(base_dir, 'saved_codes')
    code_dir = os.path.join(
        code_save_dir,
        f"code_{env_name}_{dataset}"
    )
    os.makedirs(code_dir, exist_ok=True)
    ckpt_path = os.path.join(
        model_save_dir,
        f'{env_name}_{dataset}_{model_type}',
        'best.pt',
    )
    print("Load Model:", ckpt_path)

    # e.g. hopper-medium-replay-v2, maze2d-medium-v1
    env_id = f"{env_name.lower()}-{dataset.lower()}-{settings['version']}"
    # The environment is only used here to read the space dimensions.
    env = gym.make(env_id)

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # NOTE(review): the dataset filename keeps the original hard-coded "-v2"
    # suffix and is resolved relative to the current working directory (not
    # base_dir); confirm the on-disk name used for maze2d datasets.
    dataset_path = f'data/{env_name}-{dataset}-v2.pkl'
    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # dataset files from a trusted source.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    # Normalisation statistics over every observation of every trajectory;
    # the small epsilon keeps the std strictly positive.
    states = np.concatenate([traj['observations'] for traj in trajectories], axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    model = DecisionTransformer_Dist(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=embed_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_inner=4 * embed_dim,
        activation_function='relu',
        n_positions=1024,
        resid_pdrop=dropout,
        attn_pdrop=dropout,
    ).to(device)

    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    model.eval()  # evaluation mode: disables dropout
    print("[checkpoint] loaded, iter =", ckpt.get('iter_num'), "step =", ckpt.get('global_step'))

    # Configuration consumed by the embedding/index helpers; note that
    # 'model_save_dir' points at the code output directory, not at the
    # checkpoint directory.
    variant = {
        'K': K,
        'batch_size': 64,
        'state_dim': state_dim,
        'act_dim': act_dim,
        'max_ep_len': max_ep_len,
        'scale': scale,
        'env': env_name,
        'dataset': dataset,
        'model_save_dir': code_dir,
    }

    save_dataset_embeddings(
        variant=variant,
        model=model,
        trajectories=trajectories,
        state_mean=state_mean,
        state_std=state_std,
        device=device,
    )
    build_faiss_index(variant)

    print("Done. Codes are saved in:", code_dir)

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with string attributes ``env``, ``dataset`` and
        ``model_type``; each falls back to its default when the flag is
        omitted.
    """
    # (flag, default) pairs, registered in this order.
    option_defaults = (
        ('--env', 'hopper'),
        ('--dataset', 'medium'),
        ('--model_type', 'dt_dist'),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in option_defaults:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
