import torch
from torch import nn
import utils
import wandb

import torch.nn.functional as F


class DoubleQCritic(nn.Module):
    """Critic network, employs double Q-learning.

    Two independent Q networks (``Q1`` and ``Q2``) score the same
    concatenated ``[obs, action]`` input. The most recent ``forward`` outputs
    are cached in ``self.outputs`` so that ``log`` can report them.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)
        self.Q2 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)

        # Latest forward() outputs keyed by network name; consumed by log().
        self.outputs = {}
        self.apply(utils.weight_init)

    def _build_q_network(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        """Build one Q network mapping ``obs_dim + action_dim`` inputs to 1 output."""
        return utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

    def forward(self, obs, action):
        """Compute Q values for both Q networks.

        Args:
            obs: Observation batch; concatenated with ``action`` on the last dim.
            action: Action batch with the same leading (batch) size as ``obs``.

        Returns:
            Tuple ``(q1, q2)`` with the outputs of ``Q1`` and ``Q2``.
        """
        assert obs.size(0) == action.size(0), "Observation and action batch sizes must match"

        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)
        q2 = self.Q2(obs_action)

        # Store outputs for logging
        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, step=None):
        """Log cached Q-value means and per-layer parameter means to wandb.

        Each ``wandb.log`` call without an explicit step commits a new step.
        For that reason, every metric is gathered first and sent in a single
        call, so one snapshot of the critic lands on one step.

        Args:
            step: Optional explicit wandb step, forwarded to ``wandb.log``.
        """
        metrics = self._collect_metrics()
        if metrics:
            wandb.log(metrics, step=step)

    def _collect_metrics(self):
        """Return the flat ``{metric_name: float}`` dict that ``log`` reports."""
        metrics = {
            f'train_critic/{name}_hist_mean': value.mean().item()
            for name, value in self.outputs.items()
        }
        metrics.update(self._network_param_metrics(self.Q1, 'q1'))
        metrics.update(self._network_param_metrics(self.Q2, 'q2'))
        return metrics

    @staticmethod
    def _network_param_metrics(network, name):
        """Return weight/bias means for every ``nn.Linear`` in ``network``.

        ``network`` must be iterable over its layers (e.g. ``nn.Sequential``).
        The ``fc{i}`` index in each key is the layer's position in that
        sequence. A Linear layer built with ``bias=False`` contributes only its
        weight entry, because its ``bias`` is ``None``.
        """
        metrics = {}
        for i, layer in enumerate(network):
            if not isinstance(layer, nn.Linear):
                continue
            metrics[f'train_critic/{name}_fc{i}_weight_mean'] = layer.weight.mean().item()
            if layer.bias is not None:
                metrics[f'train_critic/{name}_fc{i}_bias_mean'] = layer.bias.mean().item()
        return metrics

    def _log_network_params(self, network, name):
        """Log one network's parameter means to wandb in a single call."""
        metrics = self._network_param_metrics(network, name)
        if metrics:
            wandb.log(metrics)


class SafetyQCritic(nn.Module):
    """Safety critic made of four independent Q networks (``Q1``..``Q4``).

    Every network scores the same concatenated ``[obs, action]`` input. The
    most recent ``forward`` outputs are cached in ``self.outputs`` so that
    ``log`` can report them.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)
        self.Q2 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)
        self.Q3 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)
        self.Q4 = self._build_q_network(obs_dim, action_dim, hidden_dim, hidden_depth)

        # Latest forward() outputs keyed by network name; consumed by log().
        self.outputs = {}
        self.apply(utils.weight_init)

    def _build_q_network(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        """Build one Q network mapping ``obs_dim + action_dim`` inputs to 1 output."""
        return utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

    def forward(self, obs, action):
        """Compute Q values for all four Q networks.

        Args:
            obs: Observation batch; concatenated with ``action`` on the last dim.
            action: Action batch with the same leading (batch) size as ``obs``.

        Returns:
            Tuple ``(q1, q2, q3, q4)`` with the outputs of ``Q1``..``Q4``.
        """
        assert obs.size(0) == action.size(0), "Observation and action batch sizes must match"

        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)
        q2 = self.Q2(obs_action)
        q3 = self.Q3(obs_action)
        q4 = self.Q4(obs_action)

        # Store outputs for logging
        self.outputs['q1'] = q1
        self.outputs['q2'] = q2
        self.outputs['q3'] = q3
        self.outputs['q4'] = q4

        return q1, q2, q3, q4

    def log(self, step=None):
        """Log cached Q-value means and per-layer parameter means to wandb.

        Each ``wandb.log`` call without an explicit step commits a new step.
        For that reason, every metric is gathered first and sent in a single
        call, so one snapshot of the critic lands on one step.

        Args:
            step: Optional explicit wandb step, forwarded to ``wandb.log``.
        """
        metrics = self._collect_metrics()
        if metrics:
            wandb.log(metrics, step=step)

    def _collect_metrics(self):
        """Return the flat ``{metric_name: float}`` dict that ``log`` reports."""
        metrics = {
            f'train_safety_critic/{name}_hist_mean': value.mean().item()
            for name, value in self.outputs.items()
        }
        for name, network in (('q1', self.Q1), ('q2', self.Q2),
                              ('q3', self.Q3), ('q4', self.Q4)):
            metrics.update(self._network_param_metrics(network, name))
        return metrics

    @staticmethod
    def _network_param_metrics(network, name):
        """Return weight/bias means for every ``nn.Linear`` in ``network``.

        ``network`` must be iterable over its layers (e.g. ``nn.Sequential``).
        The ``fc{i}`` index in each key is the layer's position in that
        sequence. A Linear layer built with ``bias=False`` contributes only its
        weight entry, because its ``bias`` is ``None``.
        """
        metrics = {}
        for i, layer in enumerate(network):
            if not isinstance(layer, nn.Linear):
                continue
            metrics[f'train_safety_critic/{name}_fc{i}_weight_mean'] = layer.weight.mean().item()
            if layer.bias is not None:
                metrics[f'train_safety_critic/{name}_fc{i}_bias_mean'] = layer.bias.mean().item()
        return metrics

    def _log_network_params(self, network, name):
        """Log one network's parameter means to wandb in a single call."""
        metrics = self._network_param_metrics(network, name)
        if metrics:
            wandb.log(metrics)


