import torch
import numpy as np
import matplotlib.pyplot as plt
from algos.model import ActorModel, ACModel, Discriminator

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_mpe_env_name(args):
    """Build the descriptive MPE environment name from the env configuration.

    The name starts with the scenario name followed by the agent/landmark
    counts, and gains one suffix per enabled option so that differently
    configured environments map to different names.

    Args:
        args: Configuration object whose ``env`` attribute exposes
            ``scenario_name`` and the ``mpe_*`` options read below.

    Returns:
        str: The assembled environment name.
    """
    env_cfg = args.env
    env_name = env_cfg.scenario_name
    env_name += "_" + str(env_cfg.mpe_num_agents) + "a" + str(env_cfg.mpe_num_landmarks) + "g"
    if env_cfg.mpe_num_walls > 0:
        env_name += f"{env_cfg.mpe_num_walls}w"
    if env_cfg.mpe_sparse_reward:
        env_name += "_sparseR"
    if env_cfg.mpe_mid_sparse_reward:
        env_name += "_midSparseR"
    if env_cfg.mpe_use_new_reward:
        env_name += "_newR"
    if env_cfg.mpe_not_share_reward:
        env_name += "_notShareR"
    if env_cfg.mpe_num_agents > 1:
        # The reward type is only part of the name in the multi-agent case.
        env_name += "_" + env_cfg.mpe_reward_type
    if env_cfg.mpe_fixed_map:
        env_name += "_fixedmap"
    if env_cfg.mpe_fixed_landmark:
        env_name += "_fixedlandmark"
    if env_cfg.mpe_num_agents == 1:
        # Single-agent runs are further tagged with the agent/target ids.
        # These suffixes belong to the name being built here; the previous
        # code appended them to an undefined ``model_dir`` variable and read
        # ``mpe_aid`` from the wrong config level.
        if env_cfg.mpe_fixed_map and env_cfg.mpe_aid >= 0:
            env_name += "_a" + str(env_cfg.mpe_aid)
        if env_cfg.mpe_tid >= 0:
            env_name += "_t" + str(env_cfg.mpe_tid)
    return env_name


def get_model_dir_name(args):
    """Compose the result directory path for a training run.

    The path has the form ``<result_path>/<environment name>/<algorithm name>``
    where the algorithm name encodes the hyper-parameters and enabled options
    read from ``args``.

    Args:
        args: Run configuration. ``args.env`` holds the environment settings;
            the remaining attributes read below are algorithm settings.

    Returns:
        str: The directory path for this run.
    """
    if "mpe" in args.env.env_name:
        root_dir = get_mpe_env_name(args)
    else:
        root_dir = args.env.env_name
        if args.env.dense_reward:
            root_dir += "_denseR"

    algo_name = args.algo
    algo_name += "_ep" + str(args.ppo_epoch)
    algo_name += "_nbatch" + str(args.num_mini_batch)
    algo_name += "_lr" + str(args.lr)

    if args.algo == "PPOwPrior":
        algo_name += "_N" + str(args.N)

    if args.use_shadow_reward:
        algo_name += "_suboptimal" if args.use_suboptimal else "_optimal"

        algo_name += "_pw" + str(args.pweight)
        algo_name += "_pd" + str(args.pdecay)
        if args.pdecay_interval != 1:
            # Must extend algo_name: model_dir is only assigned at the end of
            # the function, so appending to it here raised UnboundLocalError.
            algo_name += "_per" + str(args.pdecay_interval)

        # Test the environment *name*, as done at the top of this function;
        # the previous check was performed on the env config object itself.
        if "mpe" in args.env.env_name:
            # Prefer the env-level setting and fall back to a top-level one,
            # since both locations are read elsewhere in this file.
            demonstration = getattr(args.env, "mpe_demonstration", None)
            if demonstration is None:
                demonstration = args.mpe_demonstration
            algo_name += "_" + demonstration

        if args.discrim_full_state:
            algo_name += "_discrimFull"
        if args.discrim_co_trained:
            algo_name += "_coTrained"

    if args.clip_grad:
        algo_name += "_clipGrad"
    if args.add_noise:
        algo_name += "_gradNoise"
    if args.use_state_norm:
        algo_name += "_statenorm"
    if args.use_value_norm:
        algo_name += "_valuenorm"
    if args.use_gae:
        algo_name += "_useGAE"

    algo_name += "_seed" + str(args.seed)

    return args.result_path + "/" + root_dir + "/" + algo_name


