import os
from tqdm import tqdm
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import params
from inference_utils import MODEL_PATH_DICT
import argparse
from hashlib import md5
from evaluate import load_wiki, load_truthfulqa, load_biography

def hash_answer(answer):
    """Return a short fingerprint of *answer*.

    The fingerprint is the first 8 hex characters of the MD5 digest of the
    string's default (UTF-8) encoding. It is used to build record ids, not
    for security.
    """
    raw_bytes = answer.encode()
    full_digest = md5(raw_bytes).hexdigest()
    return full_digest[0:8]

def format_chat_prompt(question, tokenizer):
    """Wrap *question* as a single user turn using the tokenizer's chat template.

    Falls back to the stripped raw question in two cases:
    - the tokenizer has no ``apply_chat_template`` method;
    - it has the method but no chat template is configured. Recent
      ``transformers`` versions raise ``ValueError`` then, e.g. for base,
      non-chat models.

    Returns:
        The prompt string (not tokenized). On the template path a generation
        prompt is appended.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        return question.strip()
    try:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            tokenize=False,
            add_generation_prompt=True,
        )
    except ValueError:
        # No chat template set on this tokenizer: use the plain question.
        return question.strip()


def load_dataset_by_name(train_data, num_train):
    """Load the evaluation split of a supported dataset, capped at *num_train* samples.

    Args:
        train_data: One of ``'truthful_qa'``, ``'bio'`` or ``'wiki'``.
            ``''`` is accepted as an alias for ``'truthful_qa'``. It is the
            CLI default for ``--train_data``, and ``precompute_space``
            already treats ``''`` as TruthfulQA when reading answers.
        num_train: Maximum number of samples, forwarded to the loader.

    Returns:
        Whatever the underlying ``load_*`` helper returns.

    Raises:
        ValueError: if *train_data* is not a supported name.
    """
    if train_data in ('truthful_qa', ''):
        return load_truthfulqa(split='validation', max_sample_num=num_train)
    if train_data == 'bio':
        return load_biography(split='validation', max_sample_num=num_train)
    if train_data == 'wiki':
        return load_wiki(data_split='test', max_sample_num=num_train)
    raise ValueError(f'{train_data} is not supported')


def _parse_context_types(compute_context_types):
    """Turn names like 'full' / 'chunk_8' into de-duplicated (name, chunk_size) pairs.

    'full' maps to chunk_size None. 'chunk_<k>' maps to int k, which must be
    positive: ``ids[-0:]`` is the whole tensor, so k == 0 would silently act
    like 'full'. Duplicates are dropped (first occurrence wins). Otherwise
    each record would be stored twice and the second merge pass would
    overwrite the output with an empty list.
    """
    context_types = []
    for t in dict.fromkeys(compute_context_types):
        if t == 'full':
            context_types.append(('full', None))
        elif t.startswith('chunk_'):
            chunk_size = int(t.split('_')[1])
            if chunk_size <= 0:
                raise ValueError(f'Context type {t} is not supported')
            context_types.append((t, chunk_size))
        else:
            raise ValueError(f'Context type {t} is not supported')
    return context_types


def _final_output_path(output_dir, train_data, model_name, base_embed_model, context_type):
    """Path of the merged output file for one context type.

    Both the skip-if-exists check and the final save use this, so they can no
    longer disagree. (The skip check used to look for
    ``{train_data}_{model_name}...``, which was never written.) The format is
    the one this script has always saved: no separator between *train_data*
    and *model_name*.
    """
    return os.path.join(
        output_dir,
        f"{train_data}{model_name}_{base_embed_model}_context_{context_type}.pt",
    )


def _correct_answers(item, train_data):
    """Return the list of reference answers for one dataset item."""
    if train_data in ('truthful_qa', ''):
        # De-duplicate; sort so answer (and record) order is identical across runs.
        return sorted(set(item["correct_answers"]))
    if train_data == 'bio':
        return [item['answer']]
    if train_data == 'wiki':
        return item['correct_answers']
    raise ValueError(f'{train_data} is not supported')


def precompute_space(
    model_name,
    model,
    tokenizer,
    base_embed_model,
    train_data='',
    output_dir='./truthfulqa_cache',
    num_train=100,
    batch_size=16,
    compute_context_types=('full', 'chunk_8', 'chunk_16'),
    skip_if_exists=True,
):
    """Precompute (context embedding, next-token logits) records for a dataset.

    For each of the first *num_train* items, the question is formatted with
    ``format_chat_prompt``. For every correct answer, the causal LM is run
    once on prompt + answer. Then, for every answer token position ``step``,
    one record per context type is stored with:

    - ``logits``: the LM's next-token logits at that position (fp16, CPU);
    - ``context_embedding``: the L2-normalised sentence embedding (fp16, CPU)
      of the decoded preceding context. For 'full' that is prompt + answer
      prefix; for 'chunk_k' it is the last k tokens of it. A zero vector with
      ``embedding_valid=False`` is stored if embedding fails.

    Records are flushed to per-batch files every *batch_size* items, then
    merged in order into one file per context type (see
    ``_final_output_path``). The batch files are deleted afterwards.

    Args:
        model_name: Name used in output file names.
        model: Causal LM whose ``device`` is used for all tensors.
        tokenizer: Matching tokenizer.
        base_embed_model: SentenceTransformer model name (also used in file names).
        train_data: Dataset name, see ``load_dataset_by_name``.
        output_dir: Directory for batch and final files (created if missing).
        num_train: Maximum number of dataset items to process.
        batch_size: Number of items per intermediate flush.
        compute_context_types: Iterable of 'full' / 'chunk_<k>' names.
        skip_if_exists: Return immediately if every final file already exists.

    Raises:
        ValueError: on an unsupported context type or dataset name.
    """
    context_types = _parse_context_types(compute_context_types)
    os.makedirs(output_dir, exist_ok=True)

    # Check for existing outputs before loading anything expensive.
    if skip_if_exists and all(
        os.path.exists(_final_output_path(output_dir, train_data, model_name, base_embed_model, ctype))
        for ctype, _ in context_types
    ):
        print("[Skipped] All output files already exist.")
        return

    device = model.device
    embed_model = SentenceTransformer(base_embed_model, device=device)
    embed_dim = embed_model.get_sentence_embedding_dimension()

    ds = load_dataset_by_name(train_data, num_train)
    try:
        # The loader may return fewer than num_train items; don't index past the end.
        num_items = min(num_train, len(ds))
    except TypeError:
        num_items = num_train

    context_results = {ctype: [] for ctype, _ in context_types}
    batch_paths = {ctype: [] for ctype, _ in context_types}
    fallback_counter = 0
    label = "correct"  # only correct answers are processed; label kept in the record schema

    print(">> Context types:", context_types)
    print(">> Init context_results:", context_results.keys())

    for batch_start in tqdm(range(0, num_items, batch_size), desc="Processing Batches"):
        batch_end = min(batch_start + batch_size, num_items)

        for i in range(batch_start, batch_end):
            item = ds[i]
            prompt = format_chat_prompt(item['question'], tokenizer)
            # NOTE(review): if the chat template already emits BOS, the default
            # add_special_tokens=True here may add a second one. Confirm per model.
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
            prompt_len = input_ids.shape[1]

            for answer in _correct_answers(item, train_data):
                record_id = f"{i}_{label}_{hash_answer(answer)}"

                # The answer is appended after the prompt, so special tokens
                # (e.g. BOS) must not be added; they would become fake targets.
                target_ids_tensor = tokenizer(
                    answer, return_tensors='pt', add_special_tokens=False
                ).input_ids.to(device)
                num_steps = target_ids_tensor.shape[1]
                if num_steps == 0:
                    print(f"[Skipped] QID {i} — label={label}\n  Answer: {answer!r} has no tokens")
                    continue

                try:
                    full_input = torch.cat([input_ids, target_ids_tensor], dim=1)
                    with torch.no_grad():
                        outputs = model(full_input, output_hidden_states=False)
                    # Position p predicts token p+1, so the answer tokens are
                    # predicted by rows [prompt_len - 1, prompt_len + num_steps - 1).
                    logits = outputs.logits[:, prompt_len - 1:-1, :].squeeze(0).cpu().half()
                    del outputs
                except Exception as e:
                    print(f"[Skipped] QID {i} — label={label}\n  Prompt: {prompt}\n  Answer: {answer}\n  Error: {e}")
                    continue

                sequence = full_input[0]
                for step in range(num_steps):
                    # Prompt plus the first `step` answer tokens.
                    prev_context_ids = sequence[:prompt_len + step]

                    for context_type, chunk_size in context_types:
                        context_ids = prev_context_ids if chunk_size is None else prev_context_ids[-chunk_size:]

                        try:
                            decoded_text = tokenizer.decode(context_ids, skip_special_tokens=True)
                            context_embedding = F.normalize(
                                embed_model.encode(decoded_text, convert_to_tensor=True), dim=-1
                            ).half().cpu()
                            embedding_valid = True
                        except Exception:
                            # Best effort: keep the record with a zero embedding and flag it.
                            context_embedding = torch.zeros(embed_dim).half()
                            embedding_valid = False
                            fallback_counter += 1

                        context_results[context_type].append({
                            "record_id": record_id,
                            "question_id": i,
                            "question": item["question"],
                            "answer": answer,
                            "label": label,
                            "step": step,
                            "context_type": context_type,
                            "context_embedding": context_embedding,
                            "embedding_valid": embedding_valid,
                            "logits": logits[step],
                        })

                        print(f"[Added] {record_id} | step={step} | type={context_type} | ctx len={len(context_ids)}")

                del logits, target_ids_tensor, full_input, sequence

        # Save intermediate files to avoid memory overload; remember the paths
        # so the merge uses exactly these files, in write order.
        for context_type, records in context_results.items():
            batch_path = os.path.join(
                output_dir,
                f"{train_data}{model_name}_{base_embed_model}_batch_{batch_start}_{batch_end}_context_{context_type}.pt",
            )
            torch.save(records, batch_path)
            batch_paths[context_type].append(batch_path)
            records.clear()
        torch.cuda.empty_cache()

    # Merge per context type, in batch order.
    for context_type, _ in context_types:
        print(f"[Merge] Merging context type: {context_type}")
        merged = []
        for path in batch_paths[context_type]:
            merged.extend(torch.load(path, weights_only=True))
            os.remove(path)

        final_path = _final_output_path(output_dir, train_data, model_name, base_embed_model, context_type)
        torch.save(merged, final_path)
        print(f"[Saved] {final_path} — {len(merged)} records")

    print(f"[Info] Fallback count: {fallback_counter}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Precompute context-embedding / next-token-logit records for a dataset."
    )
    parser.add_argument("--base_model", type=str, required=True)
    parser.add_argument("--base_embed_model", type=str, default="all-MiniLM-L6-v2")
    parser.add_argument(
        "--train_data",
        type=str,
        default="",
        choices=["", "truthful_qa", "bio", "wiki"],
    )
    parser.add_argument("--sample_num", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument(
        "--context_types",
        type=str,
        nargs="+",
        default=["chunk_8"],
        help="Context types to compute: 'full' and/or 'chunk_<k>'.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device for the LM (default: cuda if available, else cpu).",
    )
    parser.add_argument(
        "--no_skip",
        action="store_true",
        help="Recompute even if all output files already exist.",
    )
    args = parser.parse_args()

    if args.base_model not in MODEL_PATH_DICT:
        raise ValueError(f"Model {args.base_model} not supported")
    model_path = MODEL_PATH_DICT[args.base_model]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # precompute_space puts every tensor on model.device; from_pretrained alone
    # leaves the model on the CPU, so move it explicitly.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    precompute_space(
        model_name=args.base_model,
        model=model,
        tokenizer=tokenizer,
        base_embed_model=args.base_embed_model,
        train_data=args.train_data,
        num_train=args.sample_num,
        output_dir=params.output_dir,
        batch_size=args.batch_size,
        compute_context_types=args.context_types,
        skip_if_exists=not args.no_skip,
    )