class SafetyCritic(nn.Module):
    """Safety Critic Network for estimating Long Term Costs (Mean and Variance)

    ``QC`` and ``VC`` are two independent MLP heads over the concatenated
    ``[obs, action]`` input, each producing one value per sample. This class
    applies no non-negativity transform to the ``VC`` output.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.QC = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)
        self.VC = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

        # Latest forward() outputs keyed by head name; consumed by log().
        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Evaluate both cost heads.

        Args:
            obs: Observation batch; concatenated with ``action`` on the last dim.
            action: Action batch with the same leading (batch) size as ``obs``.

        Returns:
            Tuple ``(qc, vc)`` with the outputs of ``QC`` and ``VC``.
        """
        assert obs.size(0) == action.size(0), "Observation and action batch sizes must match"

        obs_action = torch.cat([obs, action], dim=-1)
        qc = self.QC(obs_action)
        vc = self.VC(obs_action)

        self.outputs["qc"] = qc
        self.outputs["vc"] = vc

        return qc, vc

    def log(self, step=None):
        """Log cached output means and per-layer parameter means to wandb.

        Each ``wandb.log`` call without an explicit step commits a new step.
        For that reason, every metric is gathered first and sent in a single
        call, so one snapshot of the critic lands on one step.

        Args:
            step: Optional explicit wandb step, forwarded to ``wandb.log``.
        """
        metrics = self._collect_metrics()
        if metrics:
            wandb.log(metrics, step=step)

    def _collect_metrics(self):
        """Return the flat ``{metric_name: float}`` dict that ``log`` reports."""
        metrics = {
            f"train_safety_critic/{k}_hist_mean": v.mean().item()
            for k, v in self.outputs.items()
        }
        # Each head is walked on its own, so the two MLPs no longer have to
        # share an identical layer layout for logging to work.
        metrics.update(self._network_param_metrics(self.QC, "qc"))
        metrics.update(self._network_param_metrics(self.VC, "vc"))
        return metrics

    @staticmethod
    def _network_param_metrics(network, name):
        """Return weight/bias means for every ``nn.Linear`` in ``network``.

        ``network`` must be iterable over its layers (e.g. ``nn.Sequential``).
        The ``fc{i}`` index in each key is the layer's position in that
        sequence. A Linear layer built with ``bias=False`` contributes only its
        weight entry, because its ``bias`` is ``None``.
        """
        metrics = {}
        for i, layer in enumerate(network):
            if not isinstance(layer, nn.Linear):
                continue
            metrics[f"train_safety_critic/{name}_fc{i}_weight_mean"] = layer.weight.mean().item()
            if layer.bias is not None:
                metrics[f"train_safety_critic/{name}_fc{i}_bias_mean"] = layer.bias.mean().item()
        return metrics


class SafetyCriticGMM(nn.Module):
    """Mixture-based safety critic for long-term cost estimation.

    Three sibling MLPs read the concatenated ``[obs, action]`` features, and
    each emits ``num_components`` values per sample:

    * ``QCs``     -- per-component values named as cost means;
    * ``VCs``     -- per-component values named as cost variances (this class
      applies no non-negativity transform to them);
    * ``Weights`` -- mixture logits, normalized with a softmax in ``forward``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth, num_components=4):
        super().__init__()
        self.num_components = num_components

        in_features = obs_dim + action_dim

        # Heads are created in the order QCs, VCs, Weights; parameter
        # registration (and any RNG-dependent init) follows that order.
        self.QCs = utils.mlp(in_features, hidden_dim, num_components, hidden_depth)
        self.VCs = utils.mlp(in_features, hidden_dim, num_components, hidden_depth)
        self.Weights = utils.mlp(in_features, hidden_dim, num_components, hidden_depth)

        # Latest forward() outputs (means and variances only) for logging.
        self.outputs = {}
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Evaluate all mixture heads.

        Args:
            obs: Observation batch; concatenated with ``action`` on the last dim.
            action: Action batch with the same leading (batch) size as ``obs``.

        Returns:
            ``(qcs, vcs, weights)``, where ``weights`` is the softmax of the
            ``Weights`` head over its last dimension.
        """
        assert obs.size(0) == action.size(0)

        features = torch.cat([obs, action], dim=-1)
        component_means = self.QCs(features)
        component_vars = self.VCs(features)
        mixture = F.softmax(self.Weights(features), dim=-1)

        self.outputs.update(qc=component_means, vc=component_vars)

        return component_means, component_vars, mixture
