import sys, os, time
sys.path.append('./')
import numpy as np
import torch
import gym
from ruamel.yaml import YAML
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import datetime
import dateutil.tz
import json, copy

LOG_STD_MAX = 2
LOG_STD_MIN = -20

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input first and output last. Each adjacent pair
            becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer except
            the last one.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        ``nn.Sequential`` alternating linear layers and activation modules;
        empty when ``sizes`` has fewer than two entries.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        nonlinearity = output_activation if idx == final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

class MLPQFunction(nn.Module):
    """Q-network: an MLP applied to the concatenation of observation and action."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Create the network.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            hidden_sizes: Iterable of hidden layer widths.
            activation: Module class used after every hidden layer.
        """
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing singleton dimension removed.

        ``obs`` and ``act`` are joined along their last dimension, so all of
        their leading dimensions must match.
        """
        joint_input = torch.cat((obs, act), dim=-1)
        q_value = self.q(joint_input)
        # The network emits a trailing dimension of size 1; drop it so that
        # callers get one scalar per input row.
        return q_value.squeeze(-1)
    
class SquashedGaussianMLPActor(nn.Module):
    """Gaussian policy whose samples are squashed by tanh and scaled by ``act_limit``.

    A shared MLP trunk feeds two linear heads that produce the mean and the
    (clamped) log standard deviation of a diagonal Normal distribution over
    pre-squash actions.
    """

    # Margin kept away from +/-1 before applying atanh, which is infinite at
    # the interval boundary.
    _ATANH_EPS = 1e-6

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        """Create the trunk and the two output heads.

        Args:
            obs_dim: Size of the observation vector.
            act_dim: Size of the action vector.
            hidden_sizes: Sequence of hidden layer widths (must be non-empty,
                since the heads read ``hidden_sizes[-1]``).
            activation: Module class used after every trunk layer, including
                the last one.
            act_limit: Scale applied to the tanh-squashed action.
        """
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def _pre_squash_distribution(self, obs):
        """Return ``(Normal(mu, std), mu)`` for the pre-squash action given ``obs``."""
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return Normal(mu, torch.exp(log_std)), mu

    @staticmethod
    def _squashed_log_prob(pi_distribution, pre_squash_action):
        """Log-density of ``tanh(pre_squash_action)`` under ``pi_distribution``.

        Uses the identity
        ``log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))`` for the tanh
        change-of-variables term. Both sums run over the last dimension so
        batched and unbatched inputs are handled the same way. The constant
        contribution of the ``act_limit`` scaling is not included.
        """
        logp = pi_distribution.log_prob(pre_squash_action).sum(axis=-1)
        correction = 2 * (np.log(2) - pre_squash_action
                          - F.softplus(-2 * pre_squash_action))
        return logp - correction.sum(axis=-1)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Produce an action for ``obs``.

        Args:
            obs: Observation tensor; the trunk is applied along its last
                dimension.
            deterministic: If True, use the distribution mean instead of a
                reparameterised sample (intended for evaluation).
            with_logprob: If True, also return the log-probability of the
                produced action; otherwise the second return value is None.

        Returns:
            Tuple ``(pi_action, logp_pi)`` where ``pi_action`` is
            ``act_limit * tanh(u)`` for the pre-squash action ``u``.
        """
        pi_distribution, mu = self._pre_squash_distribution(obs)
        if deterministic:
            # Only used for evaluating policy at test time.
            pre_squash = mu
        else:
            pre_squash = pi_distribution.rsample()

        if with_logprob:
            logp_pi = self._squashed_log_prob(pi_distribution, pre_squash)
        else:
            logp_pi = None

        pi_action = torch.tanh(pre_squash)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi

    def log_prob(self, obs, act):
        """Return the log-probability of the given squashed actions.

        Args:
            obs: Observation tensor.
            act: Actions in environment scale, i.e. expected to lie in
                ``[-act_limit, act_limit]``.

        Returns:
            Tensor of log-probabilities with the action dimension summed out.
        """
        pi_distribution, _ = self._pre_squash_distribution(obs)

        # Undo the scaling and the tanh squashing to recover the pre-squash
        # action. Actions exactly at the limits would map to +/-inf under
        # atanh, so keep them strictly inside (-1, 1).
        unit_act = act / self.act_limit
        unit_act = torch.clamp(unit_act, -1.0 + self._ATANH_EPS, 1.0 - self._ATANH_EPS)
        pre_squash = torch.atanh(unit_act)

        return self._squashed_log_prob(pi_distribution, pre_squash)

