import torch
import numpy as np
import torch.nn as nn
from configs import Args
import torch.nn.functional as F
import torch.distributions as D


# float32 machine epsilon: default clamp / log offset of TanhActor, and the
# constant added to the input gradients in DemoDICE's gradient penalties.
EPS = np.finfo(np.float32).eps
# Offset used inside DemoDICE's log(1 / (sigmoid(c) + EPS2) - 1 + EPS2) transform.
EPS2 = 1e-3


class TanhActor(nn.Module):
    """Gaussian policy whose pre-activation samples are squashed by ``tanh``.

    A two-hidden-layer ReLU MLP feeds two linear heads that produce the mean
    and log-std of a diagonal Normal; both are clamped to configurable ranges.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, name='TanhNormalPolicy',
                 mean_range=(-7., 7.), logstd_range=(-5., 2.), eps=EPS, initial_std_scaler=1,
                 kernel_initializer='he_normal', activation_fn=F.relu):
        """Build the MLP trunk and the mean / log-std heads.

        Args:
            state_dim: Size of the last dimension of the inputs.
            action_dim: Number of action components.
            hidden_size: Width of each of the two hidden layers.
            name: Label stored on the instance.
            mean_range: (min, max) clamp applied to the pre-tanh mean.
            logstd_range: (min, max) clamp applied to the log-std.
            eps: Offset used when clamping actions before ``atanh`` and inside
                the tanh log-det correction.
            initial_std_scaler: Multiplier applied to ``exp(logstd)``.
            kernel_initializer: ``'he_normal'`` applies Kaiming-normal weight
                init; any other value keeps ``nn.Linear``'s default init.
            activation_fn: Stored on the instance only.
                NOTE(review): the hidden layers always use ``nn.ReLU``; this
                argument does not change them.
        """
        super().__init__()
        self._action_dim = action_dim
        self._initial_std_scaler = initial_std_scaler
        self.name = name

        use_he_init = kernel_initializer == 'he_normal'

        layers = []
        in_size = state_dim
        for size in (hidden_size, hidden_size):
            linear = nn.Linear(in_size, size)
            if use_he_init:
                torch.nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            layers.append(linear)
            layers.append(nn.ReLU())
            in_size = size

        self._fc_layers = nn.Sequential(*layers)
        self._fc_mean = nn.Linear(hidden_size, action_dim)
        self._fc_logstd = nn.Linear(hidden_size, action_dim)

        if use_he_init:
            torch.nn.init.kaiming_normal_(self._fc_mean.weight, nonlinearity='linear')
            torch.nn.init.kaiming_normal_(self._fc_logstd.weight, nonlinearity='linear')

        self.mean_min, self.mean_max = mean_range
        self.logstd_min, self.logstd_max = logstd_range
        self.eps = eps
        self.activation_fn = activation_fn

    def _pretanh_dist(self, inputs):
        """Return the clamped diagonal Normal over pre-tanh actions for ``inputs``."""
        h = self._fc_layers(inputs)
        mean = torch.clamp(self._fc_mean(h), self.mean_min, self.mean_max)
        logstd = torch.clamp(self._fc_logstd(h), self.logstd_min, self.logstd_max)
        std = torch.exp(logstd) * self._initial_std_scaler
        return D.Normal(mean, std)

    def forward(self, inputs, step_type=None, network_state=None, training=True):
        """Sample a squashed action for each input row.

        Args:
            inputs: Tensor whose last dimension is ``state_dim``.
            step_type: Unused.
            network_state: Unused.
            training: Unused.

        Returns:
            ``((tanh(mean), action, log_prob), None)`` where ``action`` is a
            reparameterised sample passed through ``tanh`` and ``log_prob`` is
            its log-density summed over the last dimension. The trailing
            ``None`` stands in for a network state.
        """
        del step_type, network_state  # unused

        pretanh_action_dist = self._pretanh_dist(inputs)
        # rsample keeps the sample differentiable w.r.t. mean and std.
        pretanh_action = pretanh_action_dist.rsample()
        action = torch.tanh(pretanh_action)
        log_prob, _ = self.log_prob(pretanh_action_dist, pretanh_action, is_pretanh_action=True)

        return (torch.tanh(pretanh_action_dist.mean), action, log_prob), None

    def log_prob(self, pretanh_action_dist, action, is_pretanh_action=True):
        """Log-density of a squashed action under ``pretanh_action_dist``.

        Args:
            pretanh_action_dist: Distribution over pre-tanh actions.
            action: Pre-tanh action if ``is_pretanh_action`` else an action
                already in tanh space.
            is_pretanh_action: Selects how ``action`` is interpreted.

        Returns:
            ``(log_prob, pretanh_log_prob)``, both summed over the last
            dimension; ``log_prob`` includes the tanh change-of-variables term.
        """
        if is_pretanh_action:
            pretanh_action = action
            action = torch.tanh(pretanh_action)
        else:
            # Clamp away from +/-1 so atanh stays finite.
            pretanh_action = torch.atanh(torch.clamp(action, -1 + self.eps, 1 - self.eps))

        pretanh_log_prob = pretanh_action_dist.log_prob(pretanh_action).sum(-1)
        log_prob = pretanh_log_prob - torch.sum(torch.log(1 - action ** 2 + self.eps), dim=-1)

        return log_prob, pretanh_log_prob

    def get_log_prob(self, states, actions):
        """Evaluate log probs for actions conditioned on states.

        Args:
            states: A batch of states.
            actions: A batch of (tanh-space) actions to evaluate log probs on.

        Returns:
            Log probabilities of actions with a trailing singleton dimension.
        """
        pretanh_action_dist = self._pretanh_dist(states)

        # Clamp before delegating so the log-det term also sees clamped actions.
        actions = torch.clamp(actions, -1 + self.eps, 1 - self.eps)
        log_probs, _ = self.log_prob(pretanh_action_dist, actions, is_pretanh_action=False)

        # Trailing axis avoids accidental broadcasting against per-sample weights.
        return log_probs.unsqueeze(-1)



class DiscreteActor(nn.Module):
    """Categorical policy over ``action_dim`` discrete actions.

    NOTE(review): ``forward`` returns the sampled action as a float *index*
    while the greedy action is one-hot and ``get_log_prob`` takes an argmax
    over the last dimension of ``actions`` (i.e. expects one-hot / score
    vectors). Callers must convert sampled indices before passing them to
    ``get_log_prob``.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, name='DiscretePolicy',
                 kernel_initializer='he_normal', activation_fn=F.relu):
        """Build the MLP trunk and the logit head.

        Args:
            state_dim: Size of the last dimension of the inputs.
            action_dim: Number of discrete actions.
            hidden_size: Width of each of the two hidden layers.
            name: Label stored on the instance.
            kernel_initializer: ``'he_normal'`` applies Kaiming-normal weight
                init; any other value keeps ``nn.Linear``'s default init.
            activation_fn: Accepted for signature compatibility; the hidden
                layers always use ``nn.ReLU``.
        """
        super().__init__()
        self._action_dim = action_dim
        self.name = name

        use_he_init = kernel_initializer == 'he_normal'

        layers = []
        in_size = state_dim
        for size in (hidden_size, hidden_size):
            linear = nn.Linear(in_size, size)
            if use_he_init:
                torch.nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            layers.append(linear)
            layers.append(nn.ReLU())
            in_size = size

        self._fc_layers = nn.Sequential(*layers)
        self._logit_layer = nn.Linear(hidden_size, action_dim)
        if use_he_init:
            torch.nn.init.kaiming_normal_(self._logit_layer.weight, nonlinearity='linear')

    def _logits(self, inputs):
        """Return the unnormalised action scores for ``inputs``."""
        return self._logit_layer(self._fc_layers(inputs))

    def forward(self, inputs, step_type=None, network_state=None, training=True):
        """Sample an action and report the greedy one.

        Args:
            inputs: Tensor whose last dimension is ``state_dim``.
            step_type: Unused.
            network_state: Unused.
            training: Unused.

        Returns:
            ``((greedy_action, action, log_prob), None)``: ``greedy_action``
            is the one-hot argmax of the logits, ``action`` is the sampled
            action index cast to float, ``log_prob`` is the log-probability
            of the sampled index. The trailing ``None`` stands in for a
            network state.
        """
        logits = self._logits(inputs)
        dist = D.Categorical(logits=logits)
        action = dist.sample()
        # Reduce over the action axis (last), matching get_log_prob; dim=1 was
        # only correct for exactly two-dimensional inputs.
        greedy_action = F.one_hot(torch.argmax(logits, dim=-1), self._action_dim)
        log_prob = dist.log_prob(action)

        return (greedy_action, action.float(), log_prob), None

    def get_log_prob(self, states, actions, training=True):
        """Evaluate log probs for actions conditioned on states.

        Args:
            states: A batch of states.
            actions: A batch of actions whose last dimension is argmax-ed to
                recover the action index (e.g. one-hot vectors).
            training: Unused.

        Returns:
            Log probabilities of actions with a trailing singleton dimension.
        """
        dist = D.Categorical(logits=self._logits(states))
        action_index = torch.argmax(actions, dim=-1)
        return dist.log_prob(action_index).unsqueeze(-1)


