import os
import random
from typing import List, Optional, Union

import h5py
import numpy as np
import torch
from tqdm import tqdm


# Scenarios processed by this script.
SMACV1_ENV_NAMES = ["2c_vs_64zg", "5m_vs_6m", "6h_vs_8z", "corridor"]
# SMACv2 scenario names are "<race>_<unit setup>", enumerated race by race.
SMACV2_ENV_NAMES = [
    f"{race}_{setup}"
    for race in ("protoss", "terran", "zerg")
    for setup in ("5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23")
]
MAMUJOCO_ENV_NAMES = ["Hopper-v2", "Ant-v2", "HalfCheetah-v2"]

# Dataset quality tiers per benchmark. The dataset builders load modes[-1] as
# the expert source and modes[0] as the imperfect source.
SMACV1_MODES = ["poor", "medium", "good"]
SMACV2_MODES = ["poor", "medium", "expert"]
MAMUJOCO_MODES = ["medium-replay", "medium", "expert"]

# Sub-directory of the saved reward models read by load_ipl_model.
IPL_ALGO = "OMAPL"


def load_mujoco(env_name, agent_conf, mode, data_dir="../dataset/offline_dataset"):
    """Load the MaMuJoCo archive ``{env_name}-{agent_conf}-{mode}.npz`` as a dict.

    Post-processing applied to the raw archive:
      * ``states`` with more than three axes are flattened to
        ``(n_batches, n_steps, -1)``;
      * ``actions`` with as many steps as ``states`` lose their final step.

    Args:
        env_name: environment name, e.g. ``"Hopper-v2"``.
        agent_conf: agent partition string, e.g. ``"3x1"``.
        mode: dataset tier, e.g. ``"expert"``.
        data_dir: directory containing the ``.npz`` files. Defaults to the
            location this script always used.

    Returns:
        dict mapping each array name stored in the archive to a numpy array.
    """
    path = f"{data_dir}/{env_name}-{agent_conf}-{mode}.npz"
    # NOTE(review): allow_pickle=True can execute pickled code stored in the
    # archive; only load dataset files from a trusted source.
    # The context manager closes the archive's file handle deterministically.
    with np.load(path, allow_pickle=True) as archive:
        buffer = {key: archive[key] for key in archive.files}

    states = buffer["states"]
    if states.ndim > 3:
        # Merge all trailing axes into one feature axis (works for any rank > 3).
        states = states.reshape(states.shape[0], states.shape[1], -1)
        buffer["states"] = states

    actions = buffer["actions"]
    if actions.shape[1] == states.shape[1]:
        # Drop the last step so actions has one step fewer than states.
        buffer["actions"] = actions[:, :-1]
    return buffer


def sample_mujoco(buffer, batch_size=128):
    """Pick ``batch_size`` random episodes from a MaMuJoCo buffer.

    Episodes are drawn without replacement when the buffer holds at least
    ``batch_size`` of them, and with replacement otherwise.

    Returns:
        ``(obs, states, actions, rewards, dones, actives)``, each indexed by
        the same sampled episode ids.
    """
    n_episodes = buffer["dones"].shape[0]
    picked = np.random.choice(n_episodes, batch_size, replace=n_episodes < batch_size)
    fields = ("obs", "states", "actions", "rewards", "dones", "actives")
    return tuple(buffer[field][picked] for field in fields)


