import random
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import numpy as np
import json
import matplotlib.pyplot as plt
import argparse
import os
import torch
from torch.utils.data import Dataset

class VQADataset(Dataset):
    """Map-style dataset over a VQA annotation JSON file.

    The file must contain a top-level ``"data"`` list. Each entry is expected
    to provide ``question``, ``question_id`` and ``image_id`` and may provide
    ``answers``. Items are returned as dicts holding the resolved image path;
    images themselves are not opened here.
    """

    # Datasets whose image files follow the COCO naming scheme.
    _COCO_DATASETS = {"vqav2", "ok_vqa"}
    # COCO image directories accepted for those datasets.
    _COCO_SPLITS = {"train2014", "val2014", "test2015"}

    def __init__(
            self, image_dir_path, data_path, is_train, dataset_name, max_samples=None
    ):
        """Load annotations and remember how to resolve image paths.

        Args:
            image_dir_path: Directory holding the dataset images. For COCO-based
                datasets its last path component must be the COCO split name.
            data_path: Path to the annotation JSON file.
            is_train: Whether this is a training split (stored on the instance).
            dataset_name: One of "vqav2", "ok_vqa", "vizwiz", "textvqa".
            max_samples: If given, keep only the first ``max_samples`` entries.

        Raises:
            ValueError: For COCO-based datasets whose image directory is not a
                known COCO split.
        """
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(data_path, "r", encoding="utf-8") as f:
            data_dict = json.load(f)

        self.data = data_dict["data"]
        if max_samples is not None:
            self.data = self.data[:max_samples]

        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in self._COCO_DATASETS:
            # The COCO split is the last path component, e.g. ".../val2014/".
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            # Explicit check instead of `assert`, which is stripped under -O.
            if self.img_coco_split not in self._COCO_SPLITS:
                raise ValueError(
                    f"Unexpected COCO split {self.img_coco_split!r} derived from "
                    f"image_dir_path {self.image_dir_path!r}; expected one of "
                    f"{sorted(self._COCO_SPLITS)}"
                )

    def __len__(self):
        """Return the number of (possibly truncated) annotation entries."""
        return len(self.data)

    def get_img_path(self, item):
        """Return the image file path for one annotation entry.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        if self.dataset_name in self._COCO_DATASETS:
            # Train and eval splits share one naming pattern; only the split differs.
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{item['image_id']:012d}.jpg",
            )
        if self.dataset_name == "vizwiz":
            # VizWiz image_id is used directly as the file name.
            return os.path.join(self.image_dir_path, item["image_id"])
        if self.dataset_name == "textvqa":
            return os.path.join(self.image_dir_path, f"{item['image_id']}.jpg")
        raise ValueError(f"Unknown VQA dataset {self.dataset_name}")

    def __getitem__(self, idx):
        """Return entry ``idx`` as a dict; ``answers`` is included only if present."""
        item = self.data[idx]
        results = {
            "image_path": self.get_img_path(item),
            "question": item["question"],
            "question_id": item["question_id"],
            "image_id": item["image_id"],
        }
        # Unlabeled splits may not carry answers.
        if "answers" in item:
            results["answers"] = item["answers"]
        return results

def fix_seed(seed):
    """Seed every RNG used by this script so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    generators (including all CUDA devices), then asks cuDNN to use
    deterministic algorithms.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch default generator
        torch.cuda.manual_seed_all,  # every visible CUDA device
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.deterministic = True

def load_json_data(file_path, encoding="utf-8"):
    """Load and return the JSON document stored at *file_path*.

    Args:
        file_path: Path to a JSON file.
        encoding: Text encoding used to read the file. Defaults to UTF-8, the
            encoding JSON mandates, instead of the platform locale encoding.

    Returns:
        The deserialized Python object.
    """
    with open(file_path, "r", encoding=encoding) as f:
        return json.load(f)

