from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
import re
import torch
import random
import numpy as np
import os
import csv
import tqdm
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import pandas as pd
from run import zero_shot
import gc
import time
import copy
import json
# === CONFIGURATION ===
SEED = 42  # Or any integer

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # For full determinism (may slow down and not support all ops)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(SEED)

# def load_to_gpu(model):
#     model.to("cuda")
#     torch.cuda.empty_cache()

# def unload_to_cpu(model):
#     model.to("cpu")
#     torch.cuda.empty_cache()


# === LOAD MODELS ===
def load_model_and_tokenizer(name, cache_dir="/workspace/hf"):
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(name, cache_dir=cache_dir, device_map="auto", torch_dtype=torch.bfloat16)
    model.eval()
    return tokenizer, model


# eval_tokenizers_models = [load_model_and_tokenizer(m) for m in evaluation_models]

# === GENERATE CANDIDATES ===


# def truncate_to_first_sentence(text):
#     """Truncate text to the first sentence (ending in . ! or ?)."""
#     match = re.search(r'(.+?[.!?])(\s|$)', text.strip())
#     return match.group(1).strip() if match else text.strip()

def generate_candidates(context, tokenizer, model, num_return_sequences, max_gen_tokens=128):
    set_seed(SEED)
    
    # Tokenize with attention_mask
    enc = tokenizer(context, return_tensors="pt")
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    model_device = next(model.parameters()).device
    input_ids = input_ids.to(model_device)
    attention_mask = attention_mask.to(model_device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # ✅ this is the fix
            max_new_tokens=max_gen_tokens,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
        )

    candidates = []
    for out in output:
        full_text = tokenizer.decode(out[input_ids.shape[-1]:], skip_special_tokens=True)
        # first_sentence = truncate_to_first_sentence(full_text)
        candidates.append(full_text)
    return candidates


# === COMPUTE PERPLEXITY ===
def compute_perplexities(context, candidates, tokenizer, model):
    """
    Compute perplexity per candidate using model's built-in loss,
    matching the single-candidate version exactly.
    """
    model_device = next(model.parameters()).device
    full_texts = [context + c for c in candidates]

    # Tokenize all at once
    encodings = tokenizer(
        full_texts,
        return_tensors="pt",
        padding=True,
        truncation=True
    )

    input_ids = encodings["input_ids"].to(model_device)
    attention_mask = encodings["attention_mask"].to(model_device)

    perplexities = []
    with torch.no_grad():
        for i in range(len(candidates)):
            input_i = input_ids[i].unsqueeze(0)  # (1, seq_len)
            attn_i = attention_mask[i].unsqueeze(0)
            outputs = model(input_ids=input_i, attention_mask=attn_i, labels=input_i)
            loss = outputs.loss  # averaged over non-masked tokens
            perplexities.append(torch.exp(loss).item())

    return perplexities


