import os
import sys
import logging
import argparse
import random
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader,Subset
import torch.backends.cudnn as cudnn
from utils import test_single_volume
from importlib import import_module
from segment_anything import sam_model_registry
from load_LIDC_data import LIDC_IDRI,RandomGenerator
from sam_lora_image_encoder import LoRA_Sam
from utils import init_weights,init_weights_orthogonal_normal, l2_regularisation
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from tqdm import tqdm
from scipy.ndimage import zoom
from einops import repeat
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from tensorboardX import SummaryWriter
from torch.distributions import Normal, Independent, kl

# Building-block modules required by the model defined further below.
class Encoder(nn.Module):
    """
    Stack of convolutional blocks used as feature extractor by the prior/posterior nets.

    One block is created per entry of num_filters. A block consists of 3x3 convolutions, each followed by
    an in-place ReLU; every block except the first one starts with a 2x2 average pooling
    (stride 2, ceil_mode=True).

    Args:
        input_channels: channel count of the tensor fed to the first convolution.
        num_filters: output channel count of each block, in order.
        no_convs_per_block: number of convolutions per block (one is always created, the remaining
            no_convs_per_block - 1 follow it).
        initializers: accepted for interface compatibility; this class does not read it.
        padding: converted with int() and used as the padding of every convolution.
        posterior, object: when both are truthy, one additional input channel is reserved for the mask
            that gets concatenated to the input along the channel axis.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True,posterior=False,object=False):
        super(Encoder, self).__init__()
        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters
        self.posterior = posterior
        self.object = object
        # The object posterior receives the mask as one extra channel.
        if self.posterior and self.object:
            self.input_channels += 1

        conv_padding = int(padding)
        modules = []
        in_channels = self.input_channels
        for block_idx, out_channels in enumerate(self.num_filters):
            # Down-sample between blocks, never before the very first one.
            if block_idx > 0:
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            # The first convolution of a block maps in_channels -> out_channels,
            # all following ones map out_channels -> out_channels.
            conv_inputs = [in_channels] + [out_channels] * (no_convs_per_block - 1)
            for conv_in in conv_inputs:
                modules.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=conv_padding))
                modules.append(nn.ReLU(inplace=True))

            in_channels = out_channels

        self.layers = nn.Sequential(*modules)
        self.layers.apply(init_weights)

    def forward(self, input):
        """Run the input through the whole convolutional stack and return the feature map."""
        return self.layers(input)



class AxisAlignedConvGaussian_box(nn.Module):  # Box branch: the box embedding goes through an MLP and is then concatenated with the image embedding
    """
    Convolutional net that parametrizes a Gaussian with axis-aligned (diagonal) covariance
    from an image embedding combined with box embeddings.

    The flattened box embedding must have 512 features for the prior and 1024 features for the
    posterior (where two box embeddings are concatenated). It is mapped by a 3-layer MLP to a
    (B, 8, 8, 8) map, concatenated with the image embedding (264 channels in total) and turned into a
    single-channel map, up-sampled 16x, that is fed to the Encoder.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False,object=False):
        super(AxisAlignedConvGaussian_box, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.output_channels = 8
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        self.object = object
        # The posterior consumes two concatenated box embeddings, hence twice the features.
        self.name, self.input_feature = ('Posterior', 1024) if self.posterior else ('Prior', 512)
        self.box_input_channel = 264

        # MLP that maps the flattened box embedding to output_channels * 64 values (an 8x8 map per channel).
        self.fc1 = nn.Linear(self.input_feature, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, self.output_channels * 64)

        # Four conv + ReLU + 2x bilinear up-sampling stages (8x8 -> 128x128 for an 8x8 input),
        # closed by a convolution that reduces to a single channel.
        stage_channels = (self.box_input_channel, 128, 64, 32, 16)
        upsampling = []
        for stage_in, stage_out in zip(stage_channels[:-1], stage_channels[1:]):
            upsampling.extend([
                nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ])
        upsampling.append(nn.Conv2d(16, 1, kernel_size=3, padding=1))
        self.box_conv = nn.Sequential(*upsampling)

        self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers,posterior=self.posterior, object=self.object)
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1)

        # Debug/inspection slots; only show_enc is written by forward().
        for slot in ('show_img', 'show_seg', 'show_concat', 'show_enc', 'sum_input'):
            setattr(self, slot, 0)

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input,boxemb_shift,boxemb_ori=None):
        """
        Build the latent distribution.

        Args:
            input: image embedding; it is concatenated on dim 1 with a (B, 8, 8, 8) box map, so it has
                to be (B, 256, 8, 8) to match the 264 input channels of box_conv.
            boxemb_shift: box embedding, flattened per sample before the MLP.
            boxemb_ori: optional second box embedding; when given it is concatenated to boxemb_shift
                on dim 1 first (posterior case).

        Returns:
            Independent(Normal) distribution with event size latent_dim.
        """
        # Posterior case: join both box embeddings before flattening.
        box_tokens = boxemb_shift if boxemb_ori is None else torch.cat((boxemb_shift, boxemb_ori), dim=1)

        # Flatten, run the MLP and reshape to a (B, output_channels, 8, 8) map.
        batch_size = box_tokens.size(0)
        hidden = F.relu(self.fc1(box_tokens.view(batch_size, -1)))
        hidden = F.relu(self.fc2(hidden))
        box_map = self.fc3(hidden).view(batch_size, self.output_channels, 8, 8)

        # Fuse image and box information and convert it into the encoder input.
        fused = torch.cat((input, box_map), dim=1).to(torch.float32)
        encoding = self.encoder(self.box_conv(fused))
        self.show_enc = encoding

        # Global average over height and width; keepdim so the 1x1 convolution still sees a 4D tensor.
        pooled = encoding.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

        # (B, 2 * latent_dim, 1, 1) -> (B, 2 * latent_dim). Squeezing explicit dims keeps batch size 1 intact.
        mu_log_sigma = self.conv_layer(pooled).squeeze(dim=2).squeeze(dim=2)

        # First half of the channels is the mean, second half the log standard deviation.
        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # Multivariate normal with diagonal covariance, see https://github.com/pytorch/pytorch/pull/11178
        base = Normal(loc=mu, scale=torch.exp(log_sigma))
        return Independent(base, 1)


