from tqdm import tqdm
import json, pickle
import collections, time
import argparse
from typing import List, Optional, Tuple, Union
import numpy as np
import transformers
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from utils import analyze_graph, analyze_graph_v2
import sentence_transformers
import matplotlib.pyplot as plt
from nltk.tokenize import sent_tokenize, word_tokenize
from transformers import AutoTokenizer, T5EncoderModel

from load_data.preprocess import *
from load_data.supervised_dataset import SupervisedDataset, DataCollatorForSupervisedDataset
from load_data.k_shot_dataset import KshotDataset
from model.vae import VQ_VAE, VAE
from model.utils import model_name_mapping

import pandas as pd
import os

# Prompt template for a math word problem. It holds a single `{Question}`
# placeholder, presumably filled with `str.format(Question=...)` by a caller;
# nothing in the visible part of this file references it.
GSMK_QUERY_TEMPLATE = "\n".join((
    "Solve the following math problem efficiently and clearly.",
    "The last line of your response should be of the following format: 'The answer is: ANSWER.' (without quotes) where ANSWER is just the final number or expression that solves the problem.",
    "",
    "{Question}",
))

# Strings whose *last* token id is treated as a step boundary: a newline,
# optionally preceded by sentence-final punctuation and/or repeated.
_STEP_BOUNDARY_STRINGS = (
    "\n", "?\n", ".\n", "!\n",
    "\n\n", ".\n\n", "?\n\n", "!\n\n",
    ".\n\n\n", "?\n\n\n", "!\n\n\n",
)


def _step_boundary_token_ids(tokenizer):
    """Return the id of the final token of every step-boundary string."""
    # NOTE(review): this assumes the tokenizer does not append a special token
    # (e.g. EOS) after the text; otherwise `[-1]` would pick that token instead.
    return [tokenizer(text)['input_ids'][-1] for text in _STEP_BOUNDARY_STRINGS]


def _parse_row(row, df_path):
    """Return `(question, raw_lines, steps)` for one dataframe row.

    `raw_lines` is the answer text split on newlines; `steps` is the same list
    stripped and restricted to lines longer than 5 characters.
    Raises ValueError when the question or answer column is missing.
    """
    if 'Question' in row:
        question = row['Question']
    elif 'question' in row:
        question = row['question']
    else:
        raise ValueError(f"Invalid column name: {df_path}")

    if 'generated_text' in row:
        raw_lines = str(row['generated_text']).strip().split('\n')
    elif 'text' in row:
        raw_lines = str(row['text']).strip().split('\n')
    else:
        raise ValueError(f"Invalid column name: {df_path}")

    steps = [line.strip() for line in raw_lines if len(line.strip()) > 5]
    return question, raw_lines, steps


def _embed_steps(question, steps, tokenizer, embedding_model, target_layer,
                 split_ids, model_max_length, input_device):
    """Mean-pool the target layer's hidden states over each step of one example.

    Returns a list of `(step_offset, embedding)` pairs, where `step_offset`
    counts segments from the first one after the question and `embedding` is a
    1-D numpy array. Segments whose pooled vector contains NaN are dropped.
    """
    question_lines = question.strip().split('\n')
    # The last step is left out of the prompt, as in the original pipeline.
    example = question + '\n'.join(steps[:-1])

    inputs = tokenizer([example], return_tensors="pt", padding="longest",
                       max_length=model_max_length, truncation=True).to(input_device)
    with torch.no_grad():
        outputs = embedding_model(**inputs, output_hidden_states=True, return_dict=True)
    # hidden_states is indexed as [layer][batch, seq_len, hidden_size]; take the
    # single example of this batch at the target layer.
    hidden = outputs.hidden_states[target_layer][0]

    split_ids_tensor = torch.tensor(split_ids, device=inputs['input_ids'].device)
    is_boundary = torch.isin(inputs['input_ids'], split_ids_tensor)
    # A running count of boundary tokens labels every token with its segment,
    # e.g. [0,0,0,0,1,1,1,1,1,1,2,2,2,...]; padded positions are reset to 0.
    segment_ids = torch.cumsum(is_boundary, dim=-1)
    segment_ids *= inputs["attention_mask"]
    segment_ids = segment_ids[0]

    num_segments = int(torch.max(segment_ids)) + 1
    # Skip the segments taken up by the question's own lines.
    start = min(len(question_lines), num_segments - 1)

    pooled = []
    for j in range(start, num_segments):
        segment_mask = (segment_ids == j).float()
        segment_sum = (hidden * segment_mask.unsqueeze(-1)).sum(0)
        segment_len = segment_mask.sum()
        if segment_len <= 0:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError("current step is empty")
        rep = (segment_sum / segment_len).cpu().numpy()
        if np.isnan(rep).sum() == 0:
            pooled.append((j - start, rep))
    return pooled


