import pdb

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.distributions.normal import Normal
import numpy as np
from torch.autograd import Variable

# class q_MLPAgent(nn.Module):
#     def __init__(self, input_shape, args):
#         super(q_MLPAgent, self).__init__()
#         self.args = args
#
#         self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
#         self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
#         self.fc3 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
#
#         import pdb
#         pdb.set_trace()
#         if args.env in ['cts_matrix_game']:
#             import pdb
#             pdb.set_trace()
#             torch.nn.init.xavier_uniform(self.fc1.weight)
#             torch.nn.init.xavier_uniform(self.fc2.weight)
#             torch.nn.init.xavier_uniform(self.fc3.weight)
#
#         self.agent_return_logits = getattr(self.args, "agent_return_logits", False)
#
#     def init_hidden(self):
#         # make hidden states on same device as model
#         return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
#
#     def forward(self, inputs, hidden_state, actions=None):
#         x = F.relu(self.fc1(inputs))
#         x = F.relu(self.fc2(x))
#         if self.agent_return_logits:
#             actions = self.fc3(x)
#         else:
#             actions = F.tanh(self.fc3(x))
#         return {"actions": actions, "hidden_state": hidden_state}


class MLPAgent(nn.Module):
    """Three-layer feed-forward agent mapping inputs to per-action outputs.

    The agent keeps no recurrent state: ``hidden_state`` is passed straight
    through ``forward`` so the class fits an interface that expects a
    ``(inputs, hidden_state)`` call signature.

    Args:
        input_shape: size of the last dimension of ``inputs``.
        args: config namespace; reads ``rnn_hidden_dim`` (hidden width),
            ``n_actions`` (output width), ``env`` and optionally
            ``agent_return_logits`` (defaults to ``False``).
    """

    def __init__(self, input_shape, args):
        super(MLPAgent, self).__init__()
        self.args = args

        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc3 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        if args.env in ['cts_matrix_game']:
            # No debugger breakpoint here: it would block non-interactive runs.
            # Use the in-place initialiser; the un-suffixed name is deprecated.
            for layer in (self.fc1, self.fc2, self.fc3):
                torch.nn.init.xavier_uniform_(layer.weight)

        self.agent_return_logits = getattr(self.args, "agent_return_logits", False)

    def init_hidden(self):
        """Return a zero tensor of shape (1, rnn_hidden_dim).

        It is allocated with ``fc1.weight.new`` so it shares the model's
        device and dtype.
        """
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def forward(self, inputs, hidden_state, actions=None):
        """Run the MLP on ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``input_shape``.
            hidden_state: returned unchanged.
            actions: ignored; accepted for interface compatibility.

        Returns:
            dict with ``"actions"`` (raw ``fc3`` output when
            ``agent_return_logits`` is set, otherwise ``tanh`` of it, in
            [-1, 1]) and ``"hidden_state"`` (the input ``hidden_state``).
        """
        x = F.relu(self.fc1(inputs))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if self.agent_return_logits:
            actions = logits
        else:
            actions = torch.tanh(logits)
        return {"actions": actions, "hidden_state": hidden_state}

# Clamp range applied to the log standard deviation by the Gaussian agents
# in this module (the clamped value is exponentiated to obtain sigma).
LOG_STD_MIN, LOG_STD_MAX = -20, 2

class MLPAgent_Gaussian_nonn(nn.Module):
    """Input-independent tanh-squashed Gaussian policy.

    The distribution is parameterised directly by two learnable 2-element
    vectors, ``mu`` and ``log_sigma``. ``forward`` never reads the values
    in ``inputs``, only its leading dimension. ``fc1`` is used only by
    ``init_hidden`` to allocate on the model's device/dtype.

    Args:
        input_shape: in-features of ``fc1``.
        args: config namespace; reads ``rnn_hidden_dim``,
            ``action_spaces[0].high[0]`` (action bound) and optionally
            ``agent_return_logits`` (stored but not used by ``forward``).
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args

        # Learnable distribution parameters: initial mean 0, initial log-std -1.
        self.mu = torch.nn.Parameter(torch.Tensor([0, 0]))
        self.log_sigma = torch.nn.Parameter(torch.Tensor([-1, -1]))
        for label, param in (("init mu: ", self.mu), ("init log_sigma: ", self.log_sigma)):
            print(label, param)

        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)

        self.agent_return_logits = getattr(args, "agent_return_logits", False)
        self.act_limit = args.action_spaces[0].high[0]

    def init_hidden(self):
        """Return a zero tensor of shape (1, rnn_hidden_dim) on fc1's device/dtype."""
        zeros = self.fc1.weight.new(1, self.args.rnn_hidden_dim)
        return zeros.zero_()

    def forward(self, inputs, hidden_state, actions=None):
        """Sample squashed actions from the input-independent Gaussian.

        Args:
            inputs: tensor; only ``inputs.shape[0]`` is used.
            hidden_state: returned unchanged.
            actions: ignored; accepted for interface compatibility.

        Returns:
            dict with:
              ``"actions"``: ``act_limit * tanh(u)`` for a reparameterised
                sample ``u``, shape ``(inputs.shape[0] // 2, 2)``;
              ``"mu_actions"``: ``act_limit * tanh(mu)``, same shape;
              ``"logp_pi"``: per-element log-density of ``u`` minus the tanh
                correction term (not summed over the last dimension);
              ``"hidden_state"``: the input ``hidden_state``.
        """
        rows = inputs.shape[0] // 2
        mean = self.mu.expand(rows, 2)
        clamped = torch.clamp(self.log_sigma.expand(rows, 2) + 1e-6, LOG_STD_MIN, LOG_STD_MAX)
        dist = Normal(mean, clamped.exp())
        u = dist.rsample()

        # Numerically stable log-derivative of tanh: log(1 - tanh(u)^2).
        squash_correction = 2 * (np.log(2) - u - F.softplus(-2 * u))
        logp_pi = dist.log_prob(u) - squash_correction

        return {
            "actions": self.act_limit * torch.tanh(u),
            "mu_actions": self.act_limit * torch.tanh(mean),
            "logp_pi": logp_pi,
            "hidden_state": hidden_state,
        }