def load_prior(args):
    """Load the per-agent prior arrays stored under ``priors/``.

    The file names depend on the environment name:

    * names containing ``appleDoor`` load two priors named after the full
      environment name;
    * names containing ``mpe`` load two simple-spread priors from the folder
      of the configured demonstration;
    * any other name uses the text before the first underscore as file
      prefix and derives the agent count from the name itself.

    Args:
        args: Configuration exposing ``env.env_name``, ``use_suboptimal`` and,
            for MPE environments, ``mpe_demonstration``.

    Returns:
        list: One array per agent, in agent order.
    """
    env_name = args.env.env_name
    type_str = "_suboptimal" if args.use_suboptimal else ""

    if "appleDoor" in env_name:
        stem = "priors/" + env_name + type_str + "_prior"
        return [np.load(stem + str(aid) + ".npy") for aid in range(2)]

    if "mpe" in env_name:
        stem = f"priors/mpe_{args.mpe_demonstration}/mpe_simple_spread_prior"
        return [np.load(stem + str(aid) + ".npy") for aid in range(2)]

    stem = "priors/" + env_name.split("_")[0] + type_str + "_prior"
    is_single_center_square = "centerSquare" in env_name and "1a" in env_name
    # Otherwise the second-to-last character of the name is the agent count.
    agent_num = 1 if is_single_center_square else int(env_name[-2])

    if agent_num == 1:
        # The last character of the name selects which prior file to use.
        prior_ids = [int(env_name[-1])]
    elif agent_num == 2:
        prior_ids = [0, 2]
    else:
        prior_ids = list(range(agent_num))

    priors = []
    for prior_id in prior_ids:
        loaded = np.load(stem + str(prior_id) + ".npy")
        if loaded.shape[0] == 4:
            # Pad priors with 4 leading entries to 5 by adding a zero slice.
            padded = np.zeros([5, loaded.shape[1], loaded.shape[2]])
            padded[:4, :, :] = loaded
            loaded = padded
        priors.append(loaded)
    return priors


def load_expert_trajectory(args):
    """Dispatch to the expert-trajectory loader matching the environment name.

    Args:
        args: Configuration exposing ``env.env_name`` and ``use_suboptimal``,
            plus the options consulted by the selected branch.

    Returns:
        The dictionary produced by the selected loader.

    Raises:
        ValueError: If the environment name matches none of the supported
            families (``centerSquare``, ``appleDoor``, ``mpe``).
    """
    env_name = args.env.env_name
    use_suboptimal = args.use_suboptimal

    if "centerSquare" in env_name:
        # "1a" marks the single-agent variant; otherwise the second-to-last
        # character of the name is parsed as the agent count.
        agent_num = 1 if "1a" in env_name else int(env_name[-2])
        if args.discrim_full_state:
            return load_expert_trajectory_gridworld_lava_full(
                env_name, agent_num, use_suboptimal, args.discrim_co_trained
            )
        return load_expert_trajectory_gridworld_lava(env_name, agent_num, use_suboptimal)

    if "appleDoor" in env_name:
        return load_expert_trajectory_appledoor(env_name, 2, use_suboptimal)

    if "mpe" in env_name:
        demonstration = args.env.mpe_demonstration
        num_agents = args.env.mpe_num_agents
        if args.env.mpe_fixed_map:
            return load_expert_trajectory_simple_spread_fixed(num_agents, demonstration)
        return load_expert_trajectory_simple_spread(num_agents, use_suboptimal, demonstration)

    raise ValueError("No demonstration for such environment.")


