import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import MinMaxScaler,StandardScaler,scale
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GATv2Conv, GCNConv, SAGEConv, ResGatedGraphConv, GINConv, TransformerConv, RGATConv, TAGConv, VGAE, InnerProductDecoder
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.utils import negative_sampling
from math import sqrt

EPS = 1e-15
MAX_LOGSTD = 10

class ActivateLayer(nn.Module):
    """Module wrapper around an activation callable.

    Wrapping the callable lets it sit next to other layers inside module
    containers such as ``nn.ModuleList``.

    Args:
        activation_func: Callable applied to the input in ``forward``.
    """

    def __init__(self, activation_func):
        super().__init__()
        self.activation_func = activation_func

    def forward(self, x):
        """Return ``activation_func(x)``."""
        return self.activation_func(x)
    
class GraphConvLayer(nn.Module):
    """Module wrapper around an already-constructed graph convolution.

    Args:
        conv_func: Callable invoked as ``conv_func(x, edge_index)``.
    """

    def __init__(self, conv_func):
        super().__init__()
        self.conv = conv_func

    def forward(self, x, edge_index):
        """Return ``conv(x, edge_index)``."""
        return self.conv(x, edge_index)
    
class CalculateAttention(nn.Module):
    """Scaled dot-product attention evaluated without gradient tracking."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        """Return ``softmax(Q K^T / sqrt(d)) V`` where ``d = Q.size(-1)``.

        Args:
            Q, K, V: Tensors whose last two dimensions are multiplied as
                matrices (``Q @ K^T`` and ``weights @ V``).
            mask: Optional boolean tensor; positions where it is True are
                filled with ``-1e9`` before scaling and softmax.

        NOTE(review): everything runs under ``torch.no_grad()``, so the result
        carries no gradient back to ``Q``, ``K`` or ``V``.  Confirm that this
        is intended (e.g. to save memory) rather than an oversight.
        """
        with torch.no_grad():
            scores = Q @ K.transpose(-1, -2)
            if mask is not None:
                scores.masked_fill_(mask, -1e9)
            # In-place scaling keeps peak memory low for large score matrices.
            scores /= sqrt(Q.size(-1))
            weights = scores.softmax(dim=-1)
            del scores
            attended = weights @ V
            del weights
        return attended
    
class Multi_CrossAttention(nn.Module):
    """Multi-head cross attention built on ``CalculateAttention``.

    Args:
        embed_dim: Size of the last dimension of the query/key/value inputs
            and of the projected output.
        all_head_dim: Total width of the projected Q/K/V (all heads together);
            must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self,embed_dim,all_head_dim,num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.all_head_dim = all_head_dim
        self.num_heads = num_heads
        self.h_size = all_head_dim // num_heads

        assert all_head_dim % num_heads == 0

        # Bias-free projections embed_dim -> all_head_dim for Q, K and V.
        self.linear_q = nn.Linear(embed_dim, all_head_dim, bias=False)
        self.linear_k = nn.Linear(embed_dim, all_head_dim, bias=False)
        self.linear_v = nn.Linear(embed_dim, all_head_dim, bias=False)
        # Maps the concatenated heads back to embed_dim.
        self.linear_output = nn.Linear(all_head_dim, embed_dim)

        # Stored normalisation constant; forward() in this class does not read it.
        self.norm = sqrt(all_head_dim)

    def print(self):
        """Print the configured sizes and the three projection layers."""
        print(self.embed_dim,self.all_head_dim)
        print(self.linear_k,self.linear_q,self.linear_v)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, S, H*W) into (B, H, S, W)."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.h_size)
        return per_head.transpose(1, 2)

    def forward(self,q,k,v,attention_mask=None):
        """Attend ``q`` over ``k``/``v``.

        Args:
            q, k, v: Inputs whose first dimension is the batch and whose last
                dimension is ``embed_dim``.
            attention_mask: Optional tensor; entries equal to 0 are masked out.

        Returns:
            ``(output, attention)`` where ``attention`` is the concatenation
            of the per-head attended values, shaped (B, S, num_heads * h_size),
            and ``output`` is ``linear_output(attention)``.
        """
        batch_size = q.size(0)

        q_s = self._split_heads(self.linear_q(q), batch_size)
        k_s = self._split_heads(self.linear_k(k), batch_size)
        v_s = self._split_heads(self.linear_v(v), batch_size)

        if attention_mask is not None:
            attention_mask = attention_mask.eq(0)

        heads = CalculateAttention()(q_s, k_s, v_s, attention_mask)
        # (B, H, S, W) -> (B, S, H*W)
        merged = heads.transpose(1, 2).contiguous()
        attention = merged.view(batch_size, -1, self.num_heads * self.h_size)

        return self.linear_output(attention), attention

class Attention_module(nn.Module):
    """Residual cross-attention between image and RNA embeddings.

    Args:
        embed_dim: Feature size of both inputs.
        num_heads: Number of attention heads.
        brief_att: If True use ``Multi_CrossAttention``; otherwise
            ``nn.MultiheadAttention`` with ``batch_first=True``.
        all_head_dim: Total head width passed to ``Multi_CrossAttention``.
        mode: ``'cross'`` (both modalities attend to each other), ``'rna'``
            (only RNA attends to image), ``'img'`` (only image attends to
            RNA).  Any other value disables attention.
        extract_att_weight: Stored on the instance; not read by this class.
    """

    # Modes for which an attention layer is built and applied.
    _ATTENTION_MODES = ('cross', 'rna', 'img')

    def __init__(self, embed_dim, num_heads, brief_att = True, all_head_dim=256, mode='cross', extract_att_weight=True):
        super(Attention_module, self).__init__()
        self.mode = mode
        self.extract_att_weight = extract_att_weight

        if self.mode in self._ATTENTION_MODES:
            print('Attention_module %s.\n number of heads in Attention_module: %s' % (mode,num_heads))
            if brief_att:
                self.attention = Multi_CrossAttention(embed_dim,all_head_dim,num_heads)
            else:
                self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        else:
            print('NO Attention modal!')

    def forward(self, img_tensor, rna_tensor):
        """Return ``(img_emb, rna_emb, attn_weight)``.

        For the attention modes, each attended modality is the input plus the
        attention output (residual connection).  ``attn_weight`` is whatever
        the attention layer returns as its second value: a tuple
        ``(rna, img)`` in ``'cross'`` mode, a single value otherwise.  In
        ``'cross'`` mode the same attention layer (shared weights) is used for
        both directions.

        For any other mode no attention layer exists, so the inputs are
        returned unchanged together with the string ``'NO Attention modal!'``.
        Previously this path failed with ``UnboundLocalError``.
        """
        if self.mode not in self._ATTENTION_MODES:
            return img_tensor, rna_tensor, 'NO Attention modal!'

        # Add a leading batch dimension: (seq_length, embed_dim) -> (1, seq_length, embed_dim).
        img_tensor = img_tensor.unsqueeze(0)
        rna_tensor = rna_tensor.unsqueeze(0)

        if self.mode == 'cross':
            attended_rna, attn_weights_rna = self.attention(rna_tensor, img_tensor, img_tensor)
            rna_emb = rna_tensor + attended_rna

            attended_img, attn_weights_img = self.attention(img_tensor, rna_tensor, rna_tensor)
            img_emb = img_tensor + attended_img

            attn_weight = (attn_weights_rna, attn_weights_img)

        elif self.mode == 'rna':
            attended_rna, attn_weight = self.attention(rna_tensor, img_tensor, img_tensor)
            rna_emb = rna_tensor + attended_rna
            img_emb = img_tensor

        else:  # 'img'
            attended_img, attn_weight = self.attention(img_tensor, rna_tensor, rna_tensor)
            img_emb = img_tensor + attended_img
            rna_emb = rna_tensor

        # Drop the batch dimension added above.
        rna_emb = torch.squeeze(rna_emb, 0)
        img_emb = torch.squeeze(img_emb, 0)
        return img_emb, rna_emb, attn_weight

class Encoder(nn.Module):
    """Stack of graph convolutions, each followed by an activation.

    Layer ``i`` maps ``input_size`` (for ``i == 0``) or ``hidden_sizes[i-1]``
    to ``hidden_sizes[i]``.  Each layer is registered as ``encoder_L{i}`` and
    is a ``ModuleList`` holding ``[GraphConvLayer, ActivateLayer]``.

    Args:
        input_size: Feature size of the input nodes.
        hidden_sizes: Output size of every layer, in order.
        gnn_type: One of the keys of ``_CONV_FACTORIES``.
        activation_func: Activation applied after every convolution.

    Raises:
        ValueError: If ``gnn_type`` is unknown and at least one layer is built.
    """

    # gnn_type -> factory(in_size, out_size).  Lambdas keep the lookup of the
    # convolution classes lazy, so only the requested type is ever touched.
    _CONV_FACTORIES = {
        'GCN': lambda i, o: GCNConv(i, o),
        'GAT': lambda i, o: GATConv(i, o),
        'SAGE': lambda i, o: SAGEConv(i, o),
        # 'ResGated' is the spelling used by Decoder; accept both here.
        'ResGate': lambda i, o: ResGatedGraphConv(i, o),
        'ResGated': lambda i, o: ResGatedGraphConv(i, o),
        # GINConv takes a network, not channel sizes, as its first argument.
        'GIN': lambda i, o: GINConv(nn.Linear(i, o)),
        'Transformer': lambda i, o: TransformerConv(i, o),
        # NOTE(review): RGATConv also needs `num_relations` at construction and
        # an `edge_type` in forward; neither is available through this
        # interface, so this entry fails as it always did.  Confirm and extend.
        'RGAT': lambda i, o: RGATConv(i, o),
        'TAG': lambda i, o: TAGConv(i, o),
    }

    def __init__(self, input_size, hidden_sizes, gnn_type = 'GCN', activation_func = nn.ReLU()):
        super(Encoder, self).__init__()
        self.stack_net = self._make_encoder(input_size, hidden_sizes, gnn_type, activation_func)

    def _make_conv(self, gnn_type, in_size, out_size):
        """Build one convolution of kind ``gnn_type`` mapping in_size -> out_size."""
        try:
            factory = self._CONV_FACTORIES[gnn_type]
        except KeyError:
            raise ValueError(
                "Unsupported gnn_type %r; expected one of %s"
                % (gnn_type, sorted(self._CONV_FACTORIES))) from None
        return factory(in_size, out_size)

    def _build_layer(self, conv_func, activation_func, drop_p = 0):
        """Return ``ModuleList([conv, activation])`` plus ``Dropout`` when ``drop_p > 0``."""
        layers = [GraphConvLayer(conv_func), ActivateLayer(activation_func)]
        if drop_p > 0:
            layers.append(nn.Dropout(drop_p))
        return nn.ModuleList(layers)

    def _make_layer(self, block, in_size, growth_rate, num_layers, droprate):
        """Chain ``num_layers`` blocks built as ``block(in_size, in_size - i*growth_rate, droprate)``.

        Not called anywhere in the visible part of this file.
        """
        layers = []
        for i in range(num_layers):
            layers.append(block(in_size, in_size-i*growth_rate, droprate))
        return nn.Sequential(*layers)

    def _make_encoder(self, input_size, hidden_sizes, gnn_type, activation_func):
        """Build the ``encoder_L{i}`` layers described in the class docstring."""
        encoder = nn.ModuleList()
        in_size = input_size
        for i, out_size in enumerate(hidden_sizes):
            conv_func = self._make_conv(gnn_type, in_size, out_size)
            encoder.add_module(f'encoder_L{i}',
                    self._build_layer(conv_func, activation_func, drop_p = 0))
            # The next layer consumes what this one produces.
            in_size = out_size
        return encoder

    def forward(self, x, edge_index):
        """Apply every layer's convolution and activation in order.

        Only entries 0 and 1 of each layer are used; an optional dropout entry
        (never created with the current ``drop_p = 0``) would be skipped.
        """
        for layer in self.stack_net:
            x = layer[0](x, edge_index)
            x = layer[1](x)
        return x
    
    # def _make_layer(self, input_size, hidden_size, conv_type, num_layers):
    #     layers = []
    #     layers.append(GraphConvLayer(input_size, hidden_size, conv_type))
    #     for _ in range(num_layers - 1):
    #         layers.append(GraphConvLayer(hidden_size, hidden_size, conv_type))
    #     return nn.ModuleList(layers)
    
