import logging

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torchvision.models import resnet50



class MyClassifier(nn.Module):
    """Evaluation helpers shared by binary classifiers with a single output.

    Subclasses implement ``forward`` returning one score per sample; the
    methods here turn those scores into hard labels and report
    ``(precision, recall, 0-1 error rate, number of predicted positives)``.
    """

    @staticmethod
    def _concat(parts):
        """Concatenate per-batch 1-D arrays; an empty list yields an empty array."""
        return np.concatenate(parts) if parts else np.array([])

    def _infer_device(self):
        """Return the device of the first parameter/buffer, or CPU if none exist."""
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        return tensor.device if tensor is not None else torch.device("cpu")

    def zero_one_loss(self, h, t, is_logistic=False):
        """Compute classification metrics from hard predictions.

        Args:
            h: array of predicted labels.
            t: array of true labels, same length as ``h``.
            is_logistic: if True the labels are {1, 0}, otherwise {1, -1}.

        Returns:
            Tuple ``(precision, recall, error_rate, n_predicted_positive)``.
            Precision/recall are 0.0 when there are no true positives;
            ``error_rate`` is NaN when ``t`` holds no positive/negative label.
        """
        self.eval()
        positive = 1
        negative = 0 if is_logistic else -1

        is_pos_t = t == positive
        is_neg_t = t == negative
        n_p = is_pos_t.sum()
        n_n = is_neg_t.sum()
        size = n_p + n_n

        n_pp = (h == positive).sum()
        t_p = ((h == positive) & is_pos_t).sum()
        t_n = ((h == negative) & is_neg_t).sum()
        # Every sample that is not a true positive/negative is counted as a
        # false one, including predictions of 0 produced by ``sign``.
        f_p = n_n - t_n
        f_n = n_p - t_p

        precision = 0.0 if t_p == 0 else t_p / (t_p + f_p)
        recall = 0.0 if t_p == 0 else t_p / (t_p + f_n)
        error_rate = float("nan") if size == 0 else 1 - (t_p + t_n) / size
        return precision, recall, error_rate, n_pp

    def error(self, DataLoader, is_logistic=False, device=None):
        """Evaluate the model over an iterable of ``(data, target)`` batches.

        Scores are thresholded with ``sigmoid(x) > 0.5`` (labels {1, 0}) when
        ``is_logistic`` is True, otherwise with ``sign(x)`` (labels {1, -1}).
        The module is left in eval mode.

        Args:
            DataLoader: iterable yielding ``(data, target)`` tensor pairs.
            is_logistic: selects the thresholding / label convention.
            device: where to move ``data``; defaults to the model's device.

        Returns:
            The tuple returned by :meth:`zero_one_loss`.
        """
        if device is None:
            device = self._infer_device()
        targets, predictions = [], []
        self.eval()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for data, target in DataLoader:
                data = data.to(device, non_blocking=True)
                t = target.detach().cpu().numpy().reshape(-1)
                scores = self(data)
                if is_logistic:
                    h = torch.sigmoid(scores).cpu().numpy().reshape(len(t))
                    h = np.where(h > 0.5, 1, 0).astype(np.int32)
                else:
                    h = torch.sign(scores).cpu().numpy().reshape(len(t))
                targets.append(t)
                predictions.append(h)

        return self.zero_one_loss(
            self._concat(predictions), self._concat(targets), is_logistic)

    def evalution_with_density(self, DataLoader, prior, device=None):
        """Evaluate using a per-batch threshold derived from the class prior.

        Each batch's raw outputs are labelled by
        :meth:`predict_with_density_threshold`, so the threshold is
        recomputed for every batch. Labels follow the {1, -1} convention.

        Args:
            DataLoader: iterable yielding ``(data, target)`` tensor pairs.
            prior: assumed fraction of positives.
            device: where to move ``data``; defaults to the model's device.

        Returns:
            The tuple returned by :meth:`zero_one_loss`.
        """
        if device is None:
            device = self._infer_device()
        targets, predictions = [], []
        self.eval()
        with torch.no_grad():
            for data, target in DataLoader:
                data = data.to(device)
                t = target.detach().cpu().numpy().reshape(-1)
                f_x = self(data).cpu().numpy().reshape(len(t))
                h = self.predict_with_density_threshold(f_x, target, prior)
                targets.append(t)
                predictions.append(h)

        return self.zero_one_loss(
            self._concat(predictions), self._concat(targets))

    def predict_with_density_threshold(self, f_x, target, prior):
        """Label the top ``int(len(f_x) * prior)`` scores as positive.

        The threshold is the midpoint between the smallest value that should
        be positive and the largest that should be negative, after dividing
        the scores by ``prior``. Samples exactly on the threshold get 0.

        Args:
            f_x: 1-D array of model outputs.
            target: unused; kept for interface compatibility.
            prior: assumed fraction of positives.

        Returns:
            int32 array of labels in {1, -1} (0 on ties with the threshold).
        """
        density_ratio = np.asarray(f_x) / prior
        size = len(density_ratio)
        n_pi = int(size * prior)

        # Degenerate cases: the midpoint formula would index out of range
        # (n_pi == 0) or silently wrap to index -1 (n_pi == size).
        if n_pi <= 0:
            return -np.ones(size, dtype=np.int32)
        if n_pi >= size:
            return np.ones(size, dtype=np.int32)

        sorted_ratio = np.sort(density_ratio)  # ascending
        threshold = (sorted_ratio[size - n_pi] +
                     sorted_ratio[size - n_pi - 1]) / 2
        return np.sign(density_ratio - threshold).astype(np.int32)



