import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix, issparse
import seaborn as sns
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import sparsefuncs
import torch
import torch.nn as nn
import torch.nn.functional as F


class SA_Fusion(nn.Module):
    """Fuse gene and image embeddings with single-head cross-attention.

    Queries are projected from the gene embeddings; keys and values are
    projected from the image embeddings. The attended image features are
    concatenated with the original gene embeddings and passed through a
    two-layer MLP.

    Args:
        input_dim: Feature size shared by both embedding inputs.
        out_dim: Feature size of the fused output.
    """

    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.input_dim = input_dim

        # Attention projections (query from genes, key/value from images).
        self.W_q = nn.Linear(input_dim, input_dim)
        self.W_k = nn.Linear(input_dim, input_dim)
        self.W_v = nn.Linear(input_dim, input_dim)

        # Output MLP applied to [gene ; attended image] features.
        self.linear1 = nn.Linear(input_dim * 2, input_dim)
        self.activation1 = nn.ReLU()
        self.linear2 = nn.Linear(input_dim, out_dim)

    def forward(self, gene_embeddings, image_embeddings):
        """Return fused features of shape (B, N_gene, out_dim).

        Both inputs must be 3-D (``torch.bmm`` is used):
        ``gene_embeddings`` is (B, N_gene, input_dim) and
        ``image_embeddings`` is (B, N_image, input_dim).
        """
        queries = self.W_q(gene_embeddings)
        keys = self.W_k(image_embeddings)
        values = self.W_v(image_embeddings)

        # Scaled dot-product attention of every gene token over image tokens.
        scale = self.input_dim ** 0.5
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        weights = F.softmax(scores, dim=-1)
        attended = torch.bmm(weights, values)

        # Concatenate along the feature axis and project to out_dim.
        fused = torch.cat((gene_embeddings, attended), dim=-1)
        hidden = self.activation1(self.linear1(fused))
        return self.linear2(hidden)

def vit_grid_pooling(vit_output, grid_size=(3, 3), kernel_size=6, stride=4):
    """Average-pool ViT patch tokens over overlapping square windows.

    The token sequence is laid out as a square feature map, windows of
    ``kernel_size`` x ``kernel_size`` tokens are extracted every ``stride``
    tokens, and each window is averaged per channel.

    Args:
        vit_output: Tensor of shape (B, num_tokens, D). ``num_tokens`` must be
            a perfect square (196 -> 14 x 14 with the defaults).
        grid_size: (grid_h, grid_w) number of windows per axis; must be
            compatible with ``kernel_size``/``stride`` for the token grid.
        kernel_size: Window side length, in tokens.
        stride: Step between consecutive windows, in tokens.

    Returns:
        Tensor of shape (B, grid_h * grid_w, D); row ``g`` holds the pooled
        feature vector of window ``(g // grid_w, g % grid_w)``.

    Raises:
        ValueError: If ``num_tokens`` is not a perfect square.
    """
    batch_size, num_tokens, hidden_dim = vit_output.shape

    # Infer the side of the square token grid instead of hard-coding 14.
    side = int(round(num_tokens ** 0.5))
    if side * side != num_tokens:
        raise ValueError(
            f"vit_output must contain a square number of tokens, got {num_tokens}"
        )

    # (B, N, D) -> (B, D, side, side)
    feature_map = vit_output.reshape(batch_size, side, side, hidden_dim).permute(0, 3, 1, 2)

    # Sliding windows along both spatial axes: (B, D, gh, gw, k, k)
    patches = feature_map.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)

    # Flatten each window: (B, D, gh, gw, k * k)
    patches = patches.contiguous().view(batch_size, hidden_dim,
                                        grid_size[0], grid_size[1],
                                        -1)

    # Mean over every window: (B, D, gh, gw)
    pooled = patches.mean(dim=-1)

    # The channel axis has to be *moved* behind the window axis. A plain
    # view(B, -1, D) would reinterpret the (B, D, gh*gw) memory and mix
    # channels with window positions.
    return pooled.flatten(2).transpose(1, 2).contiguous()

