import os
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.manifold import TSNE
from dataclasses import dataclass, field
from tqdm import tqdm
import transformers
from torch.utils.data import DataLoader
import random
import transformers

# Dataset utilities (adjust these imports to match your own implementation)
from load_data.preprocess import GSMData, MathData, AquaData, SVAMPData
from load_data.k_shot_dataset import KshotDataset
import calculator
from model.peft_model import MyPeftModelForCausalLM
from model.utils import model_name_mapping

# Turn on cuDNN benchmark mode for the whole process (set once, at import time).
torch.backends.cudnn.benchmark = True

# Prompt template: asks the model to solve one math problem and to finish with
# a line of the form 'The answer is: ANSWER.'. The {Question} placeholder is
# filled in with str.format(Question=...) for every dataset question.
GSMK_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly.  The last line of your response should be of the following format: 'The answer is: ANSWER.' (without quotes) where ANSWER is just the final number or expression that solves the problem.

{Question}
""".strip()

@dataclass
class ModelArguments:
    """Model-related command-line options, parsed by ``HfArgumentParser`` in ``main``.

    Only ``model_name_or_path`` is read by the code visible in this file; the
    remaining fields are not referenced here (presumably shared with other
    scripts of the project -- TODO confirm).
    """
    # Used by main() as both the tokenizer/model identifier and as a path
    # component of the KMeans input directory and the plot output directory.
    model_name_or_path: str = "gpt2"
    base_model_name_or_path: str = "gpt2"
    cache_dir: str = None
    output_dir: str = None
    max_length: int = 512
    decoding_scheme: str = "greedy"
    load_in_8bit: bool = False
    use_calculator: bool = False
    parameter_efficient_mode: str = 'none'  # choices: "none", "prompt-tuning", "lora", "lora+prompt-tuning"
    hf_hub_token: str = None
    enable_cpu_offload: bool = False
    flash_attention: bool = True

@dataclass
class DataArguments:
    """Data-related command-line options, parsed by ``HfArgumentParser`` in ``main``.

    The code visible in this file reads ``dataset``, ``batch_size``, ``seed``
    and ``num_test``; the other fields are not referenced here.
    """
    # Selects the dataset class in main(); any other value raises NotImplementedError.
    dataset: str = "gsm8k"  # "gsm8k", "math", "aqua", "svamp"
    # Passed to the DataLoader; main() is written with a batch size of 1 in mind.
    batch_size: int = 1
    use_demonstrations: bool = False
    demo_selection: str = "uniform"
    candidate_size: int = 100
    k_shot: int = 4
    # Seeds Python's `random` module before the test subset is drawn.
    seed: int = 42
    num_test: int = 100  # number of test examples to visualize (adjust as needed)
    prompt_template: str = None
    embedding_model_name: str = 'all-mpnet-base-v2'

def generate_chain_of_thought(prompt, model, tokenizer, max_new_tokens=256, temperature=0.7):
    """
    Generate a Chain-of-Thought answer for ``prompt`` using greedy decoding.

    The prompt is removed at the token level: only the tokens that follow the
    prompt in the generated sequence are decoded. Comparing decoded text with
    ``str.startswith`` is unreliable, because decoding does not always
    reproduce the prompt character-for-character (whitespace or other
    tokenizer normalization), which would leave the prompt in the result.

    Args:
        prompt: Full prompt text fed to the model.
        model: Causal (decoder-only) language model exposing ``generate``; its
            output sequence is expected to start with the prompt tokens.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Accepted for interface compatibility only. Decoding is
            greedy (``do_sample=False``), so this value is not used.

    Returns:
        The generated continuation as a string, stripped of surrounding
        whitespace and without the prompt.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device) if "attention_mask" in encoded else None

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the tokens produced after the prompt.
    prompt_length = input_ids.shape[1]
    continuation_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

