import torch
from torch import nn
from torch.utils.hooks import unserializable_hook
import os
# Target device for the network: the GPU when PyTorch can see one, else the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

@unserializable_hook
def replace_nan_with_zero(grad):
    """Gradient hook that returns a copy of ``grad`` with NaN entries set to 0.

    Because only ``nan`` is specified, ``nan_to_num`` also maps +inf / -inf to
    the largest / smallest finite value of ``grad``'s dtype. The input tensor
    itself is not modified.

    The ``unserializable_hook`` decorator marks the hook so PyTorch does not
    try to serialize it along with a tensor it is registered on.
    """
    sanitized = grad.nan_to_num(nan=0.0)
    return sanitized

class Policy(nn.Module):
    """Three-layer MLP policy network (state -> action).

    The input's last dimension must be ``cfg.state_dim + cfg.index_dim``; the
    output's last dimension is ``cfg.state_dim``. Two hidden layers of width
    ``cfg.hidden_size`` use LeakyReLU with negative slope 0.1.

    Every weight and bias is re-initialized uniformly in [-0.2, 0.2] (replacing
    nn.Linear's default init). Each parameter also gets the
    ``replace_nan_with_zero`` gradient hook, so NaN entries in its gradient
    are replaced with zero.

    Args:
        cfg: Config object exposing ``state_dim``, ``index_dim`` and
            ``hidden_size``; stored as ``self.cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        in_features = cfg.state_dim + cfg.index_dim
        width = cfg.hidden_size
        # Layers are created in the same order as before so default-init RNG
        # consumption is unchanged.
        layers = [
            nn.Linear(in_features, width),
            nn.LeakyReLU(0.1),
            nn.Linear(width, width),
            nn.LeakyReLU(0.1),
            nn.Linear(width, cfg.state_dim),
        ]
        self.policy = nn.Sequential(*layers).to(device)

        # parameters() yields the same tensors, in the same order, as
        # named_parameters(); the names were never used.
        for tensor in self.policy.parameters():
            nn.init.uniform_(tensor, -0.2, 0.2)
            tensor.register_hook(replace_nan_with_zero)

    def forward(self, state_set, offset=None):
        """Apply the MLP to ``state_set`` and return its raw output.

        Args:
            state_set: Tensor whose last dimension is ``state_dim + index_dim``;
                it must be on the same device as the network (``device``).
            offset: Accepted for call-site compatibility; ignored here.
        """
        return self.policy(state_set)

def policy_initialization(cfg, directory_path='Parameter', filename='initial_policy.pt'):
    """Build a freshly initialized ``Policy`` and save its weights to disk.

    The policy's ``state_dict`` is written to ``<directory_path>/<filename>``.
    The directory is created first if it does not exist.

    Args:
        cfg: Config object passed unchanged to ``Policy``.
        directory_path: Directory that receives the checkpoint. The default
            ``'Parameter'`` is relative to the current working directory. An
            empty string means the current working directory itself.
        filename: Checkpoint file name inside ``directory_path``. Defaults to
            ``'initial_policy.pt'``.

    Returns:
        The ``Policy`` instance whose weights were saved, so callers can keep
        using exactly those initial parameters.
    """
    policy = Policy(cfg)  # state -> action

    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() pre-check (which could race with another process).
    # os.makedirs('') would raise, hence the guard for the "current dir" case.
    if directory_path:
        os.makedirs(directory_path, exist_ok=True)

    # Derive the file path from directory_path so the two can never drift apart.
    save_path = os.path.join(directory_path, filename)

    # NOTE: tensors are saved on the device the policy was moved to. Loading
    # on a machine without that device may need torch.load(..., map_location=...).
    torch.save(policy.state_dict(), save_path)
    return policy

