import torch
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.optim import Adam
from models import MiniGridCNN


class GAILDiscriminator(nn.Module):
    """Binary discriminator for GAIL-style adversarial imitation learning.

    The network is a ``base`` feature extractor followed by a ``discriminator``
    head that emits one logit per sample.  ``compute_loss`` labels the first
    half of a batch as expert (1) and the second half as policy (0).

    Args:
        env: object exposing ``observation_space.shape`` and
            ``action_space.shape``; the last entry of each sizes the input
            layer.  An empty action shape is treated as one action feature.
        layer_dims: hidden layer widths (at least one entry).
        lr: learning rate of the internal Adam optimizer.
        use_actions: feed actions to the network alongside observations.
        use_cnn_base: use ``MiniGridCNN`` as the base instead of a linear layer.
        irm_coeff: weight of the IRM penalty in ``compute_loss``.
        lip_coeff: stored on the instance; ``compute_loss`` does not use it.
        bias: whether the linear layers carry bias terms.
    """

    def __init__(self, env, layer_dims, lr, use_actions=True, use_cnn_base=False, irm_coeff=0, lip_coeff=0,
                 bias=False):
        super(GAILDiscriminator, self).__init__()

        ob_shapes = list(env.observation_space.shape)
        # An empty action shape is treated as a single action feature.
        ac_shapes = list(env.action_space.shape) or [1]
        input_dim = ob_shapes[-1] + (ac_shapes[-1] if use_actions else 0)
        layer_dims = [input_dim] + list(layer_dims)

        self.layer_dims = layer_dims
        self.lr = lr
        self.use_actions = use_actions
        self.irm_coeff = irm_coeff
        self.lip_coeff = lip_coeff
        self.use_cnn_base = use_cnn_base

        if use_cnn_base:
            self.base = MiniGridCNN(layer_dims, use_actions)
        else:
            self.base = nn.Sequential(nn.Linear(layer_dims[0], layer_dims[1], bias),
                                      nn.PReLU())

        self.discriminator_layers = self._build_head(layer_dims, bias)
        self.discriminator = nn.Sequential(*self.discriminator_layers)

        if torch.cuda.is_available():
            self.base.cuda()
            self.discriminator.cuda()

        # The optimizer is built after the device move so it tracks the final parameters.
        params = list(self.base.parameters()) + list(self.discriminator.parameters())
        self.d_optimizer = Adam(params, lr=self.lr)

    @staticmethod
    def _build_head(dims, bias):
        """Return the layers mapping ``dims[1]`` features down to one output."""
        layers = []
        for in_dim, out_dim in zip(dims[1:-1], dims[2:]):
            layers += [nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias),
                       nn.PReLU()]
        layers.append(nn.Linear(in_features=dims[-1], out_features=1, bias=bias))
        return layers

    @staticmethod
    def _interpolate(first, second):
        """Return an element-wise random convex mix of two tensors that tracks gradients."""
        eps = torch.rand(first.shape, device=first.device)
        return (eps * first + (1 - eps) * second).requires_grad_(True)

    def _features(self, ob, ac):
        """Run the base network on the inputs selected by the configuration."""
        if not self.use_actions:
            return self.base(ob)
        if self.use_cnn_base:
            return self.base(ob, ac)
        # Give rank-deficient actions a trailing feature axis so they can be concatenated.
        if ob.dim() != ac.dim():
            ac = torch.unsqueeze(ac, -1)
        return self.base(torch.cat([ob, ac], dim=-1))

    def forward(self, ob, ac):
        """Return the discriminator logits for a batch; ``ac`` is ignored when actions are disabled."""
        return self.discriminator(self._features(ob, ac))

    def get_reward(self, ob, ac):
        """Return ``-log(sigmoid(logit) + 1e-8)`` per sample and cache it on ``self.reward``."""
        d_out = self.forward(ob, ac)
        self.reward = - torch.squeeze(torch.log(torch.sigmoid(d_out) + 1e-8))
        return self.reward

    def irm_penalty(self, logits, y):
        """Return the squared gradient of the BCE loss w.r.t. a dummy scale fixed at 1."""
        scale = torch.ones((), device=logits.device, requires_grad=True)
        loss = F.binary_cross_entropy_with_logits(logits * scale, y)
        grad = autograd.grad(loss, [scale], create_graph=True)[0]
        return torch.sum(grad ** 2)

    def compute_penalty(self, logits, y):
        """Return the product of dummy-scale gradients from even- and odd-indexed samples.

        Needs at least two samples; with fewer, the odd split is empty and the
        result is NaN.
        """
        scale = torch.ones((), device=logits.device, requires_grad=True)
        # Per-sample losses are required here: slicing the mean-reduced scalar raises IndexError.
        losses = F.binary_cross_entropy_with_logits(logits * scale, y, reduction='none')
        g1 = autograd.grad(losses[0::2].mean(), [scale], create_graph=True)[0]
        g2 = autograd.grad(losses[1::2].mean(), [scale], create_graph=True)[0]
        return (g1 * g2).sum()

    # lipschitz penalty
    def lip_penalty(self, update_dict):
        """Return a gradient penalty evaluated on random policy/expert interpolations.

        Reads ``policy_obs`` and ``expert_obs`` (plus ``policy_acs`` and
        ``expert_acs`` when actions are enabled) and penalizes the squared
        deviation from 1 of the per-sample gradient norm of the logits w.r.t.
        the interpolated observations.

        NOTE(review): this assumes each entry is a single batched tensor, as in
        ``AIRLDiscriminator.lip_penalty`` -- confirm against the caller.
        """
        interp_obs = self._interpolate(update_dict['policy_obs'], update_dict['expert_obs'])
        interp_acs = None
        if self.use_actions:
            interp_acs = self._interpolate(update_dict['policy_acs'], update_dict['expert_acs'])

        estimate = self.forward(interp_obs, interp_acs).sum()
        gradient = autograd.grad(estimate, interp_obs, create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        flat = gradient.reshape(gradient.shape[0], -1)
        safe_norm = (torch.sum(flat ** 2, dim=1) + 1e-8).sqrt()
        return torch.mean((safe_norm - 1) ** 2)

    def compute_grad_pen(self,
                         expert_state,
                         expert_action,
                         policy_state,
                         policy_action,
                         lambda_=10):
        """Return ``lambda_`` times the mean squared deviation of the input-gradient norm from 1.

        A single mixing coefficient is drawn per sample and shared by state and
        action.  The mixed inputs go through the full network (base and head),
        and the norm is taken over the gradients of every input that is used.
        """
        batch = expert_state.size(0)
        alpha = torch.rand(batch, 1, device=expert_state.device)

        def mix(expert, policy):
            # Reshape alpha so it broadcasts over any number of trailing axes.
            weight = alpha.view(batch, *([1] * (expert.dim() - 1)))
            return (weight * expert + (1 - weight) * policy).requires_grad_(True)

        mixed_state = mix(expert_state, policy_state)
        inputs = [mixed_state]
        mixed_action = None
        if self.use_actions:
            mixed_action = mix(expert_action, policy_action)
            inputs.append(mixed_action)

        disc = self.forward(mixed_state, mixed_action)
        grads = autograd.grad(outputs=disc,
                              inputs=inputs,
                              grad_outputs=torch.ones_like(disc),
                              create_graph=True)
        grad = torch.cat([g.reshape(batch, -1) for g in grads], dim=1)

        return lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()

    def compute_loss(self, update_dict):
        """Return ``(total_loss, bce_loss, irm_penalty)`` for a stacked expert/policy batch.

        ``update_dict['all_obs']`` and ``update_dict['all_acs']`` are split in
        half along dim 0; the first half is labelled 1 and the second half 0.
        """
        d_out = self.forward(update_dict['all_obs'], update_dict['all_acs'])
        expert_out, policy_out = torch.chunk(d_out, chunks=2, dim=0)

        # Build the labels from the logits so they share their device and dtype.
        labels = torch.cat([torch.ones_like(expert_out),
                            torch.zeros_like(policy_out)])

        self.bce_loss = F.binary_cross_entropy_with_logits(d_out, labels)
        self.grad_penalty = self.irm_penalty(d_out, labels)
        self.loss = self.bce_loss + self.irm_coeff * self.grad_penalty

        return self.loss, self.bce_loss, self.grad_penalty

    def update(self, loss):
        """Take one Adam step on ``loss``."""
        self.d_optimizer.zero_grad()
        loss.backward()
        self.d_optimizer.step()


class AIRLDiscriminator(nn.Module):
    """AIRL-style discriminator with separate reward and value networks.

    The discriminator logit is ``r(s, a) + gamma * V(s') - V(s) - log_pi(a|s)``
    and ``forward`` returns its sigmoid.  The reward network sees observations
    (and actions when ``use_actions`` is set); the value network sees
    observations only.

    Args:
        env: object exposing ``observation_space.shape`` and
            ``action_space.shape``; the last entry of each sizes the input
            layers.  An empty action shape is treated as one action feature.
        layer_dims: hidden layer widths (at least one entry).
        lr: learning rate of the internal Adam optimizer.
        gamma: discount applied to the next-state value.
        use_actions: feed actions to the reward network.
        use_cnn_base: use ``MiniGridCNN`` bases instead of linear layers.
        irm_coeff: weight of the IRM penalty in ``compute_loss``.
        lip_coeff: weight of the gradient penalty in ``compute_loss``.
        bias: whether the linear layers carry bias terms.
    """

    def __init__(self, env, layer_dims, lr, gamma,
                 use_actions=True, use_cnn_base=False, irm_coeff=0,
                 lip_coeff=0, bias=False):
        super(AIRLDiscriminator, self).__init__()

        ob_shapes = list(env.observation_space.shape)
        # An empty action shape is treated as a single action feature.
        ac_shapes = list(env.action_space.shape) or [1]
        ob_dim = ob_shapes[-1]
        input_dim = ob_dim + (ac_shapes[-1] if use_actions else 0)
        layer_dims = [input_dim] + list(layer_dims)

        self.layer_dims = layer_dims
        self.lr = lr
        self.gamma = gamma
        self.use_actions = use_actions
        self.irm_coeff = irm_coeff
        self.lip_coeff = lip_coeff
        self.use_cnn_base = use_cnn_base

        if use_cnn_base:
            self.base = MiniGridCNN(layer_dims, use_actions)
        else:
            self.base = nn.Sequential(nn.Linear(layer_dims[0], layer_dims[1], bias),
                                      nn.PReLU())

        self.reward_layers = self._build_head(layer_dims, bias)
        self.reward = nn.Sequential(*self.reward_layers)

        if use_cnn_base:
            self.base_v = MiniGridCNN(layer_dims, use_actions=False)
        else:
            # The value network only ever receives observations, so its input
            # width is the observation width whether or not actions are enabled.
            self.base_v = nn.Sequential(nn.Linear(ob_dim, layer_dims[1], bias),
                                        nn.PReLU())

        self.value_layers = self._build_head(layer_dims, bias)
        self.value = nn.Sequential(*self.value_layers)

        if torch.cuda.is_available():
            self.base.cuda()
            self.base_v.cuda()
            self.reward.cuda()
            self.value.cuda()

        self.d_optimizer = Adam(list(self.base.parameters()) + list(self.base_v.parameters()) +
                                list(self.reward.parameters()) + list(self.value.parameters()), lr=self.lr)

    @staticmethod
    def _build_head(dims, bias):
        """Return the layers mapping ``dims[1]`` features down to one output."""
        layers = []
        for in_dim, out_dim in zip(dims[1:-1], dims[2:]):
            layers += [nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias),
                       nn.PReLU()]
        layers.append(nn.Linear(in_features=dims[-1], out_features=1, bias=bias))
        return layers

    @staticmethod
    def _interpolate(first, second):
        """Return an element-wise random convex mix of two tensors that tracks gradients."""
        eps = torch.rand(first.shape, device=first.device)
        return (eps * first + (1 - eps) * second).requires_grad_(True)

    def _discriminator_logits(self, ob, next_ob, ac, lprobs):
        """Return ``(reward, V(s), V(s'), logit)`` where ``logit = log p(tau) - log q(tau)``."""
        fitted_value_n = self.value(self.base_v(next_ob))
        fitted_value = self.value(self.base_v(ob))
        reward = self.get_reward(ob, ac)

        # log p(tau) = r(s,a) + gamma * V(s') - V(s); log q(tau) is the policy log-probability.
        qfn = reward + self.gamma * fitted_value_n
        log_p_tau = torch.squeeze(qfn - fitted_value)
        return reward, fitted_value, fitted_value_n, log_p_tau - lprobs

    def forward(self, ob, next_ob, ac, lprobs):
        """Return ``(reward, V(s), V(s'), d_out)`` with ``d_out`` the sigmoid of the AIRL logit."""
        reward, fitted_value, fitted_value_n, logits = self._discriminator_logits(ob, next_ob, ac, lprobs)
        return reward, fitted_value, fitted_value_n, torch.sigmoid(logits)

    def get_reward(self, ob, ac):
        """Return the reward network output for a batch of observations (and actions)."""
        if not self.use_actions:
            base_out = self.base(ob)
        elif self.use_cnn_base:
            base_out = self.base(ob, ac)
        else:
            # Give rank-deficient actions a trailing feature axis so they can be concatenated.
            if ob.dim() != ac.dim():
                ac = torch.unsqueeze(ac, -1)
            base_out = self.base(torch.cat([ob, ac], dim=-1))

        return self.reward(base_out)

    def irm_penalty(self, logits, y):
        """Return the squared gradient of the BCE loss w.r.t. a dummy scale fixed at 1."""
        scale = torch.ones((), device=logits.device, requires_grad=True)
        loss = F.binary_cross_entropy_with_logits(logits * scale, y)
        grad = autograd.grad(loss, [scale], create_graph=True)[0]
        return torch.sum(grad ** 2)

    def lip_penalty(self, update_dict):
        """Return a gradient penalty evaluated on random policy/expert interpolations.

        Observations, next observations and actions are interpolated
        independently.  The penalty is the mean squared deviation from 1 of the
        per-sample gradient norm of ``d_out`` w.r.t. the interpolated
        observations, evaluated with ``policy_lprobs``.
        """
        interp_obs = self._interpolate(update_dict['policy_obs'], update_dict['expert_obs'])
        interp_obs_next = self._interpolate(update_dict['policy_obs_next'], update_dict['expert_obs_next'])
        interp_acs = self._interpolate(update_dict['policy_acs'], update_dict['expert_acs'])

        _, _, _, estimate = self.forward(interp_obs,
                                         interp_obs_next, interp_acs, update_dict['policy_lprobs'])

        # Differentiate w.r.t. the observations only: when actions are disabled they
        # are not part of the graph, and asking for their gradient would raise.
        gradient = autograd.grad(estimate.sum(), interp_obs, create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        flat = gradient.reshape(gradient.shape[0], -1)
        safe_norm = (torch.sum(flat ** 2, dim=1) + 1e-8).sqrt()
        return torch.mean((safe_norm - 1) ** 2)

    def compute_loss(self, update_dict):
        """Return a dict with the total loss and its components.

        Keys: ``total_loss``, ``d_loss``, ``policy_estimate``,
        ``expert_estimate``, ``grad_penalty`` and ``lip_penalty``.  The total is
        divided by ``irm_coeff`` when that coefficient exceeds 1.
        """
        # Define log p(tau) = r(s,a) + gamma * V(s') - V(s)
        _, _, _, policy_logits = self._discriminator_logits(update_dict['policy_obs'],
                                                            update_dict['policy_obs_next'],
                                                            update_dict['policy_acs'],
                                                            update_dict['policy_lprobs'])
        _, _, _, expert_logits = self._discriminator_logits(update_dict['expert_obs'],
                                                            update_dict['expert_obs_next'],
                                                            update_dict['expert_acs'],
                                                            update_dict['expert_lprobs'])
        policy_estimate = torch.sigmoid(policy_logits)
        expert_estimate = torch.sigmoid(expert_logits)

        discriminator_loss = -(
                torch.log(expert_estimate + 1e-6)
                + torch.log(1.0 - policy_estimate + 1e-6)
        ).mean()

        # The IRM penalty applies BCE-with-logits, so it needs raw logits, with the
        # same targets as the loss above (expert = 1, policy = 0).
        logits = torch.cat([expert_logits, policy_logits], dim=0)
        labels = torch.cat([torch.ones_like(expert_logits),
                            torch.zeros_like(policy_logits)])

        grad_penalty = self.irm_penalty(logits, labels)
        lip_penalty = self.lip_penalty(update_dict)
        loss = discriminator_loss + self.irm_coeff * grad_penalty + self.lip_coeff * lip_penalty
        if self.irm_coeff > 1.0:
            loss = loss / self.irm_coeff

        return {
            'total_loss': loss,
            'd_loss': discriminator_loss,
            'policy_estimate': policy_estimate,
            'expert_estimate': expert_estimate,
            'grad_penalty': grad_penalty,
            'lip_penalty': lip_penalty,
        }

    def update(self, loss):
        """Take one Adam step on ``loss``."""
        self.d_optimizer.zero_grad()
        loss.backward()
        self.d_optimizer.step()


class AIRLInvDiscriminator(nn.Module):
    """AIRL-style discriminator with an extra critic head evaluated on a second batch.

    Besides the reward and value networks that form the AIRL logit
    ``r(s, a) + gamma * V(s') - V(s) - log_pi(a|s)``, a separate
    ``base_c``/``critic`` pair scores observations (and actions) taken from a
    second update dict.  ``compute_loss`` combines the BCE loss of the AIRL
    logits with a hinge on the gap between that loss and the critic's.

    Args:
        env: object exposing ``observation_space.shape`` and
            ``action_space.shape``; the last entry of each sizes the input
            layers.  An empty action shape is treated as one action feature.
        layer_dims: hidden layer widths (at least one entry).
        lr: learning rate of the internal Adam optimizer.
        gamma: discount applied to the next-state value.
        use_actions: feed actions to the reward and critic networks.
        use_cnn_base: use ``MiniGridCNN`` bases instead of linear layers.
        irm_coeff: weight of the hinge term in ``compute_loss``.
        lip_coeff: stored on the instance; not used by this class.
        bias: whether the linear layers carry bias terms.
    """

    def __init__(self, env, layer_dims, lr, gamma,
                 use_actions=True, use_cnn_base=False,
                 irm_coeff=0, lip_coeff=0, bias=False):
        # Naming another class in super() raised TypeError on every construction.
        super().__init__()

        ob_shapes = list(env.observation_space.shape)
        # An empty action shape is treated as a single action feature.
        ac_shapes = list(env.action_space.shape) or [1]
        ob_dim = ob_shapes[-1]
        input_dim = ob_dim + (ac_shapes[-1] if use_actions else 0)
        layer_dims = [input_dim] + list(layer_dims)

        self.layer_dims = layer_dims
        self.lr = lr
        self.gamma = gamma
        self.use_actions = use_actions
        self.irm_coeff = irm_coeff
        self.lip_coeff = lip_coeff
        self.use_cnn_base = use_cnn_base

        if use_cnn_base:
            self.base = MiniGridCNN(layer_dims, use_actions)
        else:
            self.base = nn.Sequential(nn.Linear(layer_dims[0], layer_dims[1], bias),
                                      nn.PReLU())

        self.reward_layers = self._build_head(layer_dims, bias)
        self.reward = nn.Sequential(*self.reward_layers)

        if use_cnn_base:
            self.base_v = MiniGridCNN(layer_dims, use_actions=False)
        else:
            # The value network only ever receives observations, so its input
            # width is the observation width whether or not actions are enabled.
            self.base_v = nn.Sequential(nn.Linear(ob_dim, layer_dims[1], bias),
                                        nn.PReLU())

        self.value_layers = self._build_head(layer_dims, bias)
        self.value = nn.Sequential(*self.value_layers)

        # additional head for environment invariance purposes
        if use_cnn_base:
            self.base_c = MiniGridCNN(layer_dims, use_actions=use_actions)
        else:
            self.base_c = nn.Sequential(nn.Linear(layer_dims[0], layer_dims[1], bias),
                                        nn.PReLU())

        self.critic_layers = self._build_head(layer_dims, bias)
        self.critic = nn.Sequential(*self.critic_layers)

        if torch.cuda.is_available():
            self.base.cuda()
            self.base_v.cuda()
            self.base_c.cuda()
            self.reward.cuda()
            self.value.cuda()
            self.critic.cuda()

        params = (list(self.base.parameters()) + list(self.base_v.parameters())
                  + list(self.base_c.parameters()) + list(self.reward.parameters())
                  + list(self.value.parameters()) + list(self.critic.parameters()))
        self.d_optimizer = Adam(params, lr=self.lr)

    @staticmethod
    def _build_head(dims, bias):
        """Return the layers mapping ``dims[1]`` features down to one output."""
        layers = []
        for in_dim, out_dim in zip(dims[1:-1], dims[2:]):
            layers += [nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias),
                       nn.PReLU()]
        layers.append(nn.Linear(in_features=dims[-1], out_features=1, bias=bias))
        return layers

    def _embed(self, base, ob, ac):
        """Run ``base`` on the inputs selected by the configuration."""
        if not self.use_actions:
            return base(ob)
        if self.use_cnn_base:
            return base(ob, ac)
        # Give rank-deficient actions a trailing feature axis so they can be concatenated.
        if ob.dim() != ac.dim():
            ac = torch.unsqueeze(ac, -1)
        return base(torch.cat([ob, ac], dim=-1))

    def _forward_logits(self, ob, next_ob, ac, lprobs, ob_e, ac_e):
        """Return ``(reward, V(s), V(s'), airl_logit, critic_out)`` without the final sigmoid."""
        fitted_value_n = self.value(self.base_v(next_ob))
        fitted_value = self.value(self.base_v(ob))
        reward = self.get_reward(ob, ac)
        # critic head evaluated on the second batch
        critic_out = torch.squeeze(self.get_critic(ob_e, ac_e))

        # log p(tau) = r(s,a) + gamma * V(s') - V(s); log q(tau) is the policy log-probability.
        qfn = reward + self.gamma * fitted_value_n
        log_p_tau = torch.squeeze(qfn - fitted_value)
        return reward, fitted_value, fitted_value_n, log_p_tau - lprobs, critic_out

    def forward(self, ob, next_ob, ac, lprobs, ob_e, ac_e):
        """Return ``(reward, V(s), V(s'), d_out, critic_out)``.

        ``d_out`` is the sigmoid of the AIRL logit; ``critic_out`` is the raw,
        squeezed critic output for ``(ob_e, ac_e)``.
        """
        reward, fitted_value, fitted_value_n, logits, critic_out = self._forward_logits(
            ob, next_ob, ac, lprobs, ob_e, ac_e)
        return reward, fitted_value, fitted_value_n, torch.sigmoid(logits), critic_out

    def get_reward(self, ob, ac):
        """Return the reward network output for a batch of observations (and actions)."""
        return self.reward(self._embed(self.base, ob, ac))

    def get_critic(self, ob, ac):
        """Return the critic network output for a batch of observations (and actions)."""
        return self.critic(self._embed(self.base_c, ob, ac))

    def inv_rat_loss(self, env_inv_logits, env_aware_logits, labels):
        """Return ``(inv_loss, aware_loss, max(0, inv_loss - aware_loss))``.

        Both losses are mean BCE-with-logits against ``labels``; the hinge term
        keeps the original one-element shape ``(1,)``.
        """
        env_inv_loss = torch.mean(F.binary_cross_entropy_with_logits(env_inv_logits, labels))
        env_aware_loss = torch.mean(F.binary_cross_entropy_with_logits(env_aware_logits, labels))

        # Allocate the zero on the losses' device; a CPU tensor fails against CUDA losses.
        floor = torch.zeros(1, device=env_inv_loss.device)
        diff_loss = torch.maximum(floor, env_inv_loss - env_aware_loss)

        return env_inv_loss, env_aware_loss, diff_loss

    def compute_loss(self, update_dict, update_dict_e):
        """Return a dict with the loss terms and the per-half estimates.

        The batch under the ``policy_*`` keys of ``update_dict`` is split in
        half along dim 0; the first half is labelled 1 (expert) and the second
        half 0 (policy).  ``update_dict_e`` supplies the critic inputs.  The
        combined loss is also stored on ``self.loss``.
        """
        # Define log p(tau) = r(s,a) + gamma * V(s') - V(s)
        _, _, _, logits, c_out = self._forward_logits(update_dict['policy_obs'],
                                                      update_dict['policy_obs_next'],
                                                      update_dict['policy_acs'],
                                                      update_dict['policy_lprobs'],
                                                      update_dict_e['policy_obs'],
                                                      update_dict_e['policy_acs'])
        d_out = torch.sigmoid(logits)
        expert_out, policy_out = torch.chunk(d_out, chunks=2, dim=0)
        expert_out_inv, policy_out_inv = torch.chunk(c_out, chunks=2, dim=0)
        labels = torch.cat([torch.ones_like(expert_out),
                            torch.zeros_like(policy_out)])

        # BCE-with-logits must receive the raw AIRL logits, not their sigmoid.
        env_inv_loss, env_aware_loss, diff_loss = self.inv_rat_loss(logits, c_out, labels)
        self.loss = env_inv_loss + self.irm_coeff * diff_loss

        # NOTE(review): the "_inv" estimates come from the critic head while the
        # "invariant" loss above comes from the AIRL head -- confirm the naming.
        return {
            'policy_estimate_inv': policy_out_inv,
            'policy_estimate_aw': policy_out,
            'expert_estimate_inv': expert_out_inv,
            'expert_estimate_aw': expert_out,
            'inv_loss': env_inv_loss,
            'aw_loss': env_aware_loss,
            'diff_loss': diff_loss,
            'd_loss': self.loss,
        }

    def update(self, loss):
        """Take one Adam step on ``loss``."""
        self.d_optimizer.zero_grad()
        loss.backward()
        self.d_optimizer.step()