@torch.no_grad()
def get_continuous_rewards(model, all_obs, all_states, actions, dones, gamma=0.99):
    """Score episodes with the learned reward of a continuous-action model.

    The per-step reward is recovered from the mixed Q and V values as
    ``r_t = Q_tot(s_t, a_t) - gamma * (1 - done_t) * V_tot(s_{t+1})`` and then
    summed over the last axis.

    Args:
        model: object exposing ``q``, ``v`` and ``q_mix_model``, each with a
            ``forward`` method. ``q_mix_model.forward`` returns a
            ``(weights, bias)`` pair.
        all_obs: per-agent observations (numpy). They have one more step than
            ``actions``, and the agent axis is the second-to-last axis.
        all_states: global states (numpy), with the same number of steps as
            ``all_obs``.
        actions: per-agent continuous actions (numpy), concatenated onto the
            observations as Q input.
        dones: termination flags aligned with the action steps.
        gamma: discount factor.

    Returns:
        numpy array of summed rewards. Assumes an (episode, step, agent,
        feature) layout, giving one value per episode; TODO confirm.
    """
    # NOTE(review): unlike get_discrete_rewards, no active-step mask is applied,
    # so padded steps (if any) also contribute; confirm that is intended.
    obs_all, state_all, act, done = (
        torch.from_numpy(arr) for arr in (all_obs, all_states, actions, dones)
    )
    n_agents = obs_all.shape[-2]

    def agent_ids(like):
        # One-hot agent identity broadcast over the two leading axes of `like`.
        return torch.eye(n_agents).expand(like.shape[0], like.shape[1], -1, -1)

    def mixed(values, mixer_input):
        # Monotonic mixing: weighted sum over agents plus a state-dependent bias.
        weights, bias = model.q_mix_model.forward(mixer_input)
        return (weights * values).sum(-1) + bias.squeeze(-1)

    cur_obs = obs_all[:, :-1]
    q_input = torch.cat((cur_obs, act, agent_ids(cur_obs)), -1)
    v_input = torch.cat((obs_all, agent_ids(obs_all)), -1)

    q_tot = mixed(model.q.forward(q_input), state_all[:, :-1])
    v_tot = mixed(model.v.forward(v_input), state_all)
    bootstrap = (1 - done.float()) * v_tot[:, 1:]
    step_rewards = q_tot - gamma * bootstrap
    return torch.sum(step_rewards, -1).numpy()


@torch.no_grad()
def get_discrete_rewards(model, all_obs, all_states, actions, dones, actives, gamma=0.99):
    """Score episodes with the learned reward of a discrete-action model.

    For each step, the Q value of the taken action is mixed across agents, and
    the mixed next-step V value is discounted and subtracted:
    ``r_t = Q_tot(s_t, a_t) - gamma * (1 - done_t) * V_tot(s_{t+1})``.
    Steps whose ``actives`` flag is zero contribute nothing. Rewards are summed
    over the last axis.

    Args:
        model: object exposing ``q``, ``v`` and ``q_mix_model``, each with a
            ``forward`` method. ``q_mix_model.forward`` returns a
            ``(weights, bias)`` pair.
        all_obs: per-agent observations (numpy), with one more step than
            ``actions``. The agent axis is the second-to-last axis.
        all_states: global states (numpy), with the same number of steps as
            ``all_obs``.
        actions: integer action indices used with ``gather`` on the Q output,
            so they must be int64.
        dones: termination flags aligned with the action steps.
        actives: mask multiplied into the per-step rewards.
        gamma: discount factor.

    Returns:
        numpy array of summed, masked rewards. Assumes an (episode, step)
        leading layout, giving one value per episode; TODO confirm.
    """
    obs_all, state_all, act, done, active = (
        torch.from_numpy(arr) for arr in (all_obs, all_states, actions, dones, actives)
    )
    n_agents = obs_all.shape[-2]

    # Append a one-hot agent id to every observation before splitting in time.
    one_hot = torch.eye(n_agents).expand(obs_all.shape[0], obs_all.shape[1], -1, -1)
    obs_all = torch.cat((obs_all, one_hot), -1)

    def mixed(values, mixer_input):
        # Weighted sum over agents plus a state-dependent bias.
        weights, bias = model.q_mix_model.forward(mixer_input)
        return (weights * values).sum(-1) + bias.squeeze(-1)

    chosen_q = model.q.forward(obs_all[:, :-1]).gather(-1, act.unsqueeze(-1)).squeeze(-1)
    q_tot = mixed(chosen_q, state_all[:, :-1])
    next_v_tot = mixed(model.v.forward(obs_all[:, 1:]), state_all[:, 1:])

    step_rewards = (q_tot - gamma * ((1 - done.float()) * next_v_tot)) * active.float()
    return torch.sum(step_rewards, -1).numpy()