class Adapter(nn.Module):
    """Bias-free bottleneck MLP: c_in -> c_in // reduction -> c_in.

    Each linear layer is followed by an in-place ReLU, so the output is
    non-negative and has the same feature size as the input.

    Args:
        c_in: Input (and output) feature size.
        reduction: Divisor applied to ``c_in`` to get the hidden size.
    """

    def __init__(self, c_in, reduction=4):
        super().__init__()
        hidden = c_in // reduction
        layers = [
            nn.Linear(c_in, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c_in, bias=False),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the bottleneck MLP to the last dimension of ``x``."""
        return self.fc(x)


def complete_masking(batch, p, n_tokens):
    """Apply BERT-style random masking to ``batch['tokenized_gene']``.

    The batch dict is modified in place and also returned. Keys written:

    * ``'tokenized_gene'``: the input ids with every 0 replaced by the
      padding id 1.
    * ``'masked_indices'``: the corrupted ids. PAD (1) and CLS (3) positions
      are never touched. Every other position is selected with probability
      ``p``; a selected position becomes a random id in ``[10, n_tokens)``
      with probability 0.1, otherwise it gets its original id back with
      probability 0.1, otherwise it stays 0 (roughly 10% / 9% / 81%).
    * ``'mask'``: 0 at selected positions, 1 elsewhere (PAD/CLS are forced
      to 1 so they are ignored when the loss is computed on 0 entries).
    * ``'spatial_mask'``: 0 at every non-PAD/non-CLS position, 1 at PAD/CLS.
    * ``'attention_mask'``: boolean, True where ``masked_indices`` equals the
      padding id.

    Args:
        batch: Dict holding an integer tensor under ``'tokenized_gene'``.
        p: Probability of selecting a token for masking.
        n_tokens: Exclusive upper bound for random replacement ids.
    """
    pad_id = 1
    cls_id = 3
    pad_fill = torch.tensor(pad_id)

    # 0 is the padding id of the raw tokens; remap it to pad_id.
    tokens = batch['tokenized_gene']
    tokens = torch.where(tokens == 0, pad_fill, tokens)
    batch['tokenized_gene'] = tokens

    # Random draws are made in a fixed order so that seeded runs reproduce.
    keep = 1 - torch.bernoulli(torch.ones_like(tokens), p)  # 0 = selected
    spatial_keep = 1 - torch.bernoulli(torch.ones_like(tokens), 1)  # all 0

    # PAD and CLS tokens are never corrupted and are flagged 1 in both masks.
    is_special = (tokens == pad_id) | (tokens == cls_id)
    corrupted = torch.where(is_special, tokens, tokens * keep)
    keep = torch.where(is_special, pad_fill, keep)
    spatial_keep = torch.where(is_special, pad_fill, spatial_keep)

    # 10% of the selected (currently 0) positions receive a random id.
    random_ids = torch.randint(10, n_tokens, size=corrupted.shape, device=corrupted.device)
    use_random = torch.bernoulli(torch.ones_like(random_ids) * 0.1).type(torch.int64)
    corrupted = torch.where(corrupted == 0, random_ids * use_random, corrupted)

    # 10% of the positions that are still 0 get their original id back.
    use_original = torch.bernoulli(torch.ones_like(tokens) * 0.1).type(torch.int64)
    corrupted = torch.where(corrupted == 0, tokens * use_original, corrupted)

    batch['masked_indices'] = corrupted
    batch['mask'] = keep
    batch['spatial_mask'] = spatial_keep
    batch['attention_mask'] = (corrupted == pad_id).type(torch.bool)

    return batch


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and enable deterministic mode.

    Side effects: sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable,
    switches cuDNN to deterministic/non-benchmark mode and turns on
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: Integer seed applied to every library.
    """
    # Exported before any of the torch calls below; PyTorch's deterministic
    # mode expects this variable for cuBLAS on CUDA builds.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # is_available is a function: without the call the condition is always
    # truthy and the guard has no effect.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

def sf_normalize(X):
    """Scale every row of ``X`` so that it sums to 10,000.

    Rows whose sum is zero are left as all zeros. The input is never
    modified; a normalized copy is returned.

    Args:
        X: 2-D count matrix, either a dense array or a SciPy sparse matrix
            (CSR/CSC for the sparse path). Non-floating dtypes (e.g. raw
            integer counts) are upcast to float64, because the scaled values
            cannot be stored in an integer matrix.

    Returns:
        The normalized matrix, same container type as the input.
    """
    # Both branches produce a fresh object, so the caller's data is untouched.
    if np.issubdtype(X.dtype, np.floating):
        X = X.copy()
    else:
        X = X.astype(np.float64)

    # Flatten: sparse sums come back as an (n, 1) matrix, dense ones as (n,).
    counts = np.asarray(X.sum(axis=1)).ravel()
    # Avoid division by zero for empty rows (they stay all-zero anyway).
    counts = np.where(counts == 0, 1, counts)
    # Normalize to 10000 counts per row.
    scaling_factor = 10000. / counts

    if issparse(X):
        sparsefuncs.inplace_row_scale(X, scaling_factor)
    else:
        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)

    return X