class Decoder(Encoder):
    """Graph-convolution stack mapping a latent code back to ``output_size``.

    Layer ``i`` maps ``hidden_sizes[i]`` to ``hidden_sizes[i+1]``; the last
    layer maps ``hidden_sizes[-1]`` to ``output_size``.  Each layer is
    registered as ``decoder_L{i}`` and has the same ``[conv, activation]``
    structure as the encoder layers, so ``Encoder.forward`` is reused.

    Args:
        output_size: Feature size produced by the last layer.
        hidden_sizes: Input size of every layer, in order.
        gnn_type: One of the keys of ``_DECODER_CONVS``.
        activation_func: Activation applied after every convolution
            (including the last one).

    Raises:
        ValueError: If ``gnn_type`` is unknown and at least one layer is built.
    """

    # gnn_type -> factory(in_size, out_size); lambdas keep class lookup lazy.
    _DECODER_CONVS = {
        'GCN': lambda i, o: GCNConv(i, o),
        'GAT': lambda i, o: GATConv(i, o),
        'SAGE': lambda i, o: SAGEConv(i, o),
        # 'ResGate' is the spelling used by Encoder; accept both here.
        'ResGate': lambda i, o: ResGatedGraphConv(i, o),
        'ResGated': lambda i, o: ResGatedGraphConv(i, o),
        # GINConv takes a network, not channel sizes, as its first argument.
        'GIN': lambda i, o: GINConv(nn.Linear(i, o)),
        'Transformer': lambda i, o: TransformerConv(i, o),
        # NOTE(review): RGATConv also needs `num_relations` at construction and
        # an `edge_type` in forward; this entry fails as it always did.
        'RGAT': lambda i, o: RGATConv(i, o),
        'TAG': lambda i, o: TAGConv(i, o),
    }

    def __init__(self, output_size, hidden_sizes, gnn_type = 'GCN', activation_func = nn.ReLU()):
        # When inheriting through super(), parameters can only be added, not removed.
        # The parent builds a default (GCN) encoder stack here; it is replaced
        # on the next line and only kept so weight initialisation draws from
        # the random generator in the same order as before.
        super().__init__(output_size, hidden_sizes)
        self.stack_net = self._make_decoder(output_size, hidden_sizes, gnn_type, activation_func)

    def _make_decoder(self, output_size, hidden_sizes, gnn_type, activation_func):
        """Build the ``decoder_L{i}`` layers described in the class docstring."""
        decoder = nn.ModuleList()
        last = len(hidden_sizes) - 1
        for i, in_size in enumerate(hidden_sizes):
            out_size = hidden_sizes[i + 1] if i < last else output_size
            try:
                factory = self._DECODER_CONVS[gnn_type]
            except KeyError:
                raise ValueError(
                    "Unsupported gnn_type %r; expected one of %s"
                    % (gnn_type, sorted(self._DECODER_CONVS))) from None
            conv_func = factory(in_size, out_size)
            decoder.add_module(f'decoder_L{i}',
                    self._build_layer(conv_func, activation_func, drop_p = 0))
        return decoder
    
class CustomVAE(nn.Module):
    """Graph variational auto-encoder with a pluggable decoder.

    Args:
        input_size: Feature size of the input nodes.
        hidden_size: Sequence of encoder layer sizes (passed to ``Encoder``).
        latent_size: Size of the latent mean / log-variance.
        gnn_type: Convolution type for the encoder and the two latent heads.
        decoder: Module called as ``decoder(z, edge_index)``; defaults to
            ``InnerProductDecoder()``.
        activation_func: Activation passed to ``Encoder``.

    Raises:
        ValueError: If no latent head is available for ``gnn_type``.
    """

    # gnn_type -> factory(in_size, out_size) for the mean / log-variance heads.
    _HEAD_FACTORIES = {
        'GCN': lambda i, o: GCNConv(i, o),
        'GAT': lambda i, o: GATConv(i, o),
        'SAGE': lambda i, o: SAGEConv(i, o),
        'ResGate': lambda i, o: ResGatedGraphConv(i, o),
        'ResGated': lambda i, o: ResGatedGraphConv(i, o),
        'GIN': lambda i, o: GINConv(nn.Linear(i, o)),
        'Transformer': lambda i, o: TransformerConv(i, o),
        'TAG': lambda i, o: TAGConv(i, o),
    }

    def __init__(self, input_size, hidden_size, latent_size, gnn_type = 'GAT', decoder = None, activation_func = nn.ReLU()):
        super(CustomVAE, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, gnn_type , activation_func)

        # Fail here rather than with an AttributeError on the first encode().
        if gnn_type not in self._HEAD_FACTORIES:
            raise ValueError(
                "No latent head available for gnn_type %r; expected one of %s"
                % (gnn_type, sorted(self._HEAD_FACTORIES)))
        make_head = self._HEAD_FACTORIES[gnn_type]
        self.fc_mean = make_head(hidden_size[-1], latent_size)
        self.fc_logvar = make_head(hidden_size[-1], latent_size)

        self.decoder = InnerProductDecoder() if decoder is None else decoder

    def encode(self, x, edge_index):
        """Return ``(mean, logvar)`` of the latent distribution."""
        encoded = self.encoder(x,edge_index)
        mean = self.fc_mean(encoded, edge_index)
        logvar = self.fc_logvar(encoded, edge_index)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, *args, **kwargs):
        """Forward all arguments to the configured decoder."""
        return self.decoder(*args, **kwargs)

    def forward(self, x, edge_index):
        """Return ``(z, mean, logvar, x_hat)``.

        ``x_hat`` is decoded from the sampled ``z`` itself.  The previous
        implementation first tried ``z.long()``, which truncated the latent
        values to integers and cut the gradient whenever the decoder accepted
        integer input, and hid every other error behind a bare ``except``.
        """
        mean, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mean, logvar)
        x_hat = self.decode(z, edge_index)
        return z, mean, logvar, x_hat
    
class FinalModal(nn.Module):
    """Fuse image and RNA embeddings and encode them with a graph VAE.

    The two modalities are optionally passed through ``Attention_module``,
    weighted, concatenated along dimension 1 and fed to ``CustomVAE``.

    Args:
        single_model_size: Feature size of each modality.
        hidden_size: Encoder layer sizes for the VAE.
        latent_size: Latent size of the VAE.
        num_heads: Number of attention heads.
        brief_att: Passed to ``Attention_module``.
        attn_mode: ``'cross'``, ``'rna'`` or ``'img'``; any other value skips
            attention in ``forward``.
        gnn_type: Convolution type for the VAE.
        img_weight: Scalar multiplied onto the image embedding before concatenation.
        rna_weight: Scalar multiplied onto the RNA embedding before concatenation.
        decoder: Optional decoder passed to ``CustomVAE``.
    """

    def __init__(self, single_model_size, hidden_size, latent_size, num_heads=1, brief_att = True, attn_mode='cross', gnn_type = 'GCN', img_weight = 1, rna_weight = 1, decoder = None):
        super(FinalModal, self).__init__()
        self.img_weight = img_weight
        self.rna_weight = rna_weight
        self.attn_mode = attn_mode
        self.atten = Attention_module(single_model_size, num_heads, brief_att = brief_att, mode=attn_mode, all_head_dim=256, extract_att_weight=True)
        # The VAE sees both modalities side by side, hence twice the width.
        self.vae = CustomVAE(2*single_model_size, hidden_size, latent_size, gnn_type = gnn_type, decoder=decoder)

    def forward(self, img_tensor, rna_tensor, edge_index):
        """Return ``(z, mean, logvar, x_hat, img_emb, rna_emb, attn_weight)``.

        Without attention, ``img_emb``/``rna_emb`` are the raw inputs and
        ``attn_weight`` is an explanatory string.
        """
        if self.attn_mode in ['cross','rna','img']:
            img_emb,rna_emb,attn_weight = self.atten(img_tensor, rna_tensor)
        else:
            img_emb = img_tensor
            rna_emb = rna_tensor
            attn_weight = 'No attention model for modalitys fusion'
        concat_emb = torch.cat((self.img_weight*img_emb,self.rna_weight*rna_emb),1)

        z, mean, logvar, x_hat = self.vae(concat_emb,edge_index)
        return z, mean, logvar, x_hat, img_emb, rna_emb, attn_weight

    def innerproduct_loss(self, z, pos_edge_index, neg_edge_index=None):
        """Binary cross-entropy style edge loss on positive and negative edges.

        Negative edges are drawn with ``negative_sampling`` when not given.

        NOTE(review): the logs assume ``self.vae.decoder(z, edges)`` returns
        values in (0, 1) — presumably edge probabilities; verify for custom
        decoders, where ``1 - decoder(...)`` may become non-positive.
        """
        pos_loss = -torch.log(
            self.vae.decoder(z, pos_edge_index) + EPS).mean()

        if neg_edge_index is None:
            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
        neg_loss = -torch.log(1 -
                              self.vae.decoder(z, neg_edge_index) +
                              EPS).mean()
        return pos_loss + neg_loss

    def kl_loss(self, mu = None,logstd = None) -> Tensor:
        """KL divergence of N(mu, exp(logstd)^2) from the standard normal.

        Args:
            mu: Latent means.  Falls back to ``self.__mu__`` when None.
            logstd: Log standard deviations, clamped to ``MAX_LOGSTD``.
                Falls back to ``self.__logstd__`` (unclamped) when None.

        Raises:
            AttributeError: If an argument is omitted and the corresponding
                attribute has not been set.  Nothing in this class sets them,
                so in practice both arguments must be passed.

        NOTE(review): ``CustomVAE.reparameterize`` treats the encoder's second
        output as a log-variance, while this formula expects a log standard
        deviation — confirm which one callers pass here.
        """
        if mu is None:
            mu = getattr(self, '__mu__', None)
        if logstd is None:
            logstd = getattr(self, '__logstd__', None)
        else:
            logstd = logstd.clamp(max=MAX_LOGSTD)
        if mu is None or logstd is None:
            raise AttributeError(
                "kl_loss() needs both `mu` and `logstd`: pass them explicitly "
                "(no cached __mu__/__logstd__ is available on this model)")
        return -0.5 * torch.mean(
            torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))

    def recon_loss(self, input, output):
        """Mean squared error between ``output`` and ``input``."""
        return F.mse_loss(output, input, reduction='mean')

    def print_networks(self, verbose=True):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name, net in self.named_children():
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given networks.

        Parameters:
            nets (network list)   -- a network or a list of networks; None entries are skipped
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

def test_modal():
    """Smoke-test FinalModal on a small random graph and print two losses."""
    num_nodes = 10
    num_features = 50

    # Random node features for both modalities.
    img_tensor = torch.randn(num_nodes, num_features)
    rna_tensor = torch.randn(num_nodes, num_features)

    # Random edges, symmetrised so the graph is undirected.
    edge_index = to_undirected(torch.randint(num_nodes, (2, num_nodes * 2)))

    decoder = Decoder(output_size=50,hidden_sizes=[10,16,32])
    fimodal = FinalModal(50,[32,16],10,decoder=decoder)
    print(fimodal)

    outputs = fimodal(img_tensor, rna_tensor, edge_index)
    z, mean, logvar = outputs[0], outputs[1], outputs[2]
    attn_weight = outputs[6]
    print(z.shape, attn_weight)

    inner_loss = fimodal.innerproduct_loss(z,edge_index)
    kl_loss = fimodal.kl_loss(mean, logvar)
    print(inner_loss,kl_loss)

# Run the smoke test only when this file is executed directly, not on import.
if __name__ == '__main__':
    test_modal()


import torch
# Report the PyTorch build and whether CUDA can be used.
print('torch version: ',torch.__version__)
flag = torch.cuda.is_available()
if flag:
    print("CUDA is available")
else:
    print("CUDA NOT available")

ngpu= 1
# Decide which device we want to run on
device = torch.device("cuda:0" if (flag and ngpu > 0) else "cpu")
print("The driver is：",device)
# Querying the device name fails when no CUDA device is present, so only do
# it when CUDA is available; the CPU case was already reported above.
if flag:
    print("GPU： ",torch.cuda.get_device_name(0))