def _str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no"."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse command-line options for demonstration construction.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (argparse's behavior when ``None`` is passed).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Build clustered few-shot demonstrations from a VQA train split"
    )
    parser.add_argument(
        "--task", type=str, default="textvqa",
        help="dataset used for experiment"
    )
    parser.add_argument(
        "--max_ra_len", type=int, default=5,
        help="maximum number of reasoning chains"
    )
    parser.add_argument(
        "--problems_file", type=str, default="/home/test/yxl/MCoT/data/textvqa/TextVQA_0.5.1_train.json",
        help="path to the problems json file"
    )
    parser.add_argument(
        "--images_file", type=str, default="/home/test/yxl/MCoT/data/textvqa/images",
        help="directory containing the dataset images"
    )
    parser.add_argument(
        "--demo_save_dir", type=str, default="/home/test/yxl/MCoT/textvqa/tool",
        help="where to save the constructed demonstrations"
    )
    parser.add_argument("--random_seed", type=int, default=192, help="random seed")
    parser.add_argument(
        "--encoder", type=str, default="all-MiniLM-L6-v2",
        help="which sentence-transformer encoder for clustering"
    )
    parser.add_argument(
        "--sampling", type=str, default="center",
        help="whether to sample the cluster center first"
    )
    # `type=bool` would turn any non-empty string (including "False") into True.
    parser.add_argument(
        "--debug", type=_str2bool, default=True, help="debug mode (true/false)"
    )
    parser.add_argument(
        "--num_clusters", type=int, default=4,
        help="number of clusters for KMeans"
    )
    return parser.parse_args(argv)


def _build_corpus(problems):
    """Flatten ``{question_id: item}`` into index-aligned lists for clustering.

    Returns:
        (corpus, questions, gold_ans, qids):
        - corpus: the raw question text that gets embedded;
        - questions: the "Q: ...\\n" prompt form of each question;
        - gold_ans: each item's answers, or None when the item has none;
        - qids: the original question ids.
    """
    corpus, questions, gold_ans, qids = [], [], [], []
    for qid, problem in problems.items():
        questions.append(f"Q: {problem['question']}\n")
        # VQADataset items carry "answers" only when the source entry has them.
        gold_ans.append(problem.get("answers"))
        corpus.append(f"{problem['question']}")
        qids.append(qid)
    return corpus, questions, gold_ans, qids


def _select_demos(cluster_assignment, dist, num_clusters, questions, gold_ans,
                  qids, sampling):
    """Pick one demonstration per KMeans cluster.

    ``dist[i][c]`` is the distance from corpus item ``i`` to centroid ``c``.
    With ``sampling == "center"`` the member nearest its centroid is chosen.
    Otherwise the distance-ranked members are shuffled with ``random`` and the
    first one is taken, so seeded runs select the same items as before.
    """
    members = [[] for _ in range(num_clusters)]
    for idx, cluster_id in enumerate(cluster_assignment):
        members[cluster_id].append(idx)

    demos = []
    for cluster_id, cluster_members in enumerate(members):
        print(f"Processing Cluster {cluster_id + 1}/{num_clusters}")
        if not cluster_members:
            print(f"Warning: Cluster {cluster_id + 1} is empty. Skipping...")
            continue
        # Stable sort by distance to this cluster's centroid; ties keep corpus order.
        ranked = sorted(cluster_members, key=lambda idx: dist[idx][cluster_id])
        if sampling != "center":
            random.shuffle(ranked)
        idx_in_corpus = ranked[0]

        c_question = questions[idx_in_corpus]
        c_gold_ans = gold_ans[idx_in_corpus]
        c_qid = qids[idx_in_corpus]
        demos.append({
            "qid": c_qid,
            "question": c_question,
            "gold_ans": c_gold_ans,
            "cluster_id": cluster_id,
        })
        print(f"Added demo from Cluster {cluster_id + 1} with QID: {c_qid}")
        print(f"Question: {c_question}")
        print(f"Answer: {c_gold_ans}")
        print("---")
    return demos