def _sub_tokenize_data(x: csr_matrix, max_seq_len: int = -1, aux_tokens: int = 30):
    """Rank the stored entries of every CSR row and emit their column ids.

    For each row the stored (non-zero) entries are sorted by decreasing
    value; the column index of each entry, shifted by ``aux_tokens``, becomes
    a token. Rows are right-padded with 0.

    Args:
        x: CSR matrix of shape (n_rows, n_genes).
        max_seq_len: Number of tokens kept per row. ``None`` or any value
            <= 0 means "no truncation" (row width is ``n_genes``).
        aux_tokens: Offset added to every column index, reserving the ids
            below it for auxiliary tokens.

    Returns:
        float64 array of shape (n_rows, seq_len) holding integer-valued
        token ids.
    """
    n_rows, n_genes = x.shape
    # A non-positive limit must not reach the slice below: `[:-1]` would
    # silently drop the lowest-ranked gene and `[:0]` would drop everything.
    if max_seq_len is not None and max_seq_len > 0:
        seq_len = max_seq_len
    else:
        seq_len = n_genes

    scores_final = np.zeros((n_rows, seq_len))

    for i in range(n_rows):
        start_idx = x.indptr[i]  # start of the stored entries of row i
        end_idx = x.indptr[i + 1]  # end of the stored entries of row i
        nonzero_indices = x.indices[start_idx:end_idx]
        nonzero_data = x.data[start_idx:end_idx]

        # Highest value first, truncated to the requested length.
        top = np.argsort(-nonzero_data)[:seq_len]
        tokens = nonzero_indices[top] + aux_tokens  # shift past auxiliary ids

        scores_final[i, :len(tokens)] = tokens

    return scores_final


def tokenize_data(x: np.array, max_seq_len: int = None, aux_tokens: int = None):
    """Tokenize the input gene matrix to a matrix of 32-bit integers.

    Args:
        x: Expression matrix (n_cells, n_genes); dense arrays, ``np.matrix``
            and any SciPy sparse format are accepted.
        max_seq_len: Tokens kept per cell; ``None`` (or <= 0) keeps all
            genes.
        aux_tokens: Offset added to gene indices; ``None`` selects 30, the
            default of ``_sub_tokenize_data``.

    Returns:
        int32 array of shape (n_cells, seq_len); 0 marks padding.
    """
    if max_seq_len is None:
        max_seq_len = -1
    if aux_tokens is None:
        aux_tokens = 30
    # Dense inputs (ndarray or np.matrix) have no .tocsr(); convert them first.
    if not issparse(x):
        x = csr_matrix(x)
    scores_final = _sub_tokenize_data(x.tocsr(), max_seq_len, aux_tokens)

    return scores_final.astype('i4')


def _move_central_spot_first(group, spatial_coords):
    """Reorder ``group`` in place so its most central spot comes first.

    The most central spot is the one with the smallest mean Euclidean
    distance to all spots of the group.
    """
    group_coords = spatial_coords[group]
    pairwise = np.sqrt(
        ((group_coords[:, None, :] - group_coords[None, :, :]) ** 2).sum(axis=2)
    )
    most_central_idx = int(np.argmin(pairwise.mean(axis=1)))
    group.insert(0, group.pop(most_central_idx))
    return group


