import os
import pickle
import random
from dataclasses import dataclass, field
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Dataset utilities (adjust these imports to match your own implementation)
from load_data.preprocess import GSMData, MathData, AquaData, SVAMPData
from load_data.k_shot_dataset import KshotDataset
import calculator
from model.peft_model import MyPeftModelForCausalLM
from model.utils import model_name_mapping

# Imports for the 3D Arrow3D class (uses mpl_toolkits.mplot3d.proj3d and FancyArrowPatch)
from mpl_toolkits.mplot3d import Axes3D, proj3d
import matplotlib.patches as mpatches
from utils import separate_double_and_revision_with_length
import argparse
import json
import pandas as pd

# Ask cuDNN to benchmark candidate algorithms and reuse the fastest one.
# This is a process-wide setting applied as soon as the module is imported.
torch.backends.cudnn.benchmark = True

# Prompt template: wraps a single problem statement so that the model writes a
# Chain-of-Thought answer ending in a fixed "The answer is: ..." line.
# The text is stripped of surrounding whitespace; `{Question}` is filled in
# with `str.format` by the caller.
GSMK_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly.  
The last line of your response should be of the following format: 'The answer is: ANSWER.' (without quotes) where ANSWER is just the final number or expression that solves the problem.

{Question}
""".strip()

@dataclass
class DataArguments:
    """Dataset and sampling settings for the reasoning-path script.

    Within this file only ``dataset``, ``batch_size``, ``seed`` and
    ``num_test`` are read (by ``main``); the remaining fields are carried
    along unchanged and are presumably consumed by other modules.
    """

    dataset: str = "gsm8k"  # one of "gsm8k", "math", "aqua", "svamp"
    batch_size: int = 1  # main() assumes one sample per batch
    use_demonstrations: bool = False
    demo_selection: str = "uniform"
    candidate_size: int = 100
    k_shot: int = 4
    seed: int = 42  # seed passed to random.seed() before sub-sampling
    num_test: int = 100  # number of test samples to process (adjust as needed)
    # None means "no custom template"; the annotation previously claimed `str`.
    prompt_template: Optional[str] = None
    embedding_model_name: str = 'all-mpnet-base-v2'

def generate_chain_of_thought(prompt, model, tokenizer, max_new_tokens=256, temperature=0.7):
    """Generate a Chain-of-Thought continuation for ``prompt`` with greedy decoding.

    Only the newly generated text is returned. The prompt is removed by
    dropping the first ``len(input_ids)`` tokens of the generated sequence
    rather than by string-prefix matching: decoding does not always reproduce
    the prompt character-for-character, in which case a ``startswith`` check
    silently leaves the whole prompt in the result.

    Args:
        prompt: Full prompt text fed to the model.
        model: Causal LM exposing ``parameters()`` and ``generate()``; its
            output sequences are expected to start with the prompt tokens
            (the usual behavior of decoder-only models).
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Currently ignored because decoding is greedy
            (``do_sample=False``); kept for interface compatibility.

    Returns:
        The generated continuation as a string with surrounding whitespace
        stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device) if "attention_mask" in inputs else None

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )

    # Everything after the prompt tokens is the model's own continuation.
    prompt_length = input_ids.shape[1]
    new_token_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

