import argparse
import random

import gym
import d4rl
import hydra

import numpy as np
import torch

from omegaconf import DictConfig, OmegaConf
from buffer import ReplayBuffer
from logger import Logger
from trainer import MFPolicyTrainer
from agent import CQLAgent
import torch.optim as optim

from offline_rl.cql.utils import qlearning_dataset
from offline_rl.cql.save_data_z import load_model
from offline_rl.cql.networks import MLP, ActorProb, Critic
from offline_rl.cql.module import DiagGaussian
from offline_rl.cql.agent import BasePolicy, BCPolicy, bc_goal_policy, IQLPolicy

import os
import h5py
import torch.nn as nn

"""
suggested hypers
cql-weight=5.0, temperature=1.0 for all D4RL-Gym tasks
"""
def load_z_dataset(dataset, env):
    """Build the Q-learning dataset for ``env`` from ``dataset``.

    Thin wrapper: both arguments are forwarded by keyword to
    ``qlearning_dataset`` and its result is returned unchanged.
    """
    return qlearning_dataset(dataset=dataset, env=env)


def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` -- becomes ``True``.
    This converter maps the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as "false" because it was the only value that
    # produced False under the previous ``type=bool`` behaviour.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def get_args():
    """Parse the script's command-line options.

    Covers the CQL (high-level) options, the IQL (low-level) options and the
    general training options. Boolean options take an explicit value, e.g.
    ``--eval_render true`` or ``--lr-decay false``.

    Returns:
        argparse.Namespace: the parsed options. Arguments this parser does not
        recognise are ignored rather than treated as an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="cql")
    parser.add_argument("--load_high_policy", type=str, default=None)
    parser.add_argument("--eval_render", type=_str2bool, default=False)
    parser.add_argument("--task", type=str, default="kitchen-complete-v0")  # also change the env name in cfg
    parser.add_argument("--seed", type=int, default=311)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--target-entropy", type=int, default=None)
    # Previously untyped: "--auto-alpha False" yielded the truthy string "False".
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)

    # cql
    parser.add_argument("--cql-weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-q-backup", type=_str2bool, default=False)
    parser.add_argument("--deterministic-backup", type=_str2bool, default=True)
    parser.add_argument("--with-lagrange", type=_str2bool, default=False)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
    parser.add_argument("--num-repeat-actions", type=int, default=10)

    # iql
    parser.add_argument("--iql-hidden-dims", type=int, nargs='*', default=[256, 256])
    parser.add_argument("--iql-actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-q-lr", type=float, default=3e-4)
    parser.add_argument("--critic-v-lr", type=float, default=3e-4)
    parser.add_argument("--dropout_rate", type=float, default=0.1)
    parser.add_argument("--lr-decay", type=_str2bool, default=True)
    parser.add_argument("--iql-gamma", type=float, default=0.99)
    parser.add_argument("--iql-tau", type=float, default=0.005)
    parser.add_argument("--expectile", type=float, default=0.7)
    parser.add_argument("--iql-temperature", type=float, default=0.5)

    # training loop
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=100000)
    parser.add_argument("--eval_episodes", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # This parser reads the same sys.argv as the Hydra entry point that calls
    # it, so Hydra-style ``key=value`` overrides may be present. parse_args()
    # would abort on them; parse_known_args() leaves them alone.
    args, _unrecognised = parser.parse_known_args()
    return args

@hydra.main(config_path="conf", config_name="save_config")
def train(cfg: DictConfig):
    """Build the CQL high-level agent and IQL low-level policy, then train.

    Steps, in order:
      1. Parse CLI options, create the gym environment, load the pretrained
         model via ``load_model(cfg)`` and build the Q-learning dataset.
      2. Replace ``dataset["latent_action"]`` by integer indices into the set
         of unique latent-action rows, so it acts as a discrete action.
      3. Seed Python, NumPy, torch and the environment.
      4. Construct the ``CQLAgent`` (high-level) and the ``IQLPolicy``
         (low-level) with their networks and optimizers.
      5. Load the dataset into a ``ReplayBuffer``, set up logging and run
         ``MFPolicyTrainer.train()``.

    Args:
        cfg: Hydra configuration; only forwarded to ``load_model``.
    """
    # create env and dataset
    args = get_args()
    env = gym.make(args.task)
    # bc_policy is returned by load_model but not used in this script.
    dataset, embed, project_out, bc_policy, load_args = load_model(cfg)
    dataset = load_z_dataset(dataset, args.task)
    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    if 'antmaze' in args.task:
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0
    args.obs_shape = env.observation_space.shape

    # Discrete actions: one np.unique pass yields both the unique latent rows
    # (the codebook index set) and, per transition, the index of its row.
    # The number of unique rows is the discrete action count.
    embed_index_set, latent_indices = np.unique(
        dataset["latent_action"], axis=0, return_inverse=True
    )
    args.action_dim = embed_index_set.shape[0]
    embed_index_set = embed_index_set.astype("int64")
    dataset["latent_action"] = latent_indices.astype("int64")

    # seed everything before any network is constructed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # create policy model
    if args.auto_alpha:
        # Compare against None so an explicit target entropy of 0 is honoured
        # instead of being silently replaced by the default.
        if args.target_entropy is not None:
            target_entropy = args.target_entropy
        else:
            target_entropy = -np.prod(args.action_dim)

        args.target_entropy = target_entropy

        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        # auto-tuned alpha is passed as (target entropy, log alpha, optimizer)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    # create high-level CQL policy
    high_policy = CQLAgent(
        state_size=args.obs_shape,
        action_size=args.action_dim,
        device=args.device,
        tau=args.tau,
        gamma=args.gamma,
        alpha=alpha,
        cql_weight=args.cql_weight,
        temperature=args.temperature,
        max_q_backup=args.max_q_backup,
        deterministic_backup=args.deterministic_backup,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        cql_alpha_lr=args.cql_alpha_lr,
        num_repeat_actions=args.num_repeat_actions)

    if args.load_high_policy is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        high_policy_network = torch.load(args.load_high_policy, map_location=args.device)
        high_policy.network.load_state_dict(high_policy_network)

    # Low-level networks. The actor input is the observation concatenated
    # with the option vector; the Q critics take observation + action.
    actor_backbone = MLP(input_dim=np.prod(args.obs_shape[0] + load_args["option_dim"]), hidden_dims=args.iql_hidden_dims, dropout_rate=args.dropout_rate)

    critic_q1_backbone = MLP(input_dim=np.prod(args.obs_shape[0])+load_args.env.action_dim, hidden_dims=args.iql_hidden_dims)
    critic_q2_backbone = MLP(input_dim=np.prod(args.obs_shape[0])+load_args.env.action_dim, hidden_dims=args.iql_hidden_dims)
    critic_v_backbone = MLP(input_dim=np.prod(args.obs_shape[0]), hidden_dims=args.iql_hidden_dims)
    dist = DiagGaussian(
        latent_dim=getattr(actor_backbone, "output_dim"),
        output_dim=load_args.env.action_dim,
        unbounded=False,
        conditioned_sigma=False
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    critic_q1 = Critic(critic_q1_backbone, args.device)
    critic_q2 = Critic(critic_q2_backbone, args.device)
    critic_v = Critic(critic_v_backbone, args.device)

    for m in list(actor.modules()) + list(critic_q1.modules()) + list(critic_q2.modules()) + list(critic_v.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)

    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.iql_actor_lr)
    critic_q1_optim = torch.optim.Adam(critic_q1.parameters(), lr=args.critic_q_lr)
    critic_q2_optim = torch.optim.Adam(critic_q2.parameters(), lr=args.critic_q_lr)
    critic_v_optim = torch.optim.Adam(critic_v.parameters(), lr=args.critic_v_lr)

    if args.lr_decay:
        # Only the IQL actor's learning rate is annealed; the critics keep
        # a constant learning rate.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch)
    else:
        lr_scheduler = None

    # create IQL policy
    low_policy = IQLPolicy(
        actor,
        critic_q1,
        critic_q2,
        critic_v,
        actor_optim,
        critic_q1_optim,
        critic_q2_optim,
        critic_v_optim,
        action_space=env.action_space,
        tau=args.iql_tau,
        gamma=args.iql_gamma,
        expectile=args.expectile,
        temperature=args.iql_temperature
    )

    # create buffer
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.int64,
        device=args.device
    )
    buffer.load_dataset(dataset)

    # log
    log_dirs = os.path.join("log", args.algo_name)

    logger = Logger(log_dirs, args.task, args.seed)
    # record CLI options merged with the loaded model's options
    logger.log_str_object("parameters", log_dict ={key: str(value) for key, value in {**vars(args), **load_args}.items()})

    # create policy trainer
    policy_trainer = MFPolicyTrainer(
        policy=high_policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        render=args.eval_render,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        low_policy=low_policy,
        embed_codebook=embed,
        project_out=project_out,
        embed_index_set=embed_index_set,
        lr_scheduler=lr_scheduler
    )

    # train
    policy_trainer.train()


# Script entry point. `train` is decorated with `hydra.main`, so it is called
# here without arguments (presumably Hydra supplies `cfg` -- see Hydra docs).
if __name__ == "__main__":
    train()