if __name__ == "__main__":
    # Fit a Q-function to expert demonstrations, then look for expert
    # (state, action) pairs where gradient ascent on the action reaches a
    # Q-value far above the expert action's ("misleading points") and write
    # those pairs to "<env_name>_misleading_points.txt".
    # Usage: python <this_script> <config.yaml>
    yaml = YAML()
    # Use a context manager so the config file handle is closed.
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    env_name = v['env']['env_name']
    env_T = v['env']['T']
    n_itrs = v['irl']['n_itrs']
    adam_kwargs = dict(
        lr=v['reward']['lr'],
        weight_decay=v['reward']['weight_decay'],
        betas=(v['reward']['momentum'], 0.999),
    )
    gym_env = gym.make(env_name)
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    act_limit = gym_env.action_space.high[0]
    qpos_length = len(gym_env.data.qpos)
    qvel_length = len(gym_env.data.qvel)
    state_indices = list(range(state_size))

    def load_expert(kind):
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        return torch.load(f'expert_data/{kind}/{env_name}/demonstration_det.pt').numpy()

    # Flatten every demonstration array to one row per sample.
    expert_samples = load_expert('states')[:, :, state_indices].reshape(-1, len(state_indices))
    expert_a_samples = load_expert('actions').reshape(-1, action_size)
    expert_r_samples = load_expert('rewards').reshape(-1, 1)
    expert_qpos_samples = load_expert('qpos').reshape(-1, qpos_length)
    expert_qvel_samples = load_expert('qvel').reshape(-1, qvel_length)

    device = torch.device("cpu")
    Q_function = MLPQFunction(state_size, action_size, (256, 256), nn.ReLU).to(device)
    policy = SquashedGaussianMLPActor(state_size, action_size, (256, 256), nn.ReLU, act_limit).to(device)

    # Convert the demonstrations once instead of on every training step.
    # torch.tensor copies, so later in-place updates never touch the numpy data.
    obs_tensor = torch.tensor(expert_samples, dtype=torch.float32)
    act_tensor = torch.tensor(expert_a_samples, dtype=torch.float32)
    rew_tensor = torch.tensor(expert_r_samples, dtype=torch.float32).reshape(-1)

    # Return accumulated by rolling the policy out from every expert state.
    num_samples = len(expert_qpos_samples)
    V_value = torch.zeros(num_samples)
    with torch.no_grad():  # rollouts only need actions, not gradients
        for i in range(num_samples):
            gym_env.reset()
            gym_env.set_state(np.array(expert_qpos_samples[i]), np.array(expert_qvel_samples[i]))
            # reset() returns the observation of the reset state, not of the
            # state installed by set_state, so start from the stored expert
            # observation of the same sample instead.
            s = expert_samples[i]
            # NOTE(review): the termination flag is ignored and the rollout
            # always runs env_T - 1 steps, as in the original script.
            for t in range(env_T - 1):
                obs_t = torch.as_tensor(s, dtype=torch.float32).to(device)
                a = policy(obs_t, deterministic=True, with_logprob=False)[0].cpu().numpy().flatten()
                s_nxt, r, d, _, _ = gym_env.step(a)
                V_value[i] += r
                s = s_nxt

    # Q_function returns shape (N,). The target must be flat as well:
    # subtracting an (N, 1) tensor would broadcast the residual to (N, N) and
    # regress every prediction against every target.
    q_target = V_value + rew_tensor
    Q_optimizer = torch.optim.Adam(Q_function.parameters(), **adam_kwargs)
    for _ in range(n_itrs):
        loss = ((Q_function(obs_tensor, act_tensor) - q_target) ** 2).mean()
        Q_optimizer.zero_grad()
        loss.backward()
        Q_optimizer.step()
    # Freeze the fitted Q-function; only the action is optimised below.
    for param in Q_function.parameters():
        param.requires_grad = False

    misleading_points = []
    for i in range(len(expert_a_samples)):
        obs_i = obs_tensor[i]
        action = act_tensor[i].clone().requires_grad_(True)
        action_optimizer = torch.optim.Adam([action], **adam_kwargs)
        with torch.no_grad():
            action_value = Q_function(obs_i, action)
        # Gradient ascent on Q with respect to the action.
        for _ in range(n_itrs):
            action_loss = -Q_function(obs_i, action)
            action_optimizer.zero_grad()
            action_loss.backward()
            action_optimizer.step()
        # Evaluate at the final action so the value includes the last update
        # and is defined even when n_itrs is 0.
        with torch.no_grad():
            max_value = Q_function(obs_i, action)
        if max_value - action_value > 10.0:
            print(max_value - action_value)
            # Store a flat (state, expert action) row so the result stacks
            # into a 2-D array.
            misleading_points.append(np.concatenate((expert_samples[i], expert_a_samples[i]), axis=None))

    # np.savetxt only accepts 1-D or 2-D input; force one row per point, which
    # also covers the case where no point was found.
    point_dim = len(state_indices) + action_size
    points_array = np.array(misleading_points).reshape(len(misleading_points), point_dim)
    np.savetxt(str(env_name) + "_misleading_points.txt", points_array, delimiter=',')