import argparse
import random
import ast 

import gym
import d4rl
import hydra

import os
import h5py
import torch.nn as nn

import numpy as np
import torch

from omegaconf import DictConfig, OmegaConf
from buffer import ReplayBuffer
from logger import Logger
from trainer import MFPolicyTrainer
from agent import CQLAgent
import torch.optim as optim


from gym_minigrid.minigrid import MiniGridEnv
import crafter
from offline_rl.cql.utils import qlearning_dataset
from offline_rl.cql.save_data_z import load_model
from common.network import MLPPolicy
from offline_rl.cql.agent import BasePolicy, BCPolicy, bc_goal_policy, IQLPolicy, bc_goal_categorical_policy
from offline_rl.cql.networks import Actor_categorical
from hrl.env import MinigridWrapper


"""
Suggested hyperparameters:
cql-weight=5.0, temperature=1.0 for all D4RL-Gym tasks.
"""

def load_z_dataset(dataset, env):
    """Convert ``dataset`` into the Q-learning dataset used for training.

    This is a thin wrapper that forwards both arguments, by keyword, to
    ``qlearning_dataset``.

    Args:
        dataset: Raw dataset object handed to ``qlearning_dataset``.
        env: Environment identifier handed to ``qlearning_dataset``
            (the caller in this file passes the task name string).

    Returns:
        Whatever ``qlearning_dataset`` returns.
    """
    return qlearning_dataset(env=env, dataset=dataset)


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` used to give.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def get_args():
    """Parse the command-line options of the CQL training script.

    Boolean options take an explicit value, e.g. ``--eval_render false``.

    Returns:
        argparse.Namespace: Parsed options; dashes in option names become
        underscores (``--actor-lr`` -> ``args.actor_lr``).
    """
    # NOTE(review): ``train`` is also wrapped by ``hydra.main``; confirm that
    # Hydra's own command-line handling coexists with these options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="cql")
    parser.add_argument("--load_high_policy", type=str, default=None)
    parser.add_argument("--eval_render", type=_str2bool, default=False)
    # Also change the env name in the Hydra cfg when changing the task.
    parser.add_argument("--task", type=str, default="crafter")
    parser.add_argument("--seed", type=int, default=206)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--target-entropy", type=int, default=None)
    # Without a converter any supplied value would be a (truthy) string.
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)

    parser.add_argument("--cql-weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-q-backup", type=_str2bool, default=False)
    parser.add_argument("--deterministic-backup", type=_str2bool, default=True)
    parser.add_argument("--with-lagrange", type=_str2bool, default=False)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
    parser.add_argument("--num-repeat-actions", type=int, default=10)

    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--step-per-epoch", type=int, default=10000)
    parser.add_argument("--eval_episodes", type=int, default=500)  # 500 for others
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )

    return parser.parse_args()

def _make_env(task):
    """Create the evaluation environment for ``task``.

    Args:
        task: Task name; selects the environment by substring.

    Returns:
        The environment instance (MiniGrid environments are wrapped in
        ``MinigridWrapper``).

    Raises:
        ValueError: If ``task`` matches none of the supported families.
    """
    # Each keyword is tested on its own: `("Grid" or "kitchen") in task`
    # evaluates to `"Grid" in task` and never matches kitchen tasks.
    if any(keyword in task for keyword in ("Grid", "kitchen")):
        env = gym.make(task)
        env.seed(123)
        if isinstance(env, MiniGridEnv):
            env = MinigridWrapper(env, num_stack=1, seed=123)
        return env
    if "crafter" in task:
        return crafter.Env()
    # Fail early with a clear message instead of hitting an unbound `env` later.
    raise ValueError(
        f"Unsupported task {task!r}: expected a name containing "
        "'Grid', 'kitchen' or 'crafter'."
    )


@hydra.main(config_path="conf", config_name="save_config")
def train(cfg: DictConfig):
    """Train a CQL high-level policy over discrete latent actions.

    Builds the environment, loads the pretrained components via
    ``load_model(cfg)``, converts the dataset's latent actions to integer
    indices, then constructs the agent, replay buffer, logger and trainer
    and runs training.

    Args:
        cfg: Hydra configuration, forwarded to ``load_model``.

    Raises:
        ValueError: If ``args.task`` names an unsupported environment.
    """
    # create env and dataset
    args = get_args()
    env = _make_env(args.task)

    # NOTE(review): the BC policy returned by `load_model` is not used; a new
    # `Actor_categorical` is constructed below instead. Confirm this is intended.
    dataset, embed, project_out, _, load_args = load_model(cfg)
    dataset = load_z_dataset(dataset, args.task)

    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    if 'antmaze' in args.task:
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0

    state_dim = load_args.env.state_dim
    if isinstance(state_dim, str):
        # The state dimension may be stored as a string literal, e.g. "(7, 7, 3)".
        state_dim = ast.literal_eval(state_dim)
    args.obs_shape = state_dim

    # Discretise the latent actions: every distinct latent row becomes one
    # action index. `embed_index_set[i]` is the latent row for action `i`, and
    # the inverse mapping replaces each row by its index. A single
    # `np.unique` call provides both the action count and the indices.
    embed_index_set, latent_index = np.unique(
        dataset["latent_action"], axis=0, return_inverse=True
    )
    args.action_dim = embed_index_set.shape[0]
    embed_index_set = embed_index_set.astype("int64")
    dataset["latent_action"] = latent_index.astype("int64")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # create policy model
    if args.auto_alpha:
        # `is not None` so that an explicit target entropy of 0 is respected.
        if args.target_entropy is not None:
            target_entropy = args.target_entropy
        else:
            target_entropy = -np.prod(args.action_dim)
        args.target_entropy = target_entropy

        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    # create policy
    policy = CQLAgent(
        state_size=args.obs_shape,
        action_size=args.action_dim,
        device=args.device,
        tau=args.tau,
        gamma=args.gamma,
        alpha=alpha,
        cql_weight=args.cql_weight,
        temperature=args.temperature,
        max_q_backup=args.max_q_backup,
        deterministic_backup=args.deterministic_backup,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        cql_alpha_lr=args.cql_alpha_lr,
        num_repeat_actions=args.num_repeat_actions)

    if args.load_high_policy is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        high_policy_network = torch.load(
            args.load_high_policy, map_location=args.device
        )
        policy.network.load_state_dict(high_policy_network)

    # Low-level policy acting in the environment's primitive action space.
    if isinstance(env, MinigridWrapper):
        action_size = env.unwrapped.action_space.n
    else:
        action_size = env.action_space.n
    bc_policy = Actor_categorical(
        state_dim, action_size, args.hidden_dims[0], load_args.option_dim
    )
    bcgoal_policy = bc_goal_categorical_policy(
        bc_policy,
        optim.Adam(params=bc_policy.parameters(), lr=1e-3),
        device=args.device,
    )

    # create buffer
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.int64,
        device=args.device
    )
    buffer.load_dataset(dataset)

    # log
    log_dirs = os.path.join("log", args.algo_name)

    logger = Logger(log_dirs, args.task, args.seed)
    logger.log_str_object(
        "parameters",
        log_dict={key: str(value) for key, value in vars(args).items()},
    )

    # create policy trainer
    policy_trainer = MFPolicyTrainer(
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        render=args.eval_render,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        low_policy=bcgoal_policy,
        embed_codebook=embed,
        project_out=project_out,
        embed_index_set=embed_index_set,
        deterministic=load_args.deterministic
    )

    # train
    policy_trainer.train()


# Script entry point: `train` is decorated with `hydra.main`, so it is
# invoked here without arguments.
if __name__ == "__main__":
    train()