import os
import json
import random
import torch
import pandas as pd
from tqdm import tqdm
import wandb
import numpy as np
from scipy import stats


def fv_icl_tasks_benchmark(
    model,
    tokenizer,
    task_name,
    task_dir="~pythia_replicate/dataset/icl_tasks",
    max_samples=5000,
    num_of_shots=5,
    batch_size=64,
    use_wandb=True,
    return_samples_used=False,
    seed=1234,
    return_per_sample=False,
):
    """Measure first-token accuracy of ``model`` on one few-shot ICL task.

    The task is read from ``<task_dir>/<task_name>.json``, which must hold a
    sequence of records with ``"input"`` and ``"output"`` string fields. Each
    record becomes a query ``"Q: <input>\\nA:"`` preceded by ``num_of_shots``
    demonstrations drawn from the other records. A prompt counts as correct
    when the model's most likely next token equals the first token of the
    target, encoded either with or without a leading space.

    Args:
        model: Causal LM; called as ``model(**encoded)`` and expected to
            expose ``.device`` and return an object with ``.logits``.
        tokenizer: Callable tokenizer that also provides ``encode``.
        task_name: Base name of the task file (without ``.json``).
        task_dir: Directory holding the task files. If the path does not
            exist as written, its ``~``-expanded form is tried as well.
        max_samples: If truthy and the task has more records than this, a
            random subset of this size is evaluated.
        num_of_shots: Number of demonstrations placed before each query.
        batch_size: Number of prompts per forward pass.
        use_wandb: When true, the number of samples used is written to
            ``wandb.summary``.
        return_samples_used: Also return the number of records used.
        seed: Base seed. The global ``random`` generator is seeded with it,
            and re-seeded with ``seed + index`` before each demonstration
            draw, so prompts are reproducible.
        return_per_sample: Also return the list of per-prompt booleans.

    Returns:
        ``{}`` if the task directory or JSON file is missing, ``None`` if the
        file cannot be loaded or turned into prompts. Otherwise the accuracy
        as a float, or a tuple ``(accuracy[, samples_used][, per_sample])``
        depending on the two ``return_*`` flags. Batches that raise are
        reported and left out of both the accuracy and the per-sample list.
    """
    random.seed(seed)

    def generate_few_shot_prompts(data, num_of_shots=5, max_samples=None):
        """Build one few-shot prompt and target per record in ``data``."""
        # Limit samples if specified
        if max_samples and len(data) > max_samples:
            data = random.sample(data, max_samples)
            samples_used = max_samples
        else:
            samples_used = len(data)
        if use_wandb:
            wandb.summary[f"samples_for_{task_name}"] = samples_used

        few_shot_prompts = []
        few_shot_targets = []

        # Demonstrations come from the *other* records, so at least
        # num_of_shots of them are needed. The condition is the same for
        # every record, so report it once instead of once per record.
        num_others = len(data) - 1
        if data and num_others < num_of_shots:
            print("Not enough data")
            return few_shot_prompts, few_shot_targets, samples_used

        for point_idx, point in enumerate(data):
            prompt = "Q: " + point["input"] + "\nA:"
            target = point["output"]

            random.seed(seed + point_idx)
            # Sample positions in "data without point_idx" and shift those at
            # or after point_idx by one. This yields the same draws as
            # sampling from an explicit list of the other indices, without
            # rebuilding that list for every record.
            picks = random.sample(range(num_others), num_of_shots)
            demonstrations = [
                data[pick if pick < point_idx else pick + 1] for pick in picks
            ]

            few_shot_prompt = "".join(
                f"Q: {sample['input']}\nA: {sample['output']}\n\n"
                for sample in demonstrations
            )
            few_shot_prompts.append(few_shot_prompt + prompt)
            few_shot_targets.append(target)

        return few_shot_prompts, few_shot_targets, samples_used

    def process_batch(prompts_batch, targets_batch):
        """Return one boolean per prompt: did the top next token match?"""
        # NOTE(review): truncation=True may cut the end of an over-long
        # prompt (including the final "A:") depending on the tokenizer's
        # truncation side -- confirm prompts fit the model's context.
        tokenized_batch = tokenizer(
            prompts_batch, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        with torch.no_grad():
            logits = model(**tokenized_batch).logits

        if "attention_mask" in tokenized_batch:
            # With right padding, position -1 is a pad token for every prompt
            # shorter than the longest one. Pick the last attended position
            # per row instead; for left-padded batches this is still -1.
            mask = tokenized_batch["attention_mask"].to(logits.device).long()
            ranks = torch.arange(1, mask.shape[1] + 1, device=logits.device)
            last_positions = (mask * ranks).argmax(dim=1)
            rows = torch.arange(logits.shape[0], device=logits.device)
            final_logits = logits[rows, last_positions]
        else:
            final_logits = logits[:, -1, :]

        predicted_ids = final_logits.argmax(dim=-1).tolist()

        batch_predictions = []
        for predicted_token_id, target in zip(predicted_ids, targets_batch):
            # Accept the first token of the target with or without a leading
            # space. An encoding that yields no tokens (e.g. an empty target)
            # is skipped rather than failing the whole batch.
            target_ids = set()
            for variant in (target, " " + target):
                encoded_target = tokenizer.encode(variant, add_special_tokens=False)
                if encoded_target:
                    target_ids.add(encoded_target[0])
            batch_predictions.append(predicted_token_id in target_ids)

        return batch_predictions

    # Check if task directory exists. os.path.exists does not expand "~", so
    # fall back to the expanded form only when the literal path is missing.
    if not os.path.exists(task_dir):
        expanded_dir = os.path.expanduser(task_dir)
        if expanded_dir != task_dir and os.path.exists(expanded_dir):
            task_dir = expanded_dir
        else:
            print(
                f"Warning: Task directory {task_dir} not found. Skipping ICL benchmark."
            )
            return {}

    # For few-shot, we need the JSON file to construct prompts
    json_file = os.path.join(task_dir, f"{task_name}.json")
    if not os.path.exists(json_file):
        print(
            f"Warning: JSON file {json_file} not found for few-shot prompting. Skipping ICL benchmark."
        )
        return {}
    task_file = json_file

    # Rebinding task_name here also sets the name used in the wandb summary
    # key and in the progress bar (the helper above reads it lazily).
    task_name = os.path.splitext(os.path.basename(task_file))[0]

    # Load task data and build the prompts; any failure skips the task.
    try:
        with open(task_file, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        prompts, targets, samples_used = generate_few_shot_prompts(
            json_data, num_of_shots, max_samples
        )
    except Exception as e:
        print(f"Error loading {task_file}: {e}")
        return None

    correct_predictions = 0
    total_predictions = 0
    per_sample_correct = []

    # Process data in batches
    for i in tqdm(
        range(0, len(prompts), batch_size),
        desc=f"Processing {task_name}",
        leave=False,
    ):
        batch_prompts = prompts[i : i + batch_size]
        batch_targets = targets[i : i + batch_size]

        try:
            batch_predictions = process_batch(batch_prompts, batch_targets)
        except Exception as e:
            # Best effort: report the failure and leave this batch out.
            print(f"Error processing batch: {e}")
            continue

        correct_predictions += sum(batch_predictions)
        total_predictions += len(batch_predictions)
        per_sample_correct.extend(batch_predictions)

    # Calculate accuracy for this task
    if total_predictions > 0:
        accuracy = correct_predictions / total_predictions
    else:
        accuracy = 0.0

    if return_per_sample and return_samples_used:
        return accuracy, samples_used, per_sample_correct
    elif return_per_sample:
        return accuracy, per_sample_correct
    elif return_samples_used:
        return accuracy, samples_used
    else:
        return accuracy


def fv_icl_tasks_benchmark_with_ci(
    model,
    tokenizer,
    task_name,
    task_dir="~pythia_replicate/dataset/icl_tasks",
    max_samples=5000,
    num_of_shots=5,
    batch_size=64,
    use_wandb=False,
    confidence_level=0.95,
    return_per_sample=False,
    seed=1234,
):
    """Run ``fv_icl_tasks_benchmark`` and attach a confidence interval.

    The interval is the normal-approximation (Wald) interval
    ``accuracy +/- z * sqrt(accuracy * (1 - accuracy) / n)`` clipped to
    ``[0, 1]``, where ``n`` is the number of prompts that were actually
    scored. When the accuracy is exactly 0 or 1 the rule of three
    (``3 / n``) bounds the open side instead.

    Args:
        model, tokenizer, task_name, task_dir, max_samples, num_of_shots,
        batch_size, use_wandb, seed: Forwarded unchanged to
            ``fv_icl_tasks_benchmark``.
        confidence_level: Two-sided confidence level used for the z-value.
        return_per_sample: When true, the second element of the result is the
            list of per-prompt booleans; otherwise it is ``None``.

    Returns:
        A tuple ``(summary, per_sample_correct)`` where ``summary`` is a dict
        with the keys ``accuracy``, ``ci_lower``, ``ci_upper``, ``std_error``
        and ``n_samples``.

    Raises:
        ValueError: If the underlying benchmark skipped the task (missing
            directory or file, or a load error) and returned no results.
    """
    # Always request the full three-part result so the unpacking below does
    # not depend on the caller's return_per_sample choice.
    result = fv_icl_tasks_benchmark(
        model=model,
        tokenizer=tokenizer,
        task_name=task_name,
        task_dir=task_dir,
        max_samples=max_samples,
        num_of_shots=num_of_shots,
        batch_size=batch_size,
        use_wandb=use_wandb,
        return_samples_used=True,
        seed=seed,
        return_per_sample=True,
    )

    # The benchmark signals a skipped task with {} or None instead of a tuple.
    if not isinstance(result, tuple):
        raise ValueError(
            f"ICL benchmark for task {task_name!r} produced no results; "
            "cannot compute a confidence interval."
        )
    accuracy, n_samples, per_sample_correct = result

    # Accuracy is computed over the prompts that were scored, which can be
    # fewer than n_samples if prompts were skipped or a batch failed.
    n_evaluated = len(per_sample_correct)

    if n_evaluated == 0:
        # Nothing was scored: the data puts no constraint on the accuracy.
        # std_error is kept at 0 like the other degenerate branches.
        ci_lower = 0.0
        ci_upper = 1.0
        std_error = 0.0
    elif accuracy == 0:
        ci_lower = 0
        # Rule of three. NOTE: this bound corresponds to ~95% confidence and
        # does not follow confidence_level.
        ci_upper = min(1, 3.0 / n_evaluated)
        std_error = 0
    elif accuracy == 1:
        ci_lower = max(0, 1 - 3.0 / n_evaluated)  # Rule of three
        ci_upper = 1
        std_error = 0
    else:
        z_value = stats.norm.ppf((1 + confidence_level) / 2)
        std_error = np.sqrt((accuracy * (1 - accuracy)) / n_evaluated)
        margin = z_value * std_error
        ci_lower = max(0, accuracy - margin)
        ci_upper = min(1, accuracy + margin)

    summary = {
        "accuracy": accuracy,
        "ci_lower": ci_lower,
        "ci_upper": ci_upper,
        "std_error": std_error,
        "n_samples": n_samples,
    }
    return summary, (per_sample_correct if return_per_sample else None)