def load_expert_trajectory_gridworld_lava_full(env_name, agent_num, use_suboptimal=True, discrim_co_trained=False):
    """Load full-state expert trajectories for a grid-world lava environment.

    Files are read from ``trajs/<env_name>[_suboptimal]_*<id>.csv``. At most
    two agents are loaded because only two file ids are defined.

    Args:
        env_name: Environment name used as the file-name prefix.
        agent_num: Number of agents to load trajectories for.
        use_suboptimal: Read the ``_suboptimal`` files when true.
        discrim_co_trained: When true every agent reads file id 0; otherwise
            the agents read ids 0 and 1.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one array per loaded agent. For every agent only that agent's
        column of the action file is kept.
    """
    file_ids = [0, 0] if discrim_co_trained else [0, 1]
    stem = f"trajs/{env_name}" + ("_suboptimal" if use_suboptimal else "")

    expert = {"states": [], "states_after": [], "actions": []}
    for aid, file_id in zip(range(agent_num), file_ids):
        states = np.genfromtxt(f"{stem}_states{file_id}.csv")
        states_after = np.genfromtxt(f"{stem}_states_after{file_id}.csv")
        joint_actions = np.genfromtxt(f"{stem}_actions{file_id}.csv", dtype=np.int32)
        expert["states"].append(states)
        expert["states_after"].append(states_after)
        # Column `aid` of the action file holds this agent's actions.
        expert["actions"].append(joint_actions[:, aid])
    return expert


def load_expert_trajectory_gridworld_lava(env_name, agent_num, use_suboptimal=True):
    """Load per-agent expert trajectories for the centerSquare6x6 grid world.

    Files are always read from ``trajs/centerSquare6x6[_suboptimal]_*<id>.csv``.

    Args:
        env_name: Environment name; only its last character is used, as the
            file id, when ``agent_num`` is 1.
        agent_num: Number of agents. Two agents read file ids 0 and 2; any
            other count greater than one reads ids ``0 .. agent_num - 1``.
        use_suboptimal: Read the ``_suboptimal`` files when true.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one array per agent.
    """
    if agent_num == 1:
        file_ids = [int(env_name[-1])]
    elif agent_num == 2:
        file_ids = [0, 2]
    else:
        file_ids = list(range(agent_num))

    stem = "trajs/centerSquare6x6" + ("_suboptimal" if use_suboptimal else "")

    expert = {"states": [], "states_after": [], "actions": []}
    for file_id in file_ids:
        states = np.genfromtxt(f"{stem}_states{file_id}.csv")
        states_after = np.genfromtxt(f"{stem}_states_after{file_id}.csv")
        actions = np.genfromtxt(f"{stem}_actions{file_id}.csv", dtype=np.int32)
        expert["states"].append(states)
        expert["states_after"].append(states_after)
        expert["actions"].append(actions)
    return expert


def load_expert_trajectory_gridworld_swap(env_name, agent_num, use_suboptimal=True):
    """Load per-agent expert trajectories for the swap grid world.

    Files are read from ``trajs/swap[_suboptimal]_*<agent id>.csv``.

    Args:
        env_name: Unused; the file prefix is fixed.
        agent_num: Number of agents; agent ids run from 0 to ``agent_num - 1``.
        use_suboptimal: Read the ``_suboptimal`` files when true.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one array per agent.
    """
    stem = "trajs/swap" + ("_suboptimal" if use_suboptimal else "")

    # One (states, states_after, actions) triple per agent, read in that order.
    per_agent = []
    for aid in range(agent_num):
        per_agent.append(
            (
                np.genfromtxt(f"{stem}_states{aid}.csv"),
                np.genfromtxt(f"{stem}_states_after{aid}.csv"),
                np.genfromtxt(f"{stem}_actions{aid}.csv", dtype=np.int32),
            )
        )

    return {
        "states": [triple[0] for triple in per_agent],
        "states_after": [triple[1] for triple in per_agent],
        "actions": [triple[2] for triple in per_agent],
    }


def load_expert_trajectory_appledoor(env_name, agent_num, use_suboptimal=True):
    """Load per-agent expert trajectories for an appleDoor environment.

    Files are read from ``trajs/<env_name>[_suboptimal]_*<agent id>.csv``.

    Args:
        env_name: Environment name used as the file-name prefix.
        agent_num: Number of agents; agent ids run from 0 to ``agent_num - 1``.
        use_suboptimal: Read the ``_suboptimal`` files when true.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one array per agent.
    """
    suffix = "_suboptimal" if use_suboptimal else ""
    stem = "trajs/" + env_name + suffix

    state_list, next_state_list, action_list = [], [], []
    for aid in range(agent_num):
        state_list.append(np.genfromtxt(f"{stem}_states{aid}.csv"))
        next_state_list.append(np.genfromtxt(f"{stem}_states_after{aid}.csv"))
        action_list.append(np.genfromtxt(f"{stem}_actions{aid}.csv", dtype=np.int32))

    return {"states": state_list, "states_after": next_state_list, "actions": action_list}


