
# Add noise to the existing weights; the noising scheme is the same as in P2SAM.
import os
import sys
import logging
import argparse
import random
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader,Subset
import torch.backends.cudnn as cudnn
from utils import test_single_volume
from importlib import import_module
from segment_anything import sam_model_registry
from load_LIDC_data import LIDC_IDRI,RandomGenerator
from utils import init_weights,init_weights_orthogonal_normal, l2_regularisation
import torch.nn.functional as F

from scipy.ndimage import zoom
from einops import repeat
import matplotlib.pyplot as plt
from torchvision import transforms
from tensorboardX import SummaryWriter
from ambiguous_sam_v2 import Ambiguous_Sam
import torch.nn.functional as F
from scipy.ndimage import zoom
from einops import repeat
import matplotlib.pyplot as plt
from torch.distributions import Normal, Independent, kl

# Preferred compute device: GPU index 3 when CUDA is available, CPU otherwise.
# NOTE(review): `device` is assigned again later in this file, before the model
# is built -- keep the two definitions in sync.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')


class Mask_Weights(nn.Module):
    """Container for a learnable ``(5, 1)`` weight column.

    Every entry starts at 1/6.  The training loop in this file prepends
    ``1 - weights.sum(0)`` to this column to obtain six coefficients that are
    multiplied into the predicted mask logits.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter turns the tensor into a trainable leaf on its own, so
        # the temporary does not need its own requires_grad flag.
        initial_value = torch.ones(5, 1) / 6
        self.weights = nn.Parameter(initial_value)
def calculate_dice_loss(inputs, targets, num_masks=10):
    """
    Soft DICE loss computed on sigmoid-activated logits.

    Args:
        inputs: A float tensor of arbitrary shape holding raw logits
                (the sigmoid is applied inside this function).
        targets: A tensor with the same shape as inputs holding the binary
                 label of every element (0 = negative, 1 = positive).
        num_masks: Value the summed loss is divided by.
    Returns:
        A 0-dim tensor: the sum of all per-slice DICE losses divided by
        ``num_masks``.

    Note:
        The overlap and the totals are reduced over the LAST axis only (the
        usual ``flatten(1)`` step is not applied), so one DICE term is
        produced for every index of the remaining axes.
    """
    probs = torch.sigmoid(inputs)
    overlap = (probs * targets).sum(dim=-1)
    total = probs.sum(dim=-1) + targets.sum(dim=-1)
    # The +1 in numerator and denominator smooths the ratio for empty slices.
    dice_score = (2 * overlap + 1) / (total + 1)
    return (1 - dice_score).sum() / num_masks

def calculate_sigmoid_focal_loss(inputs, targets, num_masks=10, alpha: float = 0.25, gamma: float = 2):
    """
    Sigmoid focal loss as used in RetinaNet: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of raw logits (at least 2-D, because the
                result is averaged over axis 1).
        targets: A float tensor with the same shape as inputs holding the
                 binary label of every element (0 = negative, 1 = positive).
        num_masks: Value the summed loss is divided by.
        alpha: Weighting factor balancing positive vs negative examples.
               Default = 0.25; any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) that balances
               easy vs hard examples.
    Returns:
        A 0-dim tensor: the element-wise loss averaged over axis 1, summed
        over everything else and divided by ``num_masks``.
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability the model assigns to the true class of each element.
    true_class_prob = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - true_class_prob) ** gamma)

    use_alpha_balancing = alpha >= 0
    if use_alpha_balancing:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    per_sample = focal.mean(1)
    return per_sample.sum() / num_masks

def show_mask(mask, ax, color):
    """Draw ``mask`` on ``ax`` as a half-transparent coloured overlay.

    Args:
        mask: Array whose last two axes are (height, width); it is squeezed
              before use.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        color: List of three RGB components; an alpha of 0.5 is appended.
    """
    height, width = mask.shape[-2:]
    rgba = np.array(color + [0.5])  # append the transparency value
    flat_mask = mask.squeeze()

    overlay = np.zeros((height, width, 4))  # RGBA canvas
    # Scale each colour channel by the mask value.
    for channel, component in enumerate(rgba[:3]):
        overlay[:, :, channel] = flat_mask * component
    # Only non-zero mask pixels receive the alpha value.
    overlay[:, :, 3] = (flat_mask > 0) * rgba[3]
    ax.imshow(overlay)
