"""
Heavily based on:
https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py
https://github.com/gwthomas/IQL-PyTorch/blob/main/main.py
https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/evaluation/evaluate_episodes.py
"""
import gym
import numpy as np
import torch
import argparse
import pickle
import random
import os
import math
import faiss  
import d4rl   

from typing import Optional

from decision_transformer.evaluation.evaluate_episodes import (
    evaluate_episode_retrieve_iql
)
from decision_transformer.models.decision_transformer_dist import DecisionTransformer_Dist
from src.value_functions import TwinQ   



def discount_cumsum(x, gamma):
    """Return the discounted suffix sums of ``x`` along its first axis.

    Element ``t`` of the result equals
    ``x[t] + gamma * x[t + 1] + gamma**2 * x[t + 2] + ...``, computed with a
    single backward pass.

    Args:
        x: NumPy array of per-step values; the recursion runs over axis 0.
        gamma: Discount factor applied once per step.

    Returns:
        Array allocated with ``np.zeros_like(x)``, i.e. same shape and dtype
        as ``x``. Because the dtype is inherited, integer inputs truncate the
        discounted terms; pass a float array to avoid that. An empty input
        yields an empty array.
    """
    out = np.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; indexing x[-1] below would raise IndexError.
        return out

    # The last step has no future terms, so it seeds the backward recursion.
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out




def load_iql_q_function(
    env_name, dataset,
    state_dim, act_dim,
    device="cuda",
    hidden_dim=256,
    n_hidden=2,
    version="v1",
):
    """
    Load a pretrained IQL TwinQ network from disk.

    The checkpoint is read from
        saved_iql/{env_name}-{dataset}-{version}/{env_name}-{dataset}-{version}_qf.pt
    e.g. saved_iql/maze2d-medium-v1/maze2d-medium-v1_qf.pt with the defaults.
    The file content is passed directly to ``load_state_dict``, so it must be
    a state dict matching a TwinQ built with the given sizes.

    env_name:   e.g. "hopper" or "maze2d"
    dataset:    e.g. "medium"
    state_dim:  forwarded to TwinQ as ``state_dim``
    act_dim:    forwarded to TwinQ as ``action_dim``
    device:     device the network is moved to and the checkpoint is mapped to
    hidden_dim: forwarded to TwinQ as ``hidden_dim``
    n_hidden:   forwarded to TwinQ as ``n_hidden``
    version:    dataset version suffix used in the directory and file name

    Returns the TwinQ module in eval mode.
    Raises FileNotFoundError if the checkpoint file does not exist.
    """
    iql_env_name = f"{env_name}-{dataset}-{version}"
    q_path = os.path.join("saved_iql", iql_env_name, f"{iql_env_name}_qf.pt")

    # Explicit raise instead of assert so the check survives `python -O`.
    if not os.path.exists(q_path):
        raise FileNotFoundError(f"can't find IQL qf : {q_path}")

    qf = TwinQ(
        state_dim=state_dim,
        action_dim=act_dim,
        hidden_dim=hidden_dim,
        n_hidden=n_hidden,
    ).to(device)

    state = torch.load(q_path, map_location=device)
    qf.load_state_dict(state)
    qf.eval()

    print(f" Load IQL TwinQ: {q_path}")
    return qf