def compute_embedding(text, model, tokenizer, max_length):
    """
    Embed ``text`` as the attention-mask-weighted mean of the model's
    last-layer hidden states.

    The text is tokenized once. Pooling is done in float32: the model may run
    in fp16, where summing hidden states over the sequence can overflow to
    ``inf``.

    Args:
        text: Text to embed.
        model: Model that accepts ``output_hidden_states=True`` and returns an
            object with a ``hidden_states`` sequence.
        tokenizer: Tokenizer matching ``model``.
        max_length: Kept for interface compatibility; it is not applied.
            Truncation uses the tokenizer's own default limit.

    Returns:
        1-D float32 NumPy array holding the pooled embedding.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.inference_mode():
        outputs = model(**inputs, output_hidden_states=True, return_dict=True)

        # Upcast before reducing so an fp16 sum cannot overflow.
        hidden = outputs.hidden_states[-1][0].float()
        mask = inputs['attention_mask'][0].unsqueeze(-1).to(hidden.dtype)
        token_count = mask.sum()
        if token_count.item() == 0:
            # No attended tokens: fall back to a plain mean over positions.
            pooled = hidden.mean(dim=0)
        else:
            pooled = (hidden * mask).sum(dim=0) / token_count

    embedding = pooled.cpu().numpy()
    # Return cached GPU blocks to the allocator to keep memory usage low.
    torch.cuda.empty_cache()
    return embedding

# Arrow3D class for 3D plots (built on mpl_toolkits.mplot3d.proj3d and FancyArrowPatch)
from mpl_toolkits.mplot3d import Axes3D, proj3d
import matplotlib.patches as mpatches

class Arrow3D(mpatches.FancyArrowPatch):
    """Arrow patch whose two endpoints are given in 3D data coordinates.

    The 3D endpoints are stored at construction time. Each time the artist is
    drawn or projected they are transformed with the axes' projection matrix
    and used as the 2D start and end positions of the arrow.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D positions are placeholders; they are overwritten on projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = (xs, ys, zs)

    def _place_projected_endpoints(self):
        """Project the stored 3D endpoints, update the 2D arrow, return projected z."""
        x3d, y3d, z3d = self._verts3d
        proj_x, proj_y, proj_z = proj3d.proj_transform(x3d, y3d, z3d, self.axes.get_proj())
        tail = (proj_x[0], proj_y[0])
        head = (proj_x[1], proj_y[1])
        self.set_positions(tail, head)
        return proj_z

    def draw(self, renderer):
        """Refresh the projected endpoints, then draw as a regular 2D arrow."""
        self._place_projected_endpoints()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        """Refresh the projected endpoints and return the smallest projected z."""
        return np.min(self._place_projected_endpoints())

def _collapse_consecutive_clusters(cluster_ids, centers):
    """Drop consecutive repeats from ``cluster_ids``; return (kept ids, stacked centers)."""
    kept_ids = []
    for cid in cluster_ids:
        cid = int(cid)
        if not kept_ids or cid != kept_ids[-1]:
            kept_ids.append(cid)
    return kept_ids, np.stack([centers[cid] for cid in kept_ids], axis=0)


def _project_with_tsne(points, n_components, seed):
    """Project ``points`` with t-SNE, using a perplexity that is valid for short paths."""
    n_points = len(points)
    if n_points < 2:
        # t-SNE cannot be fitted on a single sample; place it at the origin.
        return np.zeros((n_points, n_components), dtype=np.float32)
    return TSNE(
        n_components=n_components,
        learning_rate='auto',
        init='random',
        # scikit-learn requires perplexity < n_samples.
        perplexity=min(5, n_points - 1),
        random_state=seed,
    ).fit_transform(points.astype(np.float32))


