import argparse
import json
import numpy as np
import os
import pandas as pd
import polars as pl
import torch
import tqdm
from client import apply_chat_template
from config import (
    GPQA_DIR, GPQA_MAX_LEN, GPQA_NUM_CHAINS, GPQA_PROBE_FREQ,
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS, MATH_PROBE_FREQ,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS, MMLU_PROBE_FREQ,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS, GSM8K_PROBE_FREQ,
    MODEL_IDS,
)
from sklearn.model_selection import train_test_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import convert_math_data_setting_to_str, process_math_id, save_gzip_file, load_gzip_file


def get_string_embedding(s, model, tokenizer, device='cuda'):
    """Run ``model`` on a single string and return all of its hidden states.

    Args:
        s: Text to encode; it is tokenized as one un-batched input.
        model: Callable accepting the tokenizer's output as keyword arguments
            plus ``output_hidden_states=True`` and returning an object with a
            ``hidden_states`` sequence of tensors.
        tokenizer: Callable returning a mapping of tensors when invoked with
            ``return_tensors="pt"``.
        device: Device every tokenizer output tensor is moved to before the
            forward pass.

    Returns:
        The entries of ``hidden_states`` stacked along a new leading axis,
        with the size-1 batch axis removed. Callers in this file index the
        result as ``[hidden_state_index, token, feature]``.
    """
    encoded = tokenizer(s, return_tensors="pt")
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        layer_states = model(**model_inputs, output_hidden_states=True).hidden_states

    stacked = torch.stack(layer_states)
    # A single string was encoded, so the batch axis must have size one.
    assert stacked.shape[1] == 1
    return stacked.squeeze(1)


def _save_npz_atomic(path, array):
    """Save ``array`` under the key ``data`` in a compressed npz file at ``path``.

    The archive is first written to a temporary sibling file and then moved
    into place with ``os.replace``. The embedding loop treats "file exists"
    as "already computed", so an interrupted run must never leave a truncated
    archive under the final name.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        # Passing a file object (not a name) keeps numpy from appending ".npz".
        with open(tmp_path, 'wb') as tmp_file:
            np.savez_compressed(tmp_file, data=array)
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if the write or the rename failed.
        if os.path.isfile(tmp_path):
            os.remove(tmp_path)


def _embed_dataset(data_dir, csv_name, probe_freq, model_id, shuffle=False):
    """Cache prompt and response hidden states for every answered problem.

    For each row of ``<data_dir>/<csv_name>`` that has a response file
    ``<data_dir>/<MODEL_IDS[model_id]>/response/<unique_id>.json``, this writes
    into ``<data_dir>/<MODEL_IDS[model_id]>/embedding``:

    * ``<unique_id>.prompt.lasttoken.npz`` - hidden states at the last prompt
      token, and
    * ``<unique_id>.chain<k>.npz`` - one file per response in the JSON file,
      holding the hidden states at every ``probe_freq``-th response token.

    Files that already exist are skipped, so the function can be re-run to
    resume an interrupted job.

    Args:
        data_dir: Root directory of the dataset.
        csv_name: CSV file inside ``data_dir`` with ``unique_id`` and
            ``problem`` columns.
        probe_freq: Keep response token positions ``probe_freq - 1``,
            ``2 * probe_freq - 1``, ... (0-based, relative to response start).
        model_id: HuggingFace model identifier; must be a key of ``MODEL_IDS``.
        shuffle: Visit the problems in random order.
    """
    problems = pd.read_csv(os.path.join(data_dir, csv_name))
    if shuffle:
        problems = problems.sample(frac=1).reset_index(drop=True)

    model_dir = os.path.join(data_dir, MODEL_IDS[model_id])
    response_dir = os.path.join(model_dir, 'response')
    embed_dir = os.path.join(model_dir, 'embedding')
    os.makedirs(embed_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    model = model.to(device)
    model.eval()

    # The context manager closes the bar even if a problem raises.
    with tqdm.tqdm(total=problems.shape[0]) as pbar:
        for _, row in problems.iterrows():
            pbar.update(1)
            problem_id = row['unique_id']
            pbar.set_description(f"Processing {problem_id}")

            response_fname = os.path.join(response_dir, f'{problem_id}.json')
            if not os.path.isfile(response_fname):
                # No generations for this problem yet; nothing to embed.
                continue

            problem_prompt = apply_chat_template(row['problem'], model_id)
            prompt_len = len(tokenizer.encode(problem_prompt))

            prompt_fname = os.path.join(embed_dir, f'{problem_id}.prompt.lasttoken.npz')
            if not os.path.isfile(prompt_fname):
                prompt_embedding = get_string_embedding(problem_prompt, model, tokenizer, device)
                assert prompt_embedding.shape[1] == prompt_len
                prompt_embedding_last = prompt_embedding[:, -1, :]
                _save_npz_atomic(prompt_fname, prompt_embedding_last.cpu().numpy())

            with open(response_fname, 'r') as f:
                responses = json.load(f)

            for chain_id, response in enumerate(responses):
                chain_fname = os.path.join(embed_dir, f'{problem_id}.chain{chain_id}.npz')
                if os.path.isfile(chain_fname):
                    continue

                full_embedding = get_string_embedding(
                    problem_prompt + response, model, tokenizer, device
                )
                # NOTE(review): this slice assumes the first `prompt_len` tokens
                # of the concatenated text are exactly the prompt's own tokens;
                # a tokenizer merge across the prompt/response boundary would
                # shift it - confirm for the tokenizers in use.
                response_embedding = full_embedding[:, prompt_len:, :]
                # Same positions as range(probe_freq - 1, response_len, probe_freq).
                selected_embedding = response_embedding[:, probe_freq - 1::probe_freq, :]
                _save_npz_atomic(chain_fname, selected_embedding.cpu().numpy())


def embed_math(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Embed the problems listed in ``MATH_DIR/math3k.csv``.

    Args:
        model_id: HuggingFace model identifier; must be a key of ``MODEL_IDS``.
    """
    _embed_dataset(MATH_DIR, 'math3k.csv', MATH_PROBE_FREQ, model_id)


def embed_mmlu(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Embed the problems listed in ``MMLU_DIR/mmlu.csv``.

    Args:
        model_id: HuggingFace model identifier; must be a key of ``MODEL_IDS``.
    """
    _embed_dataset(MMLU_DIR, 'mmlu.csv', MMLU_PROBE_FREQ, model_id)


def embed_gsm8k(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Embed the problems listed in ``GSM8K_DIR/gsm8k.csv`` in random order.

    Args:
        model_id: HuggingFace model identifier; must be a key of ``MODEL_IDS``.
    """
    # Rows are visited in shuffled order - presumably so that several
    # concurrently running workers start on different problems (TODO confirm).
    _embed_dataset(GSM8K_DIR, 'gsm8k.csv', GSM8K_PROBE_FREQ, model_id, shuffle=True)


if __name__ == '__main__':
    # Map every supported --dataset value to its embedding routine. Using the
    # keys as argparse `choices` makes an unknown dataset a usage error
    # instead of a silent no-op.
    embedders = {
        "math": embed_math,
        "mmlu": embed_mmlu,
        "gsm8k": embed_gsm8k,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, choices=sorted(embedders),
        help="dataset whose cached responses should be embedded",
    )
    parser.add_argument(
        "-m", "--model", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="HuggingFace model identifier used to compute the hidden states",
    )
    args = parser.parse_args()

    embedders[args.dataset](model_id=args.model)
