import argparse
import os
import glob
import json
import torch
import wandb
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
from tqdm import tqdm
from torch.utils.data import DataLoader # Import DataLoader
from math import exp                     # Import exp

# --- Basic Setup ---
# Pick the compute device once at import time: GPU when PyTorch can see CUDA, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Truncation length (in tokens) used when tokenizing texts for the NLL computation.
max_length = 512

# === Argument Parser ===
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests / programmatic use). ``None``
            keeps argparse's default behavior of reading the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Evaluate the RLHF-aligned model.")
    parser.add_argument(
        "--load_model_name", type=str, required=True,
        help=(
            "Experiment case name. With --type pretrained the model is loaded from "
            "'./pretrained_models/{load_model_name}/seed{seed}/'; otherwise from "
            "'./results/{load_model_name}/checkpoint-{iteration}', or, for names "
            "containing 'kd', from 'results/{load_model_name}/checkpoint-epoch-1/' "
            "when --kd_iteration is 0 and 'results/{load_model_name}/"
            "checkpoint-{kd_iteration}/' otherwise."
        ),
    )
    parser.add_argument(
        "--pstar_model_dir", type=str, default="XXXX/SmolLM2-1.7B",
        help="Path or Hub name for the ground-truth model (p*)."
    )
    parser.add_argument(
        "--target_model_dir", type=str, default="./target_models/",
        help=(
            "Base directory of the target model used for precision and reward "
            "calculation; 'seed{seed}/' is appended to it."
        ),
    )
    parser.add_argument(
        "--sample_size", type=int, default=10000,  # Limit sample size for memory management
        help="Number of samples to use for evaluation from each dataset."
    )
    parser.add_argument(
        "--batch_size", type=int, default=512,  # Batch size for reward and NLL computation
        help="Batch size for reward and log-likelihood computation."
    )
    parser.add_argument(
        "--seed", type=int, default=1,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--wandb_project", type=str, default="XXXX",
        help="W&B project name for logging evaluation results."
    )
    parser.add_argument(
        "--iteration", type=int, default=1000,
        help="Checkpoint step to load for non-pretrained models whose name does not contain 'kd'.",
    )
    parser.add_argument(
        "--kd_iteration", type=int, default=0,
        help="Checkpoint step to load for 'kd' models; 0 loads 'checkpoint-epoch-1'.",
    )
    parser.add_argument(
        "--type", type=str, default='pretrained',
        help=(
            "'pretrained' loads the model from ./pretrained_models/; any other "
            "value loads a fine-tuned checkpoint from results/."
        ),
    )

    return parser.parse_args(argv)

# === Function to load prompt, response pairs from .json files ===
def load_prompts_responses_from_dir(directory, max_samples, all_files=False):
    """Load up to ``max_samples`` prompt/response pairs from JSON files in a directory.

    Each file is parsed with ``json.load`` and iterated as a list of records.
    Only records that are dicts with both a ``"prompt"`` and a ``"response"`` key
    are kept; anything else is skipped. By default only the first ``*.json``
    file in sorted order is read. Pass ``all_files=True`` to continue through
    the remaining files (in sorted order) until ``max_samples`` pairs are
    collected.

    Args:
        directory: Directory to search for ``*.json`` files (non-recursive).
        max_samples: Maximum number of pairs to return; ``<= 0`` returns ``[]``.
        all_files: Read every ``*.json`` file instead of only the first one.

    Returns:
        List of ``{"prompt": ..., "response": ...}`` dicts in file order.

    Raises:
        FileNotFoundError: If ``directory`` contains no ``*.json`` files.
    """
    files = sorted(glob.glob(os.path.join(directory, "*.json")))
    if not files:
        raise FileNotFoundError(f"No JSON files found in directory: {directory}")

    items = []
    if max_samples <= 0:
        return items

    for path in (files if all_files else files[:1]):
        print(f"Loading data from: {path}")
        with open(path, "r", encoding="utf-8") as f:
            data_list = json.load(f)
        for item in data_list:
            # Skip malformed records: non-dicts (where `in` would do a substring /
            # element test) and dicts missing either field.
            if not isinstance(item, dict) or "prompt" not in item or "response" not in item:
                continue
            items.append({"prompt": item["prompt"], "response": item["response"]})
            if len(items) >= max_samples:
                return items
    return items

