import argparse
import hydra
import pickle
import random
import torch
import h5py

import numpy as np
import os 
from omegaconf import DictConfig, OmegaConf
from hrl.hrl_model import HRLModel

import gym
import d4rl
import ast

def load_data_from_h5(file_path: str):
    """Load an offline dataset stored in an HDF5 file.

    The file must contain the datasets ``observations``, ``actions``,
    ``rewards`` and ``dones``. Observations must be 4-D; their axes are
    reordered with ``transpose(0, 3, 1, 2)`` (presumably image batches going
    from (N, H, W, C) to (N, C, H, W) -- confirm against the data producer).

    Args:
        file_path: path of the HDF5 file to read.

    Returns:
        dict with numpy arrays under ``"observations"``, ``"actions"``,
        ``"rewards"`` and ``"terminals"`` (the latter read from ``"dones"``).

    Raises:
        ValueError: if ``observations`` is not 4-D.
    """
    with h5py.File(file_path, 'r') as h5_file:
        # ``[:]`` already reads each dataset into an in-memory ndarray; an
        # extra ``np.array(...)`` wrapper would only duplicate it in memory.
        observations = h5_file['observations'][:]
        actions = h5_file['actions'][:]
        rewards = h5_file['rewards'][:]
        terminals = h5_file['dones'][:]

    if observations.ndim != 4:
        raise ValueError(
            f"expected 4-D observations, got shape {observations.shape}")

    return {
        "observations": observations.transpose(0, 3, 1, 2),
        "actions": actions,
        "rewards": rewards,
        "terminals": terminals,
    }

def get_args(cfg: DictConfig):
    """Fill in runtime-derived fields of ``cfg`` and return the same object.

    Sets ``cfg.trainer.device`` to ``"cuda:0"`` when CUDA is available
    (otherwise ``"cpu"``) and ``cfg.hydra_base_dir`` to ``log/<env name>``.
    The config is mutated in place.
    """
    has_gpu = torch.cuda.is_available()
    cfg.trainer.device = "cuda:0" if has_gpu else "cpu"
    env_name = cfg["env"]["name"]
    cfg.hydra_base_dir = os.path.join("log", env_name)
    return cfg

def _load_raw_dataset(args):
    """Load the untruncated offline dataset for ``args.env.name``.

    Grid environments read ``trajectory_data.pkl`` from
    ``args.train_dataset.expert_location``; kitchen environments use the gym
    env's ``get_dataset()``; crafter reads the HDF5 file at
    ``args.train_dataset.expert_location``.

    Raises:
        ValueError: if the environment name matches none of these sources.
    """
    env_name = args.env.name
    if "Grid" in env_name:
        pkl_path = os.path.join(args.train_dataset.expert_location, "trajectory_data.pkl")
        # NOTE(review): pickle.load can execute arbitrary code -- only point
        # expert_location at trusted data.
        with open(pkl_path, "rb") as file:
            return pickle.load(file)
    if "kitchen" in env_name:
        return gym.make(env_name).get_dataset()
    if "crafter" in env_name:
        return load_data_from_h5(args.train_dataset.expert_location)
    raise ValueError(f"no dataset loader for environment {env_name!r}")