def group_spots_into_unique_batches(adata_slide, batch_size=9):
    """Greedily partition the spots of one slide into spatial mini-batches.

    Spots are grouped by proximity on the (array_row, array_col) grid. While
    at least ``batch_size`` spots are unassigned, a seed spot is taken and
    its nearest unassigned spots are added until the batch is full; these
    batches never share a spot. If fewer than ``batch_size`` spots are left
    over, they form one final batch that is topped up with the nearest
    neighbours of one of them, which re-uses spots from earlier batches.

    Within every batch the most central spot (smallest mean distance to the
    rest of the batch) is placed first.

    Args:
        adata_slide: Object whose ``obs`` provides ``'array_row'`` and
            ``'array_col'`` columns (one entry per spot).
        batch_size: Number of spots per mini-batch.

    Returns:
        List of batches; each batch is a list of positional spot indices.
        Empty when the slide has no spots.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or the slide has
            spots but fewer than ``batch_size`` of them.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    array_row = np.array(adata_slide.obs['array_row'])
    array_col = np.array(adata_slide.obs['array_col'])

    spatial_coords = np.column_stack((array_row, array_col)).astype(int)
    n_spots = spatial_coords.shape[0]

    if n_spots == 0:
        # Nothing to group (and a neighbour index cannot be fit on no data).
        return []
    if n_spots < batch_size:
        # The neighbour query below cannot return batch_size spots; fail
        # here with a clear message instead of deep inside the query.
        raise ValueError(
            f"slide has {n_spots} spots, fewer than batch_size={batch_size}"
        )

    unassigned = set(range(n_spots))
    batches = []

    # Neighbour index over all spots; every query ranks the whole slide.
    nbrs = NearestNeighbors(n_neighbors=n_spots)
    nbrs.fit(spatial_coords)

    # Greedy grouping: as long as a full batch can be formed from unassigned
    # spots, build it strictly without re-using any spot.
    while len(unassigned) >= batch_size:
        seed = unassigned.pop()
        group = [seed]

        # All spots ordered from nearest to farthest from the seed.
        _, indices = nbrs.kneighbors([spatial_coords[seed]], n_neighbors=n_spots)

        # Take the nearest still-unassigned spots until the batch is full.
        for idx in indices[0].tolist():
            if len(group) >= batch_size:
                break
            if idx in unassigned:
                group.append(idx)

        # Remove the whole group from the pool so batches stay disjoint.
        unassigned.difference_update(group)
        batches.append(_move_central_spot_first(group, spatial_coords))

    # Leftover spots cannot fill a batch on their own: keep all of them
    # (so each appears at least once) and top the batch up with the nearest
    # neighbours of the first leftover spot, even if already assigned.
    if unassigned:
        group = list(unassigned)
        seed = group[0]
        _, indices = nbrs.kneighbors([spatial_coords[seed]], n_neighbors=batch_size)
        for idx in indices[0].tolist():
            if len(group) >= batch_size:
                break
            if idx not in group:
                group.append(idx)
        batches.append(_move_central_spot_first(group, spatial_coords))

    return batches


def draw_mini(combined_adata, all_batches, slide_id):
    """Plot the spatial layout of the mini-batches and save it as a PNG.

    Every spot is drawn at its (array_row, array_col) position in the colour
    of its batch; the first spot of each batch is additionally marked with a
    star. The figure is written to
    ``distribution/mini_batch_distribution_{slide_id}.png`` (the directory is
    created if needed) and closed afterwards.

    Args:
        combined_adata: Object whose ``obs`` provides ``'array_row'`` and
            ``'array_col'`` columns.
        all_batches: List of batches, each a list of positional spot indices.
        slide_id: Identifier used in the output file name.
    """
    # Discard whatever pyplot figure is currently active.
    plt.clf()
    plt.close()

    # Batches hold positional indices, so index plain arrays; integer keys
    # on a label-indexed Series would be looked up as labels (or rely on the
    # deprecated positional fallback).
    array_row = np.asarray(combined_adata.obs['array_row'])
    array_col = np.asarray(combined_adata.obs['array_col'])

    # One record per (spot, batch) membership.
    spots_data = [
        {'x': int(array_row[spot]), 'y': int(array_col[spot]), 'batch': batch_idx}
        for batch_idx, batch in enumerate(all_batches)
        for spot in batch
    ]
    spots_df = pd.DataFrame(spots_data)
    # The first spot listed in each batch is drawn as its centre.
    center_spots_df = spots_df.groupby('batch').first().reset_index()

    # Discrete palette with one entry per batch.
    num_batches = len(all_batches)
    palette = sns.color_palette("tab20", num_batches)

    # Shuffle the colours so neighbouring batch ids do not get similar hues.
    # A private RandomState(42) gives the same order as seeding the global
    # generator with 42, without resetting the caller's NumPy RNG state.
    shuffled_colors = np.random.RandomState(42).permutation(palette)
    batch_colors = {batch: shuffled_colors[batch] for batch in range(num_batches)}

    fig = plt.figure(figsize=(12, 10))
    try:
        for batch, group in spots_df.groupby('batch'):
            plt.scatter(group['x'], group['y'], color=batch_colors[batch],
                        label=f'Mini-batch {batch}', s=10)

        for batch, group in center_spots_df.groupby('batch'):
            plt.scatter(group['x'], group['y'], marker='*', color=batch_colors[batch], s=100)

        plt.title('Spatial Distribution of Mini-batches')
        plt.xlabel('X Coordinate')
        plt.ylabel('Y Coordinate')
        # Legend is placed to the right of the axes.
        plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize='small')
        plt.tight_layout()
        save_dir = 'distribution'
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f'{save_dir}/mini_batch_distribution_{slide_id}.png')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def filter_batches_by_distance(all_batches, combined_adata, max_distance=10):
    """Keep only the mini-batches whose spots all lie close together.

    A batch is kept when the Euclidean distance between every pair of its
    spots, measured on the integer (array_row, array_col) grid, is at most
    ``max_distance``.

    Args:
        all_batches: List of batches, each a list of positional spot indices.
        combined_adata: Object whose ``obs`` provides ``'array_row'`` and
            ``'array_col'`` columns.
        max_distance: Largest pairwise distance allowed inside a batch.

    Returns:
        List with the batches that pass the test, in their original order
        (the batch objects themselves, not copies).
    """
    # Extract the coordinates once. Plain arrays are used because batches
    # hold positional indices; integer keys on a label-indexed Series would
    # be looked up as labels (or rely on the deprecated positional fallback).
    array_row = np.asarray(combined_adata.obs['array_row'])
    array_col = np.asarray(combined_adata.obs['array_col'])

    filtered_batches = []
    for batch in all_batches:
        spot_idx = np.asarray(batch, dtype=int)
        coords = np.column_stack((array_row[spot_idx], array_col[spot_idx])).astype(int)

        # Pairwise Euclidean distances between all spots of the batch.
        distances = np.sqrt(((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=2))

        if np.all(distances <= max_distance):
            filtered_batches.append(batch)

    return filtered_batches


def get_safe_region(x_center, y_center, patch_size, image_w, image_h):
    """Pick a 6*patch_size square crop that contains the centre point.

    The crop is treated as a 3 x 3 grid of cells, each 2*patch_size wide.
    For every cell (grid id = row * 3 + col) the crop is positioned so that
    ``(x_center, y_center)`` is the centre of that cell. One placement that
    lies completely inside the image is chosen at random.

    Args:
        x_center, y_center: Coordinates of the point of interest.
        patch_size: Half the side length of one grid cell.
        image_w, image_h: Image size; the crop must satisfy
            ``0 <= x1``, ``x2 <= image_w`` (same for y).

    Returns:
        ``((x1, y1), (x2, y2), grid_id)`` with integer coordinates.
        ``grid_id`` is the chosen cell (0-8). If no placement fits, the
        6*patch_size window centred on the point is clipped to the image and
        ``grid_id`` is -1; that region may be smaller than the regular crop.
    """
    window = 6 * patch_size

    # Top-left corner of every placement that fits inside the image,
    # keyed by grid id in ascending order.
    candidates = {}
    for grid_id in range(9):
        row, col = divmod(grid_id, 3)
        x1 = x_center - (col * 2 + 1) * patch_size
        y1 = y_center - (row * 2 + 1) * patch_size
        if x1 >= 0 and y1 >= 0 and x1 + window <= image_w and y1 + window <= image_h:
            candidates[grid_id] = (x1, y1)

    if not candidates:
        # Fallback: largest region around the centre that stays in bounds.
        # Cast to int like the regular path so callers always get integers.
        safe_x1 = int(max(0, x_center - 3 * patch_size))
        safe_y1 = int(max(0, y_center - 3 * patch_size))
        safe_x2 = int(min(image_w, x_center + 3 * patch_size))
        safe_y2 = int(min(image_h, y_center + 3 * patch_size))
        return (safe_x1, safe_y1), (safe_x2, safe_y2), -1  # -1: irregular region

    chosen = random.choice(list(candidates))
    final_x1, final_y1 = candidates[chosen]
    return (
        (int(final_x1), int(final_y1)),
        (int(final_x1 + window), int(final_y1 + window)),
        chosen,
    )


def adjust_crop(top_left, bottom_right, img_width, img_height, patch_size):
    """Shift a crop box back inside the image, clipping it if necessary.

    Per axis: a box that starts below 0 is moved so it starts at 0 (its far
    edge is capped at the image size); a box that ends beyond the image is
    moved so it ends at the image border (its near edge is capped at 0).
    Boxes already inside the image are returned unchanged.

    Args:
        top_left: (x_min, y_min) of the crop.
        bottom_right: (x_max, y_max) of the crop.
        img_width, img_height: Image size.
        patch_size: Accepted for interface compatibility; not used.

    Returns:
        ``((x_min, y_min), (x_max, y_max))`` of the adjusted crop.
    """

    def _fit_axis(low, high, limit):
        # Box sticks out on the low side: slide it up to start at 0.
        if low < 0:
            high = int(min(high - low, limit))
            low = 0
        # Box sticks out on the high side: slide it down to end at limit.
        if high > limit:
            low = int(max(low - (high - limit), 0))
            high = limit
        return low, high

    left, right = _fit_axis(top_left[0], bottom_right[0], img_width)
    top, bottom = _fit_axis(top_left[1], bottom_right[1], img_height)

    return (left, top), (right, bottom)