import random
from actor_critic import SAC, DDPG, training, traj_generation
import torch
import numpy as np
import math
import argparse
import json, os, time

def load_agent(agent, agent_type):
    """Load saved network weights into ``agent.ac`` in place.

    Each network's state dict is read with ``torch.load`` from
    ``../model/<agent.name>_model/<agent.env_name>_<net>_<agent_type>``,
    where ``<net>`` is the attribute name of the network on ``agent.ac``.

    Args:
        agent: object exposing ``name``, ``env_name`` and ``ac``.
        agent_type: checkpoint suffix used in the file name (the callers in
            this file pass ``'exp'``, ``'med'`` or ``'rand'``).

    Raises:
        ValueError: if ``agent.name`` is not one of the known algorithms.
    """
    # Networks stored per algorithm. The original chain tested
    # ``agent.name == 'td3' or 'sac'``, which is always truthy, so the 'ppo'
    # branch could never run; an explicit table avoids that class of mistake.
    networks_by_alg = {
        'ddpg': ('pi', 'q'),
        'td3': ('pi', 'q1', 'q2'),
        'sac': ('pi', 'q1', 'q2'),
        'ppo': ('pi', 'v'),
    }
    if agent.name not in networks_by_alg:
        # Fail loudly rather than returning an agent with untouched weights.
        raise ValueError('agent name undefined: %r' % (agent.name,))

    for net_name in networks_by_alg[agent.name]:
        path = '../model/%s_model/%s_%s_%s' % (
            agent.name, agent.env_name, net_name, agent_type)
        getattr(agent.ac, net_name).load_state_dict(torch.load(path))

def sample_pair(sample_type, clip_num, data_num):
    """Draw two index arrays of length ``data_num`` over ``range(clip_num)``.

    Only ``sample_type == 'uniform'`` is supported: each array is an
    independent ``random.sample`` draw (no repeats within one array).

    Returns:
        A tuple ``(tau_1, tau_2)`` of numpy arrays.

    Raises:
        Exception: if ``sample_type`` is not ``'uniform'``.
    """
    if sample_type != 'uniform':
        raise Exception('sample type undefined')

    candidates = range(clip_num)
    first_indices = np.array(random.sample(candidates, data_num))
    second_indices = np.array(random.sample(candidates, data_num))
    return first_indices, second_indices

def _make_loaded_agent(env, alg, agent_type):
    """Build the ``alg`` agent for ``env`` and load its ``agent_type`` checkpoint."""
    # Explicit lookup instead of eval() on a command-line string.
    agent_classes = {'sac': SAC, 'ddpg': DDPG}
    try:
        agent_cls = agent_classes[alg.lower()]
    except KeyError:
        raise ValueError('algorithm undefined: %r' % (alg,)) from None
    # Hopper and Walker2d are built with terminate=False, everything else with
    # terminate=True (the flag's meaning is defined by the agent class).
    terminate = env not in ('Hopper-v3', 'Walker2d-v3')
    agent = agent_cls(env_name=env, ac_kwargs=dict(hidden_sizes=[256] * 2),
                      gamma=0.99, terminate=terminate)
    load_agent(agent, agent_type)
    return agent


def collect_trajectory_pairs(env, traj_type, clip_len, clip_num, alg):
    """Generate trajectory clips from saved agents and save them to disk.

    The result is written with ``np.savez`` to ``../dataset/<env>_<traj_type>``
    (numpy appends ``.npz``) under the keys ``traj_obs``, ``traj_act`` and
    ``traj_rew``.

    Args:
        env: environment name; also selects the ``terminate`` flag.
        traj_type: ``'expert'``, ``'medium'``, ``'random'`` or ``'mixed'``.
            ``'mixed'`` concatenates clips from the expert and medium agents.
        clip_len: forwarded to ``traj_generation``.
        clip_num: total number of clips requested.
        alg: algorithm name, matched case-insensitively (``'sac'``/``'ddpg'``).

    Raises:
        ValueError: if ``traj_type`` or ``alg`` is not recognised.
    """
    # Checkpoint suffix(es) that each trajectory type is generated from.
    checkpoint_suffixes = {
        'expert': ('exp',),
        'medium': ('med',),
        'random': ('rand',),
        'mixed': ('exp', 'med'),
    }
    if traj_type not in checkpoint_suffixes:
        raise ValueError('trajectory type undefined')
    suffixes = checkpoint_suffixes[traj_type]

    # Build every agent before generating anything, as the original did.
    # Uses the `env` argument, not the script-level `args.env`, so the
    # function also works when imported from another module.
    agents = [_make_loaded_agent(env, alg, suffix) for suffix in suffixes]

    # Split clip_num across the sources; the last one absorbs any remainder so
    # an odd clip_num no longer loses a clip in the 'mixed' case.
    share, leftover = divmod(clip_num, len(agents))
    counts = [share] * len(agents)
    counts[-1] += leftover

    clips = [traj_generation(agent, clip_len, count)
             for agent, count in zip(agents, counts)]
    if len(clips) == 1:
        traj_obs, traj_act, traj_rew = clips[0]
    else:
        traj_obs, traj_act, traj_rew = (
            np.concatenate(parts) for parts in zip(*clips))

    np.savez('../dataset/%s_%s' % (env, traj_type),
             traj_obs=traj_obs, traj_act=traj_act, traj_rew=traj_rew)

