import torch
import numpy as np
import torch.nn as nn
from configs import Args
import torch.nn.functional as F
import torch.distributions as D


# Machine epsilon of float32. In the code visible in this file it is added to
# autograd gradients before their norm is taken for the gradient penalties.
EPS = np.finfo(np.float32).eps
# Small constant used inside the cost transform
# log(1 / (sigmoid(x) + EPS2) - 1 + EPS2) so the expression stays finite when
# the sigmoid saturates at 0 or 1.
EPS2 = 1e-3


class DiscreteActor(nn.Module):
    """MLP policy that emits one unnormalised logit per discrete action.

    The network is two ``Linear`` + ``ReLU`` hidden layers of width
    ``hidden_size`` followed by a linear logit layer.

    Args:
        state_dim: Size of the last dimension of the network input.
        action_dim: Number of discrete actions (size of the logit vector).
        hidden_size: Width of both hidden layers.
        name: Accepted for interface compatibility; not used.
        kernel_initializer: ``'he_normal'`` re-initialises every weight matrix
            with Kaiming-normal init; any other value keeps PyTorch's default
            ``nn.Linear`` initialisation.
        activation_fn: Accepted for interface compatibility; not used. The
            hidden layers always use ``nn.ReLU``.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, name='DiscretePolicy',
                 kernel_initializer='he_normal', activation_fn=F.relu):
        super().__init__()
        self._action_dim = action_dim
        use_he_normal = kernel_initializer == 'he_normal'

        layers = []
        in_size = state_dim
        for size in (hidden_size, hidden_size):
            linear = nn.Linear(in_size, size)
            if use_he_normal:
                nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            layers.append(linear)
            layers.append(nn.ReLU())
            in_size = size

        self._fc_layers = nn.Sequential(*layers)
        self._logit_layer = nn.Linear(hidden_size, action_dim)
        if use_he_normal:
            nn.init.kaiming_normal_(self._logit_layer.weight, nonlinearity='linear')

    def forward(self, inputs):
        """Return unnormalised action logits for ``inputs``.

        Sub-modules are invoked by calling them (not via ``.forward``) so that
        any registered forward hooks run.
        """
        return self._logit_layer(self._fc_layers(inputs))

    def get_log_prob(self, states, avails, actions):
        """Evaluate log probs for actions conditioned on states.

        Args:
          states: A batch of network inputs.
          avails: Action-availability weights broadcastable to the logits.
            Their log is added to the logits, so an entry of 0 contributes
            ``-inf`` and removes that action from the distribution.
          actions: A batch of action indices to evaluate log probs on.
        Returns:
          Log probabilities of actions.
        """
        logits = self(states) + avails.log()
        # Explicit normalisation, kept from the original implementation.
        logits = logits - logits.logsumexp(-1, True)
        return D.Categorical(logits=logits).log_prob(actions)


class MixNet(nn.Module):
    """Maps a state vector to non-negative per-agent weights and a scalar bias.

    Args:
        st_dim: Size of the last dimension of ``states``.
        n_agents: Number of weights produced per state.
        h_dim: Width of the shared hidden layer.
    """

    def __init__(self, st_dim, n_agents, h_dim=64):
        super().__init__()
        self.n_agents = n_agents
        self.f_v = nn.Linear(st_dim, h_dim)
        self.w_v = nn.Linear(h_dim, n_agents)
        self.b_v = nn.Linear(h_dim, 1)

    def forward(self, states):
        """Return ``(w, b)`` for ``states``.

        ``w`` has ``n_agents`` entries in its last dimension and is made
        non-negative with ``abs``; ``b`` has a last dimension of 1 and is
        unconstrained. Layers are called directly (not via ``.forward``) so
        registered forward hooks run.
        """
        hidden = torch.relu(self.f_v(states))
        weights = self.w_v(hidden).abs()
        bias = self.b_v(hidden)
        return weights, bias
    

class Critic(nn.Module):
    """Per-agent MLP value/cost network with optional state-conditioned mixing.

    The MLP is two ``Linear`` + ``ReLU`` hidden layers followed by a linear
    output layer. When ``ac_dim > 0`` a ``MixNet`` is attached; it is used in
    ``forward`` only if ``states`` is supplied.

    Args:
        st_dim: State size passed to the mixing network.
        ob_dim: Observation size; the MLP input size is ``ob_dim + ac_dim``.
        ac_dim: Action-encoding size appended to the observation. ``0`` means
            no action input and no mixing network.
        n_agents: Number of agents, passed to the mixing network.
        hidden_size: Width of both hidden layers.
        output_activation_fn: Optional callable applied to the MLP output.
        use_last_layer_bias: If true, the output layer has a bias and both its
            weight and bias are drawn from ``U(-3e-3, 3e-3)``. Otherwise the
            output layer has no bias and follows ``kernel_initializer``.
        output_dim: Output features of the last layer (``None`` means 1, with
            the trailing dimension squeezed away in ``forward``).
        kernel_initializer: ``'he_normal'`` applies Kaiming-normal init to the
            hidden weights (and to the bias-free output layer).
        name: Stored on the instance as ``self.name``.
    """

    def __init__(self, st_dim, ob_dim, ac_dim, n_agents, hidden_size=256, output_activation_fn=None, use_last_layer_bias=False,
                 output_dim=None, kernel_initializer='he_normal', name='ValueNetwork'):
        super().__init__()
        self._output_dim = output_dim
        self.name = name
        use_he_normal = kernel_initializer == 'he_normal'
        out_features = output_dim or 1

        layers = []
        in_size = ob_dim + ac_dim
        for size in (hidden_size, hidden_size):
            linear = nn.Linear(in_size, size)
            if use_he_normal:
                nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
            layers.append(linear)
            layers.append(nn.ReLU())
            in_size = size
        self._fc_layers = nn.Sequential(*layers)

        if use_last_layer_bias:
            self._last_layer = nn.Linear(in_size, out_features)
            # Initialise each parameter in place so it keeps its own shape.
            # Assigning a weight-shaped tensor to the bias (as was done before)
            # gives the bias shape (out_features, in_size) and corrupts the
            # layer's output shape.
            nn.init.uniform_(self._last_layer.weight, -3e-3, 3e-3)
            nn.init.uniform_(self._last_layer.bias, -3e-3, 3e-3)
        else:
            self._last_layer = nn.Linear(in_size, out_features, bias=False)
            if use_he_normal:
                nn.init.kaiming_normal_(self._last_layer.weight, nonlinearity='linear')
        self.output_activation_fn = output_activation_fn

        self.mixer = MixNet(st_dim, n_agents) if ac_dim > 0 else None

    def forward(self, obs, states=None, do_sum=True):
        """Evaluate the network.

        Args:
            obs: MLP input whose last dimension is ``ob_dim + ac_dim``.
            states: Optional input for the mixing network. Ignored when this
                critic has no mixer.
            do_sum: If true, sum the result over its last dimension.

        Returns:
            The MLP output, with the trailing singleton squeezed when
            ``output_dim`` is ``None``. If mixing is applied, the last
            dimension is replaced by the weighted sum ``sum(w * h) + b`` kept
            as a size-1 dimension. Finally the last dimension is summed away
            when ``do_sum`` is true.
        """
        # Call sub-modules directly (not ``.forward``) so forward hooks run.
        h = self._last_layer(self._fc_layers(obs))

        if self._output_dim is None:
            h = h.squeeze(-1)

        if self.output_activation_fn is not None:
            h = self.output_activation_fn(h)

        if states is not None and self.mixer is not None:
            w, b = self.mixer(states)
            h = (w * h).sum(-1, keepdim=True) + b

        if do_sum:
            h = h.sum(-1)

        return h





class DemoDICE(nn.Module):
    """DemoDICE training: a cost network, a nu critic and a discrete actor.

    Args:
        st_dim: State size forwarded to both ``Critic`` instances.
        ob_dim: Per-agent observation size.
        ac_dim: Number of discrete actions; also the one-hot width appended to
            observations for the cost network.
        n_agents: Number of agents, forwarded to both ``Critic`` instances.
        config: Settings object. The fields read here are ``hidden_size``,
            ``grad_reg_coeffs``, ``gamma``, ``alpha``, ``device``,
            ``use_last_layer_bias_cost``, ``use_last_layer_bias_critic`` and
            ``kernel_initializer``.
    """

    def __init__(self, st_dim, ob_dim, ac_dim, n_agents, config: "Args"):
        super().__init__()
        hidden_size = config.hidden_size
        self.grad_reg_coeffs = config.grad_reg_coeffs
        self.discount = config.gamma
        self.non_expert_regularization = config.alpha + 1.
        self.device = config.device

        # The cost network sees observation + one-hot action; the nu critic is
        # built with ac_dim=0, so it sees observations only and has no mixer.
        self.cost = Critic(st_dim, ob_dim, ac_dim, n_agents, hidden_size=hidden_size,
                           use_last_layer_bias=config.use_last_layer_bias_cost,
                           kernel_initializer=config.kernel_initializer)
        self.critic = Critic(st_dim, ob_dim, 0, n_agents, hidden_size=hidden_size,
                             use_last_layer_bias=config.use_last_layer_bias_critic,
                             kernel_initializer=config.kernel_initializer)
        self.actor = DiscreteActor(ob_dim, ac_dim, hidden_size=hidden_size, kernel_initializer=config.kernel_initializer)

    @staticmethod
    def _cost_from_logits(cost_logits):
        """Return ``log(1 / (sigmoid(x) + EPS2) - 1 + EPS2)`` for raw cost outputs."""
        return torch.log(1 / (torch.sigmoid(cost_logits) + EPS2) - 1 + EPS2)

    def compute_dice_loss(self, init_states, init_obs, expert_transition, union_transition):
        """Compute the cost, nu and policy losses and back-propagate each one.

        Args:
            init_states: States fed to the nu critic together with ``init_obs``.
            init_obs: Observations used for the linear term of the nu loss.
            expert_transition: Dict read at keys ``"obs"``, ``"actions"``,
                ``"next_obs"`` and ``"states"``.
            union_transition: Dict read at keys ``"obs"``, ``"actions"``,
                ``"next_obs"``, ``"states"``, ``"next_states"`` and
                ``"avails"``.

        Observations are indexed as ``shape[0], shape[1]`` and mixed with a
        coefficient of shape ``(shape[0], shape[1], 1)``, so they are presumably
        ``(batch, n_agents, ob_dim)`` - TODO confirm against the data loader.
        Expert and union batches are combined element-wise and therefore need
        matching shapes. Actions must be integer indices (they are one-hot
        encoded).

        Returns:
            ``(cost_loss, nu_loss, pi_loss)``. ``backward()`` has already been
            called on each of them; this method neither zeroes gradients nor
            steps an optimizer.
        """
        expert_obs = expert_transition["obs"]
        expert_actions = expert_transition["actions"]
        expert_next_obs = expert_transition["next_obs"]
        expert_states = expert_transition["states"]

        union_obs = union_transition["obs"]
        union_actions = union_transition["actions"]
        union_next_obs = union_transition["next_obs"]
        union_states = union_transition["states"]
        union_next_states = union_transition["next_states"]
        union_avails = union_transition["avails"]

        reg = self.non_expert_regularization
        n_actions = self.actor._action_dim

        # Cost-network inputs: observation concatenated with one-hot action.
        expert_inputs = torch.cat([expert_obs, F.one_hot(expert_actions, num_classes=n_actions).float()], -1)
        union_inputs = torch.cat([union_obs, F.one_hot(union_actions, num_classes=n_actions).float()], -1)

        expert_cost_val = self.cost(expert_inputs, expert_states)
        union_cost_val = self.cost(union_inputs, union_states)

        # Interpolated inputs for the cost gradient penalty: expert<->union and
        # shuffled-union<->union, using the same mixing coefficients.
        unif_rand = torch.rand(size=(expert_obs.shape[0], expert_obs.shape[1], 1)).to(self.device)
        shuffle = torch.randperm(union_inputs.shape[0]).to(self.device)
        mixed_inputs1 = unif_rand * expert_inputs + (1 - unif_rand) * union_inputs
        mixed_inputs2 = unif_rand * torch.index_select(union_inputs, 0, shuffle) + (1 - unif_rand) * union_inputs
        mixed_inputs = torch.cat([mixed_inputs1, mixed_inputs2], 0)

        # Gradient penalty for the cost network (no states given, so the
        # mixer is skipped and per-agent outputs are summed).
        mixed_inputs.requires_grad_(True)
        cost_output = self._cost_from_logits(self.cost(mixed_inputs))
        cost_mixed_grad = torch.autograd.grad(
            outputs=cost_output, inputs=mixed_inputs,
            grad_outputs=torch.ones_like(cost_output),
            create_graph=True, retain_graph=True, only_inputs=True)[0] + EPS
        cost_grad_penalty = torch.mean((torch.norm(cost_mixed_grad, dim=-1, keepdim=True) - 1) ** 2)

        # Minimax discriminator loss: union samples are labelled 0, expert
        # samples 1. The logits form is used instead of sigmoid + BCE because
        # the latter's gradient vanishes once the sigmoid saturates.
        cost_loss = F.binary_cross_entropy_with_logits(union_cost_val, torch.zeros_like(union_cost_val)) + \
            F.binary_cross_entropy_with_logits(expert_cost_val, torch.ones_like(expert_cost_val)) + \
            self.grad_reg_coeffs[0] * cost_grad_penalty
        union_cost = self._cost_from_logits(union_cost_val)

        # nu learning; the cost is detached so nu_loss does not train it.
        init_nu = self.critic(init_obs, init_states)
        union_nu = self.critic(union_obs, union_states)
        union_next_nu = self.critic(union_next_obs, union_next_states)
        union_adv_nu = -union_cost.detach() + self.discount * union_next_nu - union_nu

        non_linear_loss = reg * torch.logsumexp(union_adv_nu / reg, dim=0)
        linear_loss = (1 - self.discount) * torch.mean(init_nu)
        nu_loss = non_linear_loss + linear_loss

        # Weighted BC. The weights are constants for the policy loss, so the
        # un-summed (do_sum=False) pass runs without building a graph.
        with torch.no_grad():
            agent_cost = self._cost_from_logits(self.cost(union_inputs, union_states, do_sum=False))
            agent_nu = self.critic(union_obs, union_states, do_sum=False)
            agent_next_nu = self.critic(union_next_obs, union_next_states, do_sum=False)
            agent_adv_nu = -agent_cost + self.discount * agent_next_nu - agent_nu
            # Subtract the max before exponentiating for numerical stability,
            # then normalise so the weights average to 1 over dim 0.
            weight = torch.exp((agent_adv_nu - torch.max(agent_adv_nu)) / reg)
            weight = weight / torch.mean(weight, dim=0, keepdim=True)

        pi_loss = -torch.mean(weight * self.actor.get_log_prob(union_obs, union_avails, union_actions))

        # Gradient penalty for nu (pushes the gradient norm towards 0).
        if self.grad_reg_coeffs[1] is not None:
            unif_rand2 = torch.rand(size=(expert_obs.shape[0], expert_obs.shape[1], 1)).to(self.device)
            nu_inter = unif_rand2 * expert_obs + (1 - unif_rand2) * union_obs
            nu_next_inter = unif_rand2 * expert_next_obs + (1 - unif_rand2) * union_next_obs

            nu_inter = torch.cat([union_obs, nu_inter, nu_next_inter], 0)
            nu_inter.requires_grad_(True)
            nu_output = self.critic(nu_inter)

            nu_mixed_grad = torch.autograd.grad(
                outputs=nu_output, inputs=nu_inter,
                grad_outputs=torch.ones_like(nu_output),
                create_graph=True, retain_graph=True, only_inputs=True
            )[0] + EPS

            nu_grad_penalty = torch.mean(torch.norm(nu_mixed_grad, dim=-1, keepdim=True) ** 2)
            nu_loss = nu_loss + self.grad_reg_coeffs[1] * nu_grad_penalty

        cost_loss.backward()
        nu_loss.backward()
        pi_loss.backward()
        return cost_loss, nu_loss, pi_loss

    def step(self, observation, deterministic: bool = True, avails=None):
        """Choose actions for a single (unbatched) observation.

        Args:
            observation: Array-like observation; a leading batch dimension of
                size 1 is added before it is passed to the actor.
            deterministic: If true, return the highest-logit action; otherwise
                sample from the categorical distribution over the logits.
            avails: Optional array-like availability mask with the same
                leading shape as ``observation`` and one entry per action.
                Actions whose entry is ``<= 0`` are excluded.

        Returns:
            ``numpy`` integer array of action indices, including the leading
            batch dimension of size 1.
        """
        # Follow the actor's parameters so inference works on whichever device
        # the model currently lives on.
        device = next(self.actor.parameters()).device
        obs = torch.as_tensor(np.asarray(observation, dtype=np.float32), device=device).unsqueeze(0)
        with torch.no_grad():
            # The actor returns a single logits tensor.
            logits = self.actor(obs)
            if avails is not None:
                mask = torch.as_tensor(np.asarray(avails), device=device).unsqueeze(0)
                logits = logits.masked_fill(mask <= 0, float("-inf"))
            if deterministic:
                actions = logits.argmax(dim=-1)
            else:
                actions = D.Categorical(logits=logits).sample()
        return actions.cpu().numpy()