from concurrent.futures import ThreadPoolExecutor
import torch
import pandas as pd
from PIL import Image, ImageFile
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
import os
import gc
# Let Pillow decode truncated (partially written / cut-off) image files instead
# of raising during load. This flag is set on the PIL module, so it applies to
# every image opened by this process.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# === Constants ===
# Per-channel mean / std handed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Side length (pixels) images are resized to before being fed to the model.
IMAGE_SIZE = 448
BATCH_SIZE = 1
# Device that pixel tensors are moved to before calling the model.
DEVICE = "cuda:0"
MODEL_NAME = "OpenGVLab/InternVL3-38B"
# Fall back to "" when SCRATCH is unset (same default as every other SCRATCH
# lookup in this file); os.path.join(None, ...) would raise TypeError at import.
CACHE_DIR = os.path.join(os.environ.get("SCRATCH", ""), ".cache")
# Prompt sent with every image.
QUESTION = "Is this image real or AI-generated?"
# Passed as the generation config to model.chat(); do_sample=False requests
# greedy decoding and max_new_tokens caps the answer at 10 tokens.
# NOTE(review): temperature/top_k/top_p are sampling knobs and are presumably
# ignored when do_sample is False - confirm against the installed transformers.
GENERATION_CONFIG = {
    "do_sample": False,
    "max_new_tokens": 10,
    "temperature": 0.0,
    "top_k": 1,
    "top_p": 1.0,
    "repetition_penalty": 1.0
}

def build_transform(input_size):
    """Build the torchvision preprocessing pipeline for a PIL image.

    The pipeline converts non-RGB images to RGB, resizes to a square of
    ``input_size`` x ``input_size`` with bicubic interpolation, converts to a
    tensor and normalises with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        input_size: Target height and width in pixels.

    Returns:
        A ``T.Compose`` applying the steps above in order.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def load_image(image_path, transform):
    """Load one image from disk and turn it into a model-ready tensor.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the resized RGB PIL image; its result
            must support ``.unsqueeze(0)`` (e.g. the pipeline returned by
            ``build_transform``).

    Returns:
        The transformed image with a leading batch dimension added, or
        ``None`` if anything goes wrong while opening or transforming the
        image (the error is printed, not raised).
    """
    try:
        # The context manager closes the underlying file handle as soon as the
        # pixels have been decoded; convert() returns an independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        # NOTE(review): this bilinear pre-resize to IMAGE_SIZE runs before
        # `transform`, which (when built by build_transform) performs its own
        # bicubic resize - confirm the double resize is intended.
        resized_image = image.resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
        return transform(resized_image).unsqueeze(0)
    except Exception as e:
        # Deliberate best-effort: a bad file must not abort a long run.
        print(f"Error loading {image_path}: {e}")
        return None

def load_model():
    """Load the MODEL_NAME checkpoint and its tokenizer.

    Both downloads go through CACHE_DIR so that the tokenizer files land in
    the same cache as the model weights instead of the library default.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        trust_remote_code=True,
    )
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # NOTE(review): presumably consumed by the checkpoint's remote code;
        # confirm flash-attention is installed on the target machine.
        use_flash_attn=True,
        load_in_8bit=True,
    )

    return model.eval(), tokenizer

def predict_batch(model, tokenizer, pixel_batch):
    """Ask the model QUESTION about each image tensor in ``pixel_batch``.

    Images are processed one at a time under ``torch.no_grad()`` (matching the
    main inference loop), so no autograd state is kept during generation.

    Args:
        model: Object exposing ``chat(tokenizer, pixel_values, question,
            generation_config)`` that returns a string.
        tokenizer: Tokenizer forwarded to ``model.chat``.
        pixel_batch: Iterable of image tensors.

    Returns:
        A list with one entry per input, in order: the stripped model answer,
        or the string ``"error"`` if moving the tensor or generating failed
        (the error is printed, not raised).
    """
    responses = []
    with torch.no_grad():
        for pixel_values in pixel_batch:
            try:
                pixel_values = pixel_values.to(dtype=torch.bfloat16, device=DEVICE)
                response = model.chat(tokenizer, pixel_values, QUESTION, GENERATION_CONFIG)
                responses.append(response.strip())
            except Exception as e:
                # Best-effort: one failing image must not lose the whole batch.
                print(f"Error during prediction: {e}")
                responses.append("error")
    return responses

