import json
import os
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import argparse
import torch
import csv
import re
import ast
from LLM_prompt import *
import numpy as np
import pickle
import cv2
from PIL import Image
import torchvision.transforms as transforms
import sys
sys.path.append('/mnt/nvme_share/wuwl/project/CARZero-main/')
import CARZero
import types
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import scipy.stats as stats
from nltk.tokenize import RegexpTokenizer



def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries the string attributes
        ``LLM_name_or_path``, ``tokenizer_path``, ``model_name`` and
        ``ckpt_path``. Each one is ``None`` when its flag is not given.
    """
    parser = argparse.ArgumentParser()
    # All four options share the same type and default, so register them in one loop.
    for flag in ('--LLM_name_or_path', '--tokenizer_path', '--model_name', '--ckpt_path'):
        parser.add_argument(flag, type=str, default=None)
    return parser.parse_args()

def max_length_string(strings):
    """Return the length of the longest item in ``strings``.

    Args:
        strings: a collection of sized objects (strings in this script).

    Returns:
        int: the largest ``len()`` among the items, or 0 when ``strings``
        is empty/falsy.
    """
    if not strings:
        return 0
    return max(map(len, strings))

def inference(messages, model, tokenizer, max_length):
    """Run one chat-style generation and return only the newly generated text.

    Args:
        messages: chat messages passed to ``tokenizer.apply_chat_template``.
        model: object exposing ``device`` and ``generate``.
        tokenizer: object exposing ``apply_chat_template``, ``eos_token_id``,
            ``convert_tokens_to_ids`` and ``decode``.
        max_length: forwarded to ``generate`` as ``max_new_tokens``.

    Returns:
        The decoded continuation (prompt tokens stripped, special tokens
        skipped).
    """
    prompt_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Generation stops at either the regular EOS id or the "<|eot_id|>" token id.
    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    # NOTE(review): do_sample=False requests greedy decoding, so temperature/top_p
    # are presumably ignored by the backend -- confirm before relying on them.
    generated = model.generate(
        prompt_ids,
        max_new_tokens=max_length,
        eos_token_id=stop_ids,
        do_sample=False,
        temperature=0.6,
        top_p=0.9,
    )

    # The first sequence starts with the prompt; slice it off before decoding.
    prompt_len = prompt_ids.shape[-1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

def append_to_csv(file_path, row):
    """Append ``row`` as one CSV record to ``file_path``.

    The file is opened in append mode (created if missing) with
    ``newline=''`` so the csv module controls line endings.
    """
    with open(file_path, mode='a', newline='') as handle:
        csv.writer(handle).writerow(row)

def parse_string_to_list(string):
    """Parse the textual representation of a Python list into a real list.

    Single-quoted items that sit between list delimiters are first rewritten
    to double-quoted items, so that items containing an apostrophe
    (e.g. ``'it's fine'``) still parse. The result is then evaluated with
    ``ast.literal_eval``.

    Args:
        string: text such as ``"['a', 'b']"``. Non-string input (e.g. the
            float NaN that pandas yields for an empty cell) is rejected.

    Returns:
        The parsed list, or ``None`` (after printing a diagnostic) when the
        input is not a string, cannot be parsed, or does not evaluate to a
        list.
    """
    if not isinstance(string, str):
        # re.sub would raise TypeError here; report it like any other parse failure.
        print(f"Error parsing string: expected str, got {type(string).__name__}")
        return None
    try:
        string = re.sub(r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])", r'"\1"', string)
        result = ast.literal_eval(string)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
        # literal_eval documents all of these for malformed input.
        print(f"Error parsing string: {e}")
        return None
    if isinstance(result, list):
        return result
    print("Error parsing string: The parsed result is not a list.")
    return None

def adjust_negative_label(input_string):
    """Turn a comma-separated label string into its negative ('-') form.

    ``'25'`` and ``'X'`` are returned unchanged. Any other value is split on
    ``','`` and a ``'-'`` is appended to every piece (whitespace inside a
    piece is kept as is), then the pieces are re-joined with ``','``.
    """
    if input_string in ('25', 'X'):
        return input_string
    return ','.join(piece + '-' for piece in input_string.split(','))

def adjust_positive_label(n):
    """Map a raw class index (1..26) to its positive label string.

    Indices 1..16 map to ``'<n>+'``; 17 maps to ``'15+'``; 18..25 map to
    ``'<n-1>+'``; 26 maps to ``'25'``. Any other key returns the
    out-of-range message string.

    NOTE(review): 17 -> '15+' collapses two raw indices onto label 15 and
    shifts everything after it down by one. This looks deliberate (24 signed
    labels plus '25') but should be confirmed against the prompt definition.
    """
    table = {code: f'{code}+' for code in range(1, 17)}
    table[17] = '15+'
    table.update({code: f'{code - 1}+' for code in range(18, 26)})
    table[26] = '25'
    return table.get(n, "输入超出范围")

def check_validity(lst):
    """Detect label strings that take part in a sign conflict.

    Every item of ``lst`` is a label string that may hold several labels
    separated by ``', '`` (e.g. ``'3+, 7-'``). A label is a number followed by
    a one-character sign; the bare labels ``'25'`` and ``'26'`` carry no sign
    and are ignored. A number is *conflicting* when it occurs with more than
    one distinct sign anywhere in ``lst``.

    Items are matched by the exact label numbers they contain. The previous
    substring test flagged unrelated items: a conflict on ``'1'`` also
    reported ``'11+'``, ``'21-'`` or ``'25'``.

    Args:
        lst: list of label strings.

    Returns:
        ``True`` when there is no conflict; otherwise the list of original
        items (in input order) that contain at least one conflicting number.
        Note that this function never returns ``False``.
    """
    unsigned = ('25', '26')

    def _signed_elements(item):
        # Split a label string and keep only the non-empty, sign-carrying labels.
        return [el for el in item.split(', ') if el and el not in unsigned]

    # number -> set of signs observed for that number across the whole input
    signs_seen = {}
    for item in lst:
        for element in _signed_elements(item):
            signs_seen.setdefault(element[:-1], set()).add(element[-1])

    conflicting = {number for number, signs in signs_seen.items() if len(signs) > 1}

    flagged = [
        item for item in lst
        if any(element[:-1] in conflicting for element in _signed_elements(item))
    ]
    return flagged if flagged else True

def find_and_remove_indices(lst, elements):
    """Locate each of ``elements`` in ``lst``, consuming matches as it goes.

    Works on a private copy of ``lst``. After a position is matched it is
    overwritten with ``None`` in the copy, so a value repeated in
    ``elements`` resolves to successive occurrences rather than the same
    index twice.

    Returns:
        list[int]: one index per entry of ``elements``, in the same order.

    Raises:
        ValueError: if an element has no remaining occurrence.
    """
    remaining = list(lst)
    positions = []
    for wanted in elements:
        found_at = remaining.index(wanted)
        positions.append(found_at)
        remaining[found_at] = None
    return positions

def get_imgs(img_path, transform=None):
    """Load an image as grayscale, letterbox it to 256x256 and convert to RGB.

    Args:
        img_path: path of the image file (converted with ``str``).
        transform: optional callable applied to the PIL image.

    Returns:
        The RGB ``PIL.Image`` or, when ``transform`` is given, whatever the
        transform returns.

    Raises:
        OSError: if OpenCV cannot read the file. ``cv2.imread`` signals this
            by returning ``None``, which used to surface later as an opaque
            ``AttributeError`` inside ``resize_img``.
    """
    x = cv2.imread(str(img_path), 0)  # flag 0 -> read as single-channel grayscale
    if x is None:
        raise OSError(f"cv2.imread could not read image (missing or unreadable): {img_path}")
    x = resize_img(x, 256)
    img = Image.fromarray(x).convert("RGB")
    if transform is not None:
        img = transform(img)
    return img

def resize_img(img, scale):
    """Resize ``img`` so its longer side equals ``scale`` and zero-pad the other.

    The aspect ratio is kept. The shorter side is padded symmetrically (floor
    on the leading edge, ceil on the trailing edge) with zeros until it
    reaches ``scale``.

    Args:
        img: 2-D array (the padding spec has exactly two axes).
        scale: target side length in pixels.

    Returns:
        The resized and padded array.
    """
    dims = img.shape
    longest_axis = dims.index(max(dims))
    rows_are_longest = longest_axis == 0

    # Target size as (rows, cols); the longer side becomes ``scale``.
    if rows_are_longest:
        ratio = scale / float(dims[0])
        target = (scale, int(float(dims[1]) * float(ratio)))
    else:
        ratio = scale / float(dims[1])
        target = (int(float(dims[0]) * float(ratio)), scale)

    # cv2.resize expects (width, height), hence the reversal.
    shrunk = cv2.resize(img, target[::-1], interpolation=cv2.INTER_AREA)

    if rows_are_longest:
        gap = scale - shrunk.shape[1]
        row_pad = (int(0), int(0))
        col_pad = (int(np.floor(gap / 2)), int(np.ceil(gap / 2)))
    else:
        gap = scale - shrunk.shape[0]
        row_pad = (int(np.floor(gap / 2)), int(np.ceil(gap / 2)))
        col_pad = (int(0), int(0))

    return np.pad(shrunk, [row_pad, col_pad], "constant", constant_values=0)

def build_transformation():
    """Build the torchvision preprocessing pipeline applied to each image.

    The steps, in order: random 224 crop, random horizontal flip (p=0.3),
    random affine (30 degrees, 10% translation, 0.9-1.1 scale), brightness and
    contrast jitter (0.8-1.2), tensor conversion, then normalisation with
    mean 0.4978 and std 0.2449.

    NOTE(review): every step before ``ToTensor`` is random, so features
    extracted through this pipeline differ from run to run. Confirm this is
    intended for evaluation.
    """
    return transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(0.3),
        transforms.RandomAffine(30, translate=[0.1, 0.1], scale=[0.9, 1.1]),
        transforms.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4978], std=[0.2449]),
    ])

def merge_and_deduplicate(input_list):
    """Flatten each sample's label strings into a list of unique signed labels.

    For every sample, each item is stringified and split on ``', '``. Only
    pieces containing ``'+'`` or ``'-'`` are kept, and duplicates are removed
    through a set, so the order inside each returned list is unspecified.

    Args:
        input_list: iterable of samples, each an iterable of label items.

    Returns:
        list[list[str]]: one list of unique signed labels per sample.
    """
    merged = []
    for sample_labels in input_list:
        unique = set()
        for raw in sample_labels:
            unique.update(
                piece for piece in str(raw).split(', ')
                if '+' in piece or '-' in piece
            )
        merged.append(list(unique))
    return merged

def count_elements(lst):
    """Count how often each class number in 1..24 occurs across all rows.

    Every element is expected to be a number followed by a single sign
    character (e.g. ``'12+'``). The last character is dropped and the rest
    parsed with ``int``; numbers outside 1..24 are ignored.

    Returns:
        list[tuple[int, int]]: ``(number, count)`` pairs sorted by descending
        count; ties keep first-seen order.
    """
    tally = Counter(
        number
        for row in lst
        for number in (int(element[:-1]) for element in row)
        if 1 <= number <= 24
    )
    return sorted(tally.items(), key=lambda pair: pair[1], reverse=True)

def plot_tsne2d(features, labels, output_name, title=None):
    """Project ``features`` to 2-D with t-SNE and save a per-label scatter plot.

    Args:
        features: sample matrix passed to ``TSNE.fit_transform``.
        labels: 2-D indicator structure (rows = samples, columns = labels).
            A sample is drawn once for every truthy column. Only two colours
            are defined, so column indices above 1 would raise IndexError.
        output_name: path handed to ``plt.savefig``.
        title: optional plot title.

    Side effects:
        Saves the figure, shows it with ``plt.show()`` and closes it.
    """
    projector = TSNE(n_components=2, init='pca', random_state=0, verbose=1)
    projected = projector.fit_transform(features)

    # Min-max normalise each embedding axis.
    lower = np.min(projected, 0)
    upper = np.max(projected, 0)
    embedded = (projected - lower) / (upper - lower)

    palette = ['tomato', 'royalblue']
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.title(title)

    label_matrix = np.array(labels)
    # Column indices that are set for at least one sample.
    active_columns = np.unique(np.where(label_matrix)[1])
    for column in active_columns:
        rows = np.where(label_matrix[:, column])[0]
        points = embedded[rows]
        ax.scatter(points[:, 0], points[:, 1], c=palette[column], marker=".", label=f"Label {column}")

    ax.legend(title='Labels', loc="upper right", bbox_to_anchor=(1.0, 1.0))
    plt.savefig(output_name)
    plt.show()
    plt.close()

def calculate_distribution(features, num_bins=50):
    """Return a histogram of ``features`` normalised to sum to 1.

    Args:
        features: array-like values passed to ``np.histogram``.
        num_bins: number of histogram bins.

    Returns:
        tuple: ``(prob_distribution, bin_edges)``. The first is the density
        histogram divided by its own sum; the second is the edges array
        returned by ``np.histogram``.
    """
    density, edges = np.histogram(features, bins=num_bins, density=True)
    return density / np.sum(density), edges


def fisher_score(features_class_1, features_class_2):
    """Compute the mean per-feature Fisher score between two classes.

    For each feature column the score is the between-class scatter (sample
    count times squared distance of the class mean from the overall mean,
    summed over both classes) divided by the within-class scatter (squared
    deviations from the class mean, summed over both classes).

    Args:
        features_class_1: list of 2-D arrays, concatenated along axis 0.
        features_class_2: list of 2-D arrays, concatenated along axis 0.

    Returns:
        The mean of the per-feature scores.

    Raises:
        ValueError: from ``np.concatenate`` when either list is empty.
    """
    pooled = np.concatenate(features_class_1 + features_class_2, axis=0)
    overall_mean = np.mean(pooled, axis=0)

    n_features = pooled.shape[1]
    between = np.zeros(n_features)  # between-class scatter, one entry per feature
    within = np.zeros(n_features)   # within-class scatter, one entry per feature

    for chunks in (features_class_1, features_class_2):
        samples = np.concatenate(chunks, axis=0)
        centroid = np.mean(samples, axis=0)
        between += samples.shape[0] * (centroid - overall_mean) ** 2
        within += np.sum((samples - centroid) ** 2, axis=0)

    return np.mean(between / within)

def process_text(text, device, args):
    """Clean and tokenize one or more prompt texts for the text encoder.

    Each text is split into sentences (on numbered-list markers such as
    ``"1."`` and on periods), lower-cased, reduced to ASCII word tokens,
    re-joined, and tokenized with the tokenizer at ``args.tokenizer_path``,
    padded/truncated to 100 tokens.

    Changes from the earlier version, with no change in output:
      * the tokenizer, its inverted vocabulary and the regex helpers are
        built once instead of once per text;
      * the sentence-splitting pattern is a raw string (``"\\."`` in a normal
        string literal is an invalid escape sequence);
      * the loop variables no longer shadow one another.

    Args:
        text: a single string or a list of strings.
        device: device the resulting tensors are moved to.
        args: namespace providing ``tokenizer_path``.

    Returns:
        dict with ``caption_ids``, ``attention_mask``, ``token_type_ids``
        (tensors on ``device``) and ``cap_lens`` (list of ints).
    """
    if isinstance(text, str):
        text = [text]

    # Loop-invariant helpers: build them once.
    splitter = re.compile(r"[0-9]+\.")
    word_tokenizer = RegexpTokenizer(r"\w+")
    self_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    idxtoword = {v: k for k, v in self_tokenizer.get_vocab().items()}

    processed_text_tensors = []
    for raw_text in text:
        flattened = raw_text.replace("\n", " ")
        # Split on numbered-list markers first, then on periods.
        sentences = [
            sent
            for point in splitter.split(flattened)
            for sent in point.split(".")
        ]

        all_sents = []
        for sentence in sentences:
            sentence = sentence.replace("\ufffd\ufffd", " ")
            tokens = word_tokenizer.tokenize(sentence.lower())
            # Fragments with at most one word are dropped.
            if len(tokens) <= 1:
                continue
            included_tokens = []
            for token in tokens:
                token = token.encode("ascii", "ignore").decode("ascii")
                if len(token) > 0:
                    included_tokens.append(token)
            all_sents.append(" ".join(included_tokens))

        cleaned = " ".join(all_sents)
        text_tensors = self_tokenizer(
            cleaned,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=100,
        )
        text_tensors["sent"] = [idxtoword[ix] for ix in text_tensors["input_ids"][0].tolist()]
        processed_text_tensors.append(text_tensors)

    caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
    attention_mask = torch.stack([x["attention_mask"] for x in processed_text_tensors])
    token_type_ids = torch.stack([x["token_type_ids"] for x in processed_text_tensors])

    if len(text) == 1:
        # Drop only the stacking axis so a single text keeps its batch dimension.
        caption_ids = caption_ids.squeeze(0).to(device)
        attention_mask = attention_mask.squeeze(0).to(device)
        token_type_ids = token_type_ids.squeeze(0).to(device)
    else:
        caption_ids = caption_ids.squeeze().to(device)
        attention_mask = attention_mask.squeeze().to(device)
        token_type_ids = token_type_ids.squeeze().to(device)

    # NOTE(review): ``txt`` is a raw input string, so this iterates over its
    # characters and counts those that are not "[". That looks unintended
    # (token-level lengths were probably meant). Kept unchanged because callers
    # may depend on the current values -- confirm.
    cap_lens = []
    for txt in text:
        cap_lens.append(len([w for w in txt if not w.startswith("[")]))

    return {
        "caption_ids": caption_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "cap_lens": cap_lens,
    }


if __name__ == '__main__':

    args = parse_args()

    # Source annotation JSON and the cached CSV of (image_path, report) rows built from it.
    file_input_path = '/mnt/nvme_share/zhux/dataset/MIMIC_caption/annotation.json'
    file_output_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/mimic-cxr-test.csv'

    # Read MIMIC Test Data
    # First run: build the CSV from the JSON 'test' split. Later runs: reuse the CSV.
    if not os.path.exists(file_output_path):
        with open(file_input_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        data = data['test']
        mimic_cxr_test_data = []
        for sample in data:
            prefix = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC-CXR-JPG/2.0.0/files/'
            # Only the first entry of sample['image_path'] is used.
            image_path = prefix + sample['image_path'][0]
            report = sample['report']
            # Samples whose image file is missing on disk are dropped.
            if os.path.exists(image_path):
                mimic_cxr_test_data.append([image_path, report])
            else:
                continue
        pd.DataFrame(mimic_cxr_test_data).to_csv(file_output_path, index=False, header=['image_path', 'report'])
    else:
        mimic_cxr_test_data = pd.read_csv(file_output_path).values.tolist()
    # Reports with newlines removed, in the same order as mimic_cxr_test_data.
    report_list = []
    for idx in range(len(mimic_cxr_test_data)):
        captions = mimic_cxr_test_data[idx][1]
        captions = captions.replace("\n", "")
        report_list.append(captions)
    # Longest report length in characters plus 800; used below as the tokenizer's
    # model_max_length and as max_new_tokens for generation.
    max_length = max_length_string(report_list) + 800
    # Image paths, index-aligned with report_list.
    img_path_list = []
    for idx in range(len(mimic_cxr_test_data)):
        img_path = mimic_cxr_test_data[idx][0]
        img_path_list.append(img_path)

    # LLM Init
    # Load the causal LM in float16 and move it to the GPU. The tokenizer is loaded
    # from the same path with the slow implementation (use_fast=False).
    # NOTE(review): the model is loaded even when every LLM stage below is already
    # cached on disk -- confirm whether that cost is acceptable.
    model = AutoModelForCausalLM.from_pretrained(args.LLM_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        args.LLM_name_or_path,
        model_max_length=max_length,
        use_fast=False,
        trust_remote_code=True
    )

    # LLM Cut Sent
    # Ask the LLM to split every report into a list of sentences, appending one CSV
    # row (the stringified list) per report. The file has no header row.
    cut_sent_output_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/LLM_cut_sent.csv'
    # Resume support: the number of rows already on disk is the number of reports
    # to skip.
    if os.path.exists(cut_sent_output_path):
        history = pd.read_csv(cut_sent_output_path, header=None)
        cache = len(history)
    else:
        cache = 0
    for idx in tqdm(range(0, len(report_list)), desc='Cut Sent'):
        if idx < cache:
            continue
        sample = report_list[idx]
        data_sent = prompt_data(sample)
        output = inference(data_sent, model, tokenizer, max_length)
        # parse_string_to_list returns None on a parse failure, in which case the
        # literal text "None" is written for this report.
        output_list = parse_string_to_list(output)
        append_to_csv(cut_sent_output_path, [str(output_list)])

    # LLM Create Label
    # For every sentence of every report, run the prompt cascade to obtain a label:
    # 'X' (discard), a positive label such as '3+', a negative label such as '3-',
    # or a bare '25' / '26'. One row of labels and one row with the kept sentences
    # are appended per report.
    #
    # NOTE(review): both outputs are written with csv.writer (append_to_csv) but
    # are read back with pd.read_excel and carry an .xlsx suffix. read_excel is
    # not expected to parse CSV text, so confirm how these files are produced and
    # consumed in practice.
    sent_output_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/LLM_sent_list_final.xlsx'
    label_output_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/LLM_label_list_final.xlsx'
    # header=None: the cut-sentence CSV is written without a header row. Without it
    # the first report is consumed as the header and every later index is shifted
    # by one relative to report_list / img_path_list.
    cut_report_list = pd.read_csv(cut_sent_output_path, header=None).values.tolist()
    # Resume support: the number of label rows already on disk is the number of
    # reports to skip.
    if os.path.exists(label_output_path):
        history = pd.read_excel(label_output_path, header=None)
        cache = len(history)
    else:
        cache = 0
    for idx in tqdm(range(0, len(cut_report_list)), desc='Creat Label'):
        if idx < cache:
            continue
        sample = cut_report_list[idx]
        label_list = []
        # The cell may be NaN or the text "None" when the cut-sentence stage failed
        # for this report. Treat that as "no sentences" so an (empty) row is still
        # written and row indices stay aligned with the images.
        raw_sentences = sample[0]
        sent_list = parse_string_to_list(raw_sentences) if isinstance(raw_sentences, str) else None
        if sent_list is None:
            sent_list = []
        for i in range(0, len(sent_list)):
            sent = sent_list[i]
            # Stage 1: rewrite/filter the sentence; 'X' means "discard".
            data_sent = prompt_1_data(sent)
            data_sent = inference(data_sent, model, tokenizer, max_length)
            if data_sent != 'X':
                # Stage 2: positive-class prediction, refined by prompt 3 when the
                # answer contains '4'.
                output_temp = prompt_2_data(data_sent)
                label = inference(output_temp, model, tokenizer, max_length)
                if '4' in label:
                    output_temp = prompt_3_data(output_temp)
                    label = inference(output_temp, model, tokenizer, max_length)
            else:
                label = 'X'
            if label == 'X':
                pass
            elif label != 26 and label != '26':
                label = adjust_positive_label(int(label))
            else:
                # Answer '26': try the negative-finding cascade, unless the sentence
                # only reports "no change".
                if data_sent != 'X' and 'unchanged' not in data_sent and 'no changes' not in data_sent and 'no change' not in data_sent:
                    output_temp = prompt_2_pro_data(data_sent)
                    output_sent = inference(output_temp, model, tokenizer, max_length)
                    if output_sent == '4':
                        output_temp = prompt_5_data(data_sent)
                        output_sent = inference(output_temp, model, tokenizer, max_length)
                    if output_sent == '17':
                        output_temp = prompt_4_data(data_sent)
                        output_sent = inference(output_temp, model, tokenizer, max_length)
                    if output_sent == '25':
                        output_temp = prompt_3_plus_data(data_sent)
                        output_sent = inference(output_temp, model, tokenizer, max_length)
                        if output_sent == '25':
                            output_temp = prompt_3_pro_data(data_sent)
                            output_sent = inference(output_temp, model, tokenizer, max_length)
                    label = adjust_negative_label(output_sent)
                else:
                    label = 'X'
            label_list.append(label)
            if label == 'X':
                sent_list[i] = 'X'
        # check_validity returns True when the labels are consistent, otherwise the
        # list of conflicting label strings. It never returns False, so the old
        # ``validity == False`` test could never trigger the repair below.
        validity = check_validity(label_list)
        if validity is not True:
            error_index = find_and_remove_indices(label_list, validity)
            for index in error_index:
                label_list[index] = '26'
        filtered_sent_list = [element for element in sent_list if element != 'X']
        filtered_label_list = [element for element in label_list if element != 'X']
        # Pad every label row to at least 75 columns with '0'.
        if len(filtered_label_list) < 75:
            for _ in range(75 - len(filtered_label_list)):
                filtered_label_list.append('0')
        append_to_csv(label_output_path, filtered_label_list)
        append_to_csv(sent_output_path, [str(filtered_sent_list)])
    # Reload the per-report sentence lists and label rows written by the labelling stage.
    # NOTE(review): these files are appended to with csv.writer but read here with
    # pd.read_excel -- confirm the on-disk format.
    sent = pd.read_excel(sent_output_path, header=None).values.tolist()
    label = pd.read_excel(label_output_path, header=None).values.tolist()
    sent_list = []
    for sent_sample in tqdm(sent, desc='Loading chopped sentences'):
        # The first cell holds the stringified sentence list; this yields None if it cannot be parsed.
        sent_list.append(parse_string_to_list(sent_sample[0]))
    label_list = []
    for label_sample in tqdm(label, desc='Loading sentences labels'):
        # Drop float cells (presumably NaN from empty cells -- confirm) ...
        temp = [item for item in label_sample if not isinstance(item, float)]
        # ... and the '0' / 0 padding values.
        temp = [element for element in temp if element != '0' and element != 0]
        label_list.append(temp)
    # Cache [sent_list, label_list] as a pickle. When the pickle already exists it
    # takes precedence and replaces the lists that were just built.
    pickle_path = "/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/captions_sent_label_with_LLM.pickle"
    if not os.path.isfile(pickle_path):
        print(f"Caption file {pickle_path} does not exit. Creating captions...")
        with open(pickle_path, "wb") as f:
            pickle.dump([sent_list, label_list], f, protocol=2)
            print("Save to: ", pickle_path)
    else:
        with open(pickle_path, "rb") as f:
            print(f"Loading captions from {pickle_path}")
            sent_list, label_list = pickle.load(f)

    # text data
    # Load the class prompt texts (a JSON object mapping a key to text) and tokenize
    # each entry on the CPU with process_text.
    with open('/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/MIMIC_multi_label_text_plus.json', 'r') as f:
        cls_prompts = json.load(f)
    processed_txt = {}
    for k, v in cls_prompts.items():
        processed_txt[k] = process_text(v, "cpu", args)
    # Collect the per-class tensors in dict iteration order ...
    caption_ids, attention_mask, token_type_ids = [], [], []
    for cls_name, txts in processed_txt.items():
        caption_ids.append(txts["caption_ids"])
        attention_mask.append(txts["attention_mask"])
        token_type_ids.append(txts["token_type_ids"])
    # ... and concatenate them along dim 0 into one batch on the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    caption_ids = torch.cat(caption_ids, dim=0).to(device)
    attention_mask = torch.cat(attention_mask, dim=0).to(device)
    token_type_ids = torch.cat(token_type_ids, dim=0).to(device)
    text_batch = {"caption_ids": caption_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}

    # Model Inference
    # Extract text/image embeddings and the fused similarity (SimR) for every valid
    # sample, caching the result in a pickle named after args.model_name.
    pickle_path = f"/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/{args.model_name}_feature.pickle"
    # Accepted labels: signed labels 1..24 plus 25/26 in both str and int form.
    valid_label_list = ['1+', '1-', '2+', '2-', '3+', '3-', '4+', '4-', '5+', '5-', '6+', '6-', '7+', '7-', '8+',
                        '8-', '9+', '9-', '10+', '10-', '11+', '11-', '12+', '12-', '13+', '13-', '14+', '14-',
                        '15+', '15-', '16+', '16-', '17+', '17-', '18+', '18-', '19+', '19-', '20+', '20-', '21+',
                        '21-', '22+', '22-', '23+', '23-', '24+', '24-', '25', '26', 25, 26]
    if not os.path.isfile(pickle_path):
        print(f"Caption file {pickle_path} does not exit. Creating captions...")
        CARZero_model = CARZero.load_CARZero(name=args.ckpt_path, device=device)
        # Rebinds `tokenizer`: from here on it is the text-encoder tokenizer, not the LLM one.
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
        transform = build_transformation()
        text_global_list, text_local_list, img_global_list, img_local_list, SimR_list, label_list_new = [], [], [], [], [], []
        with torch.no_grad():
            CARZero_model.eval()
            for idx in tqdm(range(len(sent_list)), desc='Model Inference'):
                sent_item = sent_list[idx]
                label_item = label_list[idx]
                # Reports with no kept sentence are skipped.
                if len(sent_item) == 0:
                    continue
                # A sample is kept only if every label is in valid_label_list. A label
                # containing ',' is split on ', ' and each part is checked.
                # NOTE(review): a non-string label outside the list (e.g. an int other
                # than 25/26) would raise on the `',' in lab` test -- confirm inputs.
                Validation = True
                for lab in label_item:
                    if lab not in valid_label_list:
                        if ',' in lab:
                            sublab_list = lab.split(', ')
                            for sublab in sublab_list:
                                if sublab not in valid_label_list:
                                    Validation = False
                        else:
                            Validation = False
                if Validation == False:
                    continue
                label_list_new.append(label_item)
                # Image lookup relies on sent_list being index-aligned with img_path_list.
                img_path = img_path_list[idx]
                imgs = [get_imgs(img_path, transform)]
                imgs = torch.stack(imgs).to(device)
                # Tokenize each sentence, padded/truncated to 100 tokens.
                ids, tokens, attention, cap_len = [], [], [], []
                for b in range(len(sent_item)):
                    sent = sent_item[b]
                    input_ids = tokenizer(sent, return_tensors="pt", truncation=True, padding="max_length", max_length=100)
                    # Number of token ids different from 0 (presumably the pad id -- confirm).
                    x_len = len([t for t in input_ids["input_ids"][0] if t != 0])
                    ids.append(input_ids["input_ids"])
                    tokens.append(input_ids["token_type_ids"])
                    attention.append(input_ids["attention_mask"])
                    cap_len.append(x_len)
                # Sentences are reordered by descending length before encoding.
                # NOTE(review): the stored text embeddings therefore follow the sorted
                # order, while label_item keeps the original sentence order -- confirm
                # that downstream per-sentence indexing accounts for this.
                sorted_cap_lens, sorted_cap_indices = torch.sort(torch.tensor(cap_len), 0, True)
                ids = torch.stack(ids).squeeze(1)
                tokens = torch.stack(tokens).squeeze(1)
                attention = torch.stack(attention).squeeze(1)
                caption_ids = ids[sorted_cap_indices].to(device)
                token_type_ids = tokens[sorted_cap_indices].to(device)
                attention_mask = attention[sorted_cap_indices].to(device)
                # Embeddings of the report sentences and of the image.
                query_emb_l, query_emb_g, _ = CARZero_model.text_encoder_forward(caption_ids, attention_mask, token_type_ids)
                label_img_emb_l, label_img_emb_g = CARZero_model.image_encoder_forward(imgs)
                text_global_list.append(query_emb_g)
                text_local_list.append(query_emb_l)
                img_global_list.append(label_img_emb_g)
                img_local_list.append(label_img_emb_l)
                # From here on query_emb_* hold the class-prompt embeddings (text_batch),
                # which are re-encoded for every sample.
                query_emb_l, query_emb_g, _ = CARZero_model.text_encoder_forward(text_batch["caption_ids"], text_batch["attention_mask"], text_batch["token_type_ids"])
                # Flatten trailing dims of the local features and move that axis to dim 1.
                label_img_emb_l = label_img_emb_l.view(label_img_emb_l.size(0), label_img_emb_l.size(1), -1)
                label_img_emb_l = label_img_emb_l.permute(0, 2, 1)  # patch_num b dim
                query_emb_l_ = query_emb_l.view(query_emb_l.size(0), query_emb_l.size(1), -1)
                query_emb_l_ = query_emb_l_.permute(0, 2, 1)  # patch_num b dim # [97, 512, 768]
                # Fuse in both directions: image tokens against prompt globals, and
                # prompt tokens against the image global.
                i2t_cls = CARZero_model.fusion_module(torch.cat([label_img_emb_g.unsqueeze(1), label_img_emb_l], dim=1), query_emb_g, return_atten=False, use_MLP=False)
                t2i_cls = CARZero_model.fusion_module(torch.cat([query_emb_g.unsqueeze(1), query_emb_l_], dim=1), label_img_emb_g, return_atten=False, use_MLP=False)
                i2t_cls = i2t_cls
                t2i_cls = t2i_cls.transpose(1, 0)
                # SimR is the average of the two directions.
                SimR = (i2t_cls + t2i_cls) / 2
                SimR_list.append(SimR)
        with open(pickle_path, "wb") as f:
            pickle.dump([text_global_list, text_local_list, img_global_list, img_local_list, SimR_list, label_list_new], f, protocol=2)
            print("Save to: ", pickle_path)
    else:
        with open(pickle_path, "rb") as f:
            print(f"Loading captions from {pickle_path}")
            text_global_list, text_local_list, img_global_list, img_local_list, SimR_list, label_list_new = pickle.load(f)

    # Output Visualization
    # Delete every global that is not in keep_vars, is not a dunder, and is neither
    # a module nor a plain function. This drops models, tokenizers and intermediate
    # data (presumably to release memory -- confirm).
    # NOTE(review): 'keep_vars' is not in its own list, so this loop deletes it too.
    # That appears to work only because it is the most recently created global and
    # is therefore visited last; a later-created global would hit a NameError.
    keep_vars = ['text_global_list', 'text_local_list', 'img_global_list', 'img_local_list', 'SimR_list', 'label_list_new', 'valid_label_list', 'args', 'tqdm', 'Counter', 'plt', 'TSNE']
    for var in list(globals().keys()):
        if var not in keep_vars and not var.startswith("__"):
            if not isinstance(globals()[var], type(sys)) and not isinstance(globals()[var], types.FunctionType):
                del globals()[var]
    # Per-report sets of unique signed labels ('<n>+' / '<n>-').
    label_list = merge_and_deduplicate(label_list_new)
    # One class is analysed at a time, giving a two-slot [positive, negative] indicator.
    num_class = 1
    # (class number, occurrence count) pairs, most frequent first.
    sorted_counts = count_elements(label_list)
    print(sorted_counts)
    # For every class (most frequent first) compute the Fisher score between the
    # negative-only ([0, 1]) and positive-only ([1, 0]) groups, separately on SimR
    # features, image global features and text global features.
    img_Fisher_score, text_Fisher_score, SimR_Fisher_score, label_num = [], [], [], []
    for i in range(len(sorted_counts)):
        class_both_list, class_inter_list = [], []
        # class_both_list[0] is the positive label, class_both_list[1] the negative one.
        class_both_list.append(str(sorted_counts[i][0]) + '+')
        class_both_list.append(str(sorted_counts[i][0]) + '-')

        # SimR Feature and Label
        filtered_label_both_list, filtered_SimR_list = [], []
        for idx in tqdm(range(len(label_list)), desc='Sample Screening'):
            SimR = SimR_list[idx].detach().cpu().numpy()
            label = label_list[idx]
            # Indicator [has positive, has negative] for the current class.
            filtered_label_both = [0] * num_class * 2
            for lab in label:
                if lab in class_both_list:
                    filtered_label_both[class_both_list.index(lab)] = 1
            if filtered_label_both != [0] * num_class * 2:
                filtered_label_both_list.append(filtered_label_both)
                # Select the SimR slice at the label's position in valid_label_list.
                # NOTE(review): assumes the class prompts are ordered like
                # valid_label_list -- confirm.
                if filtered_label_both == [1, 0]:
                    pos = valid_label_list.index(class_both_list[0])
                    filtered_SimR_list.append(SimR[:, pos, :])
                elif filtered_label_both == [0, 1]:
                    pos = valid_label_list.index(class_both_list[1])
                    filtered_SimR_list.append(SimR[:, pos, :])
                else:
                    # Both polarities present: placeholder keeps the two lists aligned;
                    # the sample is excluded by the split below.
                    filtered_SimR_list.append([])
        features_class_1, features_class_2, label_class_1, label_class_2 = [], [], [], []
        for idx in range(len(filtered_SimR_list)):
            label = filtered_label_both_list[idx]
            feature = filtered_SimR_list[idx]
            if label == [0, 1]:
                features_class_1.append(feature)
                label_class_1.append(label)
            elif label == [1, 0]:
                features_class_2.append(feature)
                label_class_2.append(label)
        filtered_SimR_list = features_class_1 + features_class_2
        label_class_list = label_class_1 + label_class_2
        # Only needed by the disabled t-SNE plot; guard the empty case so that a
        # class without single-polarity samples does not abort the whole run.
        SimR_features = np.concatenate(filtered_SimR_list, axis=0) if filtered_SimR_list else None

        print("Computing T-SNE Map")
        plot_output_path = f'/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/{args.model_name}/label {str(sorted_counts[i][0])}.png'
        title = f'{args.model_name}  label {str(sorted_counts[i][0])}'
        # plot_tsne2d(SimR_features, label_class_list, plot_output_path, title)
        # Best effort: the score falls back to 0 when it cannot be computed (e.g. one
        # group is empty). ``Exception`` keeps that behaviour without also swallowing
        # KeyboardInterrupt / SystemExit as a bare ``except`` would.
        try:
            fisher_score_SimR = fisher_score(features_class_1, features_class_2)
        except Exception:
            fisher_score_SimR = 0
        # Group sizes (negative-only, positive-only) taken from the SimR split.
        label_num.append([len(features_class_1), len(features_class_2)])

        # Image Feature and Label
        filtered_label_both_list, filtered_img_global_list = [], []
        for idx in tqdm(range(len(label_list)), desc='Sample Screening'):
            img_global = img_global_list[idx].detach().cpu().numpy()
            label = label_list[idx]
            filtered_label_both = [0] * num_class * 2
            for lab in label:
                if lab in class_both_list:
                    filtered_label_both[class_both_list.index(lab)] = 1
            if filtered_label_both != [0] * num_class * 2:
                filtered_label_both_list.append(filtered_label_both)
                filtered_img_global_list.append(img_global)
        features_class_1, features_class_2, label_class_1, label_class_2 = [], [], [], []
        for idx in range(len(filtered_img_global_list)):
            label = filtered_label_both_list[idx]
            feature = filtered_img_global_list[idx]
            if label == [0, 1]:
                features_class_1.append(feature)
                label_class_1.append(label)
            elif label == [1, 0]:
                features_class_2.append(feature)
                label_class_2.append(label)
        filtered_img_global_list = features_class_1 + features_class_2
        label_class_list = label_class_1 + label_class_2
        img_global_features = np.concatenate(filtered_img_global_list, axis=0) if filtered_img_global_list else None

        print("Computing T-SNE Map")
        plot_output_path = None
        # plot_tsne2d(img_global_features, label_class_list, plot_output_path)
        try:
            fisher_score_img = fisher_score(features_class_1, features_class_2)
        except Exception:
            fisher_score_img = 0

        # Text Feature and Label
        # Works on the raw per-sentence labels (label_list_new) so each sentence
        # embedding can be paired with its own label.
        filtered_label_both_list, filtered_text_global_list = [], []
        for idx in tqdm(range(len(label_list_new)), desc='Sample Screening'):
            text_global = text_global_list[idx]
            label = label_list_new[idx]
            # NOTE(review): this one list object is shared by every sentence of the
            # report and appended by reference. Once a report contains both
            # polarities, all of its appended entries read [1, 1] and are dropped
            # by the split below. Confirm whether per-sentence indicators were
            # intended.
            filtered_label_both = [0] * num_class * 2
            # Own index name: the previous code reused ``i`` and so shadowed the
            # outer class loop variable.
            for sent_idx in range(len(label)):
                lab = label[sent_idx]
                text_global_item = text_global[sent_idx].unsqueeze(0).detach().cpu().numpy()
                # Integer labels (bare 25 / 26) carry no polarity.
                if type(lab) == int:
                    continue
                if ',' in lab:
                    sublab_list = lab.split(', ')
                    for sublab in sublab_list:
                        if sublab in class_both_list:
                            filtered_label_both[class_both_list.index(sublab)] = 1
                            filtered_label_both_list.append(filtered_label_both)
                            filtered_text_global_list.append(text_global_item)
                else:
                    if lab in class_both_list:
                        filtered_label_both[class_both_list.index(lab)] = 1
                        filtered_label_both_list.append(filtered_label_both)
                        filtered_text_global_list.append(text_global_item)
        features_class_1, features_class_2, label_class_1, label_class_2 = [], [], [], []
        for idx in range(len(filtered_text_global_list)):
            label = filtered_label_both_list[idx]
            feature = filtered_text_global_list[idx]
            if label == [0, 1]:
                features_class_1.append(feature)
                label_class_1.append(label)
            elif label == [1, 0]:
                features_class_2.append(feature)
                label_class_2.append(label)
        filtered_text_global_list = features_class_1 + features_class_2
        label_class_list = label_class_1 + label_class_2
        text_global_features = np.concatenate(filtered_text_global_list, axis=0) if filtered_text_global_list else None

        print("Computing T-SNE Map")
        plot_output_path = None
        # plot_tsne2d(text_global_features, label_class_list, plot_output_path)
        try:
            fisher_score_text = fisher_score(features_class_1, features_class_2)
        except Exception:
            fisher_score_text = 0
        print(f"Fisher Score between class 1 and class 2: {fisher_score_img}")
        print(f"Fisher Score between class 1 and class 2: {fisher_score_text}")
        img_Fisher_score.append(fisher_score_img)
        text_Fisher_score.append(fisher_score_text)
        SimR_Fisher_score.append(fisher_score_SimR)

    # Transpose label_num (one [n_class_1, n_class_2] pair per class) into columns with zip(*).
    column_lists = list(zip(*label_num))
    # Turn every column tuple into its own flat list.
    # NOTE(review): if label_num is empty, one_d_lists is empty and the indexing
    # below raises IndexError.
    one_d_lists = [list(col) for col in column_lists]

    data = {
        'Img_Fisher_score': img_Fisher_score,
        'Text_Fisher_score': text_Fisher_score,
        'SimR_Fisher_score': SimR_Fisher_score,
        'Label0': one_d_lists[0],
        'Label1': one_d_lists[1]
    }
    # Convert the dict into a DataFrame (one row per analysed class).
    df = pd.DataFrame(data)
    # Save it as a CSV file named after the model.
    df.to_csv(f'/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_test/{args.model_name}_output.csv', index=False)
    print("CSV 文件已成功保存!")