def extract_step_type(dataset_name:str, model_name_or_path:str, batch_size:int, tokenizer_name_or_path:str,
                      model_max_length = 1024, selection_method='k-means', output_dir='extract_steps',
                      cache_dir=None, num_types=50, df_path:str=None, target_layer_ratio=0.5):
    """Embed reasoning steps, cluster them with k-means and run loop detection.

    For every row that has more than one usable step, the question plus all but
    the last step is run through the causal LM, the hidden states of one layer
    are mean-pooled per step, and the pooled vectors are clustered. The cluster
    sequence of each row is then passed to `analyze_graph_v2`.

    Files written under
    `{output_dir}/{model_name_or_path}/{dataset_name}/target_layer_ratio=...`:
    step embeddings, step texts (npy and json) and example ids; and, in the
    `{selection_method}-k={num_types}` sub-directory, the pickled KMeans model,
    cluster labels, one text file per cluster, t-SNE coordinates, a scatter
    plot and `loop_detection_results_v2.json`.

    Args:
        dataset_name: Used only to name output directories and files.
        model_name_or_path: Model identifier; resolved through
            `model_name_mapping` before loading. The unresolved value is used
            in the output path.
        batch_size: Currently unused; rows are processed one at a time.
        tokenizer_name_or_path: Tokenizer identifier.
        model_max_length: Truncation length applied when tokenizing.
        selection_method: Used only in output names; clustering is always KMeans.
        output_dir: Root directory for all outputs.
        cache_dir: Cache directory forwarded to `from_pretrained`.
        num_types: Number of KMeans clusters.
        df_path: Path ending in `.csv` (read with pandas) or anything else,
            which is handed to `load_dataset` and whose "train" split is used.
        target_layer_ratio: Fraction of `num_hidden_layers` selecting the
            hidden-state index to pool.

    Raises:
        ValueError: If `df_path` is None, a required column is missing, or no
            row yields any step embedding.
        AssertionError: If an internal consistency check fails.
    """
    # Fail fast: the original crashed with an AttributeError on `None.endswith`
    # only after the tokenizer had already been loaded.
    if df_path is None:
        raise ValueError("df_path is required: pass a .csv file or a dataset loadable by load_dataset")

    out_dir = f"{output_dir}/{model_name_or_path}/{dataset_name}/target_layer_ratio={target_layer_ratio}"

    model_name_or_path = model_name_mapping(model_name_or_path)

    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=False)
    tokenizer.model_max_length = model_max_length

    if tokenizer.pad_token is None:
        tokenizer.pad_token_id = 0

    if df_path.endswith('.csv'):
        df = pd.read_csv(df_path)
    else:
        # NOTE(review): `load_dataset` is not imported explicitly in this file;
        # presumably it arrives via `from load_data.preprocess import *`.
        ds = load_dataset(df_path)
        df = pd.DataFrame(ds["train"])

    # Output locations.
    embedding_file = f"{out_dir}/{dataset_name}_embedding.npy"
    text_file = f"{out_dir}/{dataset_name}_text.npy"
    example_id_file = f"{out_dir}/{dataset_name}_example_id.npy"
    os.makedirs(out_dir, exist_ok=True)

    # Load the model. NOTE: trust_remote_code=True runs code shipped with the
    # checkpoint; only use it with checkpoints you trust.
    embedding_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path, trust_remote_code=True, cache_dir=cache_dir,
        torch_dtype=torch.float16, device_map="auto")
    embedding_model.eval()

    target_layer = int(embedding_model.config.num_hidden_layers * target_layer_ratio)

    # Loop invariants: boundary token ids and the device that receives inputs.
    split_ids = _step_boundary_token_ids(tokenizer)
    input_device = 'cuda' if torch.cuda.is_available() else embedding_model.device

    # Compute the step embeddings.
    step_embeddings = []
    solution_steps = []
    example_ids = []
    # Row position -> (begin, end) slice into the concatenated embeddings, so
    # the loop-detection pass can reuse them instead of re-running the model.
    row_spans = {}
    num_embedded = 0
    ex_id = 0
    for row_pos, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
        question, _, steps = _parse_row(row, df_path)
        if len(steps) <= 1:
            continue

        pooled = _embed_steps(question, steps, tokenizer, embedding_model, target_layer,
                              split_ids, model_max_length, input_device)
        if not pooled:
            raise AssertionError("no step embeddings")

        solution_steps.extend(steps[offset] for offset, _ in pooled)
        example_rep = np.stack([rep for _, rep in pooled], axis=0)
        step_embeddings.append(example_rep)
        example_ids.extend([ex_id] * len(example_rep))
        row_spans[row_pos] = (num_embedded, num_embedded + len(example_rep))
        num_embedded += len(example_rep)
        ex_id += 1

    if not step_embeddings:
        raise ValueError(f"No row with more than one step found in {df_path}; nothing to cluster")

    step_embeddings = np.concatenate(step_embeddings, axis=0)
    solution_steps = np.array(solution_steps)
    example_ids = np.array(example_ids)

    # [all_steps, hidden_size]
    print("step_embeddings.shape: ", step_embeddings.shape)
    # [all_steps]
    print("solution_steps.shape: ", solution_steps.shape)
    # [all_steps]
    print("example_ids.shape: ", example_ids.shape)

    assert step_embeddings.shape[0] == solution_steps.shape[0] == example_ids.shape[0]
    np.save(embedding_file, step_embeddings)
    np.save(text_file, solution_steps)
    np.save(example_id_file, example_ids)

    with open(f"{out_dir}/{dataset_name}_text.json", 'w', encoding='utf-8') as wf:
        json.dump(solution_steps.tolist(), wf)

    out_dir = f"{out_dir}/{selection_method}-k={num_types}"
    cluster_model_file = f"{out_dir}/{dataset_name}_{selection_method}_{num_types}.pkl"
    os.makedirs(f"{out_dir}", exist_ok=True)
    step_embeddings = np.float32(step_embeddings)
    print("k-means start")
    cluster_model = KMeans(n_clusters=num_types, n_init=10, random_state=0).fit(step_embeddings)
    print("k-means end")
    print("cluster_model: ", cluster_model)
    print("cluster_model.labels_: ", cluster_model.labels_)
    print("cluster_model.cluster_centers_: ", cluster_model.cluster_centers_)

    with open(cluster_model_file, 'wb') as f:
        pickle.dump(cluster_model, f)

    all_preds = cluster_model.labels_
    print(f"all_preds.shape: {all_preds.shape}")
    assert len(all_preds) == len(solution_steps)

    np.save(f"{out_dir}/clusters.npy", all_preds)

    # One text file per cluster listing the steps assigned to it.
    for i in range(num_types):
        with open(f"{out_dir}/{dataset_name}_{num_types}_{i}.txt", 'w', encoding='utf-8') as f:
            f.write('\n'.join(list(solution_steps[cluster_model.labels_==i])))
    tsne_file = f"{out_dir}/tsne.npy"

    X = TSNE(n_components=2, learning_rate='auto',
                    init='random', perplexity=3).fit_transform(np.float32(step_embeddings))
    np.save(tsne_file, X)

    plt.scatter(X[:, 0], X[:, 1], c=all_preds, s=2, cmap='viridis')
    plt.title(f"Number of Clusters = {num_types}")
    plt.savefig(f"{out_dir}/kmeans.png")
    plt.close()

    loop_detection_results = {}

    # Per-row loop detection on the sequence of cluster ids. The embeddings of
    # each row are sliced out of `step_embeddings` rather than recomputed.
    for row_pos, (index, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
        _, raw_lines, _ = _parse_row(row, df_path)
        print(f"steps: {raw_lines}")

        span = row_spans.get(row_pos)
        if span is None:
            # Row had at most one usable step and was skipped above.
            continue
        begin, end = span
        row_embeddings = step_embeddings[begin:end]

        # 1) Quick stats to see if something is off:
        print("Embedding stats → min:", np.nanmin(row_embeddings),
            "max:", np.nanmax(row_embeddings))

        # 2) Sanity-check for non-finite values:
        if not np.all(np.isfinite(row_embeddings)):
            bad_steps = np.where(~np.isfinite(row_embeddings).all(axis=1))[0]
            print(f"⚠️ Non‐finite embeddings at steps: {bad_steps}")
            for idx in bad_steps:
                print(f"  step {idx} =>", row_embeddings[idx])

        # 3) Replace NaN/Inf, then clip extremes. Both calls return new arrays,
        #    so the shared `step_embeddings` buffer is left untouched.
        row_embeddings = np.nan_to_num(
            row_embeddings,
            nan=0.0,
            posinf=1e6,
            neginf=-1e6
        )
        row_embeddings = np.clip(row_embeddings, -1e5, 1e5)

        # 4) Ensure dtype and C-contiguity before handing over to scikit-learn.
        row_embeddings = np.require(row_embeddings,
                                    dtype=np.float32,
                                    requirements=['C'])

        prompt_clusters = cluster_model.predict(row_embeddings)

        # Euclidean distance between each pair of consecutive steps.
        distance_list = []
        for i in range(len(row_embeddings)-1):
            distance_list.append(np.linalg.norm(row_embeddings[i] - row_embeddings[i+1]))
        distance_list = np.array(distance_list)
        print(f"Distance of each step: {distance_list}")
        assert len(distance_list) == len(prompt_clusters)-1

        print("\nKMeans Cluster Predictions for Prompt Steps:", prompt_clusters)

        # Loop detection on the cluster-transition graph.
        loop_exists, loop_count, diameter, avg_clustering, avg_path_length, clustering_norm = analyze_graph_v2(prompt_clusters, distance_list)
        print(f"Loop Detection: {'存在' if loop_exists else '存在しない'}")
        print(f"完全なループ回数: {loop_count}")
        print(f"直径: {diameter}")
        print(f"平均クラスタリング係数: {avg_clustering}")
        print(f"平均パス長: {avg_path_length}")
        print(f"クラスタリング正規化: {clustering_norm}")
        loop_detection_results[index] = {
            "loop_exists": loop_exists,
            "loop_count": loop_count,
            "diameter": diameter,
            "avg_clustering": avg_clustering,
            "avg_path_length": avg_path_length,
            "clustering_norm": clustering_norm
        }

    with open(f"{out_dir}/loop_detection_results_v2.json", 'w', encoding='utf-8') as f:
        json.dump(loop_detection_results, f)

    # Fraction of analysed rows in which a loop was detected.
    loop_ratio = [1 if result['loop_exists'] else 0 for result in loop_detection_results.values()]
    print("-"*100)
    print(f"loop_ratio: {sum(loop_ratio) / len(loop_ratio)}")
    print("-"*100)


    
def build_arg_parser():
    """Build the command-line parser for the step-type extraction script.

    `--df_path` is required: without it `extract_step_type` has no data source
    and previously crashed on the `None` default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='gsm8k', help='dataset name')
    parser.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', help='model name or path')
    parser.add_argument('--tokenizer_name_or_path', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', help='tokenizer name or path')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--model_max_length', type=int, default=8192, help='model max length')
    parser.add_argument('--selection_method', type=str, default='k-means',  choices=['k-means'])
    parser.add_argument('--output_dir', type=str, default='load_data/generated_extract_steps', help='output dir')
    parser.add_argument('--cache_dir', type=str, default=None, help='cache dir')
    parser.add_argument('--num_types', type=int, default=200, help='number of reasoning types')
    parser.add_argument('--df_path', type=str, required=True, help='df path')
    parser.add_argument('--target_layer_ratio', type=float, default=0.5, help='target layer ratio')
    return parser


if __name__ == "__main__":

    args = build_arg_parser().parse_args()

    # Keyword arguments keep the eleven parameters from being silently
    # mis-ordered if the signature ever changes.
    extract_step_type(
        dataset_name=args.dataset,
        model_name_or_path=args.model_name_or_path,
        batch_size=args.batch_size,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        model_max_length=args.model_max_length,
        selection_method=args.selection_method,
        output_dir=args.output_dir,
        cache_dir=args.cache_dir,
        num_types=args.num_types,
        df_path=args.df_path,
        target_layer_ratio=args.target_layer_ratio,
    )