import torch
import torch.nn as nn
from tqdm import tqdm

def pearson_correlation(x, y):
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    Means and sums are taken over *all* elements of each tensor, so inputs
    are effectively treated as flat lists of samples. ``x`` and ``y`` are
    presumably expected to have the same shape (they only need to broadcast
    together for the element-wise product below).

    If either input is constant, the denominator is zero and the result is
    NaN (0 / 0); no epsilon is added.

    Returns:
        A 0-dim tensor holding the correlation coefficient.
    """
    # Deviations from the mean of each variable.
    x_dev = x - x.mean()
    y_dev = y - y.mean()

    # Numerator: (unnormalized) covariance of the two variables.
    cov = (x_dev * y_dev).sum()

    # Denominator: product of the two root-sum-of-squares deviations.
    scale = x_dev.pow(2).sum().sqrt() * y_dev.pow(2).sum().sqrt()

    return cov / scale

def model_inversion_stealing(aux_model, target, input_size, x_true,
                             lambda_l2=1, main_iters=1000, input_iters=100, model_iters=100,
                             device=torch.device('cuda'), show_tqdm=False):
    """Jointly reconstruct an input and fit ``aux_model`` to a target output.

    Runs ``main_iters`` rounds, each made of two phases:

    1. Input phase (``input_iters`` Adam steps on ``x_pred``): minimize
       ``MSE(aux_model(x_pred), target) + lambda_l2 * l2loss(x_pred)``.
    2. Model phase (``model_iters`` Adam steps on ``aux_model``'s parameters):
       minimize ``MSE(aux_model(x_pred), target)`` with ``x_pred`` held fixed.

    Args:
        aux_model: module whose parameters are optimized in place.
        target: tensor that ``aux_model(x_pred)`` is regressed onto with
            ``nn.MSELoss``; presumably must already live on ``device``
            (TODO confirm at call sites).
        input_size: shape passed to ``torch.empty`` for the reconstructed input.
        x_true: reference input; only used for the Pearson-correlation
            progress display when ``show_tqdm`` is true.
        lambda_l2: weight of the L2 penalty on ``x_pred``.
        main_iters: number of outer rounds.
        input_iters: Adam steps on ``x_pred`` per round.
        model_iters: Adam steps on the model per round.
        device: device on which ``x_pred`` is allocated.
        show_tqdm: show a tqdm progress bar whose postfix reports the last
            input-phase losses, the correlation with ``x_true`` and the last
            model-phase loss of each round.

    Returns:
        The reconstructed input, detached from the autograd graph.
    """
    # Every element of the candidate input starts at 0.5; it is the only tensor
    # the input phase optimizes.
    x_pred = torch.empty(input_size, device=device).fill_(0.5).requires_grad_(True)
    input_opt = torch.optim.Adam([x_pred], lr=0.001, amsgrad=True)
    model_opt = torch.optim.Adam(aux_model.parameters(), lr=0.001, amsgrad=True)
    mse = torch.nn.MSELoss()

    main_iter_range = tqdm(range(main_iters), desc="Main Iterations") if show_tqdm else range(main_iters)
    for _ in main_iter_range:
        postfix = {}

        # Input phase. backward() may also deposit gradients on the model
        # parameters; model_opt.zero_grad() clears them before each model step.
        for _ in range(input_iters):
            input_opt.zero_grad()
            pred = aux_model(x_pred)
            mse_loss = mse(pred, target)
            l2_loss = lambda_l2 * l2loss(x_pred)
            loss = mse_loss + l2_loss
            # A fresh graph is built every step, so it does not need to be retained.
            loss.backward()
            input_opt.step()

        if show_tqdm and input_iters > 0:
            # Report the last step only: .item() forces a device sync, so doing it
            # once per round keeps the inner loop fast. The correlation is purely
            # diagnostic and must not build an autograd graph.
            with torch.no_grad():
                corr = pearson_correlation(x_pred, x_true)
            postfix.update({
                "input mse loss": mse_loss.item(),
                "l2 loss": l2_loss.item(),
                "pearson": corr.item(),
            })

        # Model phase. x_pred is held fixed here; detaching it avoids building
        # graph through it and accumulating gradients that would be discarded by
        # input_opt.zero_grad() anyway.
        x_fixed = x_pred.detach()
        for _ in range(model_iters):
            model_opt.zero_grad()
            mse_loss = mse(aux_model(x_fixed), target)
            mse_loss.backward()
            model_opt.step()

        if show_tqdm:
            if model_iters > 0:
                postfix["model mse loss"] = mse_loss.item()
            # A single update per round, so the input-phase statistics are not
            # immediately overwritten by the model-phase ones.
            main_iter_range.set_postfix(postfix)

    return x_pred.detach()

def l2loss(x):
    """Return the mean of the element-wise squares of ``x`` as a 0-dim tensor."""
    squared = x.pow(2)
    return squared.mean()



class USLinear(nn.Module):
    """Bias-free linear layer wrapping ``nn.Linear(in_dim, out_dim, bias=False)``.

    The wrapped layer is stored as ``self.layer``, so its only parameter is
    registered as ``layer.weight``. The configured widths are also kept on the
    instance as ``in_dim`` and ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Store the configured widths on the instance.
        self.in_dim, self.out_dim = in_dim, out_dim
        self.layer = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        """Apply the wrapped layer, i.e. ``x @ self.layer.weight.T``."""
        return self.layer(x)


