import random
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import numpy as np
import json
import matplotlib.pyplot as plt
import argparse
import os
import torch

def fix_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``), then asks
    cuDNN to use deterministic kernels.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch generator
        torch.cuda.manual_seed_all,  # all visible CUDA devices
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def parse_arguments(argv=None):
    """Parse command-line options for ScienceQA demo construction.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            lets tests or other scripts drive the parser directly.

    Returns:
        argparse.Namespace holding the parsed options.
    """

    def str2bool(value):
        """Convert a textual flag value (true/false, yes/no, 1/0) into a bool.

        ``type=bool`` is wrong for argparse: ``bool("False")`` is True, so the
        flag could never be switched off from the command line.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Zero-shot-CoT for ScienceQA Dataset with Train Filtering")
    parser.add_argument(
        "--task", type=str, default="scienceqa",
        help="dataset used for experiment"
    )
    parser.add_argument(
        "--max_ra_len", type=int, default=5,
        help="maximum number of reasoning chains"
    )
    parser.add_argument(
        "--problems_file", type=str, default="/home/test/yxl/MCoT/data/scienceqa/problems.json",
        help="path to the problems json file"
    )
    parser.add_argument(
        "--pid_splits_file", type=str, default="/home/test/yxl/MCoT/data/scienceqa/pid_splits.json",
        help="path to the pid splits json file containing train/test splits"
    )
    parser.add_argument(
        "--demo_save_dir", type=str, default="/home/test/yxl/MCoT/sqa/tool",
        help="where to save the constructed demonstrations"
    )
    parser.add_argument("--random_seed", type=int, default=192, help="random seed")
    parser.add_argument(
        "--encoder", type=str, default="all-MiniLM-L6-v2",
        help="which sentence-transformer encoder for clustering"
    )
    parser.add_argument(
        "--sampling", type=str, default="center",
        help="whether to sample the cluster center first"
    )
    parser.add_argument(
        "--debug", type=str2bool, default=True, help="debug mode (true/false)"
    )
    parser.add_argument(
        "--num_clusters", type=int, default=4,
        help="number of clusters for KMeans"
    )
    parser.add_argument(
        "--split_type", type=str, default="train",
        choices=["train", "test", "all"],
        help="which split to use for clustering"
    )
    return parser.parse_args(argv)


def _count_reasoning_steps(solution):
    """Return the number of non-empty lines in a raw (not yet flattened) solution."""
    return sum(1 for line in solution.split("\n") if line.strip())


def _load_samples(problems, selected_pids):
    """Collect the per-problem fields used for clustering and demo building.

    Pids absent from ``problems`` are reported and skipped. Each returned dict
    carries the original pid plus the prompt-style question, raw solution,
    choices, lecture, gold answer and the text that gets embedded.
    """
    samples = []
    for qid in selected_pids:
        if qid not in problems:
            print(f"Warning: QID {qid} not found in problems data. Skipping...")
            continue
        problem = problems[qid]
        samples.append({
            "qid": qid,  # original ScienceQA pid
            # Prompt-style question stored in the demonstration.
            "question": f"Q: {problem['question']}\nChoices: {', '.join(problem['choices'])}\nLecture: {problem['lecture']}",
            "solution": problem['solution'],
            "choices": problem['choices'],
            "lecture": problem['lecture'],
            "gold_ans": problem['answer'],
            # Text embedded for clustering; the space after the question keeps
            # it from being glued onto the first choice.
            "corpus": f"{problem['question']} {' '.join(problem['choices'])} {problem['lecture']} {problem['solution']}",
        })
    return samples


def _pick_cluster_demo(samples, candidates, cluster_id, max_ra_len, split_type):
    """Return a demo for the first candidate with a short, complete rationale, or None.

    A candidate qualifies when its solution is non-empty, has at most
    ``max_ra_len`` non-empty lines, and ends with '.', '!' or '?'.
    """
    for idx in candidates:
        sample = samples[idx]
        raw_solution = sample["solution"].strip()
        # Count steps on the raw text: once newlines are flattened to spaces
        # every solution is a single line and max_ra_len would filter nothing.
        num_steps = _count_reasoning_steps(raw_solution)
        # Collapse all whitespace runs (including newlines) to single spaces.
        rationale = " ".join(raw_solution.split())
        if not rationale or num_steps > max_ra_len or rationale[-1] not in ".!?":
            continue
        return {
            "qid": sample["qid"],
            "question": sample["question"],
            "choices": sample["choices"],
            "lecture": sample["lecture"],
            "rationale": rationale,
            "gold_ans": sample["gold_ans"],
            "cluster_id": cluster_id,
            "split": split_type,
        }
    return None


def _plot_clusters(embeddings, labels, centers, num_clusters, split_type, seed, out_file):
    """Save a 2-D PCA scatter plot of the embeddings, coloured by cluster label."""
    pca_model = PCA(n_components=2, random_state=seed)
    transformed = pca_model.fit_transform(embeddings)
    centers_2d = pca_model.transform(centers)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter = plt.scatter(x=transformed[:, 0], y=transformed[:, 1],
                              c=labels, s=50, cmap=plt.cm.Paired, alpha=0.4)
        plt.scatter(centers_2d[:, 0], centers_2d[:, 1],
                    s=250, marker='*', label='centroids',
                    edgecolor='black',
                    c=np.arange(0, num_clusters), cmap=plt.cm.Paired)
        plt.xticks([])
        plt.yticks([])
        plt.title(f"KMeans Clustering for ScienceQA {split_type.upper()} Split (n={num_clusters})")
        plt.colorbar(scatter, label='Cluster Label')
        plt.savefig(out_file, dpi=600, bbox_inches='tight')
    finally:
        # Release the figure; pyplot keeps open figures alive otherwise.
        plt.close(fig)


def main():
    """Cluster ScienceQA problems and save one demonstration per cluster.

    Reads the problems and pid-split JSON files, embeds each selected problem
    with a SentenceTransformer encoder, and clusters the embeddings with
    KMeans. For each cluster it takes the candidate nearest the centroid
    (random order unless ``--sampling center``) whose solution passes the
    rationale filter. Results are written to ``demos_<split>.json`` and a PCA
    plot to ``clustering_<split>.png`` inside ``--demo_save_dir``.

    Raises:
        ValueError: if the split has no pids, none of them exist in the
            problems file, or ``--num_clusters`` is not in [1, n_samples].
    """
    args = parse_arguments()
    fix_seed(args.random_seed)

    problems_file = args.problems_file
    pid_splits_file = args.pid_splits_file
    save_dir = args.demo_save_dir
    max_ra_len = args.max_ra_len
    num_clusters = args.num_clusters
    split_type = args.split_type

    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    # Load the dataset and its split definition.
    print(f"Reading problems from {problems_file}")
    with open(problems_file, "r", encoding="utf-8") as fp:
        problems = json.load(fp)

    print(f"Reading splits from {pid_splits_file}")
    with open(pid_splits_file, "r", encoding="utf-8") as fp:
        pid_splits = json.load(fp)

    # Select the pids belonging to the requested split.
    if split_type == "all":
        selected_pids = list(problems.keys())
    else:
        selected_pids = pid_splits.get(split_type, [])
    print(f"Using {len(selected_pids)} samples from {split_type} split")

    if not selected_pids:
        raise ValueError(f"No samples found for split type: {split_type}")

    samples = _load_samples(problems, selected_pids)
    print(f"Total samples after filtering: {len(samples)}")

    # Fail before the expensive encoding step if clustering cannot succeed.
    if not samples:
        raise ValueError(f"None of the {split_type} pids were found in {problems_file}")
    if not 1 <= num_clusters <= len(samples):
        raise ValueError(
            f"num_clusters must be between 1 and {len(samples)} (number of samples), got {num_clusters}"
        )

    print("Encoding corpus with SentenceTransformer...")
    encoder = SentenceTransformer(args.encoder)
    corpus_embeddings = encoder.encode([sample["corpus"] for sample in samples])

    print(f"Performing KMeans clustering with {num_clusters} clusters...")
    clustering_model = KMeans(n_clusters=num_clusters, random_state=args.random_seed)
    clustering_model.fit(corpus_embeddings)
    cluster_assignment = clustering_model.labels_
    # Distance from every sample to every centroid, shape (n_samples, n_clusters).
    dist = clustering_model.transform(corpus_embeddings)

    demos = []
    print("Building demonstration examples...")
    for cluster_id in range(num_clusters):
        print(f"Processing Cluster {cluster_id + 1}/{num_clusters}")
        member_idx = np.flatnonzero(cluster_assignment == cluster_id)
        # Nearest-to-centroid first; stable sort keeps corpus order on ties.
        order = np.argsort(dist[member_idx, cluster_id], kind="stable")
        candidates = member_idx[order].tolist()
        # Optional random ordering instead of centre-first.
        if args.sampling != "center":
            random.shuffle(candidates)

        demo = _pick_cluster_demo(samples, candidates, cluster_id, max_ra_len, split_type)
        if demo is None:
            print(f"Warning: no sample in Cluster {cluster_id + 1} passed the rationale filter")
            continue
        demos.append(demo)  # one demo per cluster
        print(f"Added demo from Cluster {cluster_id + 1} with QID: {demo['qid']}")
        if args.debug:
            print(f"Question: {demo['question']}")
            print(f"Rationale: {demo['rationale']}")
            print(f"Answer: {demo['gold_ans']}")
            print("---")

    # Save the demonstrations.
    demo_output_file = os.path.join(save_dir, f"demos_{split_type}.json")
    print(f"Saving {len(demos)} demonstrations to {demo_output_file}")
    with open(demo_output_file, 'w', encoding="utf-8") as write_f:
        json.dump({"demos": demos}, write_f, indent=4, ensure_ascii=False)

    # Visualise the clustering, reusing the fitted labels (no second fit).
    print("Visualizing clustering results...")
    vis_output_file = os.path.join(save_dir, f"clustering_{split_type}.png")
    _plot_clusters(corpus_embeddings, cluster_assignment, clustering_model.cluster_centers_,
                   num_clusters, split_type, args.random_seed, vis_output_file)
    print(f"Clustering visualization saved to {vis_output_file}")


if __name__ == "__main__":
    main()