def load_ipl_model(env_name, seed=0):
    """Load the TorchScript reward model saved under ``IPL_ALGO`` for ``env_name`` onto the CPU."""
    return torch.jit.load(
        f"../dataset/saved_models_PreferDICE/{IPL_ALGO}/{env_name}/model_seed{seed}.pt",
        map_location="cpu",
    )


def create_mujoco_dataset(env_name, modes, ex_episodes=100, im_episodes=100, seed=0):
    """Build a MaMuJoCo dataset whose episodes are ranked by learned reward.

    ``ex_episodes`` episodes are sampled from the ``modes[-1]`` dataset and
    ``im_episodes`` from the ``modes[0]`` dataset. Every episode is scored with
    the reward model from ``load_ipl_model``. The merged episodes are returned
    sorted from highest to lowest score and padded to a common length by
    ``concat_with_rank``.

    Returns:
        dict with keys ``obs``, ``states``, ``actions``, ``dones``, ``actives``.

    Raises:
        NotImplementedError: if ``env_name`` has no known agent configuration.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Agent partition used in the offline dataset file names. It is resolved
    # before loading the reward model so unsupported envs fail fast.
    agent_confs = {"Ant-v2": "2x4", "HalfCheetah-v2": "6x1", "Hopper-v2": "3x1"}
    if env_name not in agent_confs:
        raise NotImplementedError(f"no agent configuration for MaMuJoCo env {env_name!r}")
    agent_conf = agent_confs[env_name]

    model = load_ipl_model(env_name, seed)

    ex_buffer = load_mujoco(env_name, agent_conf, modes[-1])
    im_buffer = load_mujoco(env_name, agent_conf, modes[0])

    # Dataset rewards are not used; episodes are ranked by the learned reward only.
    ex_obs, ex_states, ex_actions, _, ex_dones, ex_actives = sample_mujoco(ex_buffer, batch_size=ex_episodes)
    im_obs, im_states, im_actions, _, im_dones, im_actives = sample_mujoco(im_buffer, batch_size=im_episodes)

    ex_iq_rewards = get_continuous_rewards(model, ex_obs, ex_states, ex_actions, ex_dones)
    im_iq_rewards = get_continuous_rewards(model, im_obs, im_states, im_actions, im_dones)
    print("env_name:", env_name)
    print(f"ex_iq_rewards: {ex_iq_rewards.mean():.4f} ± {ex_iq_rewards.std():.4f}")
    print(f"im_iq_rewards: {im_iq_rewards.mean():.4f} ± {im_iq_rewards.std():.4f}")

    # Descending order of learned reward over [expert episodes, imperfect episodes].
    iq_rewards = np.concatenate([ex_iq_rewards, im_iq_rewards], 0)
    ids = np.argsort(iq_rewards)[::-1]
    un_obs = concat_with_rank(ex_obs, im_obs, ids)
    un_states = concat_with_rank(ex_states, im_states, ids)
    un_actions = concat_with_rank(ex_actions, im_actions, ids)
    un_dones = concat_with_rank(ex_dones, im_dones, ids)
    un_actives = concat_with_rank(ex_actives, im_actives, ids)

    return dict(obs=un_obs, states=un_states, actions=un_actions, dones=un_dones, actives=un_actives)
    

def load_smac(env_name, mode, data_dir="../dataset/offline_dataset"):
    """Load the SMAC archive ``{env_name}_{mode}.npz`` as a dict of arrays.

    Post-processing applied to the raw archive:
      * ``states`` with more than three axes are flattened to
        ``(n_batches, n_steps, -1)``;
      * ``dones`` and ``actives`` with more than two axes are reduced with
        ``np.all`` over their last axis;
      * ``rewards`` with more than two axes keep only index 0 of the third axis;
      * ``actions`` are cast to ``int64`` (required by ``torch.gather``).

    Args:
        env_name: scenario name, e.g. ``"5m_vs_6m"``.
        mode: dataset tier, e.g. ``"good"``.
        data_dir: directory containing the ``.npz`` files. Defaults to the
            location this script always used.

    Returns:
        dict mapping each array name stored in the archive to a numpy array.
    """
    path = f"{data_dir}/{env_name}_{mode}.npz"
    # NOTE(review): allow_pickle=True can execute pickled code stored in the
    # archive; only load dataset files from a trusted source.
    # The context manager closes the archive's file handle deterministically.
    with np.load(path, allow_pickle=True) as archive:
        buffer = {key: archive[key] for key in archive.files}

    states = buffer["states"]
    if states.ndim > 3:
        # Merge all trailing axes into one feature axis (works for any rank > 3).
        buffer["states"] = states.reshape(states.shape[0], states.shape[1], -1)

    dones = buffer["dones"]
    if dones.ndim > 2:
        buffer["dones"] = np.all(dones, -1)

    rewards = buffer["rewards"]
    if rewards.ndim > 2:
        # Presumably a shared team reward repeated per agent; keep the first copy.
        buffer["rewards"] = rewards[:, :, 0]

    actives = buffer["actives"]
    if actives.ndim > 2:
        buffer["actives"] = np.all(actives, -1)

    buffer["actions"] = buffer["actions"].astype(np.int64)

    return buffer


def sample_smac(buffer, batch_size):
    """Pick ``batch_size`` random episodes from a SMAC buffer.

    Episodes are drawn without replacement when the buffer is large enough,
    and with replacement otherwise.

    Returns:
        ``(obs, states, avails, actions, rewards, dones, actives)``, each
        indexed by the same sampled episode ids.
    """
    total = buffer["dones"].shape[0]
    chosen = np.random.choice(total, batch_size, replace=total < batch_size)
    fields = ("obs", "states", "avails", "actions", "rewards", "dones", "actives")
    return tuple(buffer[field][chosen] for field in fields)


def concat_with_rank(
    ex_data: np.ndarray,
    im_data: np.ndarray,
    ids: Optional[Union[np.ndarray, List[int]]] = None,
    dtype=None,
) -> np.ndarray:
    """Stack two episode batches along axis 0, zero-padding axis 1 to a common length.

    ``ex_data`` rows come first, followed by ``im_data`` rows. Steps beyond an
    episode batch's own length are filled with zeros (``False`` for bool).

    Args:
        ex_data: array with episodes on axis 0 and steps on axis 1.
        im_data: array with the same trailing (axis 2+) shape as ``ex_data``.
        ids: optional integer index (list or array) applied to the stacked
            rows, e.g. a ranking order.
        dtype: output dtype. Defaults to ``ex_data.dtype``, the original
            behaviour; ``im_data`` is cast to it on assignment.

    Returns:
        The stacked array, reordered by ``ids`` when given.

    Raises:
        ValueError: if the trailing shapes differ. Without this check a
            size-1 feature axis would otherwise broadcast silently.
    """
    if ex_data.shape[2:] != im_data.shape[2:]:
        raise ValueError(
            f"trailing shapes differ: {ex_data.shape[2:]} vs {im_data.shape[2:]}"
        )
    n_ex, len_ex = ex_data.shape[0], ex_data.shape[1]
    n_im, len_im = im_data.shape[0], im_data.shape[1]
    out_dtype = ex_data.dtype if dtype is None else dtype

    data = np.zeros((n_ex + n_im, max(len_ex, len_im), *ex_data.shape[2:]), dtype=out_dtype)
    data[:n_ex, :len_ex] = ex_data
    data[n_ex:, :len_im] = im_data
    if ids is not None:
        data = data[ids]
    return data
    

def create_smac_dataset(env_name, modes, ex_episodes=100, im_episodes=100, use_llm=False, seed=0):
    """Build a SMAC dataset whose episodes are ranked by learned reward.

    ``ex_episodes`` episodes are sampled from the ``modes[-1]`` dataset and
    ``im_episodes`` from the ``modes[0]`` dataset. Every episode is scored with
    the reward model from ``load_ipl_model`` (the ``<env_name>_llm`` model when
    ``use_llm`` is set). The merged episodes are returned sorted from highest
    to lowest score and padded to a common length by ``concat_with_rank``.

    Returns:
        dict with keys ``obs``, ``states``, ``avails``, ``actions``, ``dones``
        and ``actives``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    model = load_ipl_model(f"{env_name}_llm" if use_llm else env_name, seed)

    expert_buffer = load_smac(env_name, modes[-1])
    imperfect_buffer = load_smac(env_name, modes[0])

    # Dataset rewards (5th field) are not used; only the learned reward ranks episodes.
    ex_obs, ex_states, ex_avails, ex_actions, _, ex_dones, ex_actives = sample_smac(expert_buffer, batch_size=ex_episodes)
    im_obs, im_states, im_avails, im_actions, _, im_dones, im_actives = sample_smac(imperfect_buffer, batch_size=im_episodes)

    ex_scores = get_discrete_rewards(model, ex_obs, ex_states, ex_actions, ex_dones, ex_actives)
    im_scores = get_discrete_rewards(model, im_obs, im_states, im_actions, im_dones, im_actives)
    print("env_name:", env_name)
    print(f"ex_iq_rewards: {ex_scores.mean():.4f} ± {ex_scores.std():.4f}")
    print(f"im_iq_rewards: {im_scores.mean():.4f} ± {im_scores.std():.4f}")

    # Descending order of learned reward over [expert episodes, imperfect episodes].
    ranking = np.argsort(np.concatenate([ex_scores, im_scores], 0))[::-1]
    pairs = {
        "obs": (ex_obs, im_obs),
        "states": (ex_states, im_states),
        "avails": (ex_avails, im_avails),
        "actions": (ex_actions, im_actions),
        "dones": (ex_dones, im_dones),
        "actives": (ex_actives, im_actives),
    }
    return {key: concat_with_rank(ex, im, ranking) for key, (ex, im) in pairs.items()}