def load_expert_trajectory_pacman(agent_num):
    """Load per-agent expert states and actions for the pacman environment.

    Files are read from ``trajs/pacman_states_<id>.csv`` and
    ``trajs/pacman_actions_<id>.csv``.

    Args:
        agent_num: Number of agents; agent ids run from 0 to ``agent_num - 1``.

    Returns:
        dict: ``"states"`` and ``"actions"``, each a list with one array per
        agent.
    """
    expert = {"states": [], "actions": []}
    for aid in range(agent_num):
        agent_states = np.genfromtxt(f"trajs/pacman_states_{aid}.csv")
        agent_actions = np.genfromtxt(f"trajs/pacman_actions_{aid}.csv", dtype=np.int32)
        expert["states"].append(agent_states)
        expert["actions"].append(agent_actions)
    return expert


def load_expert_trajectory_simple_spread(agent_num, use_suboptimal=True, mpe_demonstration=None):
    """Load the shared simple-spread expert trajectory for every agent.

    All agents use the same three files,
    ``<path>mpe_simple_spread<_suboptimal|_best>_{states,states_after,actions}.csv``,
    where ``<path>`` is ``trajs/`` or ``trajs/mpe_<mpe_demonstration>/``.
    The files are parsed once and every agent receives its own copy of the
    arrays, instead of re-parsing identical files once per agent.

    Args:
        agent_num: Number of agents to produce entries for.
        use_suboptimal: Read the ``_suboptimal`` files when true, otherwise
            the ``_best`` files.
        mpe_demonstration: Optional demonstration name selecting the
            ``mpe_<name>/`` sub-folder.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one independent array per agent. The lists are empty (and no
        file is read) when ``agent_num`` is not positive.
    """
    path = "trajs/"
    if mpe_demonstration is not None:
        path += "mpe_" + mpe_demonstration + "/"
    prefix = path + "mpe_simple_spread" + ("_suboptimal" if use_suboptimal else "_best")

    expert = {"states": [], "states_after": [], "actions": []}
    if agent_num <= 0:
        # Nothing to load; keep the files untouched as the per-agent loop did.
        return expert

    states = np.genfromtxt(prefix + "_states.csv")
    states_after = np.genfromtxt(prefix + "_states_after.csv")
    # NOTE(review): actions are read with the default float dtype here, unlike
    # the other loaders in this file which pass dtype=np.int32 - confirm intent.
    actions = np.genfromtxt(prefix + "_actions.csv")

    for _ in range(agent_num):
        # Copies keep the per-agent arrays independent, as separate reads were.
        expert["states"].append(states.copy())
        expert["states_after"].append(states_after.copy())
        expert["actions"].append(actions.copy())
    return expert


def load_expert_trajectory_simple_spread_fixed(agent_num, mpe_demonstration=None):
    """Load per-agent ``_best`` simple-spread expert trajectories.

    Files are read from ``<path>mpe_simple_spread_best_*<agent id>.csv`` where
    ``<path>`` is ``trajs/`` or ``trajs/mpe_<mpe_demonstration>/``.

    Args:
        agent_num: Number of agents; agent ids run from 0 to ``agent_num - 1``.
        mpe_demonstration: Optional demonstration name selecting the
            ``mpe_<name>/`` sub-folder.

    Returns:
        dict: ``"states"``, ``"states_after"`` and ``"actions"``, each a list
        with one array per agent.
    """
    folder = "trajs/" if mpe_demonstration is None else "trajs/" + "mpe_" + mpe_demonstration + "/"
    stem = folder + "mpe_simple_spread" + "_best"

    state_list, next_state_list, action_list = [], [], []
    for aid in range(agent_num):
        state_list.append(np.genfromtxt(f"{stem}_states{aid}.csv"))
        next_state_list.append(np.genfromtxt(f"{stem}_states_after{aid}.csv"))
        action_list.append(np.genfromtxt(f"{stem}_actions{aid}.csv", dtype=np.int32))

    return {"states": state_list, "states_after": next_state_list, "actions": action_list}