def load_dt_and_retrieval(env_name, dataset, state_dim, act_dim,
                          K, embed_dim, n_layer, n_head, dropout,
                          device="cuda",
                          model_root="saved_models_maze",
                          code_root="saved_codes"):
    """
    Load the trained DT-Dist model and its FAISS retrieval store.

    Load:
      1) Model:    {model_root}/{env}_{dataset}_dt_dist/best.pt
                   (weights are read from the "model_state" key)
      2) Retrieve: {code_root}/code_{env}_{dataset}/
                   {env}_{dataset}_faiss.index
                   {env}_{dataset}_faiss_actions.npz  (array "actions")
                   {env}_{dataset}_faiss_rtgs.npz     (array "rtgs")

    Args:
      env_name: one of "hopper", "halfcheetah", "walker2d", "maze2d"
      dataset:  dataset name used in the paths, e.g. "medium"
      state_dim, act_dim: forwarded to DecisionTransformer_Dist
      K: forwarded to the model as ``max_length``
      embed_dim: model hidden size; ``n_inner`` is set to 4 * embed_dim
      n_layer, n_head: forwarded to the model
      dropout: used for both ``resid_pdrop`` and ``attn_pdrop``
      device: device for the model and for mapping the checkpoint
      model_root: directory holding the model checkpoints
      code_root:  directory holding the retrieval files

    Return:
      model: DecisionTransformer_Dist in eval mode
      faiss_index: FAISS index read from disk
      faiss_actions: float32 np.ndarray (presumably [N, act_dim])
      faiss_rtgs: float32 np.ndarray (presumably [N, ...])

    Raises:
      FileNotFoundError: a checkpoint or retrieval file is missing
      NotImplementedError: env_name is not supported
    """

    model_path = f"{model_root}/{env_name}_{dataset}_dt_dist/best.pt"
    # Explicit raise instead of assert so the check survives `python -O`.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model don't exist: {model_path}")

    # Only the episode-length limit is needed to build the model here.
    max_ep_len_by_env = {
        "hopper": 1000,
        "halfcheetah": 1000,
        "walker2d": 1000,
        "maze2d": 999,
    }
    if env_name not in max_ep_len_by_env:
        raise NotImplementedError(f"Unsupported env_name: {env_name}")
    max_ep_len = max_ep_len_by_env[env_name]

    model = DecisionTransformer_Dist(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=embed_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_inner=4 * embed_dim,
        activation_function="relu",
        n_positions=1024,
        resid_pdrop=dropout,
        attn_pdrop=dropout,
    ).to(device)

    ckpt = torch.load(model_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    print(f" load: {model_path}")

    code_dir = f"{code_root}/code_{env_name}_{dataset}"

    act_file = os.path.join(code_dir, f"{env_name}_{dataset}_faiss_actions.npz")
    rtg_file = os.path.join(code_dir, f"{env_name}_{dataset}_faiss_rtgs.npz")
    index_file = os.path.join(code_dir, f"{env_name}_{dataset}_faiss.index")

    # Verify every retrieval file before reading any of them.
    required_files = (
        ("actions", act_file),
        ("rtgs", rtg_file),
        ("FAISS index", index_file),
    )
    for label, path in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} don't exist: {path}")

    # np.load on an .npz keeps the archive open; the context manager closes it.
    with np.load(act_file) as act_data:
        faiss_actions = act_data["actions"].astype(np.float32)
    with np.load(rtg_file) as rtg_data:
        faiss_rtgs = rtg_data["rtgs"].astype(np.float32)
    faiss_index = faiss.read_index(index_file)

    print(f" Load FAISS : d={faiss_index.d}, ntotal={faiss_index.ntotal}")
    print(f" actions shape: {faiss_actions.shape}, rtgs shape: {faiss_rtgs.shape}")

    return model, faiss_index, faiss_actions, faiss_rtgs







# 3. Retrieval: k-nearest-neighbour lookup in the FAISS index

def retrieve_actions(embedding, faiss_index, actions, rtgs, k=10):
    """Look up the ``k`` nearest stored entries for a single query embedding.

    Args:
        embedding: Query vector, either a ``torch.Tensor`` or a NumPy array.
            It is flattened into one float32 row before the search.
        faiss_index: Object exposing ``search(query, k)`` that returns a
            ``(distances, indices)`` pair, each with one row per query.
        actions: Array indexed by the returned indices.
        rtgs: Array indexed by the returned indices.
        k: Number of neighbours requested from the index.

    Returns:
        Tuple ``(retrieved_actions, retrieved_rtgs, distances, indices)``
        for the single query row.
    """
    # Tensors are detached and moved to host memory; arrays are used as-is.
    query = (
        embedding.detach().cpu().numpy()
        if isinstance(embedding, torch.Tensor)
        else embedding
    )
    query = query.reshape(1, -1).astype(np.float32)

    distances, indices = faiss_index.search(query, k)
    neighbour_ids = indices[0]

    return actions[neighbour_ids], rtgs[neighbour_ids], distances[0], neighbour_ids





