from core.lrm import LearnRiskModel
from core.build_up import building_up
import torch
from tqdm import tqdm
from core.risk_utils import *
import numpy as np
from matplotlib import pyplot as plt

def plot_normal_distribution(mean, std, x_range=(-8, 8), num_points=100):
    """Draw the PDF of every Normal(mean[i], std[i]) and mark its 0.9 quantile.

    The first (mean, std) pair is labelled 'high risk'; every later pair is
    labelled 'low risk'. All curves go onto the current matplotlib figure,
    which is then displayed with ``plt.show()``.

    Args:
        mean: iterable of distribution means, paired element-wise with ``std``.
        std: iterable of distribution standard deviations.
        x_range: (low, high) bounds of the x axis that the PDF is sampled on.
        num_points: number of evenly spaced sample points inside ``x_range``.
    """
    for position, (loc, scale) in enumerate(zip(mean, std)):
        # Build the normal distribution for this pair.
        gaussian = torch.distributions.Normal(loc, scale)

        # Sample points along the x axis.
        grid = torch.linspace(x_range[0], x_range[1], num_points)

        # Probability density function (PDF), recovered from the log-density.
        density = gaussian.log_prob(grid).exp()
        # 0.9 quantile, drawn below as a dashed vertical line.
        quantile_90 = gaussian.icdf(torch.tensor(0.9)).item()

        # Only the first distribution is tagged as the high-risk one.
        tag = 'high risk' if position == 0 else 'low risk'

        plt.plot(grid.numpy(), density.numpy(), label=tag)
        plt.axvline(quantile_90, linestyle='--', label=f'Var of {tag}')

    plt.xlabel("x")
    plt.ylabel("Probability Density")
    plt.title("Normal Distribution PDF")
    plt.legend()
    plt.grid()
    plt.show()

def var_loss(mu, sigma, matching, w, confidence=0.9, mode='var', constraint=False):
    """Binary cross-entropy between a risk-derived probability and ``matching``.

    A risk value is obtained from the helper selected by ``mode``, treated as
    a logit ``z``, and scored against ``matching`` with
    ``-(y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z)))`` averaged over
    all elements.

    Args:
        mu: forwarded to the risk helper (presumably the distribution mean --
            confirm against ``core.risk_utils``).
        sigma: forwarded to the risk helper; in ``'cvar'`` mode it also
            multiplies the helper's result before the sigmoid.
        matching: binary targets, broadcastable against the logits.
        w: tensor used only when ``constraint`` is true; its sum is pushed
            towards 1 by a squared penalty.
        confidence: confidence level forwarded to the risk helper.
        mode: ``'cvar'`` selects ``conditional_value_at_rask``; any other
            value selects ``value_at_rask``.
        constraint: when true, add ``(w.sum() - 1) ** 2`` to the loss.

    Returns:
        The mean cross-entropy plus the optional constraint penalty.
    """
    if mode == 'cvar':
        risk = conditional_value_at_rask(mu, sigma, matching, confidence)
        logits = sigma * risk
    else:
        logits = value_at_rask(mu, sigma, confidence)

    # log(sigmoid(z)) and log(1 - sigmoid(z)) == log(sigmoid(-z)) are computed
    # with logsigmoid so a saturated sigmoid cannot produce log(0) = -inf
    # (and the resulting 0 * -inf = nan) in the loss or its gradient.
    log_sigmoid = torch.nn.functional.logsigmoid
    log_p = log_sigmoid(logits)
    log_not_p = log_sigmoid(-logits)

    loss = -((matching * log_p) + ((1 - matching) * log_not_p))
    loss = torch.mean(loss)

    if constraint:
        constraint_loss = (w.sum() - 1) ** 2
    else:
        constraint_loss = 0.

    return loss + constraint_loss


def train(mu_root, activate_root, risk_label_root, batch_size, epoch, lr):
    """Train a ``LearnRiskModel`` on CUDA and save its weights.

    Batches come from ``building_up``; each batch is optimised with Adam on
    ``var_loss`` (gradient norm clipped to 1). A running mean loss and a
    running accuracy (``var2pro`` output thresholded at 0.5) are shown on the
    tqdm bar. After the last epoch the state dict is written to
    ``lrm_mmstar_new.pth`` in the current working directory.

    Args:
        mu_root: forwarded to ``building_up``.
        activate_root: forwarded to ``building_up``.
        risk_label_root: forwarded to ``building_up``.
        batch_size: forwarded to ``building_up``.
        epoch: number of passes over the batches.
        lr: learning rate for the Adam optimiser.
    """
    batches, (in_dim, n_id) = building_up(mu_root, activate_root, risk_label_root, batch_size)

    model = LearnRiskModel(in_dim=in_dim, n_id=n_id).cuda()
    model.train()

    trainable_count = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print('Parameters:', trainable_count)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_idx in range(epoch):
        batch_losses = []
        n_correct = 0
        n_seen = 0
        with tqdm(batches, desc=f'Epoch: {epoch_idx + 1:04d}') as progress:
            for mu_batch, act_batch, labels in progress:
                mu_batch = mu_batch.float().cuda()
                labels = labels.float().cuda()
                act_batch = act_batch.float().cuda()
                optimizer.zero_grad()

                mu_bar, sigma = model(mu_batch, act_batch)

                loss = var_loss(mu_bar, sigma, labels, w=model.w, constraint=False)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()

                batch_losses.append(loss.item())

                # Binarise the predicted probabilities in place at 0.5.
                pro = var2pro(mu_bar, sigma)
                above_half = pro > 0.5
                pro[above_half] = 1.
                pro[~above_half] = 0.

                n_correct += torch.sum(pro == labels)
                # Labels are counted as rows * columns (first two dimensions).
                n_seen += labels.shape[0] * labels.shape[1]

                running_loss = np.mean(batch_losses)
                running_acc = n_correct.item() / n_seen
                progress.set_postfix(loss=f'{running_loss:.4f}', acc=f'{running_acc:.4f}')
        # Disabled debugging aid, kept for reference:
        # print('mu: ', mu_bar, 'sigma: ', sigma)  # index:4 matching; index:0 mismatching.
        # if i % 100 == 0:
        #     plot_normal_distribution(mean=[mu_bar[0].detach().cpu(), mu_bar[4].detach().cpu()], std=[sigma[0].detach().cpu(), sigma[4].detach().cpu()])
        #     print(torch.sigmoid(standard_normal_quantile(mu_bar[0].detach().cpu(), sigma[0].detach().cpu())))
    torch.save(model.state_dict(), 'lrm_mmstar_new.pth')
    print('Done')

if __name__ == '__main__':
    # Script entry point: hard-coded data locations and hyper-parameters
    # for the "mmstar_new" training run.
    run_config = {
        'mu_root': '/home/15t/fzy/code/raie/mu/mmstar_new',
        'activate_root': '/home/15t/fzy/code/raie/activate_vector/mmstar_new_train',
        'risk_label_root': '/home/15t/fzy/code/raie/risk_label/mmstar_new',
        'batch_size': 32,
        'epoch': 250,
        'lr': 1e-3,
    }
    train(**run_config)