class MixNet(nn.Module):
    """Maps a state vector to non-negative per-agent weights and a bias."""

    def __init__(self, st_dim, n_agents, h_dim=64):
        """Create the shared hidden layer and the weight / bias heads.

        Args:
            st_dim: Size of the last dimension of the state input.
            n_agents: Number of weights produced per state.
            h_dim: Width of the hidden layer.
        """
        super().__init__()
        self.n_agents = n_agents
        self.f_v = nn.Linear(st_dim, h_dim)
        self.w_v = nn.Linear(h_dim, n_agents)
        self.b_v = nn.Linear(h_dim, 1)

    def forward(self, states):
        """Return ``(w, b)`` for ``states``.

        ``w`` has ``n_agents`` entries on its last axis and is made
        non-negative with ``abs``; ``b`` has a single entry on its last axis.
        """
        # Call the submodules (not .forward) so registered hooks are honoured.
        hidden = torch.relu(self.f_v(states))
        weights = torch.abs(self.w_v(hidden))
        bias = self.b_v(hidden)
        return weights, bias
    

class Critic(nn.Module):
    """Two-hidden-layer MLP head with an optional state-conditioned mixer."""

    def __init__(self, st_dim, ob_dim, ac_dim, n_agents, hidden_size=256, output_activation_fn=None, use_last_layer_bias=False,
                 output_dim=None, kernel_initializer='he_normal', name='ValueNetwork'):
        """Build the MLP, the output layer and (when ``ac_dim > 0``) the mixer.

        Args:
            st_dim: State size handed to the mixer.
            ob_dim: Observation size; the MLP input is ``ob_dim + ac_dim``.
            ac_dim: Action size; a mixer is created only when this is > 0.
            n_agents: Number of agents handed to the mixer.
            hidden_size: Width of each of the two hidden layers.
            output_activation_fn: Optional callable applied to the output.
            use_last_layer_bias: If true the output layer has a bias and both
                its weight and bias are drawn from U(-3e-3, 3e-3); otherwise
                it has no bias.
            output_dim: Output width; ``None`` means width 1 with the last
                axis squeezed away in ``forward``.
            kernel_initializer: ``'he_normal'`` applies Kaiming-normal weight
                init; any other value keeps ``nn.Linear``'s default init.
            name: Label stored on the instance.
        """
        super().__init__()
        self._output_dim = output_dim
        self.name = name

        use_he_init = kernel_initializer == 'he_normal'
        out_features = output_dim or 1

        layers = []
        in_size = ob_dim + ac_dim
        for size in (hidden_size, hidden_size):
            linear = nn.Linear(in_size, size)
            if use_he_init:
                torch.nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            layers.append(linear)
            layers.append(nn.ReLU())
            in_size = size
        self._fc_layers = nn.Sequential(*layers)

        if use_last_layer_bias:
            self._last_layer = nn.Linear(in_size, out_features)
            # Initialise in place so each parameter keeps its own shape; the
            # bias must stay (out_features,), not (out_features, in_size).
            torch.nn.init.uniform_(self._last_layer.weight, -3e-3, 3e-3)
            torch.nn.init.uniform_(self._last_layer.bias, -3e-3, 3e-3)
        else:
            self._last_layer = nn.Linear(in_size, out_features, bias=False)
            if use_he_init:
                torch.nn.init.kaiming_normal_(self._last_layer.weight, nonlinearity='linear')
        self.output_activation_fn = output_activation_fn

        self.mixer = MixNet(st_dim, n_agents) if ac_dim > 0 else None

    def forward(self, obs, states=None, do_sum=True):
        """Evaluate the network.

        Args:
            obs: Tensor whose last dimension is ``ob_dim + ac_dim``.
                Presumably shaped (batch, n_agents, dim) — TODO confirm.
            states: Optional mixer input; ignored when the instance has no
                mixer.
            do_sum: If true, sum the result over its last axis.

        Returns:
            The MLP output (last axis squeezed when ``output_dim`` is
            ``None``), optionally passed through the output activation,
            optionally mixed as ``sum(w * h, -1, keepdim=True) + b``, and
            optionally summed over the last axis.
        """
        # Call the submodules (not .forward) so registered hooks are honoured.
        h = self._last_layer(self._fc_layers(obs))

        if self._output_dim is None:
            h = h.squeeze(-1)

        if self.output_activation_fn is not None:
            h = self.output_activation_fn(h)

        if states is not None and self.mixer is not None:
            w, b = self.mixer(states)
            h = (w * h).sum(-1, keepdim=True) + b

        if do_sum:
            h = h.sum(-1)

        return h