def _batch_log_likelihoods(model, encodings):
    """Return (summed log-likelihood, scored-token count) per sequence for one tokenized batch.

    Kept separate so the full-vocabulary logits / log-prob tensors are released as
    soon as it returns, instead of staying referenced by loop variables while the
    next batch's forward pass allocates its own. It is called from inside
    ``compute_log_likelihood``, whose ``torch.no_grad()`` context also covers it.
    """
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    # No `labels=`: the model's own loss is never used, and requesting it makes the
    # forward pass compute an extra full-vocabulary cross-entropy.
    logits = model(**encodings).logits

    # Causal-LM alignment: logits at position t score token t + 1. Plain slices are
    # views, so no copy of the [batch, seq, vocab] tensor is made here.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = attention_mask[:, 1:]

    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    # Padding positions (mask == 0) contribute neither log-probability nor tokens.
    per_sample_log_likelihood = (token_log_probs * shift_mask).sum(dim=1)  # [batch]
    token_counts = shift_mask.sum(dim=1)  # [batch]
    return per_sample_log_likelihood, token_counts


@torch.no_grad()
def compute_log_likelihood(model, tokenizer, texts, batch_size=8, tag=""):
    """Compute the token-averaged log-likelihood and perplexity of ``texts`` under ``model``.

    Texts are tokenized in batches of ``batch_size`` (truncated to the module-level
    ``max_length``), and each token after the first is scored given its prefix.
    Padding is excluded via the attention mask.
    NOTE(review): assumes the tokenizer pads on the right; with left padding,
    models that do not derive position ids from the attention mask would see
    shifted positions — confirm ``tokenizer.padding_side``.

    Args:
        model: Causal LM whose output exposes ``.logits``.
        tokenizer: Tokenizer with a pad token set.
        texts: Sequence of strings to score.
        batch_size: Number of texts per forward pass.
        tag: Label shown in the progress bar.

    Returns:
        ``(mean_log_likelihood, perplexity, log_table)``, where the mean is taken over
        all scored tokens of all texts (``-inf`` when no token was scored),
        ``perplexity = exp(-mean)`` (``inf`` if that overflows), and ``log_table`` is a
        ``wandb.Table`` with one row per text (text, summed log-likelihood, token count).
    """
    model.eval()
    total_log_likelihood = 0.0
    total_tokens = 0

    log_table = wandb.Table(columns=["text", "log_likelihood", "token_count"])

    # DataLoader over a list of strings yields lists of strings (default collate).
    dataloader = DataLoader(texts, batch_size=batch_size)

    print(f"Computing log-likelihood for {len(texts)} samples...")
    for batch in tqdm(dataloader, desc=f"Computing LL ({tag})"):
        encodings = tokenizer(batch, return_tensors="pt", padding=True, truncation=True,
                              max_length=max_length).to(device)
        per_sample_log_likelihood, token_counts = _batch_log_likelihoods(model, encodings)

        total_log_likelihood += per_sample_log_likelihood.sum().item()
        total_tokens += token_counts.sum().item()

        for text, ll, n_tok in zip(batch, per_sample_log_likelihood.tolist(), token_counts.tolist()):
            log_table.add_data(text, ll, n_tok)

    if total_tokens > 0:
        mean_log_likelihood = total_log_likelihood / total_tokens
    else:
        mean_log_likelihood = float("-inf")

    try:
        perplexity = exp(-mean_log_likelihood)
    except OverflowError:
        # math.exp raises (rather than returning inf) for finite arguments above ~709.
        perplexity = float("inf")

    return mean_log_likelihood, perplexity, log_table