import os 
import argparse
import matplotlib.pyplot as plt
import scanpy as sc
#import cv2
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.autograd import Variable
from torchvision import models
from tqdm import tqdm
import anndata as ad
from matplotlib import pyplot as plt
from sklearn import metrics
from torchvision import transforms
from torchvision import *
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import warnings
import sys
from torch.utils.data import Dataset
from utils import find_res_binary,plot_spatial
# Silence every Python warning for the rest of the process.
warnings.filterwarnings('ignore')

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device object
print(device)

# Clear PIL's pixel-count limit — presumably so very large slide images can
# be opened without a size error; confirm against the Pillow documentation.
Image.MAX_IMAGE_PIXELS = None

###### Version and Date
PROG_VERSION = '1.0'
PROG_DATE = '2023-08-31'

###### Usage
# Passed to argparse as the usage text.  The option list mirrors the required
# arguments declared in main(): -i is the H&E image (not a gem file), and
# -m / -b / -c are mandatory as well.
usage = '''

Version %s  by Luo Bingying  %s

Usage: %s -a <adata_file> -i <image_file> -o <output_dir> -m <model_name> -b <bin_size> -c <crop_size> [...]

''' % (PROG_VERSION, PROG_DATE, os.path.basename(sys.argv[0]))


def add_img2adata(img,adata,bin_size = 100, library_id = "cancer",spatial_key = "spatial"):
    """Attach an image to ``adata`` and derive per-spot image coordinates.

    Mutates ``adata`` in place and returns it.

    Args:
        img: Image convertible with ``np.array``; stored as the ``"hires"`` image.
        adata: Object with ``uns``, ``obsm['spatial']`` and ``obs``
            (presumably an AnnData — confirm).  ``obs.index`` must consist of
            labels with exactly one ``'_'``, which are split into
            ``array_row`` / ``array_col``.
        bin_size: Coordinates are shifted so the minimum becomes ``bin_size / 2``.
        library_id: Key under ``adata.uns[spatial_key]``.
        spatial_key: Key in ``adata.uns`` that receives the image metadata.

    NOTE(review): ``spot_diameter_fullres`` is fixed at 100 regardless of
    ``bin_size`` — confirm whether it should follow ``bin_size``.
    """
    adata.var_names_make_unique()
    # Build the whole entry in one go.  ``use_quality`` was previously written
    # through the literal key "spatial", which raised KeyError for any other
    # ``spatial_key``.
    adata.uns[spatial_key] = {
        library_id: {
            "images": {"hires": np.array(img)},
            "scalefactors": {"tissue_hires_scalef": 1, "spot_diameter_fullres": 100},
            "use_quality": 'hires',
        }
    }

    # Shift both axes so the smallest coordinate sits half a bin from the origin.
    coords = adata.obsm['spatial']
    for axis in (0, 1):
        coords[:, axis] = coords[:, axis] - coords[:, axis].min() + bin_size / 2
    adata.obs['img_x'] = coords[:, 0]
    adata.obs['img_y'] = coords[:, 1]

    adata.obs['cell_id'] = adata.obs.index
    row_col = adata.obs['cell_id'].astype('str').str.split('_', expand=True)
    row_col.columns = ['array_row', 'array_col']
    adata.obs = adata.obs.join(row_col)
    return adata

def image_crop(image,adata,slide_gem=False,save_path='./',crop_size=128,verbose=False,):
    """Cut one square tile per observation out of ``image`` and save it as TIFF.

    Args:
        image: Object with a ``crop(box)`` method returning something with
            ``save(path)`` (presumably a PIL image — confirm).
        adata: Object whose ``obs`` holds the tile centres; mutated in place.
        slide_gem: If true, centres are read from ``obs["x"]``/``obs["y"]``,
            otherwise from ``obs["img_x"]``/``obs["img_y"]``.
        save_path: Directory receiving the tiles.
        crop_size: Edge length of each tile in pixels.
        verbose: Print one line per generated tile.

    Returns:
        ``adata`` with the tile paths stored in ``obs["slices_path"]``.
    """
    x_col, y_col = ("x", "y") if slide_gem else ("img_x", "img_y")
    half = crop_size / 2
    tile_names = []

    with tqdm(total=len(adata),
              desc="Tiling image",
              bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:

        for img_x, img_y in zip(adata.obs[x_col], adata.obs[y_col]):
            tile_name = str(img_x) + "-" + str(img_y) + "-" + str(crop_size)
            tile_path = os.path.join(save_path, tile_name + '.tiff')
            tile_names.append(str(tile_path))

            try:
                # (left, upper, right, lower)
                box = (int(img_x - half), int(img_y - half),
                       int(img_x + half), int(img_y + half))
                tile = image.crop(box)
                tile.save(tile_path)
            except Exception as err:
                # Best effort: keep one file per observation by writing a
                # white placeholder.  Catching Exception (not a bare except)
                # lets Ctrl-C and interpreter exit still propagate.
                print(tile_path," error, generate blank (white) image:", err)
                tile = Image.new("RGB", (crop_size, crop_size), (255, 255, 255))
                tile.save(tile_path)

            if verbose:
                print(
                    "generate tile at location ({}, {})".format(
                        str(img_x), str(img_y)))
            pbar.update(1)

    adata.obs["slices_path"] = tile_names
    return adata

def plot_spatial_with_img(adata,save_path,file):
    """Plot the ``leiden`` labels over the ``hires`` image and save the figure.

    The figure width is scaled by the x/y extent ratio of
    ``adata.obsm['spatial']`` so the plot keeps the tissue's proportions.
    The figure is written to ``save_path/file`` and then shown.
    """
    coords = adata.obsm['spatial']
    x_extent = coords[:, 0].max() - coords[:, 0].min()
    y_extent = coords[:, 1].max() - coords[:, 1].min()
    aspect = x_extent / y_extent

    plt.rcParams["figure.figsize"] = (8 * aspect, 8)
    sc.pl.spatial(adata, img_key="hires", color="leiden", basis="spatial", size=1.5, alpha_img=0.4)

    out_path = os.path.join(save_path, file)
    plt.savefig(out_path, bbox_inches='tight')
    plt.show()
    return None
    
def save_data(adata,save_path,label_txt_file,adata_file):
    """Write the tile/label table and the data object into ``save_path``.

    The label file holds the ``slices_path`` and ``leiden`` columns of
    ``adata.obs``, space separated, without header or index.  The object
    itself is written through ``adata.write``.
    """
    label_path = os.path.join(save_path, label_txt_file)
    adata_path = os.path.join(save_path, adata_file)

    tile_labels = adata.obs.loc[:, ['slices_path', 'leiden']]
    tile_labels.to_csv(label_path, sep=' ', header=0, index=False)
    adata.write(adata_path)
    
class MyDataset(Dataset):
    """Image dataset described by a text file of ``<path> <int label>`` lines.

    Args:
        txt_path: File with one whitespace-separated ``path label`` pair per
            line.  Blank lines are skipped.  Because lines are split on
            whitespace, paths must not contain spaces.
        transform: Optional callable applied to each opened image.
        target_transform: Stored on the instance; ``__getitem__`` does not
            apply it.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # The context manager closes the file; it used to be left open.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue
                imgs.append((words[0], int(words[1])))

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the image is opened on demand."""
        fn, label = self.imgs[index]
        img = Image.open(fn)

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.imgs)
    
def load_model(model_name = 'resnet50'):
    """Build a feature extractor and load its weights from a local checkpoint.

    Checkpoints are read from ``<cwd>/torch_pths`` when the working directory
    path contains ``'StereoMMv1'``, otherwise from
    ``<cwd>/StereoMMv1/torch_pths``.

    Args:
        model_name: ``'resnet50'`` (torchvision ResNet-50 without its final
            layer) or ``'CHIEF'`` (timm Swin-tiny with an identity head).

    Returns:
        The extractor, moved to the module-level ``device`` and set to eval mode.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    checkpoint_names = {
        'resnet50': "resnet50-19c8e357.pth",
        'CHIEF': "CHIEF_CTransPath.pth",
    }
    # The body previously tested an undefined name (`modal_name`), so every
    # call failed; validate the real parameter up front instead.
    if model_name not in checkpoint_names:
        raise ValueError(
            "Unsupported model_name %r; expected one of %s"
            % (model_name, sorted(checkpoint_names)))

    current_path = os.getcwd()
    if 'StereoMMv1' in current_path:
        weights_dir = os.path.join(current_path, "torch_pths")
    else:
        weights_dir = os.path.join(current_path, "StereoMMv1/torch_pths")
    pth_path = os.path.join(weights_dir, checkpoint_names[model_name])

    if model_name == 'resnet50':
        model_com = models.resnet50(pretrained=False)
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        model_com.load_state_dict(torch.load(pth_path, map_location=device))
        ### strip the last layer
        extractor = torch.nn.Sequential(*list(model_com.children())[:-1])
    else:
        import timm
        # NOTE(review): ConvStem is not defined in the visible part of this
        # file; it must be provided elsewhere for this branch to work.
        extractor = timm.create_model('swin_tiny_patch4_window7_224', embed_layer=ConvStem, pretrained=False)
        extractor.head = nn.Identity()
        td = torch.load(pth_path, map_location=device)
        extractor.load_state_dict(td['model'], strict=True)

    extractor.to(device)
    extractor.eval()
    return extractor

def feature_extractor(model,dataset):
    """Run ``model`` on every sample of ``dataset`` and collect flat features.

    Args:
        model: Module called on one sample at a time with a leading batch
            dimension added.
        dataset: Indexable collection with ``len()`` whose items start with
            the input tensor (any further entries, e.g. labels, are ignored).

    Returns:
        List with one flattened NumPy array per sample, in dataset order.
    """
    feat_outputs = []
    model.eval()
    # Disable gradient calculation to improve inference speed
    with torch.no_grad(), tqdm(total=len(dataset),
              desc="calculate image feature",
              bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
        for i in range(len(dataset)):
            # Fetch the sample once: indexing the dataset twice would load
            # and transform the same sample twice.
            inputs = dataset[i][0].to(device)

            output = model(inputs.unsqueeze(0))
            feat_outputs.append(output.data.cpu().numpy().ravel())

            pbar.update(1)
    return feat_outputs
    
    
def extract_img_feat(tile_file,extractor):
    """Extract one feature row per tile listed in ``tile_file``.

    Tiles are resized to 224x224, converted to tensors and normalised with
    the fixed mean/std below before being passed to ``extractor``.

    Returns:
        A DataFrame with one row per tile.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    tiles = MyDataset(tile_file, preprocess)
    features = feature_extractor(model=extractor, dataset=tiles)
    return pd.DataFrame(features)
    
                
def save_img_feat(data,save_path,feat_file):
    """Print the shape of ``data`` and pickle it to ``save_path/feat_file``."""
    print('shape of image_feature：',data.shape)
    out_path = os.path.join(save_path, feat_file)
    data.to_pickle(out_path)
    
def generate_adata(X,spatial=None):
    """Wrap ``X`` in an AnnData object and compute its neighbour graph.

    Args:
        X: Data matrix passed to ``ad.AnnData``.
        spatial: Optional coordinates stored as ``obsm['spatial']``.  When
            omitted, no ``'spatial'`` entry is created (the default used to
            be written as ``np.array(None)``).

    Returns:
        The AnnData object after ``sc.pp.neighbors`` on ``X``.
    """
    adata = ad.AnnData(X)
    if spatial is not None:
        adata.obsm['spatial'] = np.array(spatial)
    sc.pp.neighbors(adata,use_rep='X',n_neighbors=20, n_pcs=30)
    return adata
  