class Fcomb_box(nn.Module):  # Combines the latent sample with the embedding of the main network
    """
    Stack of 1x1 convolutions that merges a latent sample with a feature map.

    The latent sample z (batch_size x latent_dim) is broadcast over the spatial extent of the feature
    map (batch_size x 256 x H x W), concatenated to it along the channel axis and reduced back to 256
    channels. Because only 1x1 convolutions are used, the spatial size is left untouched.

    Args:
        num_filters, num_output_channels, num_classes: stored on the instance, not used by the layers.
        latent_dim: size of the latent sample; the first convolution expects 256 + latent_dim channels
            (512 for the default latent_dim of 256 used by Ambiguous_Sam).
        no_convs_fcomb: total number of 1x1 convolutions is max(no_convs_fcomb, 2).
        initializers: dict; initializers['w'] == 'orthogonal' selects init_weights_orthogonal_normal,
            anything else selects init_weights.
        use_tile: when falsy no layers are built and forward() returns None.
    """
    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True):
        super(Fcomb_box, self).__init__()
        self.num_channels = num_output_channels #output channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [1,2,3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if self.use_tile:
            # Channel count of the feature map that is combined with z.
            feature_channels = 256

            # N x (1x1 convolution + ReLU), the last convolution has no activation.
            # The input width follows latent_dim instead of being fixed to 512, which is the same
            # value for latent_dim == 256 but also works for other latent sizes.
            layers = [
                nn.Conv2d(feature_channels + self.latent_dim, feature_channels, kernel_size=1),
                nn.ReLU(inplace=True),
            ]
            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(feature_channels, feature_channels, kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)
            self.last_layer = nn.Conv2d(feature_channels, feature_channels, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                init_fn = init_weights_orthogonal_normal
            else:
                init_fn = init_weights
            self.layers.apply(init_fn)
            self.last_layer.apply(init_fn)

    def tile(self, a, dim, n_tile):
        """
        Repeat every slice of `a` along `dim` n_tile times, keeping the copies of a slice adjacent.

        For a dimension of size 1 (the only way forward() uses it) this equals tiling. The result lives
        on the same device as `a`; no device is hard-coded.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Z is batch_size x latent_dim and feature_map is batch_size x no_channels x H x W.
        Z is broadcast to batch_size x latent_dim x H x W and concatenated to feature_map on dim 1.

        Returns:
            The combined map after all 1x1 convolutions, or None when use_tile is falsy.
        """
        if not self.use_tile:
            # No layers were built in this configuration; keep the historical "return None".
            return None

        height = feature_map.shape[self.spatial_axes[1]]
        width = feature_map.shape[self.spatial_axes[2]]

        # (B, latent_dim) -> (B, latent_dim, 1, 1) -> (B, latent_dim, H, W)
        z = z.unsqueeze(2).unsqueeze(2)
        z = self.tile(z, 2, height)
        z = self.tile(z, 3, width)

        # Concatenate the feature map and the broadcast latent sample along the channel axis.
        combined = torch.cat((feature_map, z), dim=1)
        return self.last_layer(self.layers(combined))




class AxisAlignedConvGaussian_object(nn.Module):  # Object branch: latent distribution computed from the image (and the mask for the posterior)
    """
    Convolutional net that parametrizes a Gaussian with axis-aligned (diagonal) covariance.

    The input (optionally concatenated with a segmentation mask on the channel axis) is encoded by an
    Encoder, globally averaged and mapped by a 1x1 convolution to the mean and the log standard
    deviation of the latent distribution.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False,object=False):
        super(AxisAlignedConvGaussian_object, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        self.object = object
        self.name = 'Posterior' if self.posterior else 'Prior'

        self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers,posterior=self.posterior, object=self.object)
        # 1x1 convolution producing mean and log-sigma, hence 2 * latent_dim channels.
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1)

        # Debug/inspection slots; only show_enc is written by forward().
        for slot in ('show_img', 'show_seg', 'show_concat', 'show_enc', 'sum_input'):
            setattr(self, slot, 0)

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input,segm=None):
        """
        Build the latent distribution.

        Args:
            input: 4D input tensor; cast to float32 before encoding.
            segm: optional mask, concatenated to input on dim 1 when given (posterior case).

        Returns:
            Independent(Normal) distribution with event size latent_dim.
        """
        net_input = input if segm is None else torch.cat((input, segm), dim=1)

        encoding = self.encoder(net_input.to(torch.float32))
        self.show_enc = encoding

        # Global average over height and width; keepdim so the 1x1 convolution still sees a 4D tensor.
        pooled = encoding.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

        # (B, 2 * latent_dim, 1, 1) -> (B, 2 * latent_dim). Squeezing explicit dims keeps batch size 1 intact.
        mu_log_sigma = self.conv_layer(pooled).squeeze(dim=2).squeeze(dim=2)

        # First half of the channels is the mean, second half the log standard deviation.
        mu = mu_log_sigma[:, :self.latent_dim]
        log_sigma = mu_log_sigma[:, self.latent_dim:]

        # Multivariate normal with diagonal covariance, see https://github.com/pytorch/pytorch/pull/11178
        base = Normal(loc=mu, scale=torch.exp(log_sigma))
        return Independent(base, 1)


class Fcomb_object(nn.Module):  # Combines the latent sample with the embedding of the main network
    """
    Stack of 1x1 convolutions that merges a latent sample with a feature map.

    The latent sample z (batch_size x latent_dim) is broadcast over the spatial extent of the feature
    map (batch_size x 256 x H x W), concatenated to it along the channel axis and reduced back to 256
    channels. Because only 1x1 convolutions are used, the spatial size is left untouched.

    Args:
        num_filters, num_output_channels, num_classes: stored on the instance, not used by the layers.
        latent_dim: size of the latent sample; the first convolution expects 256 + latent_dim channels
            (512 for the default latent_dim of 256 used by Ambiguous_Sam).
        no_convs_fcomb: total number of 1x1 convolutions is max(no_convs_fcomb, 2).
        initializers: dict; initializers['w'] == 'orthogonal' selects init_weights_orthogonal_normal,
            anything else selects init_weights.
        use_tile: when falsy no layers are built and forward() returns None.
    """
    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True):
        super(Fcomb_object, self).__init__()
        self.num_channels = num_output_channels #output channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [1,2,3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if self.use_tile:
            # Channel count of the feature map that is combined with z.
            feature_channels = 256

            # N x (1x1 convolution + ReLU), the last convolution has no activation.
            # The input width follows latent_dim instead of being fixed to 512, which is the same
            # value for latent_dim == 256 but also works for other latent sizes.
            layers = [
                nn.Conv2d(feature_channels + self.latent_dim, feature_channels, kernel_size=1),
                nn.ReLU(inplace=True),
            ]
            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(feature_channels, feature_channels, kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)
            self.last_layer = nn.Conv2d(feature_channels, feature_channels, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                init_fn = init_weights_orthogonal_normal
            else:
                init_fn = init_weights
            self.layers.apply(init_fn)
            self.last_layer.apply(init_fn)

    def tile(self, a, dim, n_tile):
        """
        Repeat every slice of `a` along `dim` n_tile times, keeping the copies of a slice adjacent.

        For a dimension of size 1 (the only way forward() uses it) this equals tiling. The result lives
        on the same device as `a`; no device is hard-coded.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Z is batch_size x latent_dim and feature_map is batch_size x no_channels x H x W.
        Z is broadcast to batch_size x latent_dim x H x W and concatenated to feature_map on dim 1.

        Returns:
            The combined map after all 1x1 convolutions, or None when use_tile is falsy.
        """
        if not self.use_tile:
            # No layers were built in this configuration; keep the historical "return None".
            return None

        height = feature_map.shape[self.spatial_axes[1]]
        width = feature_map.shape[self.spatial_axes[2]]

        # (B, latent_dim) -> (B, latent_dim, 1, 1) -> (B, latent_dim, H, W)
        z = z.unsqueeze(2).unsqueeze(2)
        z = self.tile(z, 2, height)
        z = self.tile(z, 3, width)

        # Concatenate the feature map and the broadcast latent sample along the channel axis.
        combined = torch.cat((feature_map, z), dim=1)
        return self.last_layer(self.layers(combined))


class Ambiguous_Sam(nn.Module):
    """
    SAM (wrapped by LoRA_Sam) extended with latent-variable nets that perturb its embeddings.

    Two prior/posterior pairs are built: the "box" pair works on the image embedding plus the sparse
    box embeddings, the "object" pair works on the raw image (plus the mask for the posterior). A
    sample of each latent space is merged by Fcomb_box into the dense prompt embedding and by
    Fcomb_object into the image embedding before SAM's mask decoder runs.

    The latent distributions and samples of the last forward pass are kept as attributes
    (prior_*_latent_space, posterior_*_latent_space, z_prior_*, z_posterior_*).
    """

    def __init__(self,lora_ckpt="/data/cxli/yuzhi/samed/SAMed-main/ckpoint/weight_ckpoint/500samples_noflip_100epoch.pth", device='cuda:3',input_channels=1, num_classes=6, img_size=128,num_filters=[32,64,128,192],latent_dim=256, no_convs_fcomb=4, beta=10.0):
        super(Ambiguous_Sam, self).__init__()
        self.device =device
        self.ckpt="sam_vit_b_01ec64.pth"
        # Stored only; the LoRA checkpoint is not loaded here.
        self.lora_ckpt=lora_ckpt
        self.img_size=img_size
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Copy so the instance never aliases (and cannot mutate) the shared default list.
        self.num_filters = list(num_filters)
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w':'he_normal', 'b':'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.sam,self.img_embedding_size = sam_model_registry["vit_b"](image_size=self.img_size,
                                            num_classes=self.num_classes,
                                            checkpoint=self.ckpt, pixel_mean=[0, 0, 0],
                                            pixel_std=[1, 1, 1])
        self.sam.to(device)
        self.lora_sam=LoRA_Sam(self.sam,4).to(device)
        # self.lora_sam.load_lora_parameters(self.lora_ckpt)
        self.prior_object = AxisAlignedConvGaussian_object(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers,posterior=False,object=True).to(device)
        self.prior_box = AxisAlignedConvGaussian_box(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers,posterior=False,object=False).to(device)
        self.posterior_object = AxisAlignedConvGaussian_object(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers,posterior=True,object=True).to(device)
        self.posterior_box = AxisAlignedConvGaussian_box(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers,posterior=True,object=False).to(device)
        self.fcomb_object = Fcomb_object(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, {'w':'orthogonal', 'b':'normal'}, use_tile=True).to(device)
        self.fcomb_box= Fcomb_box(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, {'w':'orthogonal', 'b':'normal'}, use_tile=True).to(device)

    def forward(self,batch_input,batch_input_ori,batch_boxori,batch_boxshift,batch_mask,input_size=128,train=True):
        """
        Predict masks with latent-perturbed SAM embeddings.

        Args:
            batch_input: images passed through SAM's preprocess and image encoder.
            batch_input_ori: images fed to the object prior/posterior nets.
            batch_boxori: box prompts, encoded by SAM's prompt encoder (used by the box posterior).
            batch_boxshift: shifted box prompts, encoded by SAM's prompt encoder; their embeddings
                are also the prompts handed to the mask decoder.
            batch_mask: mask for the object posterior; a channel axis is inserted at dim 1
                (assumes shape (B, H, W) - TODO confirm against the data loader).
            input_size: side length passed as input_size to SAM's postprocess_masks.
            train: True -> sample (rsample) from the posteriors; False -> sample from the priors.

        Returns:
            dict with keys 'masks', 'iou_predictions' and 'low_res_logits'.
        """
        sam = self.lora_sam.sam

        # SAM pre-processing and embeddings.
        input_images = sam.preprocess(batch_input)
        image_embeddings = sam.image_encoder(input_images)
        sparse_embeddings_shift, dense_embeddings_shift = sam.prompt_encoder(
            points=None, boxes=batch_boxshift, masks=None
        )
        sparse_embeddings_ori, dense_embeddings_ori = sam.prompt_encoder(
            points=None, boxes=batch_boxori, masks=None
        )

        # Sub-modules are invoked through __call__ (not .forward) so that registered hooks run.
        self.prior_box_latent_space = self.prior_box(image_embeddings, sparse_embeddings_shift)
        self.prior_object_latent_space = self.prior_object(batch_input_ori)

        if train:
            # Argument order is kept exactly as before: the original-box embedding comes first.
            self.posterior_box_latent_space = self.posterior_box(image_embeddings, sparse_embeddings_ori, sparse_embeddings_shift)
            self.posterior_object_latent_space = self.posterior_object(batch_input_ori, batch_mask.unsqueeze(1))
            # rsample keeps the samples differentiable w.r.t. the distribution parameters.
            self.z_posterior_box = self.posterior_box_latent_space.rsample()
            self.z_posterior_object = self.posterior_object_latent_space.rsample()
            self.z_prior_box = self.prior_box_latent_space.rsample()
            self.z_prior_object = self.prior_object_latent_space.rsample()
            z_box, z_object = self.z_posterior_box, self.z_posterior_object
        else:
            self.z_prior_box = self.prior_box_latent_space.sample()
            self.z_prior_object = self.prior_object_latent_space.sample()
            z_box, z_object = self.z_prior_box, self.z_prior_object

        # Inject the latent samples into the dense prompt embedding and the image embedding.
        dense_embeddings_disturb = self.fcomb_box(dense_embeddings_shift, z_box)
        image_embeddings_disturb = self.fcomb_object(image_embeddings, z_object)

        low_res_masks, iou_predictions = sam.mask_decoder(
            image_embeddings=image_embeddings_disturb,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings_shift,
            dense_prompt_embeddings=dense_embeddings_disturb,
            multimask_output=True
        )

        # NOTE(review): original_size is fixed to 128x128 regardless of input_size/img_size - confirm intended.
        masks = sam.postprocess_masks(
            low_res_masks,
            input_size=(input_size, input_size),
            original_size=(128, 128)
        )

        return {
            'masks': masks,
            'iou_predictions': iou_predictions,
            'low_res_logits': low_res_masks
        }


