from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from pyrallis import field

from osrl.algorithms import CDT, CDTTrainer
from osrl.common.exp_util import load_config_and_model, seed_all
import pickle

@dataclass
class EvalConfig:
    """Command-line configuration for the evaluation entry point.

    The fields are declared as a dataclass so that ``pyrallis`` can populate
    them; the defaults below are used for anything not overridden.
    """

    # Location of the saved model weights (a ``.pt`` file).
    path: str = "/home/fn/OSRL/examples/Mymodel/cdt_no_other_loss_0826.pt"
    # Target returns to evaluate; paired element-wise with ``costs``.
    returns: List[float] = field(default=[5, 20, 50], is_mutable=True)
    # Target costs to evaluate; must have the same length as ``returns``.
    costs: List[float] = field(default=[10, 10, 10], is_mutable=True)
    # NOTE(review): the default is None although the annotation is List[float];
    # not read by the code visible in this file.
    noise_scale: List[float] = None
    # Number of evaluation episodes.
    eval_episodes: int = 20
    # NOTE(review): presumably selects a "best" checkpoint; not read by the
    # code visible in this file -- confirm before relying on it.
    best: bool = False
    # Torch device string the model is moved to.
    device: str = "cpu"
    # Torch thread count, applied when ``device`` is "cpu".
    threads: int = 4

def get_val_path(path, batch_size=2048):
    """Build a validation batch from the leading entries of ``path``.

    Args:
        path: Mapping that provides the keys ``'actions'``,
            ``'next_observations'``, ``'observations'``, ``'rewards'``,
            ``'terminals'``, ``'costs'`` and ``'dieds'``, each holding a
            sliceable sequence (for example a list or a NumPy array).
        batch_size: Maximum number of leading entries to take per key.

    Returns:
        A dict mapping each of the keys above to a NumPy array built from the
        first ``min(batch_size, len(path['rewards']))`` entries of that key.

    Raises:
        KeyError: If ``path`` lacks one of the expected keys.
    """
    # Same key order as the dict this function has always returned.
    keys = (
        'actions',
        'next_observations',
        'observations',
        'rewards',
        'terminals',
        'costs',
        'dieds',
    )

    num_available = len(path['rewards'])
    # Clamp to the data actually present: indexing a fixed ``batch_size``
    # entries raised IndexError whenever the dataset was shorter than that.
    num_taken = max(0, min(batch_size, num_available))

    return {key: np.array(path[key][:num_taken]) for key in keys}

@pyrallis.wrap()
def eval(args: EvalConfig):
    """Evaluate a saved CDT model on the validation data for each target pair.

    Loads the pickled validation dataset, rebuilds the CDT network with fixed
    hyper-parameters, restores its weights from ``args.path`` and, for every
    ``(return, cost)`` pair taken from ``args.returns`` / ``args.costs``, calls
    ``CDTTrainer.evaluate`` on a validation batch and prints the physician and
    agent actions it returns.

    Args:
        args: Parsed evaluation configuration.

    Raises:
        AssertionError: If ``args.returns`` and ``args.costs`` differ in length.
    """
    seed = 10

    seed_all(seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    target_entropy = -2

    dataset_path_val = '/home/fn/OSRL/examples/train/my_cdt_data_val_noauto.pkl'
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(dataset_path_val, 'rb') as f:
        data_val = pickle.load(f)

    # Model setup; these hyper-parameters have to match the saved checkpoint.
    cdt_model = CDT(
        state_dim=48,
        action_dim=2,
        max_action=[1, 1],
        embedding_dim=128,
        seq_len=10,
        episode_len=300,
        num_layers=3,
        num_heads=8,
        attention_dropout=0.1,
        residual_dropout=0.1,
        embedding_dropout=0.1,
        time_emb=True,
        use_rew=True,
        use_cost=True,
        cost_transform=True,
        add_cost_feat=False,
        mul_cost_feat=False,
        cat_cost_feat=False,
        action_head_layers=1,
        cost_prefix=False,
        stochastic=True,
        init_temperature=0.1,
        target_entropy=target_entropy,
    )
    # Use the parsed ``args.path`` (not the class-level default) so a path
    # given on the command line is honoured; ``map_location`` lets a
    # checkpoint saved on another device load onto ``args.device``.
    state_dict = torch.load(args.path, map_location=args.device)
    cdt_model.load_state_dict(state_dict)
    cdt_model.to(args.device)

    trainer = CDTTrainer(cdt_model,
                         reward_scale=0.1,
                         cost_scale=1,
                         cost_reverse=False,
                         device=args.device)

    rets = args.returns
    costs = args.costs
    # NOTE(review): fixed at 10 and independent of ``args.eval_episodes``;
    # confirm whether the config value should be used here instead.
    eval_episodes = 10
    assert len(rets) == len(
        costs
    ), f"The length of returns {len(rets)} should be equal to costs {len(costs)}!"
    for target_ret, target_cost in zip(rets, costs):
        # Re-seed before every target so each evaluation starts from the same
        # random state.
        seed_all(seed)
        val_df = get_val_path(data_val)
        # Targets are multiplied by the same scales passed to the trainer
        # above (reward_scale=0.1, cost_scale=1).
        agent_action, phy_action, action_ems = trainer.evaluate(
            val_df,
            eval_episodes,
            target_ret * 0.1,
            target_cost * 1,
        )

        print(
            f"Physician action {phy_action}, Agent action {agent_action}"
        )


# Script entry point. ``eval`` is decorated with ``pyrallis.wrap()``, so it is
# called without arguments here and receives its ``EvalConfig`` from the wrapper.
if __name__ == "__main__":
    eval()