# === Reward Calculation Function (provided code) ===
@torch.no_grad()
def get_reward_signal_response_only(
    prompts: List[str],
    responses: List[str],
    reward_model,
    reward_tokenizer,
    device,
    scaling_factor: float,
    max_length: int = 512,
    fallback_nll: float = 10.0,
) -> torch.FloatTensor:
    """Reward each response by how likely ``reward_model`` finds it after its prompt.

    For sample i the reward is ``10 * exp(-scaling_factor * nll_i)``, where
    ``nll_i`` is the mean token-level cross-entropy of the *response* tokens of
    ``prompts[i] + responses[i]``. Prompt tokens and padding are masked out.

    Args:
        prompts: Prompt strings.
        responses: Response strings, parallel to ``prompts``.
        reward_model: Causal LM whose output exposes ``.logits``.
        reward_tokenizer: Tokenizer applied to the concatenated texts (padded and
            truncated) and, separately, to each prompt to find where the response starts.
        device: Device the encoded batch and masks are placed on.
        scaling_factor: Multiplier applied to the NLL before exponentiation.
        max_length: Truncation length for the concatenated texts.
        fallback_nll: NLL used for samples with no scorable response tokens (empty
            response, or response truncated away) and for NaN results.

    Returns:
        1-D tensor with one reward per sample (empty for an empty batch).

    Raises:
        ValueError: If ``prompts`` and ``responses`` differ in length.
    """
    if len(prompts) != len(responses):
        raise ValueError(
            f"prompts and responses must have the same length "
            f"(got {len(prompts)} and {len(responses)})"
        )
    if len(prompts) == 0:
        return torch.empty(0, device=device)

    reward_model.eval()
    inputs = reward_tokenizer([p + r for p, r in zip(prompts, responses)],
                              return_tensors="pt", padding=True, truncation=True,
                              max_length=max_length).to(device)
    input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

    # Prompt length in tokens, tokenized on its own.
    # NOTE(review): assumes the prompt tokenizes identically as a prefix of p + r
    # (a token merge across the boundary shifts the split by one) and that the
    # tokenizer pads on the right — confirm.
    prompt_lens = [reward_tokenizer(p, return_tensors="pt")["input_ids"].size(1) for p in prompts]

    logits = reward_model(input_ids=input_ids, attention_mask=attention_mask).logits
    shift_logits, shift_labels = logits[:, :-1, :].contiguous(), input_ids[:, 1:].contiguous()
    shift_attn = attention_mask[:, 1:]
    B, Tm1 = shift_labels.size()

    # Label position j scores token j + 1, so the first response token (index pl)
    # sits at label position pl - 1. Clamp into [0, Tm1]: an empty prompt (pl == 0)
    # must score the whole sequence rather than wrap around to the last column.
    starts = torch.tensor([min(max(pl - 1, 0), Tm1) for pl in prompt_lens], device=device)
    positions = torch.arange(Tm1, device=device)
    resp_mask = positions.unsqueeze(0) >= starts.unsqueeze(1)
    final_mask = resp_mask & shift_attn.bool()

    loss_tok = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none"
    ).view(B, Tm1)

    token_counts = final_mask.sum(dim=1)
    nll_sum = loss_tok.masked_fill(~final_mask, 0.0).sum(dim=1)
    nll_per_sample = nll_sum / token_counts.clamp_min(1)
    # Without this, samples with no scorable response tokens get nll == 0, i.e. the
    # maximum possible reward; give them the fallback (penalty) NLL instead.
    nll_per_sample = torch.where(
        token_counts > 0, nll_per_sample, torch.full_like(nll_per_sample, fallback_nll)
    )
    nll_per_sample = torch.nan_to_num(nll_per_sample, nan=fallback_nll)

    rewards = 10.0 * torch.exp(-scaling_factor * nll_per_sample)
    return rewards

# === Main Execution Function ===
def _resolve_pprime_paths(args):
    """Return ``(model_dir, data_dir)`` for the evaluated model (p') from the CLI args."""
    if args.type == 'pretrained':
        model_dir = f'./pretrained_models/{args.load_model_name}/seed{args.seed}/'
        size_tag = '135m' if '135' in args.load_model_name else '360m'
        data_dir = f'./generated_data/pretrained_model_seed{args.seed}/{size_tag}/'
        return model_dir, data_dir

    if 'kd' not in args.load_model_name:
        model_dir = f'./results/{args.load_model_name}/checkpoint-{args.iteration}'
    elif args.kd_iteration == 0:
        model_dir = f"results/{args.load_model_name}/checkpoint-epoch-1/"
    else:
        model_dir = f"results/{args.load_model_name}/checkpoint-{args.kd_iteration}/"
    data_dir = f'./generated_data/{args.load_model_name}_validation/'
    return model_dir, data_dir


def _compute_target_rewards(samples, target_model, tokenizer, batch_size, scaling_factor=0.5):
    """Score every p' (prompt, response) pair with the target model, ``batch_size`` pairs at a time."""
    prompts = [item["prompt"] for item in samples]
    responses = [item["response"] for item in samples]

    all_rewards = []
    for start in tqdm(range(0, len(prompts), batch_size)):
        rewards = get_reward_signal_response_only(
            prompts=prompts[start:start + batch_size],
            responses=responses[start:start + batch_size],
            reward_model=target_model,
            reward_tokenizer=tokenizer,
            device=device,
            scaling_factor=scaling_factor,
        )
        all_rewards.extend(rewards.cpu().tolist())
    return all_rewards