def compute_embedding(text, model, tokenizer, max_length):
    """Embed ``text`` as the attention-masked mean of the last hidden layer.

    The pooling is carried out in float32. With a half-precision model the
    per-token sum can exceed the float16 range and turn into ``inf``, and a
    bfloat16 result cannot be converted to NumPy at all, so the hidden states
    are upcast before any reduction.

    Args:
        text: Text to embed (a single string).
        model: Model that accepts ``output_hidden_states=True`` and exposes
            ``parameters()``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Accepted for interface compatibility but not used;
            truncation relies on the tokenizer's own default limit.

    Returns:
        A 1-D float32 NumPy array of size ``hidden_dim``.
    """
    # A single, unpadded tokenization is enough for one sequence; padding a
    # sequence to its own length would add nothing.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.inference_mode():
        outputs = model(**inputs, output_hidden_states=True, return_dict=True)

        # (seq_len, hidden_dim) for the only sequence in the batch.
        last_hidden = outputs.hidden_states[-1][0].float()
        mask = inputs['attention_mask'][0].unsqueeze(-1).to(last_hidden.dtype)
        token_count = mask.sum()
        if token_count.item() == 0:
            # No attended tokens: fall back to a plain mean over positions.
            pooled = last_hidden.mean(dim=0)
        else:
            pooled = (last_hidden * mask).sum(dim=0) / token_count

    embedding = pooled.cpu().numpy()
    # Release cached GPU blocks between steps to keep peak memory low.
    torch.cuda.empty_cache()
    return embedding

