import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import math

# Neural network that approximates new_target_potential
class ApproxPotential(nn.Module):
    """MLP surrogate for ``new_target_potential``.

    Architecture: ``input_dim -> mid_dim -> mid_dim//2 -> mid_dim//4 -> 1``
    with ``tanh`` after each of the first three linear layers and a linear
    (unactivated) scalar output.

    Args:
        input_dim (int): Number of input features (size of the last
            dimension of ``x`` in :meth:`forward`).
        mid_dim (int): Width of the first hidden layer; the second and third
            hidden layers use ``mid_dim // 2`` and ``mid_dim // 4`` units.
            NOTE(review): ``mid_dim < 4`` yields a zero-width third layer.
    """

    def __init__(self, input_dim, mid_dim=512):
        super().__init__()
        self.input_dim = input_dim
        self.mid_dim = mid_dim
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.act1 = torch.tanh
        self.fc2 = nn.Linear(mid_dim, mid_dim // 2)
        self.act2 = torch.tanh
        self.fc3 = nn.Linear(mid_dim // 2, mid_dim // 4)
        self.act3 = torch.tanh
        self.fc4 = nn.Linear(mid_dim // 4, 1)
        # Device chosen once at construction (CUDA when available). A later
        # ``model.to(...)`` moves the parameters but does not update this
        # attribute, so ``forward`` follows the parameters rather than it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move the model parameters to the chosen device.
        self.to(self.device)

    def forward(self, x):
        """Evaluate the surrogate potential.

        Args:
            x (torch.Tensor | array-like): Inputs whose last dimension is
                ``input_dim``. They are converted to the dtype and device of
                the model's parameters. For tensors the autograd graph is
                kept, so gradients w.r.t. ``x`` can still be taken.

        Returns:
            torch.Tensor: Output of shape ``(*x.shape[:-1], 1)``, on the
            parameters' device.
        """
        # Match the parameters' dtype/device. The original ``.float()``
        # broke on CPU inputs to a CUDA model and on ``.double()`` models.
        ref = self.fc1.weight
        hidden = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        hidden = self.act1(self.fc1(hidden))
        hidden = self.act2(self.fc2(hidden))
        hidden = self.act3(self.fc3(hidden))
        return self.fc4(hidden)

    def xavier_init(model):
        """Apply Xavier (Glorot) uniform initialisation to ``model`` in place.

        Declared without ``self`` on purpose: it works both as
        ``net.xavier_init()`` and as ``ApproxPotential.xavier_init(module)``.

        - Parameters named ``bias`` (top-level or ``<prefix>.bias``) are
          zeroed.
        - Parameters with two or more dimensions are drawn from
          ``U(-b, b)``, where ``b = sqrt(6 / (fan_in + fan_out))``. For a 2-D
          ``nn.Linear`` weight of shape ``(out, in)`` this is
          ``sqrt(6) / sqrt(out + in)``.
        - Any other parameter (e.g. a 1-D LayerNorm scale) has no defined
          fan-in/fan-out and is left untouched.

        Args:
            model (nn.Module): Module whose parameters are re-initialised.
        """
        for name, param in model.named_parameters():
            # rsplit also catches a top-level parameter named exactly "bias".
            if name.rsplit(".", 1)[-1] == "bias":
                nn.init.zeros_(param)
            elif param.dim() >= 2:
                nn.init.xavier_uniform_(param)