import math

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from tqdm import tqdm

def compute_batch_energy_entropy(texts, model, tokenizer, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """
    Compute the mean per-token energy and entropy of each text in a batch.

    Each text is prefixed with the tokenizer's BOS token (when it has one) so that
    the first real token is also scored. For every position t whose next token
    t+1 is a real (non-padding) token:

    * energy  = -logit[t, x_{t+1}]  (negated logit of the token that actually follows)
    * entropy = Shannon entropy, in nats, of softmax(logit[t])

    Both quantities are averaged over the scored positions of each text.

    Args:
        texts: list of strings to score.
        model: causal LM called as ``model(input_ids, attention_mask=...)`` whose
            output exposes ``.logits`` -- assumed (batch, seq, vocab) as for HF
            causal LMs.
        tokenizer: HF-style tokenizer; needs a pad token when the batch holds
            texts of different lengths.
        device: device the encoded batch is moved to (should match the model's).

    Returns:
        (energies, entropies): two lists of Python floats, one entry per text.
        A text with no scored position (e.g. empty and no BOS token) gets NaN.
    """
    # Some tokenizers have no BOS token; fall back to no prefix instead of crashing.
    bos = tokenizer.bos_token or ""
    texts = [bos + text for text in texts]

    # BOS was added by hand above; letting the tokenizer add its own special tokens
    # would duplicate it for tokenizers that prepend BOS themselves.
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        # Pass the padding mask. No labels are passed, because the LM loss is never used.
        outputs = model(input_ids, attention_mask=attention_mask)
        # Logits at position t predict token t+1. Upcast so half-precision models
        # do not lose accuracy in the softmax and the sums.
        logits = outputs.logits[:, :-1, :].float()
        labels = input_ids[:, 1:]

        # Score a transition only when both the predicting position and its target
        # are real tokens; this is correct for right and left padding alike.
        valid = (attention_mask[:, :-1] * attention_mask[:, 1:]).to(logits.dtype)
        token_counts = valid.sum(dim=1)  # 0 for a text with nothing scored -> NaN below

        # energy: negated logit of the observed next token, averaged per text
        selected_logits = logits.gather(2, labels.unsqueeze(-1)).squeeze(-1)
        energies = (-(selected_logits * valid).sum(dim=1) / token_counts).cpu().tolist()

        # entropy: log_softmax is exact where log(softmax + eps) is biased and can underflow
        log_probs = torch.log_softmax(logits, dim=-1)
        token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        entropies = ((token_entropy * valid).sum(dim=1) / token_counts).cpu().tolist()

    return energies, entropies

def process_csv(input_csv, output_csv, only_refined=False, model_name="gpt2", device=None, batch_size=16):
    """Score the text columns of a CSV with a causal LM and save the results.

    Reads ``input_csv`` and computes, per row, the mean energy and entropy (via
    ``compute_batch_energy_entropy``) of the ``refined_text`` column and, unless
    ``only_refined`` is set, of the ``input_text`` column. The scores are stored
    in ``input_energy``/``input_entropy`` and ``refined_energy``/``refined_entropy``,
    and the frame is written to ``output_csv`` without its index.

    Args:
        input_csv: path of the CSV to read.
        output_csv: path the scored CSV is written to.
        only_refined: skip scoring ``input_text`` when True.
        model_name: name or path passed to ``from_pretrained``.
        device: torch device; defaults to "cuda" when available, else "cpu".
        batch_size: number of texts scored per forward pass.
    """
    df = pd.read_csv(input_csv)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Batching needs a pad token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    # The scorer shifts the attention mask by one position, which matches the
    # model's position handling only when padding is appended on the right.
    tokenizer.padding_side = "right"
    model.eval()

    def texts_of(column):
        """Return a column as strings; missing cells become "" instead of the text "nan"."""
        return df[column].fillna("").astype(str).tolist()

    def batch_compute(texts):
        """Score ``texts`` in chunks of ``batch_size``; return (energies, entropies)."""
        energies, entropies = [], []
        for start in tqdm(range(0, len(texts), batch_size), desc="Computing energy and entropy"):
            batch = texts[start:start + batch_size]
            batch_E, batch_H = compute_batch_energy_entropy(batch, model, tokenizer, device)
            # Report rows whose scores will be non-finite in the output CSV.
            for offset, (e, h) in enumerate(zip(batch_E, batch_H)):
                if not math.isfinite(e) or math.isnan(h):
                    print(f"Abnormal value: index={start + offset}, energy={e}, entropy={h}, text={batch[offset]}")
            energies.extend(batch_E)
            entropies.extend(batch_H)
        return energies, entropies

    def report(column, energies, entropies):
        """Print plain averages (NaN if there are no rows or any score is NaN)."""
        n = len(energies)
        mean_e = sum(energies) / n if n else float("nan")
        mean_h = sum(entropies) / n if n else float("nan")
        print(f"{column} average energy: {mean_e:.4f}, average entropy: {mean_h:.4f}")

    if not only_refined:
        # The report lives inside this branch: input_* values exist only when scored.
        input_E, input_H = batch_compute(texts_of("input_text"))
        df["input_energy"] = input_E
        df["input_entropy"] = input_H
        report("input_text", input_E, input_H)

    refined_E, refined_H = batch_compute(texts_of("refined_text"))
    df["refined_energy"] = refined_E
    df["refined_entropy"] = refined_H
    report("refined_text", refined_E, refined_H)

    df.to_csv(output_csv, index=False)
    print(f"Results saved to {output_csv}")

def debug_from_csv(output_csv):
    """Print the rows of a results CSV with non-finite energy or NaN entropy.

    Only the score columns present in the file are checked. This allows CSVs
    written with ``only_refined=True``, which have no ``input_*`` columns, to be
    inspected as well.

    Args:
        output_csv: path to a CSV produced by ``process_csv``.
    """
    df = pd.read_csv(output_csv)
    print("Rows with energy=inf or entropy=nan:")
    infinities = [float("inf"), -float("inf")]
    mask = pd.Series(False, index=df.index)
    for column in ("input_energy", "refined_energy"):
        if column in df.columns:
            # An energy is abnormal when it is missing/NaN or +-inf.
            mask |= df[column].isna() | df[column].isin(infinities)
    for column in ("input_entropy", "refined_entropy"):
        if column in df.columns:
            mask |= df[column].isna()
    print(df[mask])

# Command-line entry point (example usage)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="calculate energy and entropy for texts in a CSV file")
    # Not argparse-required: --debug only reads --output_csv, so input is checked below instead.
    parser.add_argument("--input_csv", type=str, default=None, help="path to input CSV file (required unless --debug)")
    parser.add_argument("--output_csv", type=str, required=True, help="path to output CSV file")
    parser.add_argument("--only_refined", action="store_true", help="only compute for refined_text column")
    parser.add_argument("--model_name", type=str, default="gpt2", help="pretrained model name")
    parser.add_argument("--device", type=str, default=None, help="torch device (default: cuda if available, else cpu)")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size for processing texts")
    parser.add_argument("--debug", action="store_true", help="debug from output CSV")
    args = parser.parse_args()

    if args.debug:
        # Inspect an existing results file; no model is loaded.
        debug_from_csv(args.output_csv)
    else:
        if args.input_csv is None:
            parser.error("--input_csv is required unless --debug is given")
        process_csv(
            args.input_csv,
            args.output_csv,
            only_refined=args.only_refined,
            model_name=args.model_name,
            device=args.device,
            batch_size=args.batch_size,
        )