def _plot_clusters(embeddings, labels, centroids, num_clusters, title, seed,
                   out_path):
    """Project embeddings and centroids to 2-D with PCA and save a scatter plot."""
    pca_model = PCA(n_components=2, random_state=seed)
    transformed = pca_model.fit_transform(embeddings)
    centers = pca_model.transform(centroids)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter = plt.scatter(x=transformed[:, 0], y=transformed[:, 1],
                              c=labels, s=50, cmap=plt.cm.Paired, alpha=0.4)
        plt.scatter(centers[:, 0], centers[:, 1],
                    s=250, marker='*', label='centroids',
                    edgecolor='black',
                    c=np.arange(0, num_clusters), cmap=plt.cm.Paired)
        plt.xticks([])
        plt.yticks([])
        plt.title(title)
        plt.colorbar(scatter, label='Cluster Label')
        plt.savefig(out_path, dpi=600, bbox_inches='tight')
    finally:
        # pyplot keeps figures alive until explicitly closed.
        plt.close(fig)


def main():
    """Cluster training questions and save one demonstration per cluster.

    Reads options via ``parse_arguments``. Embeds up to 1000 questions with a
    SentenceTransformer and clusters them with KMeans. Writes
    ``demos_train.json`` and ``clustering_train.png`` into ``--demo_save_dir``.

    Raises:
        ValueError: If ``--num_clusters`` is < 1 or exceeds the sample count.
    """
    args = parse_arguments()
    fix_seed(args.random_seed)

    task = args.task
    save_dir = args.demo_save_dir
    num_clusters = args.num_clusters
    if num_clusters < 1:
        raise ValueError(f"--num_clusters must be >= 1, got {num_clusters}")

    encoder = SentenceTransformer(args.encoder)
    os.makedirs(save_dir, exist_ok=True)

    dataset = VQADataset(
        image_dir_path=args.images_file,
        data_path=args.problems_file,
        is_train=True,
        dataset_name=task,  # previously hard-coded to "textvqa", ignoring --task
        max_samples=1000,
    )
    # Keyed by question_id; a later duplicate id overwrites an earlier one.
    problems = {
        item["question_id"]: item
        for item in (dataset[i] for i in range(len(dataset)))
    }

    corpus, questions, gold_ans, qids = _build_corpus(problems)
    print(f"Total samples after filtering: {len(corpus)}")
    if len(corpus) < num_clusters:
        raise ValueError(
            f"Need at least {num_clusters} samples for {num_clusters} clusters, "
            f"got {len(corpus)}"
        )

    print("Encoding corpus with SentenceTransformer...")
    corpus_embeddings = encoder.encode(corpus)

    print(f"Performing KMeans clustering with {num_clusters} clusters...")
    clustering_model = KMeans(n_clusters=num_clusters, random_state=args.random_seed)
    clustering_model.fit(corpus_embeddings)
    cluster_assignment = clustering_model.labels_
    dist = clustering_model.transform(corpus_embeddings)

    print("Building demonstration examples...")
    demos = _select_demos(cluster_assignment, dist, num_clusters, questions,
                          gold_ans, qids, args.sampling)

    demo_output_file = os.path.join(save_dir, "demos_train.json")
    print(f"Saving {len(demos)} demonstrations to {demo_output_file}")
    with open(demo_output_file, 'w', encoding="utf-8") as write_f:
        json.dump({"demos": demos}, write_f, indent=4, ensure_ascii=False)

    print("Visualizing clustering results...")
    vis_output_file = os.path.join(save_dir, "clustering_train.png")
    # Reuse the fitted labels; the old fit_predict call refit KMeans from scratch.
    _plot_clusters(
        corpus_embeddings,
        cluster_assignment,
        clustering_model.cluster_centers_,
        num_clusters,
        f"KMeans Clustering for {task} train Split (n={num_clusters})",
        args.random_seed,
        vis_output_file,
    )
    print(f"Clustering visualization saved to {vis_output_file}")


# Run the demo-construction pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()