def eval_episodes_factory_q(
    env,
    model_type,
    env_targets,
    num_eval_episodes,
    state_dim,
    act_dim,
    max_ep_len,
    scale,
    mode,
    state_mean,
    state_std,
    device,
    qf,
    num_samples=6,
    faiss_index=None,
    faiss_actions=None,
    faiss_rtgs=None,
    K=20,
    num_retrieved=5,
):
    """Build one evaluation callable per target return.

    Each returned callable takes a model, runs ``num_eval_episodes`` episodes
    through ``evaluate_episode_retrieve_iql`` under ``torch.no_grad()`` and
    returns a dict of mean/std statistics keyed by
    ``target_{target}_return_*``, ``target_{target}_length_*`` and
    ``target_{target}_norm_score_*``.

    Args:
        env: Environment forwarded to the episode evaluator.
        model_type: Must be "dt_dist"; anything else makes the callables
            raise NotImplementedError.
        env_targets: Iterable of target returns; one callable is built for
            each. The evaluator receives ``target / scale``.
        num_eval_episodes: Episodes run per target.
        state_dim, act_dim, max_ep_len, scale, mode, state_mean, state_std,
        device, qf: Forwarded unchanged to the episode evaluator.
        num_samples: Forwarded to the evaluator as ``num_samples``.
        faiss_index, faiss_actions, faiss_rtgs: Retrieval store forwarded to
            the evaluator.
        K: Forwarded to the evaluator as ``K``.
        num_retrieved: Forwarded to the evaluator as ``num_retrieved``.

    Returns:
        List of callables ``fn(model) -> dict``, in the order of
        ``env_targets``.
    """
    def make_eval(target_rew):
        def fn(model):
            # Checked once per call rather than once per episode.
            if model_type != "dt_dist":
                raise NotImplementedError("Q-guided only supports dt_dist")

            returns, lengths, scores = [], [], []
            for _ in range(num_eval_episodes):
                with torch.no_grad():
                    ret, length, norm_score = evaluate_episode_retrieve_iql(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        qf,
                        faiss_index=faiss_index,
                        faiss_actions=faiss_actions,
                        faiss_rtgs=faiss_rtgs,
                        K=K,
                        # Honour the caller's settings instead of fixed values.
                        num_samples=num_samples,
                        num_retrieved=num_retrieved,
                        max_ep_len=max_ep_len,
                        scale=scale,
                        target_return=target_rew / scale,
                        mode=mode,
                        state_mean=state_mean,
                        state_std=state_std,
                        device=device,
                    )
                returns.append(ret)
                lengths.append(length)
                scores.append(norm_score)
            return {
                f"target_{target_rew}_return_mean": np.mean(returns),
                f"target_{target_rew}_return_std": np.std(returns),
                f"target_{target_rew}_length_mean": np.mean(lengths),
                f"target_{target_rew}_length_std": np.std(lengths),
                f"target_{target_rew}_norm_score_mean": np.mean(scores),
                f"target_{target_rew}_norm_score_std": np.std(scores),
            }

        return fn

    return [make_eval(tar) for tar in env_targets]