def main():
    """Evaluate an aligned model (p') against the ground-truth model (p*) and a target model.

    Logs to W&B:
      1. recall (p' scoring p* samples) and precision (p* scoring p' samples),
         as per-token log-likelihood and perplexity;
      2. target precision (target model scoring p' samples);
      3. mean / variance of the response-only reward the target model assigns to p' samples.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Set Model and Data Paths ---
    pstar_data_dir = './generated_data/ground_truth_model/'
    pprime_model_dir, pprime_data_dir = _resolve_pprime_paths(args)

    # os.path.join tolerates a base path with or without a trailing slash; the
    # trailing "" keeps the result ending in a separator, as before.
    args.target_model_dir = os.path.join(args.target_model_dir, f"seed{args.seed}", "")

    # --- Initialize W&B ---
    wandb.login()
    wandb.init(
        project=args.wandb_project,
        name=f"eval-aligned-{args.load_model_name}-seed{args.seed}",
        config=vars(args),
        resume="allow"
    )
    print("--- W&B Initialized ---")

    # --- Load Tokenizer and Models ---
    print("--- Loading Tokenizer and Models ---")
    # NOTE(review): a single tokenizer is used for all three models; this assumes
    # they share the SmolLM2 vocabulary — confirm if other model families are used.
    tokenizer = AutoTokenizer.from_pretrained("XXXX/SmolLM2-1.7B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Aligned Model (p') from: {pprime_model_dir}")
    model_p_prime = AutoModelForCausalLM.from_pretrained(pprime_model_dir).to(device).eval()

    print(f"Loading Ground-Truth Model (p*) from: {args.pstar_model_dir}")
    model_p_star = AutoModelForCausalLM.from_pretrained(args.pstar_model_dir).to(device).eval()

    print(f"Loading Target Model from: {args.target_model_dir}")
    target_model = AutoModelForCausalLM.from_pretrained(args.target_model_dir).to(device).eval()
    print("--- All models loaded successfully ---")

    # --- Load Data ---
    print("\n--- Loading Datasets for Evaluation ---")
    p_prime_samples = load_prompts_responses_from_dir(pprime_data_dir, args.sample_size)
    p_star_samples = load_prompts_responses_from_dir(pstar_data_dir, args.sample_size)

    p_prime_responses = [item["response"] for item in p_prime_samples]
    p_star_responses = [item["response"] for item in p_star_samples]
    print(f"Loaded {len(p_prime_responses)} samples from aligned model.")
    print(f"Loaded {len(p_star_responses)} samples from ground-truth model.")

    # === 1. Calculate Precision/Recall vs. Ground-Truth Model ===
    print("\n--- 1. Calculating Precision and Recall vs. Ground-Truth Model ---")

    # Recall: how well p' covers what p* generates.
    recall_ll, recall_ppl, recall_table = compute_log_likelihood(
        model_p_prime, tokenizer, p_star_responses, args.batch_size, tag="Recall"
    )
    print(f"   => Recall (p' on p* data): Avg LL/token={recall_ll:.4f}, PPL={recall_ppl:.4f}")

    # Precision: how plausible p*'s model finds what p' generates.
    precision_ll, precision_ppl, precision_table = compute_log_likelihood(
        model_p_star, tokenizer, p_prime_responses, args.batch_size, tag="Precision"
    )
    print(f"   => Precision (p* on p' data): Avg LL/token={precision_ll:.4f}, PPL={precision_ppl:.4f}")

    # === 2. Calculate Precision vs. Target Model ===
    print("\n--- 2. Calculating Precision vs. Target Model ---")
    target_precision_ll, target_precision_ppl, target_precision_table = compute_log_likelihood(
        target_model, tokenizer, p_prime_responses, args.batch_size, tag="Target Precision"
    )
    print(f"   => Target Precision (target on p' data): Avg LL/token={target_precision_ll:.4f}, PPL={target_precision_ppl:.4f}")

    # === 3. Calculate Average Reward and Variance vs. Target Model ===
    print("\n--- 3. Calculating Average Reward and Variance vs. Target Model ---")
    all_rewards = _compute_target_rewards(
        p_prime_samples, target_model, tokenizer, args.batch_size, scaling_factor=0.5
    )

    if all_rewards:
        avg_reward = np.mean(all_rewards)
        var_reward = np.var(all_rewards)
    else:
        # np.mean / np.var of an empty list warn and return nan; report nan explicitly.
        avg_reward = var_reward = float("nan")
    print(f"   => Average Reward: {avg_reward:.4f}")
    print(f"   => Reward Variance: {var_reward:.4f}")

    # === 4. W&B Logging ===
    print("\n--- Logging all metrics to W&B ---")
    wandb.log({
        # Log original Recall/Precision as LL
        "gt_recall_ll_per_token": recall_ll,
        "gt_precision_ll_per_token": precision_ll,
        "target_precision_ll_per_token": target_precision_ll,

        # Add Perplexity
        "gt_recall_perplexity": recall_ppl,
        "gt_precision_perplexity": precision_ppl,
        "target_precision_perplexity": target_precision_ppl,

        # Reward-related metrics
        "target_avg_reward": avg_reward,
        "target_reward_variance": var_reward,

        # Add detailed tables
        "details/gt_recall": recall_table,
        "details/gt_precision": precision_table,
        "details/target_precision": target_precision_table,
    })

    wandb.finish()
    print("--- Evaluation Complete and Logged to W&B ---")

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