def _save_path_plot_2d(coords, cluster_ids, out_path):
    """Draw the 2D cluster-center path with arrows and ID labels, then save it."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(coords[:, 0], coords[:, 1], s=120, c='blue', edgecolors='black', zorder=3, label='Cluster Centers')
    for start, end in zip(coords[:-1], coords[1:]):
        ax.annotate("",
                    xy=end, xycoords='data',
                    xytext=start, textcoords='data',
                    arrowprops=dict(arrowstyle="->", color='red', lw=2, shrinkA=5, shrinkB=5, alpha=0.8))
    for point, cid in zip(coords, cluster_ids):
        ax.text(point[0], point[1], str(cid), fontsize=14, color='black', zorder=4)
    ax.set_title("Reasoning Path Cluster Centers (2D TSNE)")
    ax.set_xlabel("TSNE Dimension 1")
    ax.set_ylabel("TSNE Dimension 2")
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.show()
    # Close explicitly so figures do not accumulate across samples.
    plt.close(fig)


def _save_path_plot_3d(coords, cluster_ids, out_path):
    """Draw the 3D cluster-center path with arrows and ID labels, then save it."""
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    # A single scatter coloured by cluster ID; it also feeds the colorbar.
    points = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                        c=cluster_ids, cmap='viridis', s=120, edgecolors='black', label='Reasoning Steps')
    for start, end in zip(coords[:-1], coords[1:]):
        arrow = Arrow3D([start[0], end[0]], [start[1], end[1]], [start[2], end[2]],
                        mutation_scale=20, lw=2, arrowstyle="-|>", color="red", alpha=0.8)
        ax.add_artist(arrow)
    for point, cid in zip(coords, cluster_ids):
        ax.text(point[0], point[1], point[2], str(cid), fontsize=14, color='black')
    ax.set_title("Reasoning Path Cluster Centers (3D TSNE)", fontsize=18, pad=20)
    ax.set_xlabel("TSNE Dimension 1", fontsize=16)
    ax.set_ylabel("TSNE Dimension 2", fontsize=16)
    ax.set_zlabel("TSNE Dimension 3", fontsize=16)
    fig.colorbar(points, ax=ax, pad=0.1, label="Cluster ID")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.show()
    # Close explicitly so figures do not accumulate across samples.
    plt.close(fig)


def main():
    """Generate Chain-of-Thought answers and plot their KMeans cluster paths.

    For every selected test question the script:
      1. generates an answer with the causal LM,
      2. splits it into non-empty lines ("steps") and embeds each step,
      3. assigns each step to a cluster of the pre-trained KMeans model,
      4. collapses consecutive repeats and projects the cluster centers with
         t-SNE to 2D and 3D,
      5. saves both plots as PNG files in the output directory.

    Raises:
        NotImplementedError: If ``data_args.dataset`` is not a supported name.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    random.seed(data_args.seed)

    # ----- Basic settings -----
    model_name = model_args.model_name_or_path
    # NOTE: machine-specific absolute path; adjust to your environment.
    dir_path = f"/home/pj24002027/ku40003286/Cyclic_Reasoning/cot_reasoning/load_data/extract_steps/{model_name}/gsm8k"
    kmeans_model_path = os.path.join(dir_path, "k-means-k=200", "gsm8k_k-means_200.pkl")

    output_dir = f"simple_reasoning_path/{model_name}/gsm8k"
    os.makedirs(output_dir, exist_ok=True)

    # Load the fitted KMeans model.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(kmeans_model_path, 'rb') as f:
        kmeans = pickle.load(f)
    # Make the centers float64 and C-contiguous once, before any predict call.
    kmeans.cluster_centers_ = np.require(kmeans.cluster_centers_, dtype=np.float64, requirements=['C'])

    # Load tokenizer and model (fp16).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model.eval()

    # Select the dataset class (gsm8k, math, aqua, svamp).
    dataset_classes = {
        "gsm8k": GSMData,
        "math": MathData,
        "aqua": AquaData,
        "svamp": SVAMPData,
    }
    if data_args.dataset not in dataset_classes:
        raise NotImplementedError(f"Unsupported dataset: {data_args.dataset!r}")
    dataset = dataset_classes[data_args.dataset]("test", [], prompt_template=None, tokenizer=tokenizer)

    # Keep a random subset of num_test examples. random.sample draws without
    # replacement, so no example is selected twice.
    if len(dataset) > data_args.num_test:
        chosen = random.sample(range(len(dataset)), k=data_args.num_test)
        picked = [dataset[i] for i in chosen]
        dataset.x = [item['x'] for item in picked]
        dataset.y = [item['y'] for item in picked]
    assert len(dataset) <= data_args.num_test
    dataloader = DataLoader(dataset, batch_size=data_args.batch_size, shuffle=False)

    # Flatten the batches so every question is processed, whatever the batch size.
    questions = (
        question
        for batch in tqdm(dataloader, desc="Processing batches")
        for question in batch['x']
    )
    for i, question in enumerate(questions):
        prompt = GSMK_QUERY_TEMPLATE.format(Question=question)
        generated_text = generate_chain_of_thought(prompt, model, tokenizer, max_new_tokens=8192)
        print("Generated Chain-of-Thought Answer:")
        print(generated_text)

        torch.cuda.empty_cache()

        # One reasoning step per non-empty line of the generated text.
        steps = [line.strip() for line in generated_text.split("\n") if line.strip() != ""]
        if not steps:
            # Nothing to embed or plot for an empty generation.
            print(f"Sample {i}: no reasoning steps generated; skipping visualization.")
            continue
        print("\nExtracted Steps:")
        for j, step in enumerate(steps, start=1):
            print(f"Step {j}: {step}")

        # Embed every step and make the matrix float64 / C-contiguous for KMeans.
        step_embeddings = np.stack(
            [compute_embedding(step, model, tokenizer, tokenizer.model_max_length) for step in steps],
            axis=0,
        )
        step_embeddings = np.require(step_embeddings, dtype=np.float64, requirements=['C'])

        # Cluster ID of each step.
        prompt_clusters = kmeans.predict(step_embeddings)
        print("\nKMeans Cluster Predictions for Prompt Steps:", prompt_clusters)

        # Replace each step by its cluster center and drop consecutive repeats.
        unique_cluster_ids, unique_cluster_path = _collapse_consecutive_clusters(
            prompt_clusters, kmeans.cluster_centers_
        )

        tsne_path = _project_with_tsne(unique_cluster_path, 2, data_args.seed)
        _save_path_plot_2d(tsne_path, unique_cluster_ids, f"{output_dir}/cot_cluster_path_{i}.png")

        tsne_path_3d = _project_with_tsne(unique_cluster_path, 3, data_args.seed)
        _save_path_plot_3d(tsne_path_3d, unique_cluster_ids, f"{output_dir}/cot_cluster_path_3d_{i}.png")

        torch.cuda.empty_cache()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