class MultiLayerPerceptron(MyClassifier, nn.Module):
    """Five-layer MLP that outputs a single score per sample.

    Four bias-free ``Linear -> BatchNorm1d -> activation`` hidden layers of
    width ``hidden_dim`` are followed by a linear output layer with one unit.

    Args:
        dim: number of input features; inputs are flattened to ``(-1, dim)``.
        act_func: hidden activation, ``'relu'`` or ``'softsign'``.
        hidden_dim: width of every hidden layer (default 300).

    Raises:
        ValueError: if ``act_func`` is not one of the supported names.
    """

    _ACTIVATIONS = {'relu': F.relu, 'softsign': F.softsign}

    def __init__(self, dim, act_func='relu', hidden_dim=300):
        super(MultiLayerPerceptron, self).__init__()
        # Fail early: an unknown name would otherwise leave ``self.af`` unset
        # and only surface as an AttributeError inside ``forward``.
        if act_func not in self._ACTIVATIONS:
            raise ValueError(
                "unsupported act_func %r; expected one of %s"
                % (act_func, sorted(self._ACTIVATIONS)))

        self.input_dim = dim
        self.num_classifier = 1

        # Hidden Linear layers omit bias, presumably because the following
        # BatchNorm1d (affine by default) supplies its own shift.
        # Attribute names are kept stable so saved state_dicts still load.
        self.l1 = nn.Linear(dim, hidden_dim, bias=False)
        self.b1 = nn.BatchNorm1d(hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.b2 = nn.BatchNorm1d(hidden_dim)
        self.l3 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.b3 = nn.BatchNorm1d(hidden_dim)
        self.l4 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.b4 = nn.BatchNorm1d(hidden_dim)
        self.l5 = nn.Linear(hidden_dim, self.num_classifier)
        self.af = self._ACTIVATIONS[act_func]

    def forward(self, x):
        """Return raw scores of shape ``(N, 1)`` (no output activation)."""
        return self.l5(self.get_features(x))

    def get_features(self, x):
        """Extract features from the penultimate layer (before l5).

        Returns a tensor of shape ``(N, hidden_dim)`` where ``N`` is the
        number of rows after flattening ``x`` to ``(-1, input_dim)``.
        """
        h = x.view(-1, self.input_dim)
        for linear, norm in ((self.l1, self.b1), (self.l2, self.b2),
                             (self.l3, self.b3), (self.l4, self.b4)):
            h = self.af(norm(linear(h)))
        return h


# class CardMLP(nn.Module):
#     def __init__(self, num_classifier):
#         super().__init__()
#         self.fc1 = nn.Linear(30, 16)
#         self.fc2 = nn.Linear(16, 18)
#         self.fc3 = nn.Linear(18, 20)
#         self.fc4 = nn.Linear(20, 24)
#
#         self.num_classifier = num_classifier
#         self.classifier = nn.Linear(24, self.num_classifier)
#
#     def forward(self, x):
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.dropout(x, p=0.25)
#         x = F.relu(self.fc3(x))
#         x = F.relu(self.fc4(x))
#         return self.classifier(x)
#
#
# import logging
#
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.utils.data as data_utils
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler
#
#
# class CardMLP(nn.Module):
#     def __init__(self, num_classifier):
#         super().__init__()
#         self.fc1 = nn.Linear(30, 16)
#         self.fc2 = nn.Linear(16, 18)
#         self.fc3 = nn.Linear(18, 20)
#         self.fc4 = nn.Linear(20, 24)
#
#         self.num_classifier = num_classifier
#         self.classifier = nn.Linear(24, self.num_classifier)
#
#     def forward(self, x):
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.dropout(x, p=0.25)
#         x = F.relu(self.fc3(x))
#         x = F.relu(self.fc4(x))
#         return self.classifier(x)
#
#
# class MixpulMLP(nn.Module):
#     def __init__(self, input_size, hidden_size=256, num_classifier=1):
#         super(MixpulMLP, self).__init__()
#         self.dim = input_size
#         self.num_classifier = num_classifier
#
#         self.fc1 = nn.Linear(input_size, hidden_size)
#         self.bn1 = nn.BatchNorm1d(num_features=hidden_size)
#         self.relu1 = nn.ReLU()
#         #        self.drop1 = nn.Dropout(p=0.5)
#
#         self.fc2 = nn.Linear(hidden_size, hidden_size)
#         self.bn2 = nn.BatchNorm1d(num_features=hidden_size)
#         self.relu2 = nn.ReLU()
#         #        self.drop2 = nn.Dropout(p=0.5)
#
#         self.fc3 = nn.Linear(hidden_size, int(hidden_size / 2))
#         self.bn3 = nn.BatchNorm1d(num_features=int(hidden_size / 2))
#         self.relu3 = nn.ReLU()
#         #        self.drop3 = nn.Dropout(p=0.5)
#
#         self.fc4 = nn.Linear(int(hidden_size / 2), int(hidden_size / 4))
#         self.bn4 = nn.BatchNorm1d(num_features=int(hidden_size / 4))
#         self.relu4 = nn.ReLU()
#         #        self.drop4 = nn.Dropout(p=0.5)
#
#         self.fc5 = nn.Linear(int(hidden_size / 4), self.num_classifier)
#
#     def forward(self, x):
#         x = x.view(-1, self.dim)
#
#         out = self.fc1(x)
#         out = self.bn1(out)
#         out = self.relu1(out)
#         #        out = self.drop1(out)
#
#         out = self.fc2(out)
#         out = self.bn2(out)
#         out = self.relu2(out)
#         #        out = self.drop2(out)
#
#         out = self.fc3(out)
#         out = self.bn3(out)
#         out = self.relu3(out)
#         #        out = self.drop3(out)
#         #
#         out = self.fc4(out)
#         out = self.bn4(out)
#         out = self.relu4(out)
#         #        out = self.drop4(out)
#
#         out = self.fc5(out)
#
#         # 重要修改：移除最终的softmax层，改为直接返回logits
#         # 原代码使用了F.softmax(out, dim=1)，这与后续的交叉熵损失不兼容
#         # return F.softmax(out, dim=1)
#
#         return out  # 直接返回logits，交给损失函数处理激活
#
#