def save_dataset(env_name, outputs, use_llm=False):
    """Write every array in ``outputs`` to ``dataset/<name>.h5`` with gzip compression.

    ``<name>`` is ``env_name``, suffixed with ``_llm`` when ``use_llm`` is set.
    The file is overwritten if it exists. All shapes are printed first; the
    arrays are then written under a tqdm progress bar.
    """
    os.makedirs("dataset", exist_ok=True)
    tag = f"{env_name}_llm" if use_llm else env_name
    with h5py.File(f"dataset/{tag}.h5", "w") as h5_file:
        for name, array in outputs.items():
            print(f"{tag}: saving {name} with shape {array.shape}")
        for name, array in tqdm(outputs.items(), desc="Saving", ncols=80):
            h5_file.create_dataset(name, data=array, compression="gzip")


def main():
    """Build and save every ranked dataset: all MaMuJoCo envs first, then SMAC.

    SMACv1 and SMACv2 scenarios are processed twice: first with the LLM reward
    models (``use_llm=True``), then with the standard ones.
    """
    episode_counts = dict(ex_episodes=200, im_episodes=1000)

    for env_name in MAMUJOCO_ENV_NAMES:
        save_dataset(env_name, create_mujoco_dataset(env_name, MAMUJOCO_MODES, **episode_counts))

    smac_suites = ((SMACV1_ENV_NAMES, SMACV1_MODES), (SMACV2_ENV_NAMES, SMACV2_MODES))
    for use_llm in (True, False):
        for env_names, modes in smac_suites:
            for env_name in env_names:
                data = create_smac_dataset(env_name, modes, use_llm=use_llm, **episode_counts)
                save_dataset(env_name, data, use_llm=use_llm)
    

if __name__ == "__main__":
    # Build and save all datasets only when run as a script, not on import.
    main()