import logging

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class CardMLP(nn.Module):
    """Small fully connected classifier: four ReLU hidden layers plus a linear head.

    Layer widths are ``input_size -> 16 -> 18 -> 20 -> 24 -> num_classifier``.
    Dropout is applied after the second hidden layer, only while the module is
    in training mode. ``forward`` returns the raw output of the final
    ``nn.Linear``; no softmax/sigmoid is applied.

    NOTE(review): an identical ``CardMLP`` class is defined again later in this
    module, and that later definition is the one bound to the name at import.

    Args:
        num_classifier: Number of output units of the classifier head.
        input_size: Size of the last input dimension. Defaults to 30, the value
            previously hard-coded (presumably the dataset's feature count --
            confirm against the data pipeline).
        dropout_p: Dropout probability applied after the second hidden layer.
    """

    def __init__(self, num_classifier, input_size=30, dropout_p=0.25):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 16)
        self.fc2 = nn.Linear(16, 18)
        self.fc3 = nn.Linear(18, 20)
        self.fc4 = nn.Linear(20, 24)
        self.dropout_p = dropout_p

        self.num_classifier = num_classifier
        self.classifier = nn.Linear(24, self.num_classifier)

    def forward(self, x):
        """Map ``(*, input_size)`` inputs to ``(*, num_classifier)`` raw outputs."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # F.dropout defaults to training=True, which would keep dropping units
        # after model.eval(); follow the module's train/eval mode instead.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return self.classifier(x)


import logging

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class CardMLP(nn.Module):
    """Small fully connected classifier: four ReLU hidden layers plus a linear head.

    Layer widths are ``input_size -> 16 -> 18 -> 20 -> 24 -> num_classifier``.
    Dropout is applied after the second hidden layer, only while the module is
    in training mode. ``forward`` returns the raw output of the final
    ``nn.Linear``; no softmax/sigmoid is applied.

    NOTE(review): an identical ``CardMLP`` class is defined earlier in this
    module; this later definition shadows it at import time.

    Args:
        num_classifier: Number of output units of the classifier head.
        input_size: Size of the last input dimension. Defaults to 30, the value
            previously hard-coded (presumably the dataset's feature count --
            confirm against the data pipeline).
        dropout_p: Dropout probability applied after the second hidden layer.
    """

    def __init__(self, num_classifier, input_size=30, dropout_p=0.25):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 16)
        self.fc2 = nn.Linear(16, 18)
        self.fc3 = nn.Linear(18, 20)
        self.fc4 = nn.Linear(20, 24)
        self.dropout_p = dropout_p

        self.num_classifier = num_classifier
        self.classifier = nn.Linear(24, self.num_classifier)

    def forward(self, x):
        """Map ``(*, input_size)`` inputs to ``(*, num_classifier)`` raw outputs."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # F.dropout defaults to training=True, which would keep dropping units
        # after model.eval(); follow the module's train/eval mode instead.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return self.classifier(x)


class MixpulMLP(nn.Module):
    """Five-layer MLP with batch normalization that returns raw logits.

    Each hidden layer is ``Linear -> BatchNorm1d -> ReLU``; widths are::

        input_size -> hidden_size -> hidden_size -> hidden_size // 2
                   -> hidden_size // 4 -> num_classifier

    The output of the last ``Linear`` is returned as-is. An earlier version
    applied ``F.softmax(out, dim=1)`` here. It was removed because a softmax
    output is incompatible with the cross-entropy loss used downstream, which
    expects logits; apply the activation in the loss or at inference instead.

    NOTE(review): with the default ``num_classifier=1`` there is a single
    logit, so a binary logit loss (e.g. ``nn.BCEWithLogitsLoss``) is the likely
    match -- confirm against the training code.

    Args:
        input_size: Number of features per sample. ``forward`` flattens every
            input to ``(-1, input_size)``.
        hidden_size: Width of the first two hidden layers. The third and fourth
            are ``hidden_size // 2`` and ``hidden_size // 4``.
        num_classifier: Number of output logits.
    """

    def __init__(self, input_size, hidden_size=256, num_classifier=1):
        super().__init__()
        self.dim = input_size
        self.num_classifier = num_classifier

        half = hidden_size // 2
        quarter = hidden_size // 4

        # Submodule names and registration order are kept stable so existing
        # state_dict checkpoints still load.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_size)
        self.relu1 = nn.ReLU()

        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(num_features=hidden_size)
        self.relu2 = nn.ReLU()

        self.fc3 = nn.Linear(hidden_size, half)
        self.bn3 = nn.BatchNorm1d(num_features=half)
        self.relu3 = nn.ReLU()

        self.fc4 = nn.Linear(half, quarter)
        self.bn4 = nn.BatchNorm1d(num_features=quarter)
        self.relu4 = nn.ReLU()

        self.fc5 = nn.Linear(quarter, self.num_classifier)

    def forward(self, x):
        """Return logits of shape ``(N, num_classifier)``.

        ``x`` is flattened to ``(N, input_size)`` first, so its total element
        count must be a multiple of ``input_size``. In training mode
        BatchNorm1d rejects a batch containing a single sample.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted tensors, are accepted too; contiguous inputs are not copied.
        out = x.reshape(-1, self.dim)

        out = self.relu1(self.bn1(self.fc1(out)))
        out = self.relu2(self.bn2(self.fc2(out)))
        out = self.relu3(self.bn3(self.fc3(out)))
        out = self.relu4(self.bn4(self.fc4(out)))

        return self.fc5(out)