def _truncate_to_ratio(dataset, ratio):
    """Keep only the first ``ratio`` fraction of episodes in ``dataset``.

    Episodes end at steps where ``dataset["terminals"] == 1``. The number of
    kept episodes is ``int(ratio * n_episodes)`` clamped to
    ``[1, n_episodes]``. Every entry is sliced up to and including the last
    kept terminal step, so the final episode keeps its terminal flag.

    Raises:
        ValueError: if there are no terminal steps.
    """
    terminal_steps = np.flatnonzero(np.asarray(dataset["terminals"]) == 1)
    if terminal_steps.size == 0:
        raise ValueError("dataset['terminals'] has no terminal steps; cannot truncate by episode ratio")
    # Clamp: 0 would otherwise index [-1] (keeping everything) and
    # ratio > 1 would index past the end.
    n_keep = min(max(int(ratio * terminal_steps.size), 1), terminal_steps.size)
    end = int(terminal_steps[n_keep - 1]) + 1  # inclusive of the terminal step
    return {key: value[:end] for key, value in dataset.items()}


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from keys.

    Presumably the checkpoint may come from a DataParallel/DDP-wrapped model.
    Only the leading prefix is removed, so nested names such as
    ``"x.submodule.weight"`` are left intact.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _write_h5(dataset, path):
    """Write every entry of ``dataset`` as an HDF5 dataset at ``path`` (overwriting).

    Torch tensors are moved to CPU numpy arrays; anything else goes through
    ``np.asarray``. The parent directory is created if needed.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with h5py.File(path, 'w') as h5_file:
        for key, value in dataset.items():
            if torch.is_tensor(value):
                value = value.detach().cpu().numpy()
            h5_file.create_dataset(key, data=np.asarray(value))


def train(args):
    """Label an offline dataset with option indices from a pretrained HRL model.

    No optimisation happens here. The function:

    1. loads the checkpoint at ``args.segment_model`` into a new ``HRLModel``;
    2. loads the dataset for ``args.env.name`` and keeps the first
       ``args.ratio`` fraction of episodes;
    3. runs all observations through the option selector and its quantizer
       ``Z`` and stores the code indices in ``dataset["latent_action"]``
       (a CPU torch tensor);
    4. writes the dataset to ``<save dir>/<env name>.h5``. The save dir is
       ``args.save_path`` when set, else ``args.savepath``.

    Side effects on ``args``: sets ``args.method`` and ``args.savepath``.

    Args:
        args: config produced by ``get_args``.

    Returns:
        ``(dataset, embed, project_out, bc_policy)``. ``embed`` is
        ``Z._codebook.embed``, ``project_out`` is ``Z.project_out`` and
        ``bc_policy`` is the model's ``bc_policy`` submodule.

    Raises:
        ValueError: if the environment is unsupported or the dataset has no
            terminal steps.
    """
    device = args.trainer.device
    # Load onto CPU so GPU-saved checkpoints also open on CPU-only hosts;
    # load_state_dict copies the tensors onto the model's parameters.
    checkpoint = torch.load(args.segment_model, map_location="cpu")

    args.method = args.model.name
    exp_name = f'{args.project_name}-{args.train_dataset.num_trajectories}-{args.method}'
    args.savepath = f'{args.hydra_base_dir}/{args.savedir}/{exp_name}'

    state_dim = args.env.state_dim
    action_dim = args.env.action_dim
    if isinstance(state_dim, str):
        # A string config value is parsed as a Python literal (a tuple denotes an image shape).
        state_dim = ast.literal_eval(state_dim)

    if isinstance(state_dim, tuple):
        assert not args.trainer.state_il, "Cannot do state imitation learning with an image input"

    option_selector_args = dict(args.option_selector)
    option_selector_args['state_dim'] = state_dim
    option_selector_args['option_dim'] = args.option_dim
    option_selector_args['codebook_dim'] = args.codebook_dim
    state_reconstructor_args = dict(args.state_reconstructor)
    lang_reconstructor_args = dict(args.lang_reconstructor)
    decision_transformer_args = {'state_dim': state_dim,
                                 'action_dim': action_dim,
                                 'option_dim': args.option_dim,
                                 'discrete': args.env.discrete,
                                 'hidden_size': args.dt.hidden_size,
                                 'use_language': args.method == 'vanilla',
                                 'use_options': args.method != 'vanilla',
                                 'option_il': args.dt.option_il,
                                 'predict_q': args.use_iq,
                                 'max_length': 1 if 'option' not in args.method else args.model.K,   # used to be K
                                 'max_ep_len': args.env.eval_episode_factor*1,
                                 'n_layer': args.dt.n_layer,
                                 'n_head': args.dt.n_head,
                                 'activation_function': args.dt.activation_function,
                                 'n_positions': args.dt.n_positions,
                                 'n_ctx': args.dt.n_positions,
                                 'resid_pdrop': args.dt.dropout,
                                 'attn_pdrop': args.dt.dropout,
                                 'no_states': args.dt.no_states,
                                 'no_actions': args.dt.no_actions,
                                 }
    hrl_model_args = dict(args.model)
    iq_args = args.iq

    dataset = _load_raw_dataset(args)
    dataset = _truncate_to_ratio(dataset, args.ratio)

    model = HRLModel(option_selector_args, state_reconstructor_args,
                     lang_reconstructor_args, decision_transformer_args, iq_args, device, state_dim=state_dim,
                     action_dim=action_dim, **hrl_model_args)
    model.load_state_dict(_strip_module_prefix(checkpoint['model']))
    model = model.to(device)

    core = model.module if isinstance(model, torch.nn.DataParallel) else model
    option_selector = core.option_selector
    bc_policy = core.bc_policy
    vq = option_selector.Z
    project_out = vq.project_out
    embed = vq._codebook.embed

    states = torch.as_tensor(np.asarray(dataset["observations"])).float().to(device)

    # Pure inference: eval mode disables dropout and no_grad avoids building
    # an autograd graph over the whole dataset. NOTE(review): some VQ layers
    # (e.g. EMA codebooks) also update their codebook in training-mode
    # forwards -- confirm for Z. The previous train/eval mode is restored so
    # the returned modules behave as before.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            state_embeddings = option_selector.embed_state(states)
            option_preds = option_selector.pred_options(state_embeddings)
            _, vq_indice, _ = vq(option_preds)
    finally:
        model.train(was_training)

    if 'antmaze' in args.env.name:
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0
    dataset["latent_action"] = vq_indice.detach().cpu()

    # Prefer an explicit save_path; otherwise use the savepath computed above.
    save_dir = args.save_path if "save_path" in args and args.save_path else args.savepath
    _write_h5(dataset, os.path.join(save_dir, args.env.name + ".h5"))
    return dataset, embed, project_out, bc_policy


def load_model(cfg: DictConfig):
    """Prepare the config, seed all RNGs, then run ``train`` on it.

    Seeds Python's ``random``, NumPy and torch with ``args.seed``. When the
    device is CUDA and ``args.cuda_deterministic`` is set, cuDNN is put into
    deterministic, non-benchmarking mode.

    Returns:
        ``(dataset, embed, project_out, bc_policy, args)`` -- the outputs of
        ``train`` followed by the prepared config.
    """
    args = get_args(cfg)

    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    run_device = torch.device(args.trainer.device)
    wants_determinism = (
        run_device.type == 'cuda'
        and torch.cuda.is_available()
        and args.cuda_deterministic
    )
    if wants_determinism:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    results = train(args)
    dataset, embed, project_out, bc_policy = results
    return dataset, embed, project_out, bc_policy, args
    
if __name__ == "__main__":
    # load_model() requires a config object, so build one from a YAML file
    # plus optional key=value overrides before calling it.
    parser = argparse.ArgumentParser(
        description="Label an offline dataset with option indices from a pretrained HRL model.")
    parser.add_argument("config", help="path to a YAML config file")
    parser.add_argument("overrides", nargs="*",
                        help="config overrides in dotlist form, e.g. ratio=0.5")
    cli_args = parser.parse_args()
    cfg = OmegaConf.merge(OmegaConf.load(cli_args.config),
                          OmegaConf.from_dotlist(cli_args.overrides))
    load_model(cfg)