class Arrow3D(mpatches.FancyArrowPatch):
    """Arrow patch whose two endpoints are given in 3D data coordinates.

    The 3D endpoints are stored on the instance and projected to 2D with the
    owning axes' projection matrix every time the arrow is drawn or asked for
    its 3D projection.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D positions are placeholders; they are overwritten on projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = (xs, ys, zs)

    def _project_endpoints(self):
        """Project the stored endpoints, update the 2D positions, return depths."""
        x3d, y3d, z3d = self._verts3d
        px, py, pz = proj3d.proj_transform(x3d, y3d, z3d, self.axes.get_proj())
        self.set_positions((px[0], py[0]), (px[1], py[1]))
        return pz

    def draw(self, renderer):
        self._project_endpoints()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # Implemented so that mplot3d can project this artist; the return
        # value is the smallest projected depth of the two endpoints.
        depths = self._project_endpoints()
        return np.min(depths)

def _sanitize_step_embeddings(step_embeddings):
    """Report non-finite values, neutralize them, and return a C-contiguous float32 array."""
    # Upcast first: in float16 the replacement/clip bounds below (1e6, 1e5)
    # are themselves out of range and would be turned back into inf.
    step_embeddings = np.asarray(step_embeddings, dtype=np.float32)

    print("Embedding stats → min:", np.nanmin(step_embeddings),
          "max:", np.nanmax(step_embeddings))

    finite = np.isfinite(step_embeddings)
    if not finite.all():
        bad_steps = np.where(~finite.all(axis=1))[0]
        print(f"⚠️ Non‐finite embeddings at steps: {bad_steps}")
        for idx in bad_steps:
            print(f"  step {idx} =>", step_embeddings[idx])

    step_embeddings = np.nan_to_num(step_embeddings, nan=0.0, posinf=1e6, neginf=-1e6)
    step_embeddings = np.clip(step_embeddings, -1e5, 1e5)
    return np.require(step_embeddings, dtype=np.float32, requirements=['C'])


def _draw_trajectory_2d(fig, ax, train_tsne, prompt_tsne, prompt_clusters, title,
                        *, show_legend=False, zoom=False):
    """Draw background points, numbered reasoning steps and connecting arrows on ``ax``."""
    ax.scatter(train_tsne[:, 0], train_tsne[:, 1],
               s=3, c='lightgray', alpha=0.5, label='Training Data')
    steps_scatter = ax.scatter(prompt_tsne[:, 0], prompt_tsne[:, 1],
                               c=prompt_clusters, cmap='viridis', s=120,
                               edgecolors='black', zorder=3, label='Reasoning Steps')

    # Red arrows between consecutive steps.
    for start, end in zip(prompt_tsne[:-1], prompt_tsne[1:]):
        ax.annotate("",
                    xy=end, xycoords='data',
                    xytext=start, textcoords='data',
                    arrowprops=dict(arrowstyle="->", color='red', lw=2,
                                    shrinkA=5, shrinkB=5, alpha=0.8))
    # 1-based step number next to every point.
    for number, point in enumerate(prompt_tsne, start=1):
        ax.text(point[0], point[1], str(number), fontsize=14, color='black', zorder=4)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel("TSNE Dimension 1", fontsize=14)
    ax.set_ylabel("TSNE Dimension 2", fontsize=14)
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
    if show_legend:
        ax.legend(loc='upper right', fontsize=12)
    if zoom:
        # Restrict the view to the reasoning steps plus a 10% margin.
        x_min, x_max = prompt_tsne[:, 0].min(), prompt_tsne[:, 0].max()
        y_min, y_max = prompt_tsne[:, 1].min(), prompt_tsne[:, 1].max()
        margin_x = 0.1 * (x_max - x_min)
        margin_y = 0.1 * (y_max - y_min)
        ax.set_xlim(x_min - margin_x, x_max + margin_x)
        ax.set_ylim(y_min - margin_y, y_max + margin_y)

    cbar = fig.colorbar(steps_scatter, ax=ax, pad=0.02)
    cbar.set_label("KMeans Cluster ID", fontsize=14)
    cbar.ax.tick_params(labelsize=12)


def _save_trajectory_2d(combined_embeddings, n_train, prompt_clusters, output_dir, sample_idx, seed):
    """Project to 2D with t-SNE and save the full-view / zoomed figure for one sample."""
    tsne = TSNE(n_components=2, learning_rate='auto', init='random',
                perplexity=30, random_state=seed)
    combined_tsne = tsne.fit_transform(combined_embeddings.astype(np.float32))
    train_tsne, prompt_tsne = combined_tsne[:n_train], combined_tsne[n_train:]

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    _draw_trajectory_2d(fig, axes[0], train_tsne, prompt_tsne, prompt_clusters,
                        "Chain-of-Thought: Full TSNE", show_legend=True)
    _draw_trajectory_2d(fig, axes[1], train_tsne, prompt_tsne, prompt_clusters,
                        "Chain-of-Thought: Zoomed", zoom=True)

    fig.tight_layout()
    fig.savefig(f"{output_dir}/cot_trajectory_{sample_idx}.png", dpi=300)
    plt.show()
    # Figures are created once per sample; close them so they do not pile up.
    plt.close(fig)


def _save_trajectory_3d(combined_embeddings, n_train, prompt_clusters, output_dir, sample_idx, seed):
    """Project to 3D with t-SNE and save the trajectory figure for one sample."""
    tsne_3d = TSNE(n_components=3, learning_rate='auto', init='random',
                   perplexity=30, random_state=seed)
    combined_tsne_3d = tsne_3d.fit_transform(combined_embeddings.astype(np.float32))
    train_tsne_3d = combined_tsne_3d[:n_train]
    prompt_tsne_3d = combined_tsne_3d[n_train:]

    fig3d = plt.figure(figsize=(14, 12))
    ax3d = fig3d.add_subplot(111, projection='3d')
    ax3d.scatter(train_tsne_3d[:, 0], train_tsne_3d[:, 1], train_tsne_3d[:, 2],
                 s=3, c='lightgray', alpha=0.2, label='Training Data')
    scatter3d = ax3d.scatter(prompt_tsne_3d[:, 0], prompt_tsne_3d[:, 1], prompt_tsne_3d[:, 2],
                             c=prompt_clusters, cmap='viridis', s=120,
                             edgecolors='black', label='Reasoning Steps')
    for j in range(len(prompt_tsne_3d) - 1):
        xs = [prompt_tsne_3d[j, 0], prompt_tsne_3d[j + 1, 0]]
        ys = [prompt_tsne_3d[j, 1], prompt_tsne_3d[j + 1, 1]]
        zs = [prompt_tsne_3d[j, 2], prompt_tsne_3d[j + 1, 2]]
        arrow = Arrow3D(xs, ys, zs, mutation_scale=20, lw=2,
                        arrowstyle="-|>", color="red", alpha=0.8)
        ax3d.add_artist(arrow)
        ax3d.text(prompt_tsne_3d[j, 0], prompt_tsne_3d[j, 1], prompt_tsne_3d[j, 2],
                  str(j + 1), fontsize=14, color='black')
    ax3d.text(prompt_tsne_3d[-1, 0], prompt_tsne_3d[-1, 1], prompt_tsne_3d[-1, 2],
              str(len(prompt_tsne_3d)), fontsize=14, color='black')

    ax3d.set_title("Reasoning Trajectory", fontsize=18, pad=20)
    ax3d.set_xlabel("TSNE Dimension 1", fontsize=16)
    ax3d.set_ylabel("TSNE Dimension 2", fontsize=16)
    ax3d.set_zlabel("TSNE Dimension 3", fontsize=16)
    ax3d.legend(loc='upper right', fontsize=18)
    # `fontsize` is not a colorbar keyword; the label size is set on the label.
    cbar3d = fig3d.colorbar(scatter3d, ax=ax3d, pad=0.1)
    cbar3d.set_label("KMeans Cluster ID", fontsize=16)

    fig3d.tight_layout()
    fig3d.savefig(f"{output_dir}/cot_trajectory_3d_{sample_idx}.png", dpi=300)
    plt.show()
    plt.close(fig3d)


def main(args):
    """Generate Chain-of-Thought answers, cluster their steps and record loops.

    For every sampled test problem the model generates an answer, the answer
    is split into non-empty lines ("steps"), each step is embedded and
    assigned to a KMeans cluster, and the cluster sequence is passed to
    ``separate_double_and_revision_with_length``. The per-sample results are
    written to two CSV files; with ``args.is_visualize`` 2D and 3D t-SNE
    trajectory figures are saved as well.

    Args:
        args: Namespace providing ``model_name``, ``num_types``,
            ``target_layer_ratio`` and ``is_visualize``. If it also carries
            ``num_test``, that value overrides the ``DataArguments`` default.

    Raises:
        FileNotFoundError: If the training-embedding file is missing.
        NotImplementedError: If the configured dataset name is unknown.
    """
    # ----- Basic settings -----
    data_args = DataArguments()
    # Honor the --num_test command-line option (it used to be parsed but ignored).
    requested_num_test = getattr(args, "num_test", None)
    if requested_num_test is not None:
        data_args.num_test = requested_num_test
    random.seed(data_args.seed)

    model_name = args.model_name
    # NOTE: machine-specific absolute path; adjust for your environment.
    dir_path = f"/home/acd13972py/Cyclic_Reasoning/cot_reasoning/load_data/generated_extract_steps/{model_name}/{data_args.dataset}/target_layer_ratio={args.target_layer_ratio}"
    kmeans_model_path = os.path.join(dir_path, f"k-means-k={args.num_types}", f"{data_args.dataset}_k-means_{args.num_types}.pkl")
    train_embeddings_path = os.path.join(dir_path, f"{data_args.dataset}_embedding.npy")

    output_dir = f"{dir_path}/k={args.num_types}/reasoning_path2"
    os.makedirs(output_dir, exist_ok=True)

    # Load the fitted KMeans model and the background (training) embeddings.
    # NOTE: pickle.load can execute arbitrary code; only load files you produced.
    with open(kmeans_model_path, 'rb') as f:
        kmeans = pickle.load(f)
    if not os.path.isfile(train_embeddings_path):
        raise FileNotFoundError(f"Training embeddings not found at {train_embeddings_path}")
    train_embeddings = np.load(train_embeddings_path).astype(np.float64)

    # Load tokenizer and model (fp16).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model.eval()

    # Pick the dataset class (gsm8k, math, aqua or svamp).
    dataset_classes = {
        "gsm8k": GSMData,
        "math": MathData,
        "aqua": AquaData,
        "svamp": SVAMPData,
    }
    if data_args.dataset not in dataset_classes:
        raise NotImplementedError
    data_class = dataset_classes[data_args.dataset]
    dataset = data_class("test", [], prompt_template=None, tokenizer=tokenizer)

    # Keep only num_test samples. random.sample draws without replacement, so
    # the subset contains no duplicated problems.
    if len(dataset) > data_args.num_test:
        chosen = random.sample(range(len(dataset)), k=data_args.num_test)
        picked = [dataset[i] for i in chosen]
        dataset.x = [item['x'] for item in picked]
        dataset.y = [item['y'] for item in picked]
    assert len(dataset) <= data_args.num_test
    dataloader = DataLoader(dataset, batch_size=data_args.batch_size, shuffle=False)

    double_check_results = []
    revision_results = []

    # One sample per batch is assumed: only the first question of a batch is used.
    for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
        # NOTE(review): the GSM8K template is applied to every dataset.
        prompt = GSMK_QUERY_TEMPLATE.format(Question=batch['x'][0])
        generated_text = generate_chain_of_thought(prompt, model, tokenizer, max_new_tokens=8192)
        print("Generated Chain-of-Thought Answer:")
        print(generated_text)

        torch.cuda.empty_cache()

        # Split the generated text into steps (one per non-empty line).
        steps = [line.strip() for line in generated_text.split("\n") if line.strip() != ""]
        print("\nExtracted Steps:")
        for j, step in enumerate(steps, start=1):
            print(f"Step {j}: {step}")

        if not steps:
            # np.stack cannot handle an empty list; skip instead of losing the
            # whole run. This sample contributes no row to the result files.
            print(f"Sample {i}: no reasoning steps were generated; skipping.")
            continue

        # NOTE(review): the KMeans model is stored under a target_layer_ratio
        # directory while compute_embedding pools the last hidden layer;
        # confirm both use the same layer.
        step_embeddings = np.stack(
            [compute_embedding(step, model, tokenizer, tokenizer.model_max_length)
             for step in steps],
            axis=0,
        )
        step_embeddings = _sanitize_step_embeddings(step_embeddings)

        prompt_clusters = kmeans.predict(step_embeddings)
        print("\nKMeans Cluster Predictions for Prompt Steps:", prompt_clusters)

        # Loop detection on the sequence of cluster ids.
        double_checks, revisions = separate_double_and_revision_with_length(prompt_clusters)
        print(f"Double Checks: {double_checks}")
        print(f"Revisions: {revisions}")
        double_check_results.append(double_checks)
        revision_results.append(revisions)

        if args.is_visualize:
            # Background and step embeddings are projected together so that
            # they share one t-SNE space.
            combined_embeddings = np.concatenate([train_embeddings, step_embeddings], axis=0)
            n_train = len(train_embeddings)
            _save_trajectory_2d(combined_embeddings, n_train, prompt_clusters,
                                output_dir, i, data_args.seed)
            _save_trajectory_3d(combined_embeddings, n_train, prompt_clusters,
                                output_dir, i, data_args.seed)
            torch.cuda.empty_cache()

    # Persist the loop-detection results.
    double_check_df = pd.DataFrame(double_check_results)
    double_check_df.to_csv(os.path.join(output_dir, "double_check_results.csv"), index=False)
    revision_df = pd.DataFrame(revision_results)
    revision_df.to_csv(os.path.join(output_dir, "revision_results.csv"), index=False)

if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser(description="Visualize reasoning path")
    cli.add_argument(
        "--model_name",
        type=str,
        default="Qwen/Qwen2.5-32B",
        help="Model name",
    )
    cli.add_argument(
        "--num_test",
        type=int,
        default=100,
        help="Number of test samples",
    )
    cli.add_argument(
        "--num_types",
        type=int,
        default=10,
        help="Number of types",
    )
    cli.add_argument(
        "--target_layer_ratio",
        type=float,
        default=0.1,
        help="Target layer ratio",
    )
    cli.add_argument(
        "--is_visualize",
        action="store_true",
        help="Visualize reasoning path",
    )
    args = cli.parse_args()
    main(args)