def show_box(box, ax, color):
    """Outline a box on ``ax``.

    Args:
        box: Indexable with at least four entries, read as
             ``(x0, y0, x1, y1)`` corner coordinates.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
        color: Edge colour handed to ``plt.Rectangle``.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor=color,
        facecolor='none',
        lw=2,
    )
    ax.add_patch(outline)


def kl_divergence(posterior_latent_space, prior_latent_space, analytic=True, calculate_posterior=False, z_posterior=None):
    """
    Calculate the KL divergence between the posterior and prior, KL(Q||P).

    Args:
        posterior_latent_space: Distribution Q.
        prior_latent_space: Distribution P.
        analytic: If True, use the closed-form KL from
                  ``torch.distributions.kl``; otherwise estimate it from a
                  single sample as ``log Q(z) - log P(z)``.
        calculate_posterior: Only used when ``analytic`` is False. If True,
                  the sample is drawn here with ``rsample()``; otherwise
                  ``z_posterior`` is used.
        z_posterior: Sample to evaluate when ``analytic`` is False and
                  ``calculate_posterior`` is False.
    Returns:
        The KL divergence (analytic) or its single-sample estimate.
    """
    if analytic:
        # Original author's note: this needs to be added to the torch source
        # code, see https://github.com/pytorch/pytorch/issues/13545
        return kl.kl_divergence(posterior_latent_space, prior_latent_space)

    sample = posterior_latent_space.rsample() if calculate_posterior else z_posterior
    log_q = posterior_latent_space.log_prob(sample)
    log_p = prior_latent_space.log_prob(sample)
    return log_q - log_p

# ---------------------------------------------------------------------------
# Model, data and optimizer setup
# ---------------------------------------------------------------------------
# Training device: GPU 3 when CUDA is available, CPU otherwise.  A bare
# "cuda:3" string here would discard that fallback.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
patch_size = [128, 128]

net = Ambiguous_Sam()
net.to(device)

# Place the mask weights on the training device directly; a plain .cuda()
# would put them on the current default GPU instead.
mask_weights = Mask_Weights().to(device)
mask_weights.train()

db = LIDC_IDRI(dataset_location='/data/cxli/yuzhi/Ambiguous_SAM/LIDC/data/', transform=transforms.Compose([
  RandomGenerator(output_size=patch_size)
]))
dataset_size = len(db)

# Indices of every sample in the dataset.
indices = list(range(dataset_size))

# Split points: first 60% train, next 20% validation, last 20% test.
train_split = int(np.floor(0.6 * dataset_size))
new_train_split = int(np.floor(0.5 * train_split))  # half of the training split; not used below
validation_split = int(np.floor(0.8 * dataset_size))

# Assign the index ranges of the three subsets.
train_indices = indices[:train_split]
validation_indices = indices[train_split:validation_split]
test_indices = indices[validation_split:]

train_dataset = Subset(db, train_indices)
# validation_dataset = Subset(db, validation_indices)
test_dataset = Subset(db, test_indices)

# NOTE(review): the training loader does not shuffle -- confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=False)
# validation_loader = DataLoader(validation_dataset, batch_size=5, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Print the subset sizes so the split can be checked.
print(f"Total dataset size: {dataset_size}")
print(f"Training set size: {len(train_indices)}")
print(f"Validation set size: {len(validation_indices)}")
print(f"Test set size: {len(test_indices)}")

print("Number of training/test patches:", (len(train_indices), len(test_indices)))
writer = SummaryWriter('tf-logs/train_onestage')

# NOTE(review): only net.parameters() are optimised; mask_weights.weights is
# not passed to the optimizer, so it keeps its initial value -- confirm this
# is intended.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
max_epoch = 101



# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# The cross-entropy criterion is stateless, so build it once rather than on
# every batch.
cel = torch.nn.CrossEntropyLoss()

for epoch_num in range(1, max_epoch):
    loss_epoch = 0.0
    segloss_epoch = 0.0
    print(epoch_num)
    for i_batch, sampled_batch in enumerate(train_loader):
        # Move the inputs straight to the training device.  Chaining
        # .cuda().to(device) would first copy each tensor to the current
        # default GPU and then copy it again to `device`.
        image_batch = sampled_batch['image'].to(device)
        label_batch = sampled_batch['label'].to(device)
        image_batch_oc = sampled_batch['image_oc'].to(device)
        box1024_batch = sampled_batch['box_1024'].to(device)
        boxshift_batch = sampled_batch['box_shift'].to(device)
        # Not moved to the device; only read for the box overlay below.
        boxori_batch = sampled_batch['box_ori']

        assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}'
        outputs = net.forward(image_batch, image_batch_oc, box1024_batch, boxshift_batch, label_batch)
        output_masks = outputs['masks']
        low_res_logits = outputs['low_res_logits']

        logits_high = output_masks.to(device)
        # Six mixing coefficients: (1 - sum of the five learnable weights)
        # followed by the five weights themselves.
        weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0).to(device)
        # assumes output_masks is (batch, 6, H, W) so the (6, 1, 1) weights
        # broadcast over axis 1 -- TODO confirm
        logits_high = logits_high * weights.unsqueeze(-1)

        # Weighted sum over the mask axis, kept as a singleton channel.
        logits_high_res = logits_high.sum(1).unsqueeze(1)
        mask = logits_high_res > 0

        # KL terms between posterior and prior latent spaces of the network.
        kl1 = torch.mean(kl_divergence(net.posterior_box_latent_space, net.prior_box_latent_space))
        kl2 = torch.mean(kl_divergence(net.posterior_object_latent_space, net.prior_object_latent_space))

        # Segmentation losses.
        cel_loss = cel(logits_high, label_batch.long())
        reg_loss = (
            l2_regularisation(net.prior_box)
            + l2_regularisation(net.posterior_box)
            + l2_regularisation(net.fcomb_box.layers)
            + l2_regularisation(net.prior_object)
            + l2_regularisation(net.posterior_object)
            + l2_regularisation(net.fcomb_object.layers)
        )
        gt_mask = label_batch.unsqueeze(1)
        dice_loss = calculate_dice_loss(logits_high_res, gt_mask.long())
        focal_loss = calculate_sigmoid_focal_loss(logits_high_res, gt_mask.float())
        seg_loss = cel_loss + dice_loss + focal_loss

        loss = seg_loss + 1e-5 * reg_loss + kl1 + kl2
        segloss_epoch += seg_loss.item()
        loss_epoch += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every 5th batch of every 5th epoch: push preview images to TensorBoard.
        if epoch_num % 5 == 0 and i_batch % 5 == 0:
            output_masks_multi = output_masks
            pred_mask = output_masks_multi[0] > 0.5
            labs = label_batch[0, ...].unsqueeze(0) * 50
            output_masks = torch.argmax(torch.softmax(output_masks, dim=1), dim=1, keepdim=True)

            # Min-max normalise the first channel of the first image.  The
            # denominator is clamped so a constant image does not divide by zero.
            image = image_batch[0, 0:1, :, :]
            image = (image - image.min()) / (image.max() - image.min()).clamp_min(1e-8)

            fig, ax = plt.subplots()
            ax.imshow(image.squeeze().cpu().numpy(), cmap='gray')  # single-channel grayscale
            box = boxori_batch[0].cpu().numpy()  # box of the first image in the batch
            show_box(box, ax, color='red')
            ax.axis('off')

            # Rasterise the matplotlib figure into a CHW float tensor in [0, 1].
            fig.canvas.draw()
            image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            image_from_plot = np.moveaxis(image_from_plot, 2, 0)  # HWC to CHW
            image_from_plot = torch.tensor(image_from_plot).unsqueeze(0) / 255.0
            writer.add_image('train_onestage/epoch{}_batch{}_img'.format(epoch_num, i_batch), image_from_plot.squeeze())
            plt.close(fig)

            writer.add_image('train_onestage/epoch{}_batch{}_ori_pred'.format(epoch_num, i_batch), output_masks[0, ...] * 50)
            writer.add_image('train_onestage/epoch{}_batch{}_lab'.format(epoch_num, i_batch), labs)
            writer.add_image('train_onestage/epoch{}_batch{}_weighted_pred'.format(epoch_num, i_batch), mask[0, ...])

            # One image per predicted mask of the first sample; the mask index
            # is used as the global step.
            for i in range(6):
                writer.add_image('train_onestage/epoch{}_batch{}_pred'.format(epoch_num, i_batch), (pred_mask[i, ...] * 50).unsqueeze(0), global_step=i)

    writer.add_scalar('train_seg_loss', segloss_epoch / len(train_loader), global_step=epoch_num)
    writer.add_scalar('train_total_loss', loss_epoch / len(train_loader), global_step=epoch_num)
    print("seg_loss", segloss_epoch / len(train_loader))
    print(loss_epoch / len(train_loader))

    # Checkpoint the network and the mixing coefficients every 5 epochs.
    if epoch_num % 5 == 0:
        file_name = "ckpoint/onestage/allsamples_100epoch_1e-4_{}epoch.pth".format(epoch_num)
        weight_file = 'ckpoint/onestage_weight/allsamples_100epoch_1e-4_{}epoch.pt'.format(epoch_num)
        # Create the target directories up front so that torch.save cannot
        # fail on a missing directory after epochs of training.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        os.makedirs(os.path.dirname(weight_file), exist_ok=True)
        torch.save(net.state_dict(), file_name)
        torch.save(weights, weight_file)