def _preference_probability(utility_1, utility_2):
    """Return exp(u1) / (exp(u1) + exp(u2)) without overflowing.

    Evaluated as a logistic function of ``u1 - u2``; the branch keeps the
    argument of ``math.exp`` non-positive so large utilities cannot raise
    ``OverflowError``.
    """
    diff = float(utility_1) - float(utility_2)
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    scaled = math.exp(diff)
    return scaled / (1.0 + scaled)


def preference_label(env, traj_type, sample_type, pref_type, data_num):
    """Sample clip pairs from a saved dataset and label which one is preferred.

    Reads ``../dataset/<env>_<traj_type>.npz``, draws ``data_num`` index pairs
    with ``sample_pair``, and labels each pair ``1`` or ``2``. The first clip
    is preferred with probability ``exp(u1) / (exp(u1) + exp(u2))``, where
    ``u`` is the sum of that clip's rewards. With ``pref_type == 'mistake'``,
    10% of labels are then replaced by a fair coin flip.

    The trajectories, both index arrays and the labels are written with
    ``np.savez`` to
    ``../dataset/<env>_<traj_type>_<sample_type>_<pref_type>_<data_num>``.

    Raises:
        ValueError: if ``pref_type`` is neither ``'regular'`` nor ``'mistake'``.
    """
    if pref_type not in ('regular', 'mistake'):
        raise ValueError('preference type undefined')

    # Context manager closes the underlying .npz file handle after reading.
    with np.load('../dataset/%s_%s.npz' % (env, traj_type)) as traj_file:
        traj_obs = traj_file['traj_obs']
        traj_act = traj_file['traj_act']
        traj_rew = traj_file['traj_rew']

    traj_idx_1, traj_idx_2 = sample_pair(sample_type, traj_rew.shape[0], data_num)
    pref = np.empty([data_num])
    for i, (idx_1, idx_2) in enumerate(zip(traj_idx_1, traj_idx_2)):
        utility_1, utility_2 = np.sum(traj_rew[idx_1]), np.sum(traj_rew[idx_2])
        prob = _preference_probability(utility_1, utility_2)
        pref[i] = 1 if random.random() < prob else 2
        # 'mistake' labels: with probability 0.1 overwrite with a coin flip.
        if pref_type == 'mistake' and random.random() < 0.1:
            pref[i] = 1 if random.random() < 0.5 else 2

    np.savez('../dataset/%s_%s_%s_%s_%d' % (env, traj_type, sample_type, pref_type, data_num),
             traj_obs=traj_obs, traj_act=traj_act, traj_rew=traj_rew,
             traj_idx_1=traj_idx_1, traj_idx_2=traj_idx_2, pref=pref)

if __name__ == '__main__':
    # Script entry point: make sure the trajectory dataset exists, then write
    # one preference-labelled dataset per size in `data_nums`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='sac')
    parser.add_argument('--env', type=str, default='Hopper-v3')
    parser.add_argument('--traj', type=str, default='expert')
    parser.add_argument('--sample', type=str, default='uniform')
    parser.add_argument('--pref', type=str, default='regular')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Pair sampling and preference labelling draw from the stdlib `random`
    # module, so it must be seeded too for --seed to make runs reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    data_nums = [1000, 3000, 10000, 30000, 100000]

    # Create the output directory up front so a missing folder does not
    # surface only after trajectory generation has finished.
    os.makedirs('../dataset', exist_ok=True)

    # Trajectories are generated once, sized for the largest preference set.
    if not os.path.isfile('../dataset/%s_%s.npz' % (args.env, args.traj)):
        collect_trajectory_pairs(args.env, args.traj, clip_len=20,
                                 clip_num=data_nums[-1], alg=args.alg)

    for data_num in data_nums:
        preference_label(args.env, args.traj, args.sample, args.pref, data_num)


