import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.utils.math import normal_entropy, normal_log_density



class Actor(nn.Module):
    """Gaussian policy network for continuous actions.

    The network outputs the mean of a diagonal Gaussian over actions. The
    log standard deviation is a learned parameter vector shared by all
    states. Flat states are fed directly to the MLP; 3-D (image) states are
    first encoded by ``CNN`` and the encoded features are fed to the MLP.

    Args:
        state_dim: An int, or a sequence of length 1 (flat state) or
            length 3 (image state). A 3-D shape whose first entry is
            greater than 4 is treated as HWC and reordered to CHW when
            sizing the network; inputs passed to ``forward`` are not
            permuted.
        action_dim: Number of action dimensions.
        activation: ``'tanh'``, ``'relu'`` or ``'sigmoid'``. Any other
            value falls back to tanh.
        hidden_size: Widths of the hidden linear layers.
        log_std: Initial value of every entry of the log-std parameter.

    Raises:
        NotImplementedError: If ``state_dim`` has neither 1 nor 3 entries.
    """

    def __init__(self, state_dim, action_dim, activation='tanh', hidden_size=(64, 64), log_std=0.5):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.is_disc_action = False
        self.hidden_size = hidden_size

        # Normalise to a tuple so lists and other sequences behave like
        # tuples in the size computations below.
        if isinstance(state_dim, int):
            state_dim = (state_dim,)
        state_dim = tuple(state_dim)

        if len(state_dim) == 3:
            self.conv = True
        elif len(state_dim) == 1:
            self.conv = False
        else:
            raise NotImplementedError

        activations = {
            'tanh': torch.tanh,
            'relu': torch.relu,
            'sigmoid': torch.sigmoid,
        }
        # Unknown names keep the historical default of tanh.
        self.activation = activations.get(activation, torch.tanh)

        if self.conv:
            if state_dim[0] > 4:
                # Looks like HWC; the convolutional encoder is channels-first.
                height, width, channels = state_dim
                state_dim = (channels, height, width)
            self.cnn = CNN(state_dim[0])
            # Probe the encoder once so the MLP is sized for its output,
            # not for the raw pixel count.
            with torch.no_grad():
                probe = torch.zeros(1, *state_dim)  # shape: (1, C, H, W)
                last_dim = self.cnn(probe).size(1)
        else:
            last_dim = int(np.prod(state_dim))

        self.affine_layers = nn.ModuleList()
        for nh in hidden_size:
            self.affine_layers.append(nn.Linear(last_dim, nh))
            last_dim = nh

        # Small initial weights and zero bias on the mean head.
        self.action_mean = nn.Linear(last_dim, action_dim)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)

        self.action_log_std = nn.Parameter(torch.ones(action_dim) * log_std)

    def forward(self, x):
        """Return ``(action_mean, action_log_std, action_std)`` for ``x``.

        The log-std parameter is expanded to the shape of the mean.
        """
        if self.conv:
            x = self.cnn(x)
        # The hidden layers run for both flat and encoded image inputs.
        for affine in self.affine_layers:
            x = self.activation(affine(x))

        action_mean = self.action_mean(x)
        action_log_std = self.action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)

        return action_mean, action_log_std, action_std

    def select_action(self, x):
        """Sample an action from the Gaussian defined by ``forward(x)``."""
        action_mean, _, action_std = self.forward(x)
        return torch.normal(action_mean, action_std)

    def get_kl(self, x):
        """Return the KL between a detached copy of the policy and itself.

        The value is zero, but gradients flow through the non-detached
        terms. The result is summed over dim 1 with ``keepdim=True``.
        """
        mean1, log_std1, std1 = self.forward(x)

        mean0 = mean1.detach()
        log_std0 = log_std1.detach()
        std0 = std1.detach()
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    def get_log_prob(self, x, actions, return_entropy=False):
        """Return ``normal_log_density`` of ``actions`` under the policy.

        When ``return_entropy`` is true, also return the mean of
        ``normal_entropy(action_std)`` as a second value.
        """
        action_mean, action_log_std, action_std = self.forward(x)
        action_log_probs = normal_log_density(actions, action_mean, action_log_std, action_std)
        if not return_entropy:
            return action_log_probs
        return action_log_probs, normal_entropy(action_std).mean()

    def get_fim(self, x):
        """Return ``(cov_inv, mean, info)``.

        ``cov_inv`` is ``exp(action_log_std) ** -2`` repeated
        ``x.size(0)`` times and detached. ``info`` holds the position of
        ``action_log_std`` in ``named_parameters()`` (``std_id``) and the
        number of scalar parameters that precede it (``std_index``).
        """
        mean, _, _ = self.forward(x)
        cov_inv = self.action_log_std.exp().pow(-2).squeeze(0).repeat(x.size(0))
        std_id = 0
        std_index = 0
        param_count = 0
        for param_idx, (name, param) in enumerate(self.named_parameters()):
            if name == "action_log_std":
                std_id = param_idx
                std_index = param_count
            param_count += param.numel()
        return cov_inv.detach(), mean, {'std_id': std_id, 'std_index': std_index}