class MLPAgent_Gaussian(nn.Module):
    """MLP policy producing a tanh-squashed diagonal Gaussian over actions.

    Two ReLU hidden layers feed separate linear heads for the mean and the
    log standard deviation. Actions are sampled with the reparameterisation
    trick (``rsample``), squashed with ``tanh`` and scaled by ``act_limit``.

    Args:
        input_shape: size of the last dimension of ``inputs``.
        args: config namespace; reads ``rnn_hidden_dim``, ``n_actions``,
            ``env``, ``action_spaces[0].high[0]`` (action bound) and optionally
            ``agent_return_logits`` (stored, not used by ``forward``). When
            ``env == 'cts_matrix_game'`` this also sets ``args.cts_ver = 1``.
    """

    def __init__(self, input_shape, args):
        super(MLPAgent_Gaussian, self).__init__()
        self.args = args

        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc3_mu = nn.Linear(args.rnn_hidden_dim, args.n_actions)
        self.fc3_log_sigma = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        if args.env in ['cts_matrix_game']:
            # cts_ver is pinned to 1 (and written back to args), so the
            # small-uniform weight init below is the only variant that can run.
            self.args.cts_ver = 1
            for layer in (self.fc1, self.fc2, self.fc3_mu, self.fc3_log_sigma):
                torch.nn.init.uniform_(layer.weight, a=-0.01, b=0.01)

        self.agent_return_logits = getattr(self.args, "agent_return_logits", False)
        self.act_limit = self.args.action_spaces[0].high[0]

    def init_hidden(self):
        """Return a zero tensor of shape (1, rnn_hidden_dim) on fc1's device/dtype."""
        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()

    def forward(self, inputs, hidden_state, actions=None):
        """Sample a squashed Gaussian action for each input row.

        Args:
            inputs: tensor of shape ``(..., input_shape)``; any number of
                leading dimensions is supported.
            hidden_state: returned unchanged.
            actions: ignored; accepted for interface compatibility.

        Returns:
            dict with:
              ``"actions"``: ``act_limit * tanh(u)``, shape ``(..., n_actions)``;
              ``"mu_actions"``: ``act_limit * tanh(mu)``, same shape;
              ``"logp_pi"``: log-density of the squashed sample, summed over
                the action dimension, shape ``(...)``;
              ``"hidden_state"``: the input ``hidden_state``;
              ``"log_sigma"``: raw (unclamped) output of the log-std head.
        """
        x = F.relu(self.fc1(inputs))
        x = F.relu(self.fc2(x))

        mu = self.fc3_mu(x)
        # NOTE(review): the epsilon is added to the layer *input*, not to
        # sigma; sigma is bounded by the clamp on the next line.
        log_sigma = self.fc3_log_sigma(x + 1E-6)
        sigma = torch.exp(torch.clamp(log_sigma, LOG_STD_MIN, LOG_STD_MAX))
        pi_distribution = Normal(mu, sigma)
        pi_action = pi_distribution.rsample()

        # Both terms reduce over the last (action) dimension so logp_pi keeps
        # the leading shape of `inputs` for any input rank.
        logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
        # Stable form of log(1 - tanh(u)^2), the tanh change-of-variables term.
        logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(dim=-1)

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        mu_action = self.act_limit * torch.tanh(mu)

        return {"actions": pi_action, "mu_actions": mu_action, "logp_pi": logp_pi,
                "hidden_state": hidden_state, "log_sigma": log_sigma}