def iterative_generate(context, generation_models, gen_tokenizers_models, eval_tokenizers_models, max_total_tokens=500, num_candidates=3):
    iteration_count = 0
    contexts = {}
    total_tokens = 0
    chosen_per_iteration = []

    while True:
        start_time = time.time()

        if iteration_count > 0:
            if total_tokens >= max_total_tokens:
                print("\n✅ Reached max length.")
                break
            if "</think>" in contexts[0]:
                print("\n✅ Reached end of thought.")
                break

        all_candidates = []
        candidates_per_model = {}

        for model_index, (tokenizer, model) in enumerate(gen_tokenizers_models):
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            if iteration_count == 0:
                context_in_chat_temp = tokenizer.apply_chat_template(
                    context, tokenize=False, add_generation_prompt=True
                ) + "<think>"
                max_total_tokens = len(tokenizer.encode(context_in_chat_temp, add_special_tokens=True)) + 4096
                contexts[model_index] = context_in_chat_temp
                total_tokens = len(tokenizer.encode(contexts[model_index]))

            candidates = generate_candidates(
                contexts[model_index], tokenizer, model, num_candidates
            )
            print(f"\nMODEL: {generation_models[model_index]}")
            print("CANDIDATES:", candidates)

            all_candidates.extend(candidates)
            candidates_per_model[model_index] = candidates

        # === Batched Perplexity Evaluation ===
        scored_candidates_dict = {cand: [] for cand in all_candidates}
        for model_index, (tokenizer, model) in enumerate(eval_tokenizers_models):
            try:
                ppls = compute_perplexities(contexts[model_index], all_candidates, tokenizer, model)
                for i, cand in enumerate(all_candidates):
                    scored_candidates_dict[cand].append(ppls[i])
            except Exception as e:
                print(f"⚠️ Error in batched scoring: {e}")
                continue

        # === Aggregate scores ===
        scored_candidates = []
        for cand, ppl_list in scored_candidates_dict.items():
            if ppl_list:
                avg_ppl = sum(ppl_list) / len(ppl_list)
                scored_candidates.append((cand, avg_ppl))

        if not scored_candidates:
            print("❌ No valid candidates scored. Exiting.")
            break

        best_candidate, best_score = min(scored_candidates, key=lambda x: x[1])
        print(f"\n🔹 Selected: {best_candidate.strip()}")

        chosen_model_index = None
        for model_index, cands in candidates_per_model.items():
            if best_candidate in cands:
                chosen_model_index = model_index
                break

        chosen_per_iteration.append({
            "iteration": iteration_count,
            "selected_candidate": best_candidate,
            "selected_model_index": chosen_model_index,
            "all_candidates": copy.deepcopy(candidates_per_model),
            "score": best_score,
        })

        # Append selected to each context
        for key in contexts:
            contexts[key] += " " + best_candidate.strip()

        iteration_count += 1
        elapsed = time.time() - start_time
        print(f"⏱️ Iteration {iteration_count} completed in {elapsed:.2f} seconds.")

    return contexts, chosen_per_iteration



def main():
    import argparse

    parser = argparse.ArgumentParser(description="Context-aware sentence merging using perplexity.")
    # parser.add_argument("jsonl_files", nargs='+', help="Paths to input JSONL files.")
    parser.add_argument("--output", type=str, default="ensemble_outputs/merged_output.csv", help="Path to save merged output.")
    parser.add_argument("--models", nargs='+', default=[
        "BytedTsinghua-SIA/DAPO-Qwen-32B",
        "Qwen/QwQ-32B"
    ], help="Hugging Face model names.")
    set_seed(SEED)
    args = parser.parse_args()
    device = 'auto'
    generation_models = evaluation_models = args.models
    gen_tokenizers_models = eval_tokenizers_models = [load_model_and_tokenizer(m) for m in generation_models]
    os.makedirs(os.path.dirname(args.output), exist_ok=True)
    write_header = not os.path.isfile(args.output)
    df = pd.read_csv("../dataset/test_data.csv")
    df = df.sample(n=100, random_state=42)

    with open(args.output, 'a', newline='', encoding='utf-8') as csvfile:
        fieldnames = ["Row Number", "Note ID", "Calculator ID", "Question", "Patient Note", "Ensembled Thought"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_ALL)
    
        if write_header:
            writer.writeheader()
        for index in tqdm.tqdm(range(len(df))):
            row = df.iloc[index]
    
            patient_note = row["Patient Note"]
            question = row["Question"]
            calculator_id = str(row["Calculator ID"])
            note_id = str(row["Note ID"])
    
            # Create messages
            system, user = zero_shot(patient_note, question)
            messages = [
                {"role": "system", "content": system},
                {"role": "user", "content": user}
            ]
    
            # Generate output
            final_output, selected_distributions = iterative_generate(
                messages,
                generation_models,
                gen_tokenizers_models,
                eval_tokenizers_models
            )
            print(f"🔍 Iteration-wise Selection Summary:")
            for d in selected_distributions:
                print(f"Iteration {d['iteration']}: Model {d['selected_model_index']} -> {d['selected_candidate']}")
            with open(f"ensemble_outputs/selected_distributions_row_{index}.json", "w") as f:
                json.dump(selected_distributions, f, indent=2)
            # Write a single row to the CSV
            writer.writerow({
                "Row Number": int(row["Row Number"]),
                "Note ID": note_id,
                "Calculator ID": calculator_id,
                "Question": question,
                "Patient Note": patient_note,
                "Ensembled Thought": final_output[0].strip()
            })
            csvfile.flush()
            # after writing the row:
            gc.collect()
            torch.cuda.empty_cache()
    
            print(f"✅ Done. Row-wise output saved to: {args.output}")

if __name__ == "__main__":
    set_seed(SEED)
    device = "auto"
    # === RUN ===
    main()