def run_inference_on_csv(in_csv, output_csv, save_every_n_batches=10):
    """Run the model on every image listed in ``in_csv`` and save the answers.

    The run is resumable: file names already present in ``output_csv`` are
    skipped, and new results are appended to the rows that file already holds.
    Images are read from ``$SCRATCH/OpenFake/<file_name>``.

    Every processed image yields one output row with the columns
    ``file_name``, ``model``, ``label``, ``prediction``. ``model`` and
    ``label`` are copied from the input row when present; ``prediction`` is
    the stripped model answer, or ``"error"`` when the image could not be
    loaded or generation failed.

    Args:
        in_csv: Path of the input CSV; must contain a ``file_name`` column.
        output_csv: Path of the results CSV (created if missing).
        save_every_n_batches: Write ``output_csv`` after every this many
            processed images. ``0`` or ``None`` disables periodic saves.

    Returns:
        None. Results are written to ``output_csv``; it is also written when
        the loop is interrupted, so finished work is not lost.

    Raises:
        AssertionError: If ``in_csv`` has no ``file_name`` column.
    """
    result_columns = ['file_name', 'model', 'label', 'prediction']

    df = pd.read_csv(in_csv)
    assert 'file_name' in df.columns, "CSV must have a 'file_name' column"

    if os.path.exists(output_csv):
        existing_df = pd.read_csv(output_csv)
        already_done = set(existing_df['file_name'].tolist())
        print(f"Found {len(already_done)} already processed images.")
    else:
        existing_df = pd.DataFrame(columns=result_columns)
        already_done = set()

    df = df[~df['file_name'].isin(already_done)]
    print(f"Processing {len(df)} remaining images...")

    if df.empty:
        print("All images already processed.")
        return

    model, tokenizer = load_model()
    transform = build_transform(IMAGE_SIZE)
    image_root = os.path.join(os.environ.get("SCRATCH", ""), "OpenFake")

    # New rows are collected in a list and turned into a DataFrame only when
    # saving; concatenating a DataFrame per image is quadratic in run length.
    new_rows = []

    def save_results():
        """Write previously saved rows plus all new rows to output_csv."""
        new_df = pd.DataFrame(new_rows, columns=result_columns)
        # Skip empty frames so concat never sees an all-empty operand.
        frames = [frame for frame in (existing_df, new_df) if not frame.empty]
        combined = pd.concat(frames, ignore_index=True) if frames else new_df
        combined.to_csv(output_csv, index=False)

    try:
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
            pixel = load_image(os.path.join(image_root, row['file_name']), transform)
            if pixel is None:
                # load_image already printed the reason.
                response = "error"
            else:
                # Run prediction safely
                try:
                    with torch.no_grad():
                        pixel = pixel.to(dtype=torch.bfloat16, device=DEVICE)
                        response = model.chat(tokenizer, pixel, QUESTION, GENERATION_CONFIG).strip()
                except Exception as e:
                    print(f"Error during prediction: {e}")
                    response = "error"

                # Free up memory
                del pixel
                torch.cuda.empty_cache()
                gc.collect()

            # Same four-column layout for successes and failures, so a failed
            # image never shifts the generator name into the label column.
            new_rows.append((row['file_name'], row.get('model'), row.get('label'), response))

            # Save progress periodically
            if save_every_n_batches and len(new_rows) % save_every_n_batches == 0:
                save_results()
                print(f"Saved progress after {len(new_rows)} images.")
    finally:
        # Final save; also runs on Ctrl-C or a crash so progress is kept.
        save_results()

    print(f"\nDone! Final results saved to: {output_csv}")


def to_int_label(x):
    """Map various ground-truth formats to {0,1} (0 = real, 1 = fake).

    Accepts the numbers 0 / 1 (int, float or bool) and the case-insensitive
    strings listed below.

    Raises:
        ValueError: If the value is not a recognised label (this includes NaN
            and numbers other than 0 and 1).
    """
    if isinstance(x, (int, float)):
        # Only exact 0 / 1 are valid; int() alone would silently turn 0.5
        # into 0 and let 2 through as a third class.
        if x in (0, 1):
            return int(x)
        raise ValueError(f"Un-recognised label: {x}")
    x = str(x).strip().lower()
    if x in {"real", "0", "false"}:
        return 0
    if x in {"fake", "ai", "1", "true", "ai-generated", "ai generated"}:
        return 1
    raise ValueError(f"Un-recognised label: {x}")


