import numpy as np
import torch
from torch._C import ScriptMethod
from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.distributions.normal import Normal

def init_params(m: nn.Module, gain: float = 1.):
    """Orthogonally initialize Linear/Conv2d weights and zero their biases.

    Meant to be passed (usually via ``functools.partial``) to
    ``nn.Module.apply``; any module that is not an ``nn.Linear`` or
    ``nn.Conv2d`` is left untouched.

    Args:
        m: Module visited by ``apply``.
        gain: Scaling factor forwarded to ``nn.init.orthogonal_``.
    """
    # isinstance instead of a class-name string: the previous string
    # comparison ("Convv2D") never matched nn.Conv2d, so conv layers
    # silently kept their default initialization.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.orthogonal_(m.weight, gain=gain)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class ACModel(nn.Module):
    """Actor-critic network with an optional convolutional image encoder.

    The actor head produces either Categorical log-probabilities (discrete
    actions) or the concatenated mean and log-std of a diagonal Normal
    (continuous actions). The critic head produces ``reward_dim`` outputs.
    Both heads read the same embedding (conv features or raw observations).
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 reward_dim,
                 is_continous=False,
                 include_conv=True,
                 device=None):
        """
        Args:
            obs_space: Observation space exposing ``.shape``. With
                ``include_conv`` the first two dims are image height and
                width; otherwise ``shape[0]`` is the flat feature size.
            action_space: Action space exposing ``.shape`` (continuous) or
                ``.n`` (discrete).
            reward_dim: Output size of the critic head.
            is_continous: Use a Normal policy instead of a Categorical one.
            include_conv: Encode observations with the conv stack first.
            device: Stored as ``self.device``; the module is not moved here.
        """
        super().__init__()

        self.include_conv = include_conv
        self.is_continous = is_continous
        self.device = device
        self.reward_dim = reward_dim

        if self.include_conv:
            # Define image embedding
            self.image_conv = nn.Sequential(
                nn.Conv2d(3, 16, (2, 2)),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),
                nn.Conv2d(16, 32, (2, 2)),
                nn.ReLU(),
                nn.Conv2d(32, 64, (2, 2)),
                nn.ReLU()
            )
            n = obs_space.shape[0]
            m = obs_space.shape[1]
            # conv(2x2): s-1, maxpool(2): //2, two more conv(2x2): -2 total,
            # so each spatial dim becomes (s-1)//2 - 2; 64 output channels.
            self.image_embedding_size = ((n - 1) // 2 - 2) * ((m - 1) // 2 - 2) * 64
        else:
            self.image_embedding_size = obs_space.shape[0]

        if self.is_continous:
            # First half of the actor output is the mean, second half log-std.
            self.out_shape = 2 * action_space.shape[0]
        else:
            self.out_shape = action_space.n

        # Define actor's model
        self.actor = nn.Sequential(
            nn.Linear(self.image_embedding_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.out_shape)
        )

        # Define critic's model for reward value function
        self.critic = nn.Sequential(
            nn.Linear(self.image_embedding_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.reward_dim)
        )

        # Initialize parameters: small gain on the actor output keeps the
        # initial policy close to uniform / zero-mean.
        if self.include_conv:
            self.image_conv.apply(partial(init_params, gain=np.sqrt(2)))
        self.actor.apply(partial(init_params, gain=0.01))
        self.critic.apply(partial(init_params, gain=1.))

    def forward(self, obs):
        """Compute the action distribution and critic output for ``obs``.

        Args:
            obs: With ``include_conv``, a channel-last image batch
                ``(B, H, W, 3)``; otherwise features whose last dim equals
                ``image_embedding_size``.

        Returns:
            ``(dist, value)``: ``dist`` is a ``Normal`` (continuous) or
            ``Categorical`` (discrete) distribution; ``value`` is the critic
            output with ``reward_dim`` entries in its last dim.
        """
        if self.include_conv:
            # (B, H, W, C) -> (B, C, W, H) -> (B, C, H, W) for Conv2d.
            x = obs.transpose(1, 3).transpose(2, 3)
            x = self.image_conv(x)
            embedding = x.reshape(x.shape[0], -1)
        else:
            embedding = obs

        x = self.actor(embedding)

        if self.is_continous:
            half = self.out_shape // 2
            # Slice instead of index_select so no index tensor has to live on
            # self.device, which may be None or differ from x's device.
            mu = torch.nan_to_num(x[..., :half])
            sigma = torch.nan_to_num(torch.exp(x[..., half:2 * half]))
            dist = Normal(mu, sigma)
        else:
            # dim=-1 equals dim=1 for batched (B, A) input and also handles a
            # single un-batched observation.
            logits = torch.nan_to_num(F.log_softmax(x, dim=-1))
            dist = Categorical(logits=logits)

        value = self.critic(embedding)

        return dist, value

    def get_policy_params(self):
        """
        Returns the parameters of the actor (plus the shared conv encoder
        when ``include_conv`` is set) as a list.
        """
        params = []
        if self.include_conv:
            params.extend(self.image_conv.parameters())
        params.extend(self.actor.parameters())
        return params

    def zero_policy_grad(self):
        """Reset the gradients of all policy parameters to ``None``."""
        for p in self.get_policy_params():
            p.grad = None

    @staticmethod
    def _sanitize_params(params):
        """Replace NaN/inf in each parameter and its gradient, in place.

        Parameters whose ``grad`` is ``None`` (e.g. after
        ``zero_policy_grad`` or when no loss reached them) only have their
        values sanitized instead of raising AttributeError.
        """
        with torch.no_grad():
            for p in params:
                p.nan_to_num_()
                if p.grad is not None:
                    p.grad.nan_to_num_()

    def check_policy_grad_nan(self):
        """Sanitize NaN/inf in the policy parameters and gradients; returns 0."""
        self._sanitize_params(self.get_policy_params())
        return 0

    def check_grad_nan(self):
        """Sanitize NaN/inf in all parameters and gradients; returns 0."""
        self._sanitize_params(self.parameters())
        return 0