class DemoDICE(nn.Module):
    """Class that implements DemoDICE training.

    Holds a cost network, a nu network (``critic``) and a tanh-Gaussian actor,
    and computes / back-propagates their three losses.
    """

    def __init__(self, st_dim, ob_dim, ac_dim, n_agents, config: Args):
        """Create the cost, critic and actor networks.

        Args:
            st_dim: State size forwarded to the ``Critic`` constructors.
            ob_dim: Per-agent observation size.
            ac_dim: Per-agent action size.
            n_agents: Number of agents forwarded to the ``Critic`` constructors.
            config: Supplies ``hidden_size``, ``grad_reg_coeffs``, ``gamma``,
                ``alpha``, ``device``, ``use_last_layer_bias_cost``,
                ``use_last_layer_bias_critic`` and ``kernel_initializer``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.grad_reg_coeffs = config.grad_reg_coeffs
        self.discount = config.gamma
        self.non_expert_regularization = config.alpha + 1.
        self.device = config.device

        self.cost = Critic(st_dim, ob_dim, ac_dim, n_agents, hidden_size=hidden_size,
                           use_last_layer_bias=config.use_last_layer_bias_cost,
                           kernel_initializer=config.kernel_initializer)
        # The nu network sees observations only, hence ac_dim=0.
        self.critic = Critic(st_dim, ob_dim, 0, n_agents, hidden_size=hidden_size,
                             use_last_layer_bias=config.use_last_layer_bias_critic,
                             kernel_initializer=config.kernel_initializer)
        self.actor = TanhActor(ob_dim, ac_dim, hidden_size=hidden_size, kernel_initializer=config.kernel_initializer)

    @staticmethod
    def _cost_from_logits(cost_val):
        """Return ``log(1 / (sigmoid(cost_val) + EPS2) - 1 + EPS2)``."""
        return torch.log(1 / (torch.sigmoid(cost_val) + EPS2) - 1 + EPS2)

    def compute_dice_loss(self, init_states, init_obs, expert_transition, union_transition):
        """Compute the cost, nu and policy losses and back-propagate each.

        Side effect: ``backward()`` is called on all three losses, so
        gradients are accumulated on the parameters. This method neither
        zeroes gradients nor steps an optimizer.

        Args:
            init_states: States paired with ``init_obs`` for the nu network.
            init_obs: Initial observations.
            expert_transition: Mapping with keys ``"obs"``, ``"actions"``,
                ``"next_obs"`` and ``"states"``.
            union_transition: Mapping with keys ``"obs"``, ``"actions"``,
                ``"next_obs"``, ``"states"`` and ``"next_states"``.
                NOTE(review): the interpolation below assumes the expert and
                union batches have identical shapes, presumably
                (batch, n_agents, dim) — confirm against the caller.

        Returns:
            ``(cost_loss, nu_loss, pi_loss)``.
        """
        expert_obs = expert_transition["obs"]
        expert_actions = expert_transition["actions"]
        expert_next_obs = expert_transition["next_obs"]
        expert_states = expert_transition["states"]

        union_obs = union_transition["obs"]
        union_actions = union_transition["actions"]
        union_next_obs = union_transition["next_obs"]
        union_states = union_transition["states"]
        union_next_states = union_transition["next_states"]

        expert_inputs = torch.cat([expert_obs, expert_actions], -1)
        union_inputs = torch.cat([union_obs, union_actions], -1)

        # ---- cost (discriminator) ------------------------------------
        expert_cost_val = self.cost(expert_inputs, expert_states)
        union_cost_val = self.cost(union_inputs, union_states)

        # Random tensors are allocated on the target device directly rather
        # than created on CPU and copied.
        mix_shape = (expert_obs.shape[0], expert_obs.shape[1], 1)
        unif_rand = torch.rand(size=mix_shape, device=self.device)
        perm = torch.randperm(union_inputs.shape[0], device=self.device)
        mixed_inputs1 = unif_rand * expert_inputs + (1 - unif_rand) * union_inputs
        mixed_inputs2 = unif_rand * torch.index_select(union_inputs, 0, perm) + (1 - unif_rand) * union_inputs
        mixed_inputs = torch.cat([mixed_inputs1, mixed_inputs2], 0)

        # Gradient penalty on the transformed cost w.r.t. interpolated inputs.
        mixed_inputs.requires_grad_(True)
        cost_output = self._cost_from_logits(self.cost(mixed_inputs))
        cost_mixed_grad = torch.autograd.grad(
            outputs=cost_output, inputs=mixed_inputs,
            grad_outputs=torch.ones_like(cost_output),
            create_graph=True, retain_graph=True, only_inputs=True)[0] + EPS
        cost_grad_penalty = torch.mean((torch.norm(cost_mixed_grad, dim=-1, keepdim=True) - 1) ** 2)

        # Minimax discriminator loss (expert -> 1, union -> 0). Working on
        # logits avoids the saturation / vanishing gradient of sigmoid + BCE.
        cost_loss = (
            F.binary_cross_entropy_with_logits(union_cost_val, torch.zeros_like(union_cost_val))
            + F.binary_cross_entropy_with_logits(expert_cost_val, torch.ones_like(expert_cost_val))
            + self.grad_reg_coeffs[0] * cost_grad_penalty
        )
        union_cost = self._cost_from_logits(union_cost_val)

        # ---- nu learning ---------------------------------------------
        init_nu = self.critic(init_obs, init_states)
        union_nu = self.critic(union_obs, union_states)
        union_next_nu = self.critic(union_next_obs, union_next_states)
        # The cost is detached so the nu loss does not train the cost network.
        union_adv_nu = - union_cost.detach() + self.discount * union_next_nu - union_nu

        non_linear_loss = self.non_expert_regularization * torch.logsumexp(
            union_adv_nu / self.non_expert_regularization, dim=0)
        linear_loss = (1 - self.discount) * torch.mean(init_nu)
        nu_loss = non_linear_loss + linear_loss

        # ---- weighted behaviour cloning ------------------------------
        # Recompute the advantage without the final sum so the weights keep
        # the last axis of the network outputs.
        union_cost_val = self.cost(union_inputs, union_states, do_sum=False)
        union_nu = self.critic(union_obs, union_states, do_sum=False)
        union_next_nu = self.critic(union_next_obs, union_next_states, do_sum=False)
        union_cost = self._cost_from_logits(union_cost_val)
        union_adv_nu = - union_cost.detach() + self.discount * union_next_nu - union_nu

        # Subtract the max before exp for numerical stability; the batch-mean
        # normalisation below cancels the shift.
        weight = torch.exp((union_adv_nu - torch.max(union_adv_nu)) / self.non_expert_regularization).unsqueeze(-1)
        weight = weight / torch.mean(weight, dim=0, keepdim=True)
        pi_loss = - torch.mean(weight.detach() * self.actor.get_log_prob(union_obs, union_actions))

        # ---- gradient penalty for nu ---------------------------------
        if self.grad_reg_coeffs[1] is not None:
            unif_rand2 = torch.rand(size=mix_shape, device=self.device)
            nu_inter = unif_rand2 * expert_obs + (1 - unif_rand2) * union_obs
            nu_next_inter = unif_rand2 * expert_next_obs + (1 - unif_rand2) * union_next_obs

            nu_inter = torch.cat([union_obs, nu_inter, nu_next_inter], 0)
            nu_inter.requires_grad_(True)
            nu_output = self.critic(nu_inter)

            nu_mixed_grad = torch.autograd.grad(
                outputs=nu_output, inputs=nu_inter,
                grad_outputs=torch.ones_like(nu_output),
                create_graph=True, retain_graph=True, only_inputs=True
            )[0] + EPS

            # Penalises the squared gradient norm itself (no "- 1" target).
            nu_grad_penalty = torch.mean((torch.norm(nu_mixed_grad, dim=-1, keepdim=True)) ** 2)
            nu_loss = nu_loss + self.grad_reg_coeffs[1] * nu_grad_penalty

        cost_loss.backward()
        nu_loss.backward()
        pi_loss.backward()
        return cost_loss, nu_loss, pi_loss

    def step(self, observation, deterministic: bool = True):
        """Select an action for a single (unbatched) observation.

        Args:
            observation: Array-like or tensor observation; a leading batch
                axis of size 1 is added before it is fed to the actor.
            deterministic: If true return ``tanh(mean)``, otherwise the
                sampled action.

        Returns:
            The chosen action as a NumPy array (with the leading batch axis).
        """
        if not torch.is_tensor(observation):
            # One contiguous array avoids the slow list-of-ndarrays path.
            observation = np.asarray(observation, dtype=np.float32)
        # Build the input on the model's device to avoid a device mismatch.
        observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            all_actions, _ = self.actor(observation)
        if deterministic:
            actions = all_actions[0]
        else:
            actions = all_actions[1]
        return actions.cpu().numpy()