def main(args):
    """Evaluate a trained DT-Dist model with retrieval and IQL Q-guidance.

    Steps: create the gym environment, load the pickled offline trajectories
    to compute state normalisation statistics, load the model together with
    its FAISS retrieval store and the IQL Q-function, then run the evaluation
    callables for every target return and print the aggregated metrics.

    Args:
        args: Namespace providing ``device``, ``env``, ``dataset``,
            ``model_type``, ``mode``, ``K``, ``embed_dim``, ``n_layer``,
            ``n_head``, ``dropout`` and ``num_eval_episodes``.

    Raises:
        NotImplementedError: ``model_type`` is not "dt_dist" or ``env`` is
            not a supported environment name.
        FileNotFoundError: the trajectory pickle is missing.
        ValueError: the trajectory pickle contains no trajectories.
    """
    device = args.device

    env_name, dataset = args.env, args.dataset
    model_type = args.model_type

    # Fail before any expensive loading; explicit raise survives `python -O`.
    if model_type != "dt_dist":
        raise NotImplementedError("only support dt_dist")

    env_id = f"{env_name.lower()}-{dataset.lower()}-v1"  # e.g. hopper-medium-v2, maze2d-medium-v1
    env = gym.make(env_id)

    # Per-environment (target returns, max episode length, return scale).
    env_configs = {
        "hopper": ([3600, 1800], 1000, 1000.0),
        "halfcheetah": ([12000, 6000], 1000, 1000.0),
        "walker2d": ([5000, 2500], 1000, 1000.0),
        "maze2d": ([300, 200, 150, 100, 50, 20], 999, 10),
    }
    if env_name not in env_configs:
        raise NotImplementedError(f"Unsupported env_name: {env_name}")
    env_targets, max_ep_len, scale = env_configs[env_name]

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    dataset_path = f"data/{env_name}-{dataset}-v1.pkl"
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"dataset not found: {dataset_path}")
    # NOTE: pickle.load can execute arbitrary code; only load trusted datasets.
    with open(dataset_path, "rb") as f:
        trajectories = pickle.load(f)
    if len(trajectories) == 0:
        # np.concatenate below would fail with a far less helpful message.
        raise ValueError(f"dataset contains no trajectories: {dataset_path}")

    mode = args.mode
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        if mode == "delayed":
            # Delayed reward: move the whole episode return onto the last step.
            path["rewards"][-1] = path["rewards"].sum()
            path["rewards"][:-1] = 0.0
        states.append(path["observations"])
        traj_lens.append(len(path["observations"]))
        returns.append(path["rewards"].sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # Normalisation statistics over all states; epsilon avoids a zero std.
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = traj_lens.sum()

    print("=" * 50)
    print(f"[Eval] env: {env_name}, dataset: {dataset}")
    print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
    print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
    print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
    print("=" * 50)

    K = args.K

    # 1) load the DT model and the retrieval store
    model, faiss_index, faiss_actions, faiss_rtgs = load_dt_and_retrieval(
        env_name=env_name,
        dataset=dataset,
        state_dim=state_dim,
        act_dim=act_dim,
        K=K,
        embed_dim=args.embed_dim,
        n_layer=args.n_layer,
        n_head=args.n_head,
        dropout=args.dropout,
        device=device,
    )

    # 2) load the IQL Q-function
    qf = load_iql_q_function(
        env_name=env_name,
        dataset=f"{dataset}",   # 'medium'
        state_dim=state_dim,
        act_dim=act_dim,
        device=device,
    )

    # 3) evaluate DT-Dist with the Q-guided policy
    eval_fns_q = eval_episodes_factory_q(
        env=env,
        model_type=model_type,
        env_targets=env_targets,
        num_eval_episodes=args.num_eval_episodes,
        state_dim=state_dim,
        act_dim=act_dim,
        max_ep_len=max_ep_len,
        scale=scale,
        mode=mode,
        state_mean=state_mean,
        state_std=state_std,
        device=device,
        qf=qf,
        num_samples=6,
        K=K,
        faiss_index=faiss_index,
        faiss_actions=faiss_actions,
        faiss_rtgs=faiss_rtgs,
    )

    # Merge the per-target metric dicts; keys are unique per target.
    all_results_q = {}
    for fn in eval_fns_q:
        all_results_q.update(fn(model))

    for k, v in all_results_q.items():
        print(f"{k}: {v}")



if __name__ == "__main__":
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = [
        ("--env", str, "hopper"),
        ("--dataset", str, "medium"),
        ("--mode", str, "normal"),            # normal / delayed
        ("--K", int, 20),
        ("--model_type", str, "dt_dist"),     # main() only accepts dt_dist
        ("--embed_dim", int, 128),
        ("--n_layer", int, 3),
        ("--n_head", int, 1),
        ("--activation_function", str, "relu"),  # parsed, but not read by main()
        ("--dropout", float, 0.1),
        ("--num_eval_episodes", int, 10),
        ("--device", str, "cuda"),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default in cli_options:
        parser.add_argument(flag, type=arg_type, default=default)

    args = parser.parse_args()
    main(args)