class DiscreteActor(nn.Module):
    """Categorical policy network for discrete actions.

    Flat states go straight through the MLP; 3-D (image) states are first
    encoded by ``CNN``. The head produces one logit per action.

    Args:
        state_dim: An int, or a sequence of length 1 (flat state) or
            length 3 (image state). A 3-D shape whose first entry is
            greater than 4 is treated as HWC and reordered to CHW when
            sizing the network; inputs passed to ``forward`` are not
            permuted.
        action_num: Number of discrete actions.
        hidden_size: Widths of the hidden linear layers.
        activation: ``'tanh'``, ``'relu'`` or ``'sigmoid'``.
        cnn: Unused; kept so existing call sites keep working.

    Raises:
        NotImplementedError: If ``state_dim`` has neither 1 nor 3 entries.
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, state_dim, action_num, hidden_size=(64, 64), activation='tanh', cnn=None):
        super().__init__()

        self.is_disc_action = True

        # Normalise to a tuple so lists and other sequences behave like
        # tuples in the size computations below.
        if isinstance(state_dim, int):
            state_dim = (state_dim,)
        state_dim = tuple(state_dim)

        if len(state_dim) == 3:
            self.conv = True
            if state_dim[0] > 4:
                # Looks like HWC; the convolutional encoder is channels-first.
                height, width, channels = state_dim
                state_dim = (channels, height, width)
        elif len(state_dim) == 1:
            self.conv = False
        else:
            raise NotImplementedError

        activations = {
            'tanh': torch.tanh,
            'relu': torch.relu,
            'sigmoid': torch.sigmoid,
        }
        # Fail here rather than with an AttributeError on the first forward.
        if activation not in activations:
            raise ValueError(f"unsupported activation: {activation!r}")
        self.activation = activations[activation]

        if self.conv:
            self.cnn = CNN(state_dim[0])
            # Probe the encoder once to learn its output width.
            with torch.no_grad():
                device = next(self.cnn.parameters()).device
                probe = torch.zeros(1, *state_dim, device=device)
                last_dim = self.cnn(probe).size(1)
        else:
            last_dim = int(np.prod(state_dim))

        self.affine_layers = nn.ModuleList()
        for nh in hidden_size:
            self.affine_layers.append(nn.Linear(last_dim, nh))
            last_dim = nh

        # Small initial weights and zero bias on the action head.
        self.action_head = nn.Linear(last_dim, action_num)
        self.action_head.weight.data.mul_(0.1)
        self.action_head.bias.data.mul_(0.0)

    def _logits(self, x):
        """Run the encoder (if any) and the MLP; return the head's logits."""
        if self.conv:
            x = self.cnn(x)
        for affine in self.affine_layers:
            x = self.activation(affine(x))
        return self.action_head(x)

    def forward(self, x):
        """Return action probabilities (softmax over the last dim)."""
        return torch.softmax(self._logits(x), dim=-1)

    def select_action(self, x):
        """Sample one action per row of probabilities.

        The result is passed through ``squeeze()``, so a batch of one
        yields a 0-d tensor.
        """
        action_prob = self.forward(x)
        action = action_prob.multinomial(1)
        return action.squeeze()

    def get_kl(self, x):
        """Return the KL between a detached copy of the policy and itself.

        Log-probabilities come from ``log_softmax`` so a probability that
        underflows to zero cannot produce ``nan``. The result is summed
        over dim 1 with ``keepdim=True``.
        """
        log_prob1 = torch.log_softmax(self._logits(x), dim=-1)
        log_prob0 = log_prob1.detach()
        kl = log_prob0.exp() * (log_prob0 - log_prob1)
        return kl.sum(1, keepdim=True)

    def get_log_prob(self, x, actions):
        """Return the log-probability of ``actions``, gathered along dim 1.

        Uses ``log_softmax`` instead of ``log(softmax(...))`` so very
        unlikely actions give a finite value rather than ``-inf``.
        """
        log_prob = torch.log_softmax(self._logits(x), dim=-1)
        return log_prob.gather(1, actions.long().unsqueeze(1))

    def get_fim(self, x):
        """Return ``(M, action_prob, {})``.

        ``M`` is the flattened, detached reciprocal of the probabilities.
        """
        action_prob = self.forward(x)
        M = action_prob.pow(-1).view(-1).detach()
        return M, action_prob, {}