def load_models(model_dir, env, model="best", use_local_obs=False):
    """Load the saved per-agent actor-critic models and build an action selector.

    Args:
        model_dir: Directory containing the saved status files. Its path must
            contain ``"PPO"`` or ``"POfD"``.
        env: Environment exposing ``agent_num`` and per-agent
            ``observation_space`` / ``action_space`` sequences.
        model: ``"best"`` or ``"last"`` to load ``best_status.pt`` /
            ``last_status.pt``; any other value loads ``status_<model>.pt``.
        use_local_obs: When true each agent is fed ``state[aid]``; otherwise
            every agent is fed ``state.flatten()``.

    Returns:
        tuple: ``(acmodels, select_action)`` where ``acmodels`` is the list of
        loaded models and ``select_action(state, mask=None)`` returns a list
        with one sampled action per agent.

    Raises:
        ValueError: If ``model_dir`` names neither a PPO nor a POfD run.
    """
    if model == "best":
        status_file = "/best_status.pt"
    elif model == "last":
        status_file = "/last_status.pt"
    else:
        status_file = "/status_" + str(model) + ".pt"
    status = torch.load(model_dir + status_file, map_location=device)
    print(f"frames: {status['num_frames']}")

    # Both substrings need their own membership test. The previous condition,
    # `"PPO" or "POfD" in model_dir`, evaluated to the non-empty string "PPO"
    # and was therefore always true, making this error unreachable.
    if "PPO" not in model_dir and "POfD" not in model_dir:
        raise ValueError("No such algorithm!")

    acmodels = [
        ACModel(env.observation_space[aid], env.action_space[aid])
        for aid in range(env.agent_num)
    ]

    def select_action(state, mask=None):
        """Sample one action per agent from the loaded models."""
        actions = [0] * env.agent_num
        for aid in range(env.agent_num):
            cur_state = state[aid] if use_local_obs else state.flatten()
            # The value estimate returned alongside the distribution is unused.
            dist, _value = acmodels[aid](cur_state, mask)
            actions[aid] = dist.sample()
        return actions

    for aid in range(env.agent_num):
        acmodels[aid].load_state_dict(status["model_state"][aid])
        acmodels[aid].to(device)

    return acmodels, select_action


def load_discriminator(model_dir, env, model="best"):
    """Load the saved per-agent discriminators.

    Args:
        model_dir: Directory containing the saved status files. The path also
            selects the discriminator input size: 10 when it contains
            ``"mpe"``, 25 when it contains ``"pacman"``, otherwise the agent's
            observation length divided by the number of agents.
        env: Environment exposing ``agent_num`` and per-agent
            ``observation_space`` / ``action_space`` sequences.
        model: ``"best"`` or ``"last"`` to load ``best_status.pt`` /
            ``last_status.pt``; any other value loads ``status_<model>.pt``.

    Returns:
        list: One loaded ``Discriminator`` per agent, moved to ``device``.
    """
    if model == "best":
        status_file = "/best_status.pt"
    elif model == "last":
        status_file = "/last_status.pt"
    else:
        status_file = "/status_" + str(model) + ".pt"
    status = torch.load(model_dir + status_file, map_location=device)

    discriminators = []
    for aid in range(env.agent_num):
        # This has to be a single if/elif/else chain: with two separate `if`
        # statements the "mpe" value was always overwritten by the `else`
        # branch of the "pacman" check.
        if "mpe" in model_dir:
            state_dim = 10
        elif "pacman" in model_dir:
            state_dim = 25
        else:
            state_dim = int(env.observation_space[aid].shape[0] / env.agent_num)
        action_num = env.action_space[aid].n

        discriminator = Discriminator(state_dim, action_num)
        discriminator.load_state_dict(status["discriminator_state"][aid])
        discriminator.to(device)
        discriminators.append(discriminator)

    return discriminators