def to_int_pred(x):
    """Heuristic mapping of free-text VLM output to 0 (real), 1 (fake) or None.

    "real" and "ai" are matched as whole words, so words that merely contain
    those letters ("contains", "certain", "unrealistic") do not count.
    "fake" and "generated" are matched as substrings. Negations such as
    "not real" are NOT interpreted.

    Returns:
        0 if the text mentions "real" but not "ai"; 1 if it mentions "ai",
        "fake" or "generated"; None when nothing matches (caller drops it).
    """
    text = str(x).strip().lower()
    # Split on every non-alphanumeric character: "ai-generated." -> ai, generated
    words = set("".join(ch if ch.isalnum() else " " for ch in text).split())
    has_ai = "ai" in words
    if "real" in words and not has_ai:
        return 0
    if has_ai or "fake" in text or "generated" in text:
        return 1
    return None          # unknown -> drop


def report_metrics(result_df):
    """Print overall and per-generator metrics for a predictions DataFrame.

    Args:
        result_df: DataFrame with ``label`` and ``prediction`` columns and,
            optionally, a ``model`` column naming the generator.

    Rows whose prediction cannot be parsed by ``to_int_pred`` are dropped
    before labels are parsed. The input frame is not modified.

    Raises:
        ValueError: If a remaining row has a label ``to_int_label`` rejects.
    """
    from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, average_precision_score

    scored_df = result_df.copy()
    scored_df["pred_int"] = scored_df["prediction"].apply(to_int_pred)

    # remove rows the heuristic could not parse
    clean_df = scored_df.dropna(subset=["pred_int"]).copy()
    n_dropped = len(scored_df) - len(clean_df)
    if n_dropped:
        print(f"Dropped {n_dropped} of {len(scored_df)} rows with unparseable predictions.")
    if clean_df.empty:
        print("No parseable predictions; metrics cannot be computed.")
        return

    clean_df["pred_int"] = clean_df["pred_int"].astype(int)
    # Labels are parsed only for rows that are kept.
    clean_df["label_int"] = clean_df["label"].apply(to_int_label)

    y_true = clean_df["label_int"].values
    # Hard 0/1 predictions, not probabilities: the ranking metrics below are
    # computed from a single operating point.
    y_score = clean_df["pred_int"].values

    if clean_df["label_int"].nunique() < 2:
        # roc_auc_score raises when only one class is present in y_true.
        auc = float("nan")
        auc_pr = float("nan")
    else:
        auc = roc_auc_score(y_true, y_score)
        auc_pr = average_precision_score(y_true, y_score)
    acc = accuracy_score(y_true, y_score)
    f1 = f1_score(y_true, y_score)
    n_positive = (clean_df["label_int"] == 1).sum()
    n_true_positive = ((clean_df["label_int"] == 1) & (clean_df["pred_int"] == 1)).sum()
    # "or 1" avoids a division by zero when there are no positive labels.
    tpr = n_true_positive / (n_positive or 1)

    print("\n===== OVERALL METRICS =====")
    print(f"ROC-AUC : {auc:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"F1 score: {f1:.4f}")
    print(f"TPR     : {tpr:.4f}")
    print(f"AUC-PR  : {auc_pr:.4f}")

    print("\n===== PER-GENERATOR ACCURACY =====")
    if "model" in clean_df.columns:
        # groupby skips rows whose generator name is missing (NaN).
        for gen_name, grp in clean_df.groupby("model"):
            acc_gen = accuracy_score(grp["label_int"], grp["pred_int"])
            # str() so non-string generator names do not break the format spec.
            print(f"{str(gen_name):30s}  {acc_gen:.4f}  (n={len(grp)})")


if __name__ == "__main__":
    INPUT_CSV = os.path.join(os.environ.get("SCRATCH", ""), "OpenFake", "test", "metadata.csv")
    OUTPUT_CSV = "vlm_baseline_output_predictions.csv"
    run_inference_on_csv(INPUT_CSV, OUTPUT_CSV)

    report_metrics(pd.read_csv(OUTPUT_CSV))