def choose_res(adata,res_range,method = 'num_cluster',cluster_num = 10,criterion = 'CH_score'):
    """Pick a Leiden resolution from ``res_range``.

    Every candidate resolution is clustered with ``sc.tl.leiden`` (which
    overwrites ``adata.obs['leiden']``).

    Args:
        adata: Data object with a neighbour graph.
        res_range: Iterable of candidate resolutions.
        method: ``'score'`` picks the resolution with the best ``criterion``;
            ``'num_cluster'`` returns the first resolution yielding exactly
            ``cluster_num`` clusters.
        cluster_num: Target number of clusters for ``'num_cluster'``.
        criterion: ``'CH_score'`` or ``'sil_score'`` (higher is better) or
            ``'DB_score'`` (lower is better); used by ``'score'`` only.

    Returns:
        The chosen resolution, or None when ``'num_cluster'`` finds no match
        or ``method`` is neither of the two supported values.

    Raises:
        ValueError: In ``'score'`` mode, for an unknown ``criterion`` or an
            empty ``res_range``.
    """
    if method == 'score':
        if criterion not in ('CH_score', 'sil_score', 'DB_score'):
            raise ValueError(
                "Unknown criterion %r; expected 'CH_score', 'sil_score' or 'DB_score'"
                % (criterion,))

        res_outs = []
        for res in res_range:
            sc.tl.leiden(adata, resolution=res)
            labels = adata.obs['leiden']
            # NOTE(review): scores are computed on adata.X, presumably the
            # embedding the neighbour graph was built on — confirm.
            emb = adata.X
            CH_score = metrics.calinski_harabasz_score(emb, labels)
            DB_score = metrics.davies_bouldin_score(emb, labels)
            sil_score = metrics.silhouette_score(emb, labels)
            num = str(len(labels.unique()))
            print(str(str(res)+"_"+num),":",DB_score,CH_score,sil_score)
            res_outs.append([res, num, DB_score, CH_score, sil_score])

        if not res_outs:
            raise ValueError("res_range is empty; no resolution to choose from")

        # Five values per row need five column names.
        res_df = pd.DataFrame(res_outs, columns=['res', 'num', 'DB_score', 'CH_score', 'sil_score'])
        if criterion == 'DB_score':
            best_idx = res_df[criterion].idxmin()
        else:
            best_idx = res_df[criterion].idxmax()
        best_res = res_df.loc[best_idx, 'res']
        print('best res choosen by',criterion,':',str(best_res))
        # Return the best resolution, not the last one tried.
        return best_res

    if method == 'num_cluster':
        for res in res_range:
            sc.tl.leiden(adata, resolution=res)
            num = len(adata.obs.leiden.unique())
            print(str(res),":", num)
            if num == cluster_num:
                print('res',str(res),'reach number of cluser',str(cluster_num))
                return res

        print("No suitable resolution found")

    return None

def definite_res(adata,res,save_path,plot_file='choose_res_spatial_plot.png',title=None):
    """Cluster ``adata`` at resolution ``res`` and save a spatial plot.

    ``sc.tl.leiden`` is run at ``res``, the resulting labels are printed, and
    the spatial embedding coloured by ``leiden`` is written to
    ``save_path/plot_file`` and shown.  The figure width follows the x/y
    extent ratio of ``adata.obsm['spatial']``.

    Returns:
        ``adata`` after clustering.
    """
    sc.tl.leiden(adata, resolution=res)
    print(adata.obs.leiden.unique())

    coords = adata.obsm['spatial']
    x_extent = coords[:, 0].max() - coords[:, 0].min()
    y_extent = coords[:, 1].max() - coords[:, 1].min()
    aspect = x_extent / y_extent
    plt.rcParams["figure.figsize"] = (8 * aspect, 8)

    sc.pl.embedding(adata, basis="spatial", color=["leiden"], title=title, size=20, show=False)
    out_path = os.path.join(save_path, plot_file)
    plt.savefig(out_path)
    plt.show()
    return adata