class Critic(nn.Module):
    """State-value network that outputs one scalar per input row.

    Flat states go straight through the MLP; 3-D (image) states are first
    encoded by ``CNN``.

    Args:
        state_dim: An int, or a sequence of length 1 (flat state) or
            length 3 (image state). A 3-D shape whose first entry is
            greater than 4 is treated as HWC and reordered to CHW when
            sizing the network; inputs passed to ``forward`` are not
            permuted.
        hidden_size: Widths of the hidden linear layers.
        activation: ``'tanh'``, ``'relu'`` or ``'sigmoid'``.
        cnn: Unused; kept so existing call sites keep working.

    Raises:
        NotImplementedError: If ``state_dim`` has neither 1 nor 3 entries.
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, state_dim, hidden_size=(64, 64), activation='tanh', cnn=None):
        super().__init__()

        # Normalise to a tuple so lists and other sequences behave like
        # tuples in the size computations below.
        if isinstance(state_dim, int):
            state_dim = (state_dim,)
        state_dim = tuple(state_dim)

        if len(state_dim) == 3:
            self.conv = True
            if state_dim[0] > 4:
                # Looks like HWC; the convolutional encoder is channels-first.
                height, width, channels = state_dim
                state_dim = (channels, height, width)
        elif len(state_dim) == 1:
            self.conv = False
        else:
            raise NotImplementedError

        activations = {
            'tanh': torch.tanh,
            'relu': torch.relu,
            'sigmoid': torch.sigmoid,
        }
        # Fail here rather than with an AttributeError on the first forward.
        if activation not in activations:
            raise ValueError(f"unsupported activation: {activation!r}")
        self.activation = activations[activation]

        if self.conv:
            self.cnn = CNN(state_dim[0])
            # Probe the encoder once to learn its output width.
            with torch.no_grad():
                device = next(self.cnn.parameters()).device
                probe = torch.zeros(1, *state_dim, device=device)
                last_dim = self.cnn(probe).size(1)
        else:
            last_dim = int(np.prod(state_dim))

        self.affine_layers = nn.ModuleList()
        for nh in hidden_size:
            self.affine_layers.append(nn.Linear(last_dim, nh))
            last_dim = nh

        # Small initial weights and zero bias on the value head.
        self.value_head = nn.Linear(last_dim, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)

    def forward(self, x):
        """Return the value head's output for ``x``."""
        if self.conv:
            x = self.cnn(x)

        for affine in self.affine_layers:
            x = self.activation(affine(x))
        return self.value_head(x)
    


class CNN(nn.Module):
    """Three-layer convolutional encoder followed by one linear layer.

    All layers use orthogonal weight initialisation with the ReLU gain
    and zero biases.

    Args:
        in_channels: Number of channels of the channels-first input.
        hidden_size: Width of the output feature vector.
        input_hw: ``(height, width)`` of the images the encoder will see.
            Used only to size the linear layer. The default of
            ``(84, 84)`` gives the previously hard-coded ``32 * 7 * 7``
            input width.
    """

    def __init__(self, in_channels, hidden_size=512, input_hw=(84, 84)):
        super().__init__()

        relu_gain = nn.init.calculate_gain('relu')

        def ortho(m, gain=1.0):
            nn.init.orthogonal_(m.weight, gain)
            nn.init.constant_(m.bias, 0.)
            return m

        self.conv1 = ortho(nn.Conv2d(in_channels, 32, 8, stride=4), gain=relu_gain)
        self.conv2 = ortho(nn.Conv2d(32, 64, 4, stride=2), gain=relu_gain)
        self.conv3 = ortho(nn.Conv2d(64, 32, 3, stride=1), gain=relu_gain)

        # Size the linear layer from a probe pass instead of assuming
        # 84x84 inputs. The probe is all zeros, so it draws no random
        # numbers and leaves initialisation unchanged.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, *input_hw)
            flat_dim = self._encode(probe).size(1)

        self.fc = ortho(nn.Linear(flat_dim, hidden_size), gain=relu_gain)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three conv+ReLU stages and flatten per sample."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x.reshape(x.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-activated features for a batch of images.

        Raises:
            RuntimeError: If the output contains NaN or Inf.
        """
        x = F.relu(self.fc(self._encode(x)))

        if not torch.isfinite(x).all():
            raise RuntimeError("CNN produced NaN/Inf")
        return x
    