def mkdir(path):
    """Create ``path`` (including parents) unless something already exists there.

    ``exist_ok=True`` closes the gap between the existence check and the
    creation: if another process creates the directory in between, the call
    no longer fails.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def main():
    """Command-line entry point of the H&E image-feature pipeline.

    Steps, each skipped when its output file already exists in the output
    folder:
      1. read the AnnData file and the image, optionally crop the image to the
         bounding box of the spots, and call ``image_crop`` to produce tiles;
      2. load the feature extractor and call ``extract_img_feat`` on the tile
         label file, caching the result as ``img_feat.pkl``;
      3. build an AnnData from the features, search a clustering resolution
         with ``find_res_binary`` for the requested number of clusters, plot
         the result and write ``img_adata.h5ad``.
    """
    import argparse
    ArgParser = argparse.ArgumentParser(usage=usage,description='Process some integers.')
    ArgParser.add_argument("--version", action="version", version=PROG_VERSION)

    ArgParser.add_argument("-a", "--adata", action="store", dest="input", required=True, type=str, help="path of adata file")
    ArgParser.add_argument("-i", "--image", action="store", dest="image", required=True, type=str, help="path of H&E image file")
    ArgParser.add_argument("-o", "--output", action="store", dest="output", required=True, type=str, help="output folder")
    ArgParser.add_argument("-m", "--model", action="store", dest="model_name", required=True, type=str, help="feature ectract model")
    ArgParser.add_argument("-b", "--bin_size", action="store", dest="bin_size", required=True, type=int, help="bin size of adata")
    ArgParser.add_argument("-c", "--crop_size", action="store", dest="crop_size", required=True, type=int, help="crop size of image")
    ArgParser.add_argument("-n", "--num_cluster", action="store", dest="num_cluster", required=False, default=10, type=int, help="number of cluster")
    ArgParser.add_argument("-g", "--slide_gem",  action="store_true", dest="slide_gem", default=False, help="gem file for whole slide")
    # ArgParser.add_argument("-r", "--res_range", dest="res_range", required=False, default=np.around(np.arange(0.3,0.9,0.02), 3), type=float,help="resolution range for clustering")
    # parse_known_args: unrecognised command-line options are collected in
    # `args` instead of aborting; `args` is not used afterwards.
    (para, args) = ArgParser.parse_known_args()
    print(para,'\n')
    #print(type(para.bin_size))
    
    # Report the visible GPUs; this is informational only.
    num_gpus = torch.cuda.device_count()
    print("number of GPUs：", num_gpus)
    if num_gpus!=0:
        current_gpu = torch.cuda.current_device()
        print("The GPU index currently in use：", current_gpu)

    # NOTE(review): `device` is only used for the log line below; it is not
    # passed to any helper in this function.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print(f'CUDA available, using {torch.cuda.get_device_name(device)}.')
    
    # The right-hand side is a 2-tuple ('<bin>_crop_', '<crop>'); the join on
    # the next line turns it into the prefix '<bin>_crop_<crop>'.
    name = str(para.bin_size)+'_crop_',str(para.crop_size)
    name = "".join(name)
    label_txt_file=str(name+'_label.txt')
    adata_file=str(name+'_adata.h5ad')
    
    rna_adata = sc.read_h5ad(para.input)
    
    # Step 1: tile generation, skipped when the label file already exists.
    # NOTE(review): on the skip path `rna_adata` stays as read from the input
    # file; the saved '<name>_adata.h5ad' is not reloaded -- confirm intended.
    if os.path.exists(os.path.join(para.output,label_txt_file)):
        print('The image tile has been completed. %s, read it directly' % os.path.join(para.output,label_txt_file)) 
    else:
        #img = cv2.imread(para.image)
        image_raw = Image.open(para.image)
        print(type(image_raw), image_raw.size)
        if para.slide_gem:
        #if int(rna_adata.obs.x.max()-rna_adata.obs.x.min()+para.bin_size)< img.size[0]:
            # Crop box = extent of obs.x / obs.y padded by half a bin per side.
            harfbin = int(para.bin_size)/2
            cbox = (int(rna_adata.obs.x.min()-harfbin), int(rna_adata.obs.y.min()-harfbin), int(rna_adata.obs.x.max()+harfbin), int(rna_adata.obs.y.max()+harfbin))
            image_added = image_raw.crop(cbox)  
        else:
            image_added = image_raw.copy()
        
        # Attach the image only when the AnnData has no uns['spatial'] entry.
        if 'spatial'  not in rna_adata.uns:
            print("add image to adata.uns['spatial']")
            rna_adata = add_img2adata(image_added,rna_adata,bin_size = para.bin_size,library_id = "cancer",spatial_key = "spatial")
        else:
            pass

        # Tiles are written to <output>/<bin>_crop_<crop>/.
        tile_path = os.path.join(para.output,name)
        mkdir(tile_path)
        crop_size = int(para.crop_size)  
        rna_adata = image_crop(image_added,rna_adata,save_path=tile_path,crop_size=crop_size,verbose=False)
        # rna_adata = image_crop(image_raw,rna_adata,save_path=tile_path,crop_size=crop_size,verbose=False)
        plot_spatial_with_img(rna_adata,para.output,file='rna_spatial_with_img.png')

        save_data(rna_adata,para.output,label_txt_file=label_txt_file,adata_file=adata_file)
        print('=======================finish generate tiles=============================')
        
    # Step 2: image features, loaded from the pickle cache when present.
    if os.path.exists(os.path.join(para.output,'img_feat.pkl')):
        print('The image feature extract has been completed. %s, read it directly' % os.path.join(para.output,'img_feat.pkl')) 
        feat_outputs = pd.read_pickle(os.path.join(para.output,'img_feat.pkl'))
    else:
        feature_extractor = load_model(para.model_name)
        # if torch.cuda.device_count() > 1:
        #     print("Turn on parallelism: use multiple GPUs for training")
        #     feature_extractor = torch.nn.DataParallel(feature_extractor)
        print('The device where the model is located:',next(feature_extractor.parameters()).device)

        label_txt_path = os.path.join(para.output,label_txt_file)
        feat_outputs = extract_img_feat(label_txt_path,extractor=feature_extractor)
        save_img_feat(feat_outputs,para.output,feat_file = 'img_feat.pkl')
        print('=======================finish generate image feature=============================')
    
    # Step 3: cluster the image features and persist the result.
    # Requires rna_adata.obsm['spatial'] to exist at this point.
    img_adata = generate_adata(feat_outputs.values,spatial=rna_adata.obsm['spatial'])
    print(f"Clustering H&E image feature at number of clusters: {para.num_cluster}")
    resolution, img_adata = find_res_binary(img_adata, resolution_min=0.1, resolution_max=1.2, num_clusters=para.num_cluster,key_added='cluster')
    print(f"Final Resolution: {resolution}")
    plot_spatial(img_adata,para.output,title ='H&E morphology featuer domain',group='cluster')

    img_adata.write(os.path.join(para.output,'img_adata.h5ad'))
    
# Run the command-line pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()


import torch
from tqdm import tqdm
import random
import numpy as np
from sklearn.cluster import KMeans
from torch.nn.parameter import Parameter
from utils import *
import gc

class train():
    """Training driver for the image + RNA graph variational auto-encoder.

    The object stores the two feature tensors, the graph connectivity, the
    model and its optimizer.  ``train_ae`` runs the plain auto-encoder
    training, ``train_dec`` adds a cluster-initialised fine-tuning stage with
    label-stability early stopping, and ``predict`` returns the latent
    embedding as a NumPy array.

    The model is expected to return the 7-tuple
    ``(z, mean, logvar, x_hat, img_emb, rna_emb, attn_weight)`` and to provide
    ``recon_loss``, ``kl_loss`` and ``innerproduct_loss``.
    """

    def __init__(self, img_tensor, rna_tensor, edge_index, model, custom_decoder = True,n_epochs=100, opt="adam", lr=0.0001, weight_decay=0.0001, save_att=False, verbose=True):#, attn_out = None
        """Store the training inputs and build the optimizer.

        Parameters:
            img_tensor, rna_tensor -- feature tensors passed to the model
            edge_index -- graph connectivity passed to the model
            model -- the network to train
            custom_decoder (bool) -- True: reconstruction loss on ``x_hat``;
                                     False: inner-product loss on ``z``
            n_epochs (int) -- number of epochs used by ``train_ae``
            opt (str) -- "adam" or "sgd"
            lr, weight_decay -- optimizer settings (weight decay: Adam only)
            save_att (bool) -- pickle attention weights during ``train_ae``
            verbose (bool) -- stored; not used by the methods of this class

        Raises:
            ValueError -- if ``opt`` is neither "sgd" nor "adam"
        """
        self.img_tensor = img_tensor
        self.rna_tensor = rna_tensor
        self.edge_index = edge_index
        self.model = model
        self.custom_decoder = custom_decoder
        self.n_epochs = n_epochs
        self.save_att = save_att
        self.verbose = verbose
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if opt=="sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)
        elif opt=="adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(),lr=lr, weight_decay=weight_decay)
        else:
            # Previously an unknown name silently left the trainer without an
            # optimizer and failed later with an AttributeError.
            raise ValueError("Unsupported optimizer %r; expected 'sgd' or 'adam'." % (opt,))

    def _compute_loss(self, z, mean, logvar, x_hat, recon_weight=1):
        """Return the training loss for one forward pass.

        The KL term is always scaled by ``1 / number_of_rows``.  With the
        custom decoder the other term is ``recon_weight * recon_loss`` against
        the concatenated input features; otherwise it is the inner-product
        loss on ``z``.
        """
        kl_scale = 1 / self.img_tensor.shape[0]
        if self.custom_decoder:
            concat_tensor = torch.cat((self.img_tensor,self.rna_tensor),1)
            recon_loss = self.model.recon_loss(x_hat,concat_tensor)
            kl_loss = self.model.kl_loss(mean, logvar)
            return recon_weight*recon_loss + kl_scale * kl_loss
        inner_loss = self.model.innerproduct_loss(z,self.edge_index)
        kl_loss = self.model.kl_loss(mean, logvar)
        return inner_loss + kl_scale * kl_loss

    def _raw_features(self):
        """Return the concatenated input features as a CPU NumPy array."""
        concat_tensor = torch.cat((self.img_tensor,self.rna_tensor),1)
        # .detach().cpu() so this also works for CUDA tensors.
        return concat_tensor.detach().cpu().numpy()

    def _soft_assignment(self, z, alpha):
        """Student's-t soft assignment of ``z`` to the centers ``model.dec``.

        q_ij is proportional to (1 + |z_i - mu_j|^2 / alpha) ** (-(alpha+1)/2),
        normalised over j.
        """
        dist_sq = torch.sum((z.unsqueeze(1) - self.model.dec)**2, dim=2)
        q = 1.0 / ((1.0 + dist_sq / alpha) + 1e-8)
        # The exponent is (alpha+1)/2; the former `q**(alpha+1.0)/2.0` raised
        # to alpha+1 and then halved because ** binds tighter than /.
        q = q ** ((alpha + 1.0) / 2.0)
        return q / torch.sum(q, dim=1, keepdim=True)

    def train_ae(self, gradient_clipping=5.):
        """Train the auto-encoder for ``self.n_epochs`` epochs.

        Parameters:
            gradient_clipping (float) -- max gradient norm for clipping

        Returns:
            ``(loss_list, img_emb, rna_emb, attn_weight)`` where the last
            three come from the final forward pass (``None`` when
            ``n_epochs`` is 0).
        """
        print('tarining AE')
        loss_list = []
        img_emb = rna_emb = attn_weight = None
        pbar = tqdm(range(1, self.n_epochs+1),desc='Training model...')
        for epoch in pbar:
            self.model.train()
            self.optimizer.zero_grad()
            z, mean, logvar, x_hat, img_emb,rna_emb, attn_weight = self.model(self.img_tensor, self.rna_tensor, self.edge_index)
            loss = self._compute_loss(z, mean, logvar, x_hat)
            loss_list.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
            self.optimizer.step()

            # Release cached memory after every epoch.
            gc.collect()
            torch.cuda.empty_cache()

            if epoch % 2 == 0:
                pbar.set_postfix_str("loss: {:.4f}".format(loss.item()))
            if self.save_att:
                self.save_attn_weight(epoch, attn_weight, inter=10, out_dir='./')

        return loss_list, img_emb, rna_emb, attn_weight
    
    def train_dec(self, init = "kmeans", init_spa = True, n_cluster = 10, n_neighbors = 20, max_epochs = 100, update_interval = 3, tol = 1e-5, alpha = 0.9):
        """Run ``train_ae`` and then fine-tune with cluster-based early stopping.

        Cluster centers are initialised from k-means or Leiden labels.  Every
        ``update_interval`` epochs the soft assignments are recomputed and the
        fraction of changed hard labels is compared with ``tol``; training
        stops once it drops below ``tol``.  The gradient step itself uses the
        same loss as ``train_ae`` (reconstruction weighted by 10 with the
        custom decoder); no clustering term is added to the loss.

        Parameters:
            init (str) -- "kmeans" or "leiden"
            init_spa (bool) -- True: cluster the learned embedding;
                               False: cluster the concatenated input features
            n_cluster (int) -- number of clusters
            n_neighbors (int) -- neighbours for the Leiden graph
            max_epochs (int) -- upper bound on fine-tuning epochs
            update_interval (int) -- epochs between assignment refreshes
            tol (float) -- label-change fraction that stops training
            alpha (float) -- degrees of freedom of the Student's-t kernel

        Returns:
            ``(loss_list, img_emb, rna_emb, attn_weight)``; ``loss_list`` is
            the one produced by ``train_ae``.

        Raises:
            ValueError -- if ``init`` is neither "kmeans" nor "leiden"
        """
        if init not in ("kmeans", "leiden"):
            raise ValueError("Unsupported init %r; expected 'kmeans' or 'leiden'." % (init,))

        loss_list, img_emb, rna_emb, attn_weight = self.train_ae()
        x = get_embedding(self.img_tensor, self.rna_tensor, self.edge_index, self.model)
        
        # NOTE(review): created after the optimizer, so these centers are not
        # among the optimized parameters -- confirm this is intended.
        self.model.dec = Parameter(torch.Tensor(n_cluster, x.shape[1])).to(self.device)
        print('training use DEC')

        if init=="kmeans":
            print(str("Initializing cluster centers with kmeans, n_clusters known: "+str(n_cluster)))
            kmeans = KMeans(n_cluster, n_init=20)
            if init_spa:
                # k-means on the learned embedding
                y_pred = kmeans.fit_predict(x)
            else:
                # k-means on the raw concatenated input features
                y_pred = kmeans.fit_predict(self._raw_features())
        else:
            if init_spa:
                adata=sc.AnnData(x)
            else:
                adata=sc.AnnData(self._raw_features())
            sc.pp.neighbors(adata, n_neighbors=n_neighbors)
            
            # NOTE(review): `choose_res` must be importable here; a helper with
            # this keyword signature is named `find_res_step` elsewhere in the
            # project -- confirm which one is meant.
            res = choose_res(adata, cluster_num = n_cluster ,res_range = np.around(np.arange(0.3,0.9,0.04), 3), determine_clus_num = True)
            # Printed after `res` is known; it used to be printed before the
            # assignment, which raised UnboundLocalError.
            print("Initializing cluster centers with leiden, resolution = ", res)
            adata = definite_res(adata,res,'./',plot_file='init_spatial_plot.png',title='init_leiden')
            try :
                select_res(adata,res,method='leiden',plot=True,title='init_plot')
            except Exception as err:
                # The plot is best effort; report why it failed and carry on.
                print('try select_res error', err)
            y_pred=adata.obs['leiden'].astype(int).to_numpy()
            
        y_pred_last = y_pred

        # Initial centers = mean embedding of each initial cluster.
        Group=pd.Series(y_pred,index=np.arange(0,x.shape[0]),name="Group")
        Mergefeature=pd.concat([pd.DataFrame(x),Group],axis=1)
        cluster_centers=np.asarray(Mergefeature.groupby("Group").mean())
        
        self.model.dec.data.copy_(torch.Tensor(cluster_centers))
        self.model.train()
        for epoch in range(max_epochs):
            if epoch%update_interval == 0:
                # Assignments are only used for the hard labels below, so no
                # autograd graph is needed.
                with torch.no_grad():
                    zq, _, _, _, _, _, _ = self.model.forward(self.img_tensor, self.rna_tensor, self.edge_index, atten_model = True)
                    q = self._soft_assignment(zq, alpha)

                # Stop criterion, evaluated only when q was just refreshed.
                # Checking it on other epochs compared identical stale labels
                # and therefore always stopped at epoch 1 when
                # update_interval > 1.
                y_pred = torch.argmax(q, dim=1).cpu().numpy()
                delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / x.shape[0]
                y_pred_last = y_pred
                if epoch>0 and delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print("Reach tolerance threshold. Stopping training.")
                    print("Total epoch:", epoch)
                    break

            self.optimizer.zero_grad()
            z, mean, logvar, x_hat, img_emb,rna_emb, attn_weight = self.model.forward(self.img_tensor, self.rna_tensor, self.edge_index, atten_model = True)
            loss = self._compute_loss(z, mean, logvar, x_hat, recon_weight=10)
            loss.backward()
            self.optimizer.step()
            if epoch%10==0:
                print("Epoch ", epoch, " loss:",loss) 
        
        return loss_list, img_emb, rna_emb, attn_weight
                
    def target_distribution(self, q):
        """Return the sharpened target distribution for soft assignments ``q``.

        ``p_ij = (q_ij^2 / sum_i q_ij)`` normalised so every row sums to 1.
        """
        p = q**2 / torch.sum(q, dim=0)
        p = p / torch.sum(p, dim=1, keepdim=True)
        
        return p

    def save_attn_weight(self, epoch, attn_weight, inter=10, out_dir='./'):
        """Pickle the attention weights every ``inter`` epochs.

        Parameters:
            epoch (int) -- current epoch; used in the file name
                           'epoch_<epoch>_att_weight.pkl'
            attn_weight -- object to pickle
            inter (int) -- save only when ``epoch`` is a multiple of ``inter``
            out_dir (str) -- existing directory the file is written to
        """
        import pickle
        if epoch % inter == 0:
            save_filename = 'epoch_%s_att_weight.pkl' % (epoch)
            with open(os.path.join(out_dir, save_filename), 'wb') as file:
                pickle.dump(attn_weight, file)

    def predict(self, ):
        """Return the latent embedding ``z`` as a CPU NumPy array.

        The model is switched to eval mode and run without gradients.
        """
        self.model.eval()
        with torch.no_grad():
            z, mean, logvar, x_hat, img_emb,rna_emb, attn_weight = self.model(self.img_tensor, self.rna_tensor, self.edge_index)
        emb = z.to('cpu').detach().numpy()

        return emb


def get_embedding(img_tensor, rna_tensor, edge_index, model):
    """Return the model's latent embedding as a CPU NumPy array.

    The model is put in eval mode and called without gradient tracking; only
    the first of its seven outputs is used.
    """
    model.eval()
    with torch.no_grad():
        latent, _mean, _logvar, _x_hat, _img_emb, _rna_emb, _attn = model(
            img_tensor, rna_tensor, edge_index
        )
    return latent.to('cpu').detach().numpy()


import scanpy as sc
import pandas as pd
import sklearn.neighbors
import numpy as np
import anndata
import random
import torch
import os
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler,StandardScaler,scale

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Parameters:
        seed (int) -- seed applied to ``random``, ``numpy.random`` and torch
                      (CPU and all CUDA devices)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune and pick different kernels from run
    # to run, which defeats the reproducibility this function is for.
    torch.backends.cudnn.benchmark = False
    
def generate_adata(X, spatial=None, n_neighbors=20, n_pcs=30):
    """Wrap a feature matrix in an AnnData and compute its neighbour graph.

    Parameters:
        X -- feature matrix handed to ``sc.AnnData``
        spatial -- optional coordinates stored in ``obsm['spatial']``
        n_neighbors, n_pcs -- forwarded to ``sc.pp.neighbors`` (``use_rep='X'``)

    Returns:
        The new AnnData with the neighbour graph computed.
    """
    feat_adata = sc.AnnData(X)
    if spatial is not None:
        feat_adata.obsm['spatial'] = np.array(spatial)
    neighbor_kwargs = dict(use_rep='X', n_neighbors=n_neighbors, n_pcs=n_pcs)
    sc.pp.neighbors(feat_adata, **neighbor_kwargs)
    return feat_adata

def find_res_binary(adata, resolution_min, resolution_max, num_clusters, method = 'leiden',key_added = 'cluster'):
    """Bisect the clustering resolution until ``num_clusters`` clusters appear.

    Parameters:
        adata -- AnnData with a neighbour graph; labels go to ``obs[key_added]``
        resolution_min, resolution_max -- initial search interval
        num_clusters (int) -- desired number of clusters
        method (str) -- 'leiden' or 'louvain'
        key_added (str) -- ``adata.obs`` column receiving the labels

    Returns:
        ``(resolution, adata)``: the last resolution tried and the AnnData
        labelled at that resolution.  The search also ends when the interval
        becomes narrower than 1e-6, so the cluster count is not guaranteed.

    Raises:
        ValueError -- unknown ``method`` or ``resolution_min > resolution_max``
    """
    if method == 'leiden':
        cluster_fn = sc.tl.leiden
    elif method == 'louvain':
        cluster_fn = sc.tl.louvain
    else:
        raise ValueError("Unsupported method %r; expected 'leiden' or 'louvain'." % (method,))
    if resolution_min > resolution_max:
        raise ValueError("resolution_min must not exceed resolution_max.")

    num_clusters = int(num_clusters)

    # Probe the upper bound once and widen the interval if it is too coarse.
    cluster_fn(adata, resolution=resolution_max, key_added=key_added)
    if len(np.unique(adata.obs[key_added])) < num_clusters:
        resolution_max = resolution_max+1
        print('Number of clusters at the maximum resolution is less than %s, adjust maximum resolution to %s' % (num_clusters,resolution_max))
    
    while resolution_min <= resolution_max:
        resolution = (resolution_min + resolution_max) / 2
        # Use the requested algorithm here too; the loop used to run Leiden
        # even when method='louvain'.
        cluster_fn(adata, resolution=resolution, key_added=key_added)

        n_found = len(np.unique(adata.obs[key_added]))
        print(f"Resolution: {resolution_min, resolution_max, resolution}, Unique Clusters: {n_found}")

        if n_found == num_clusters:
            break
        elif n_found < num_clusters:
            resolution_min = resolution
        else:
            resolution_max = resolution
            
        # Give up once the interval can no longer be split meaningfully.
        if resolution_max - resolution_min < 1e-6:
            break
            
    return resolution, adata

def plot_spatial(adata,save_path,title=None,group='cluster',set_scale = 6):
    """Plot ``group`` on the spatial coordinates, save the figure and show it.

    Parameters:
        adata -- AnnData with ``obsm['spatial']``; when ``uns['spatial']`` is
                 present the plot is drawn with ``sc.pl.spatial``, otherwise
                 with ``sc.pl.embedding``
        save_path -- directory the PNG is written to
        title -- plot title and file stem; when None the file is named after
                 ``group`` (the default used to crash on ``None + '.png'``)
        group -- ``adata.obs`` column used for colouring
        set_scale -- figure height; the width follows the coordinate extent
    """
    coords = adata.obsm['spatial']
    scale = (coords[:,0].max()-coords[:,0].min())/(coords[:,1].max()-coords[:,1].min())
    plt.rcParams["figure.figsize"] = (set_scale*scale,set_scale)
    if 'spatial' in adata.uns:
        sc.pl.spatial(adata, img_key="hires", color=[group],basis="spatial", title = title,size=1, alpha_img=0.8, alpha=0.5, show=False)
    else:
        sc.pl.embedding(adata, color=[group],basis="spatial", title = title,size=20, show=False)
    file_stem = title if title is not None else group
    plt.savefig(os.path.join(save_path,'%s.png' % file_stem), bbox_inches='tight')
    plt.show()

def cal_spatial_net(adata, rad_cutoff=None, k_cutoff=None, map_id=True, verbose=True):
    """Build the spatial neighbour edge list from ``adata.obsm['spatial']``.

    Parameters:
        adata -- AnnData providing ``obsm['spatial']`` and ``obs.index``
        rad_cutoff -- radius for a radius-neighbour graph
        k_cutoff -- number of neighbours for a kNN graph (the query uses
                    ``k_cutoff + 1`` neighbours); takes precedence over
                    ``rad_cutoff`` when both are given, as before
        map_id (bool) -- replace positional cell ids by ``obs.index`` labels
        verbose (bool) -- print progress and graph statistics

    Returns:
        ``(Spatial_Net, id_cell_trans)`` when ``map_id`` is True, otherwise
        just ``Spatial_Net``.  ``Spatial_Net`` has the columns
        'Cell1', 'Cell2', 'Distance'; ``id_cell_trans`` maps position -> label.

    Raises:
        ValueError -- if neither ``rad_cutoff`` nor ``k_cutoff`` is given
    """
    if rad_cutoff is None and k_cutoff is None:
        raise ValueError("Either rad_cutoff or k_cutoff must be provided.")

    if verbose:
        print('Calculating spatial location graph......')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index

    if k_cutoff is not None:
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)
        distances, indices = nbrs.kneighbors(coor)
    else:
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
        distances, indices = nbrs.radius_neighbors(coor, return_distance=True)

    # Assemble all edges in one frame instead of one small DataFrame per cell.
    counts = np.array([len(neigh) for neigh in indices])
    Spatial_Net = pd.DataFrame({
        'Cell1': np.repeat(np.arange(len(indices)), counts),
        'Cell2': np.concatenate(list(indices)),
        'Distance': np.concatenate(list(distances)),
    })
    # Keep the historical row labels (0..k restarting for every cell) so that
    # callers relying on the old index see no difference.
    Spatial_Net.index = np.concatenate([np.arange(c) for c in counts])

    if verbose:
        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))
        
    if map_id:
        id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
        Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
        Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
        return Spatial_Net,id_cell_trans
    else:
        return Spatial_Net
    
def edgelist2adj(edgelist, weight = None):
    """Convert an edge list into a dense adjacency matrix.

    Parameters:
        edgelist -- DataFrame with the columns 'Cell1' and 'Cell2'
        weight -- when not None, this same value is stored as the 'weight'
                  attribute of every edge

    Returns:
        The array produced by ``networkx.to_numpy_array`` for the undirected
        graph; nodes appear in order of first occurrence in the edge list.
    """
    import networkx as nx

    graph = nx.Graph()
    edge_attrs = {} if weight is None else {'weight': weight}

    for _, edge in edgelist.iterrows():
        graph.add_edge(edge['Cell1'], edge['Cell2'], **edge_attrs)

    return nx.to_numpy_array(graph)

    
def pruning_knn(knn_df,clustet_df):
    """Keep only the edges whose two cells share a cluster label.

    Parameters:
        knn_df -- edge list with the columns 'Cell1' and 'Cell2'
        clustet_df -- cluster labels indexed by cell id; either two columns
                      named 'cluster_x' and 'cluster_y' (an edge is kept when
                      the cells agree in at least one of them) or a single
                      column

    Returns:
        The surviving rows of ``knn_df`` with a fresh 0..n-1 index.  An empty
        DataFrame is returned when ``clustet_df`` has neither one nor two
        columns.

    Raises:
        KeyError -- if an edge references a cell missing from ``clustet_df``
    """
    n_modalities = clustet_df.shape[1]
    if n_modalities == 2:
        print('pruning use 2 modality')
        label_cols = ['cluster_x', 'cluster_y']
    elif n_modalities == 1:
        print('pruning use 1 modality')
        label_cols = list(clustet_df.columns)
    else:
        return pd.DataFrame()

    # Plain dict lookups are much cheaper than a .loc call per edge.
    label_maps = [clustet_df[col].to_dict() for col in label_cols]

    # One boolean per edge, then a single filter.  This replaces the row-wise
    # private DataFrame._append, which copied the result on every kept edge.
    keep = np.array(
        [
            any(labels[cell1] == labels[cell2] for labels in label_maps)
            for cell1, cell2 in zip(knn_df['Cell1'], knn_df['Cell2'])
        ],
        dtype=bool,
    )
    return knn_df.loc[keep].reset_index(drop=True)

def index_knn(knn_df,id_cell_trans):
    """Translate cell labels in an edge list back to positional ids.

    Parameters:
        knn_df -- edge list with the columns 'Cell1' and 'Cell2'
        id_cell_trans -- dict mapping positional id -> cell label

    Returns:
        A copy of ``knn_df`` whose 'Cell1' and 'Cell2' hold positional ids;
        the input frame is left untouched.
    """
    label_to_id = {label: idx for idx, label in id_cell_trans.items()}
    return knn_df.assign(
        Cell1=knn_df['Cell1'].map(label_to_id),
        Cell2=knn_df['Cell2'].map(label_to_id),
    )

def get_cluster_id(adata,res,method='leiden'):
    """Cluster ``adata`` via ``select_res`` and return the labels.

    Returns:
        A one-column DataFrame holding ``adata.obs['cluster']``.
    """
    clustered = select_res(adata, res, method, plot=False, title=None)
    return pd.DataFrame(clustered.obs.loc[:, 'cluster'])
    
def purning_by_cluster(knn_df,rna_data,img_data,init_res):
    """Prune a kNN edge list using cluster labels of one or two modalities.

    Parameters:
        knn_df -- edge list handed to ``pruning_knn``
        rna_data, img_data -- an AnnData, a feature matrix (wrapped with
                              ``generate_adata``), or None to skip that
                              modality
        init_res -- Leiden resolution used for both modalities

    Returns:
        The pruned edge list returned by ``pruning_knn``.

    Raises:
        ValueError -- if both ``rna_data`` and ``img_data`` are None
    """
    if rna_data is None and img_data is None:
        raise ValueError("At least one of rna_data and img_data must be provided.")

    def _cluster_labels(data):
        # None is checked first: previously generate_adata(None) was called
        # before the None branches, so single-modality pruning never worked.
        if data is None:
            return None
        adata = data if isinstance(data, anndata.AnnData) else generate_adata(data)
        return get_cluster_id(adata,res=init_res,method='leiden')

    rna_cluster = _cluster_labels(rna_data)
    img_cluster = _cluster_labels(img_data)

    if rna_cluster is not None and img_cluster is not None:
        # The merge suffixes give the 'cluster_x' / 'cluster_y' columns that
        # pruning_knn reads in its two-modality branch.
        clus_df = pd.merge(rna_cluster,img_cluster,left_index=True,right_index=True, how='left')
    elif rna_cluster is not None:
        clus_df = rna_cluster
    else:
        clus_df = img_cluster
    
    prun_knn_df = pruning_knn(knn_df,clus_df)
    print('edges of raw knn graph: %s. edges of purning knn graph: %s' % (knn_df.shape[0],prun_knn_df.shape[0]))
    return prun_knn_df

def extract_rna_feat(adata,num_feat=2048,dim_reduction_method = 'high_var'):
    """Return a reduced expression matrix as a DataFrame.

    Parameters:
        adata -- AnnData; ``adata.X`` may be a NumPy array or a sparse matrix
        num_feat (int) -- number of components / genes to keep
        dim_reduction_method -- one of
            'pca'           PCA with ``num_feat`` components,
            'high_variable' genes flagged by ``sc.pp.highly_variable_genes``
                            (flavor 'seurat_v3'; annotates ``adata.var``),
            'high_var'      the ``num_feat`` columns with the largest variance,
            None            the full matrix

    Returns:
        ``pandas.DataFrame`` with one row per observation.

    Raises:
        ValueError -- for an unknown ``dim_reduction_method``
    """
    raw = adata.X
    # Densify once; every branch below works on this array, so dense and
    # sparse inputs behave the same ('high_var' used to call .toarray() on
    # dense arrays and the other branches could hand a sparse matrix to
    # DataFrame).
    data = raw if isinstance(raw, np.ndarray) else raw.toarray()

    if dim_reduction_method == 'pca':
        pca = PCA(n_components=num_feat)
        rna_df = pca.fit_transform(data)
    elif dim_reduction_method == 'high_variable':
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=num_feat)
        rna_df = data[:, adata.var['highly_variable'].to_numpy()]
    elif dim_reduction_method == 'high_var':
        variances = np.var(data, axis=0)
        # Columns are kept in ascending-variance order, as before.
        top_n = np.argsort(variances)[-num_feat:]
        rna_df = data[:, top_n]
    elif dim_reduction_method is None:
        rna_df = data
    else:
        raise ValueError(
            "Unknown dim_reduction_method %r; expected 'pca', 'high_variable', "
            "'high_var' or None." % (dim_reduction_method,)
        )
    
    return pd.DataFrame(rna_df)

def scale_data(data,scaler='zscore'):
    """Scale ``data`` column-wise and return the transformed array.

    Parameters:
        data -- array-like accepted by the scaler's ``fit_transform``
        scaler -- 'zscore' (StandardScaler), 'minmax' (MinMaxScaler; the
                  historical misspelling 'ninmax' is still accepted), or any
                  object providing ``fit_transform``

    Raises:
        ValueError -- for an unrecognised scaler name
    """
    if scaler=='zscore':
        transformer = StandardScaler()
    elif scaler in ('minmax', 'ninmax'):
        # 'ninmax' was the only spelling the old code recognised.
        transformer = MinMaxScaler()
    elif hasattr(scaler, 'fit_transform'):
        # Ready-made scaler objects already worked by falling through.
        transformer = scaler
    else:
        raise ValueError("Unknown scaler %r; expected 'zscore' or 'minmax'." % (scaler,))
    
    return transformer.fit_transform(data)

def load_graph_edgelist(edgelist_path):
    """Read a whitespace-separated edge list file.

    The first two fields of every non-blank line are parsed as integers;
    further fields (e.g. weights) are ignored.  Blank lines are skipped
    instead of raising ``IndexError``.

    Returns:
        list of ``(int, int)`` tuples in file order.
    """
    edgelist = []
    with open(edgelist_path, 'r') as edgelist_file:
        for line in edgelist_file:
            fields = line.split()
            if not fields:
                continue
            edgelist.append((int(fields[0]), int(fields[1])))
    return edgelist

def plot_loss_curve(loss_values,out_dir):
    """Plot the per-epoch training loss, save it and show it.

    Parameters:
        loss_values -- sequence of loss values, one per epoch
        out_dir -- directory receiving 'Train_loss_curve.png'
    """
    n_epochs = len(loss_values)
    epoch_axis = range(1, n_epochs + 1)
    plt.plot(epoch_axis, loss_values, 'b', label='Training Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    figure_path = os.path.join(out_dir,'Train_loss_curve.png')
    plt.savefig(figure_path)
    plt.show()

def makedir(dir):
    """Create directory ``dir`` (including missing parents) if it is absent.

    ``exist_ok=True`` avoids the race between the former existence check and
    the creation call.
    """
    os.makedirs(dir, exist_ok=True)

def find_res_sigmoid(adata, cluster_range=(12,13), by=0.1, res=1, verbose=False):
    """Search a Leiden resolution whose cluster count falls in a range.

    The resolution is parametrised as ``10 * sigmoid(x)`` and ``x`` is moved
    by ``by`` up or down after every clustering, so only resolutions in the
    open interval (0, 10) can be represented.

    Parameters:
        adata -- AnnData with a neighbour graph; labels go to
                 ``obs['clusters']`` and the result to
                 ``uns['best_resolution']``
        cluster_range -- an int, or a ``(low, high)`` pair of accepted counts
        by (float) -- step applied to ``x``
        res (float) -- starting resolution, 0 < res < 10
        verbose (bool) -- print every tried resolution

    Returns:
        ``adata`` (modified in place).

    Raises:
        ValueError -- ``res`` outside (0, 10), or the search stepped over the
                      requested range in both directions
    """
    if not 0 < res < 10:
        raise ValueError("res must lie strictly between 0 and 10.")

    if verbose:
        print("Find suitable resolution, start with", res)

    if isinstance(cluster_range, int):
        cluster_range = [cluster_range, cluster_range]
    lower, upper = cluster_range[0], cluster_range[1]

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Inverse of res = 10 * sigmoid(x).
    x = -np.log(10/res - 1)
    plus_counter = minus_counter = 0
    n_cluster = 1

    while n_cluster < lower or n_cluster > upper:
        sc.tl.leiden(adata, resolution=res, key_added='clusters')
        n_cluster = len(np.unique(adata.obs['clusters']))

        # Report before `res` is advanced, so the printed resolution is the
        # one that produced this cluster count.
        if verbose:
            print("resolution", res, "... ",n_cluster, "clusters.")

        if n_cluster < lower:
            x = x + by
            plus_counter = plus_counter + 1
        elif n_cluster > upper:
            x = x - by
            minus_counter = minus_counter + 1
        else:
            break

        res = round(sigmoid(x) * 10, 3)

        # Having stepped both up and down means the range lies between two
        # consecutive steps and cannot be hit with this step size.
        if plus_counter and minus_counter:
            print()
            raise ValueError("Specific cluster range was skipped! Try expanding the cluster range or reducing the resolution step size.")

    adata.uns['best_resolution'] = res

    if verbose:
        print("Final resolution:", res, "with", n_cluster, "clusters.")

    return adata

def select_res(adata,res,method='leiden',plot=False,save_path=None,title=None):
    """Cluster ``adata`` at resolution ``res`` and optionally plot the result.

    Parameters:
        adata -- AnnData with a neighbour graph; labels go to ``obs['cluster']``
        res -- clustering resolution
        method (str) -- 'leiden' or 'louvain'; any other value skips clustering
        plot (bool) -- draw, save and show the clusters (on
                       ``obsm['spatial']`` when available, otherwise on a
                       freshly computed UMAP)
        save_path -- directory for the figure; the current directory when None
        title -- plot title and file stem; when None the file is named
                 'cluster_res_<res>.png'

    Returns:
        The same ``adata`` object.
    """
    if method == 'leiden':
        sc.tl.leiden(adata, resolution=res, key_added='cluster')
    elif method == 'louvain':
        sc.tl.louvain(adata, resolution=res, key_added='cluster')
        
    if plot:
        if 'spatial' in adata.obsm.keys():
            coords = adata.obsm['spatial']
            scale = (coords[:,0].max()-coords[:,0].min())/(coords[:,1].max()-coords[:,1].min())
            plt.rcParams["figure.figsize"] = (8*scale,8)
            sc.pl.embedding(adata, color=["cluster"],basis="spatial", title = title,size=20, show=False)
        else:
            sc.tl.umap(adata)
            sc.pl.embedding(adata, color=["cluster"], title = title,size=20, show=False)

        # The defaults save_path=None / title=None used to raise TypeError in
        # os.path.join and in title+'.png'; fall back to usable values.
        out_dir = save_path if save_path is not None else '.'
        file_stem = title if title is not None else 'cluster_res_%s' % res
        plt.savefig(os.path.join(out_dir, file_stem + '.png'), bbox_inches='tight')
        plt.show()
    return adata

def find_res_step(adata, cluster_num = 12, res_range = np.around(np.arange(0.3,0.9,0.04), 3), determine_clus_num = True,criterion = 'CH_score'):
    """Scan ``res_range`` for a Leiden resolution.

    Parameters:
        adata -- AnnData with a neighbour graph; labels go to ``obs['leiden']``
        cluster_num (int) -- wanted number of clusters (count mode)
        res_range -- iterable of resolutions, tried in order
        determine_clus_num (bool) -- True: return the first resolution giving
            exactly ``cluster_num`` clusters (the last one tried if none
            does); False: return the resolution with the best clustering score
        criterion (str) -- 'CH_score' or 'sil_score' (higher is better) or
            'DB_score' (lower is better); score mode only

    Returns:
        The selected resolution.

    Raises:
        ValueError -- unknown ``criterion`` in score mode
    """
    if determine_clus_num:
        print('choose res by number of clusters')
        res = None
        for res in res_range:
            sc.tl.leiden(adata, resolution=res)
            num = len(adata.obs.leiden.unique())
            print(str(res),":", num)
            if num == cluster_num:
                print('res',str(res),'reach number of cluser',str(cluster_num))
                return res

        print("No suitable resolution found") 
        return res

    print('choose res by best cluster scores')
    if criterion not in ('CH_score', 'sil_score', 'DB_score'):
        raise ValueError("Unknown criterion %r; expected 'CH_score', 'sil_score' or 'DB_score'." % (criterion,))

    from sklearn import metrics

    # The old code scored undefined globals (`vgae`, `emb`); the scores are
    # now computed on this AnnData's own matrix.
    # NOTE(review): assumes adata.X is the embedding to be scored -- confirm.
    emb = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()

    res_outs = []
    for res in res_range:
        sc.tl.leiden(adata, resolution=res)
        labels = adata.obs['leiden']
        CH_score = metrics.calinski_harabasz_score(emb, labels)
        DB_score = metrics.davies_bouldin_score(emb, labels)
        sil_score = metrics.silhouette_score(emb, labels)
        num = str(len(labels.unique()))
        print(str(str(res)+"_"+num),":",DB_score,CH_score,sil_score)
        res_outs.append([res,num,DB_score,CH_score,sil_score])

    # Five fields per row; the 'num' column name was missing before.
    res_df = pd.DataFrame(res_outs,columns=['res','num','DB_score','CH_score','sil_score'])
    if criterion == 'DB_score':
        best_row = res_df[criterion].idxmin()
    else:
        best_row = res_df[criterion].idxmax()
    best_res = res_df.loc[best_row, 'res']
    print('best res choosen by',criterion,':',str(best_res))
    # Return the best resolution, not the last one of the scan.
    return best_res
        
            
def definite_res(adata,res,save_path=None,plot_file='choose_res_spatial_plot.png',title=None):
    """Run Leiden at a fixed resolution and draw the spatial cluster plot.

    The figure is drawn but neither saved nor shown; ``save_path`` and
    ``plot_file`` are accepted and currently unused.

    Returns:
        The same ``adata`` object with the labels in ``obs['leiden']``.
    """
    sc.tl.leiden(adata, resolution=res)

    # Width/height ratio of the coordinate extent keeps the figure proportional.
    coords = adata.obsm['spatial']
    x_extent = coords[:,0].max()-coords[:,0].min()
    y_extent = coords[:,1].max()-coords[:,1].min()
    aspect = x_extent/y_extent
    plt.rcParams["figure.figsize"] = (8*aspect,8)

    sc.pl.embedding(adata, basis="spatial", color=["leiden"], title=title, size=20, show=False)
    return adata




import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import matplotlib.pyplot as plt
from torch_geometric.utils import to_undirected
from torch_geometric.data import Data
from sklearn.decomposition import PCA
from utils import *
from models import *
from trainer import *
import sys
import random
import pickle
import warnings
warnings.filterwarnings("ignore")

# Program version and release date; both are interpolated into the usage
# banner below, and PROG_VERSION is what --version prints.
PROG_VERSION = '1.0'
PROG_DATE = '2023-09-20'

# Usage banner handed to argparse.ArgumentParser(usage=...); the last
# placeholder is filled with the name of the running script.
usage = '''

Version %s  by Luo Bingying  %s

Usage: %s --rna_data <adata_file> --image_data <image_file> -o <output_dir> [...]

''' % (PROG_VERSION, PROG_DATE, os.path.basename(sys.argv[0]))

def _build_arg_parser():
    """Create the command-line parser used by ``main``.

    Returns:
        argparse.ArgumentParser: parser with every option of the script registered.
    """
    import argparse

    def int_or_none(value):
        # Pass None through untouched, convert anything else to int.
        if value is None:
            return None
        return int(value)

    ArgParser = argparse.ArgumentParser(usage=usage,description='run model')
    ArgParser.add_argument("--version", action="version", version=PROG_VERSION)
    # The flag runs `pip list` at the end of main(), so the help text says so.
    ArgParser.add_argument("--sessioninfo", dest="sessioninfo",required=False, action='store_true', default=False, help="Print pip list.")
    ArgParser.add_argument("--verbose", dest="verbose",required=False, action='store_true', default=False, help="Print cuda memory use.")

    ArgParser.add_argument("-t", "--toy", action="store_true", dest="toy_data", default=False, help="Whether to use toy data.")
    ArgParser.add_argument("--rna_data", action="store", dest="rna_data", required=False, type=str, help="Path of rna feature file.")
    ArgParser.add_argument("--image_data", action="store", dest="image_data", required=False, type=str, help="Path of H&E image feature file.")
    ArgParser.add_argument("-o", "--output", action="store", dest="output", required=True, type=str, help="Output folder.")
    # Was required=True, which made default=100 unreachable; the option is now
    # optional so the documented default actually applies.
    ArgParser.add_argument("--epochs", dest="epochs",required=False, type=int, default=100, help="Number of training epochs.")
    ArgParser.add_argument("--lr", dest="lr", required=False, type=float, default=0.0001, help="Learning rate for training.")
    ArgParser.add_argument("--opt", dest="opt", required=False, type=str, default="adam", help="Optimizer for training.")
    ArgParser.add_argument("--dim_reduction_method", dest="dim_reduction_method",required=False, type=str, default="high_var", help="RNA data dim reduction method.")
    ArgParser.add_argument("--scale", dest="scale",required=False, type=str, default="zscore", help="Feature normalization method.")
    ArgParser.add_argument("--feat_pca", dest="feat_pca",required=False, action='store_true', default=False, help="PCA when ectract feature.")
    ArgParser.add_argument("--radiu_cutoff", dest="radiu_cutoff",required=False, type=int_or_none, default=None,help="Radiu for KNN graph.")
    ArgParser.add_argument("--knn", dest="knn",required=False, type=int_or_none, default=None, help="K for KNN graph.")
    ArgParser.add_argument("--purning", dest="purning_knn",required=False, action='store_true', default=False, help="Prune the knn graph")
    ArgParser.add_argument("--num_heads", dest="num_heads",required=False, type=int, default=1, help="Number of attention heads for cross attention layer.")
    ArgParser.add_argument("--hidden_dims", dest="hidden_dims",required=False, type=int, default=[1024, 512], help="Hidden dimension for each hidden layer.", nargs="*")
    ArgParser.add_argument("--latent_dim", dest="latent_dim",required=False, type=int, default=100, help="Latent dimension (output dimension for node embeddings).")
    ArgParser.add_argument("--brief_att", dest="brief_att",required=False, action='store_true', default=False, help="Attention layer with fewer parameters.")
    ArgParser.add_argument("--att_mode", dest="att_mode",required=False, action='store', default='cross', help="Attention layer type.")
    ArgParser.add_argument("--decoder", dest="customize_decoder",required=False, action='store_true', default=False, help="Construct the neural network decoder.")
    ArgParser.add_argument("--gnn_type", dest="gnn_type",required=False, type=str, default="GCN", help="Graph Neural Network type.")
    ArgParser.add_argument("--rna_weight", dest="rna_weight",required=False, type=int, default=1, help="Weight of rna attention to concat.")
    ArgParser.add_argument("--img_weight", dest="img_weight",required=False, type=int, default=1, help="Weight of img attention to concat.")
    ArgParser.add_argument("--dec", action="store_true", dest="dec", default=False, help="Whether to add DEC for cluster.")
    ArgParser.add_argument("--n_cluster", dest="n_cluster",required=False, type=int, default=10, help="number of clustered categories.")
    # ArgParser.add_argument("--res_range", dest="res_range", required=False, default=np.around(np.arange(0.2,1.2,0.03), 3), type=float,help="resolution range for clustering.")
    return ArgParser


def main():
    """Command-line entry point: build inputs, train the model and save the results.

    Steps, in order:
      1. Parse the command line and pick the torch device.
      2. Build the image / RNA feature tensors and the edge index, either from
         random toy data (``--toy``) or from the files given by ``--image_data``
         and ``--rna_data`` (the KNN edge list is cached as ``knn.txt`` in the
         output folder and reused when present).
      3. Construct ``FinalModal`` (optionally with a custom ``Decoder``) and
         train it with ``train_dec`` (``--dec``) or ``train_ae``.
      4. Write embeddings, attention weights and loss values to the output
         folder; for real data additionally cluster the embedding and write
         ``emb_adata.h5ad``.

    Exits through ``ArgumentParser.error`` (status 2) when real data is
    requested but ``--rna_data`` or ``--image_data`` is missing.
    """
    import time

    ArgParser = _build_arg_parser()
    (para, args) = ArgParser.parse_known_args()
    print(para)
    if args:
        # parse_known_args drops unknown options silently; surface them so a
        # mistyped flag does not go unnoticed.
        print('Warning: ignoring unrecognized arguments: %s' % args)
    if not para.toy_data and (para.rna_data is None or para.image_data is None):
        # Fail early with a clear message instead of an error deep in a loader.
        ArgParser.error('--rna_data and --image_data are required unless --toy is given.')

    set_seed(77)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print("number of GPUs：", num_gpus)
        current_gpu = torch.cuda.current_device()
        print("The GPU index currently in use：", current_gpu)
        print(f'CUDA available, using {torch.cuda.get_device_name(device)}.')
    else:
        print('CUDA not available, use CPU')

    out_dir = para.output
    makedir(out_dir)

    if para.toy_data:
        # Create the number of nodes and feature dimensions
        num_nodes = 100
        num_features = 2000

        img_tensor = torch.randn(num_nodes, num_features).double().to(device)
        rna_tensor = torch.randn(num_nodes, num_features).double().to(device)

        edge_index = torch.randint(num_nodes, (2, num_nodes * 2))
        edge_index = edge_index.to(device)

    else:
        ###========== Feature processing of each modality ==========###
        img_feat = pd.read_pickle(para.image_data)
        rna_adata = sc.read_h5ad(para.rna_data)

        ### Single modality preprocessing
        rna_feat = extract_rna_feat(rna_adata,num_feat=2048,dim_reduction_method = para.dim_reduction_method)
        # RNA features are only rescaled when the reduction method is not 'pca'.
        if not para.dim_reduction_method == 'pca':
            rna_feat = scale_data(rna_feat,scaler=para.scale)
        img_feat = scale_data(img_feat,scaler=para.scale)

        if para.feat_pca:
            print('use PCA when excrate single feature')
            # The same PCA object is refitted for each modality in turn.
            pca = PCA(n_components=200)
            rna_feat = pca.fit_transform(rna_feat)
            img_feat = pca.fit_transform(img_feat)
            print('final rna feature shape: %s. image feature shape: %s' % (rna_feat.shape,img_feat.shape))

        rna_tensor = torch.from_numpy(np.array(rna_feat)).double().to(device)
        img_tensor = torch.from_numpy(np.array(img_feat)).double().to(device)

        print('input image feature shape: %s. \n input rna feature shape: %s.' % (img_tensor.shape,rna_tensor.shape))

        ### knn
        knn_path = os.path.join(out_dir,'knn.txt')
        if os.path.exists(knn_path):
            print('The KNN file already exists in %s, read it directly' % knn_path)
        else:
            knn_df,id_cell_trans = cal_spatial_net(rna_adata,rad_cutoff=para.radiu_cutoff,k_cutoff=para.knn, map_id=True)

            # Whether to prune
            if para.purning_knn:
                knn_df = purning_by_cluster(knn_df=knn_df,rna_data=rna_adata,img_data=img_feat,init_res=0.4)

            index_knn_df = index_knn(knn_df,id_cell_trans)
            # Only the first two columns are written, space separated, no header.
            index_knn_df.iloc[:,0:2].to_csv(knn_path,index=False,sep=' ',header=False)

        edgelist = load_graph_edgelist(knn_path)
        edge_index = np.array(edgelist).astype(int).T
        edge_index = to_undirected(torch.from_numpy(edge_index).to(torch.long))
        edge_index = edge_index.to(device)

    single_model_size=img_tensor.shape[1]
    hidden_size=para.hidden_dims
    latent_size=para.latent_dim

    # Inspecting the data used for graph convolution
    data_obj = Data(edge_index=edge_index, x=img_tensor)
    print('Data inspection results for graph convolutional networks: %s. \n' % data_obj.validate(raise_on_error=True))

    ###========== CONSTRUCT ==========###
    start_time = time.time()

    if para.customize_decoder:
        # Decoder layer sizes: latent size followed by the hidden sizes reversed.
        decoder_hidden_size = hidden_size.copy()
        decoder_hidden_size.reverse()
        decoder_hidden_size.insert(0,latent_size)

        decoder = Decoder(output_size=single_model_size*2,hidden_sizes=decoder_hidden_size, gnn_type = para.gnn_type)
        fimodal = FinalModal(single_model_size, hidden_size, latent_size, num_heads=para.num_heads, brief_att = para.brief_att, attn_mode=para.att_mode,gnn_type = para.gnn_type, img_weight = para.img_weight, rna_weight = para.rna_weight, decoder = decoder)

    else:
        # NOTE(review): --gnn_type, --img_weight and --rna_weight are not
        # forwarded in this branch, so FinalModal's own defaults are used here.
        # Confirm whether that is intended before forwarding them.
        fimodal = FinalModal(single_model_size, hidden_size, latent_size, num_heads=para.num_heads, brief_att = para.brief_att, attn_mode=para.att_mode)

    fimodal = fimodal.double().to(device)
    fimodal.print_networks()
    current_memory = torch.cuda.memory_allocated() / 1024**2  # convert bytes to MB
    print(f"CUDA memory consumption at the current stage: {current_memory} MB")

    ###========== TRAINING ==========###
    torch.cuda.empty_cache()
    #attn_out_dir = os.path.join(out_dir,'attn_weight')
    trainer = train(img_tensor, rna_tensor, edge_index, fimodal, custom_decoder = para.customize_decoder,n_epochs=para.epochs, opt=para.opt, lr=para.lr, weight_decay=0.0001, save_att=False, verbose=True)#, attn_out = attn_out_dir
    if para.dec :
        loss_values, img_emb, rna_emb, attention_weight = trainer.train_dec(init = "kmeans", n_cluster = para.n_cluster, n_neighbors = 20, max_epochs = 100, update_interval = 10, tol = 1e-5, alpha = 0.9)
        flag = 'dec'
    else:
        loss_values, img_emb, rna_emb, attention_weight = trainer.train_ae()
        flag = 'ae'
    emb = trainer.predict()

    if para.verbose:
        # Only query the CUDA memory report when a GPU is actually present.
        if torch.cuda.is_available():
            print(torch.cuda.max_memory_allocated())
            print(torch.cuda.memory_summary())
        else:
            print('CUDA not available, no CUDA memory statistics to report')
    torch.cuda.empty_cache()

    # The measured run time covers model construction and training only.
    end_time = time.time()
    run_time = end_time - start_time

    # Save data
    emb = pd.DataFrame(emb)
    emb.to_pickle(os.path.join(out_dir,'embedding.pkl'))
    emb.to_csv(os.path.join(out_dir,'embedding.csv'))
    with open(os.path.join(out_dir,'att_weight.pkl'), 'wb') as file:
        pickle.dump(attention_weight, file)
    img_emb = pd.DataFrame(img_emb.cpu().detach().numpy())
    img_emb.to_pickle(os.path.join(out_dir,'img_embedding.pkl'))
    rna_emb = pd.DataFrame(rna_emb.cpu().detach().numpy())
    rna_emb.to_pickle(os.path.join(out_dir,'rna_embedding.pkl'))

    # Draw the loss curve
    loss_values = pd.DataFrame(loss_values)
    loss_values.to_csv(os.path.join(out_dir,'loss.csv'))
    try:
        plot_loss_curve(loss_values,out_dir)
    except Exception as err:
        # Plotting is best effort: report the cause and keep going, but do not
        # swallow KeyboardInterrupt / SystemExit as a bare except would.
        print('Error when plot_loss_curve: %r' % (err,))

    # clustering
    if not para.toy_data:
        # Re-save the embedding once it carries the observation names as index;
        # the earlier, unindexed copy stays as a fallback if this step fails.
        emb.index = rna_adata.obs.index
        emb.to_pickle(os.path.join(out_dir,'embedding.pkl'))
        emb.to_csv(os.path.join(out_dir,'embedding.csv'))
        #spatial_coord = pd.read_csv('./spatial.csv',index_col=0)
        stereomm_adata = generate_adata(emb , spatial=rna_adata.obsm['spatial'])

        resolution, stereomm_adata = find_res_binary(stereomm_adata, resolution_min=0.1, resolution_max=1.2, num_clusters=para.n_cluster)
        print(f"Final Resolution: {resolution}")
        plot_spatial(stereomm_adata,out_dir,'StereoMM_domain')

        stereomm_adata.write(os.path.join(out_dir,'emb_adata.h5ad'))

    print('======= Finish!!!!!!======')
    print('RUNING TIME: %s' % (run_time))

    if para.sessioninfo:
        os.system('pip list')
    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
