import os
import sys
import json
import torch
import numpy as np
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any
from tqdm import tqdm
from datasets import load_from_disk, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from itertools import product

from utils.eval_utils import (
    calculate_perplexity, 
    find_repeated_substrings, 
    generate_text_with_processor,
    _transform_results_to_new_format
)
from utils.analysis_utils import run_analysis_and_plot
from utils.bootstrap_mauve import bootstrap_mauve_from_file, plot_mauve_comparison_lines

logger = logging.getLogger(__name__)

from logits_processors.primal_threshold_processor_fix_k import PrimalThresholdProcessor
from transformers import (
    TopKLogitsWarper,
    LogitsProcessorList
)

def parse_args():
    """Parse the evaluation script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below. The
        list-valued options (``top_k``, ``primal_alpha``, ``primal_k_max``)
        default to an empty list, and ``verbose`` is a boolean flag.
    """
    cli = argparse.ArgumentParser(description="Evaluate perplexity of texts generated by different sampling methods")

    # (flag, add_argument keyword options) pairs, registered in this order so
    # the generated --help output keeps its layout.
    option_table = [
        ("--model_choice", dict(type=str, default="llama3", choices=["llama3", "gpt2"],
                                help="Choose the model to evaluate: 'llama3' or 'gpt2'")),
        ("--model_name", dict(type=str, default=None,
                              help="HuggingFace model name or path (overrides model_choice if set)")),
        ("--top_k", dict(type=int, nargs='+', default=[],
                         help="List of Top-k values for TopK sampling method")),
        ("--primal_alpha", dict(type=float, nargs='+', default=[],
                                help="List of Alpha values for Primal Threshold sampling")),
        ("--primal_k_max", dict(type=int, nargs='+', default=[],
                                help="List of Maximum k values for Primal Threshold sampling")),
        ("--temperature", dict(type=float, default=1.0,
                               help="Temperature for sampling (0 means greedy decoding)")),
        ("--prefix_length", dict(type=int, default=35,
                                 help="Number of tokens to use as prefix from the original text")),
        ("--max_length", dict(type=int, default=256,
                              help="Maximum length of generated text (including prefix)")),
        ("--num_samples", dict(type=int, default=5000,
                               help="Number of samples to use from the test dataset")),
        ("--output_path", dict(type=str, default=".",
                               help="Directory where evaluation result files (texts_results_...json, metrics_results_...json) will be saved.")),
        ("--gpu_id", dict(type=int, default=0, help="GPU ID to use")),
        ("--verbose", dict(action="store_true", help="Enable verbose output")),
        ("--log_level", dict(type=str, default="INFO",
                             choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                             help="Logging level")),
        ("--quantization", dict(type=str, choices=["none", "4bit", "8bit"], default="8bit",
                                help="Quantization method to use for loading the model")),
        ("--batch_size", dict(type=int, default=16,
                              help="Batch size for perplexity calculation")),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def _resolve_device(gpu_id):
    """Restrict the process to ``gpu_id`` and return the torch device string to use.

    ``CUDA_VISIBLE_DEVICES`` is set to ``gpu_id``. When that mask takes effect,
    the selected GPU is the only visible device and is renumbered to ordinal 0,
    so ``cuda:<gpu_id>`` would be an invalid ordinal for any ``gpu_id > 0``.
    If CUDA had already been initialised (mask ignored), the original ordinal
    is still valid, so it is kept whenever it is within the visible range.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    if not torch.cuda.is_available():
        return "cpu"
    visible_devices = torch.cuda.device_count()
    device_index = gpu_id if 0 <= gpu_id < visible_devices else 0
    return f"cuda:{device_index}"


def _build_quantization_config(model_choice, quantization):
    """Return the ``BitsAndBytesConfig`` for the requested mode, or ``None``.

    Quantization is skipped when ``model_choice`` is "gpt2" or when
    ``quantization`` is "none".
    """
    if model_choice == "gpt2":
        logger.info("Skipping quantization for GPT-2 model.")
        return None
    if quantization == "none":
        logger.info("Quantization explicitly set to 'none'.")
        return None

    quantization_config = None
    if quantization == "4bit":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )
        logger.info("Applying 4-bit quantization.")
    elif quantization == "8bit":
        # BitsAndBytesConfig has no 8-bit compute-dtype option; the previously
        # passed `bnb_8bit_compute_dtype` was an unused keyword and is dropped.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger.info("Applying 8-bit quantization.")
    return quantization_config


def _build_processors(args, device):
    """Create ``(processor, name, temperature)`` tuples for every requested method.

    Non-positive ``--top_k`` values are ignored. Primal Threshold processors
    are built for the cartesian product of ``--primal_alpha`` and
    ``--primal_k_max`` and only when both lists are non-empty.
    """
    processors = []

    for k in args.top_k:
        if k > 0:
            processor_name = f"top_k_{k}"
            processors.append((TopKLogitsWarper(top_k=k), processor_name, args.temperature))
            logger.info(f"Including processor: {processor_name}")

    if args.primal_alpha and args.primal_k_max:
        for alpha, k_max in product(args.primal_alpha, args.primal_k_max):
            processor_name = f"primal_alpha_{alpha}_k_{k_max}"
            primal_threshold_processor = PrimalThresholdProcessor(
                alpha=alpha,
                k_max=k_max,
                device=device,
            )
            processors.append((primal_threshold_processor, processor_name, args.temperature))
            logger.info(f"Including processor: {processor_name}")
    elif args.primal_alpha or args.primal_k_max:
        logger.warning("Both --primal_alpha and --primal_k_max must be provided to run Primal Threshold sampling.")

    return processors


def _dump_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON, replacing any existing file."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)


def _format_score(value):
    """Format a float score with four decimals; NaN or non-float values give "N/A"."""
    if isinstance(value, float) and not np.isnan(value):
        return f"{value:.4f}"
    return "N/A"


def _print_aggregate_table(args, model_name_to_load, num_results, metrics_data, processor_names):
    """Print average perplexity and repetition frequency for the originals and each method.

    ``metrics_data`` is the second value returned by
    ``_transform_results_to_new_format``; it is read through its
    "perplexity_data" and "repetition_data" mappings. ``None`` entries (and NaN
    perplexities) are excluded from the averages.
    """
    print("-" * 100)
    print(f"Model: {model_name_to_load}")
    print(f"Temperature (for sampling methods): {args.temperature}")
    print(f"Prefix Length: {args.prefix_length}, Max Length: {args.max_length}")
    print(f"Number of Samples Evaluated: {num_results}")
    print("-" * 100)
    print(f"{'Method':<35} {'Avg Perplexity':<15} {'Rep Freq (%)':<15}")
    print("-" * 100)

    perplexity_data = metrics_data["perplexity_data"]
    repetition_data = metrics_data["repetition_data"]

    original_perplexities = [
        p for p in perplexity_data.get("original", []) if p is not None and not np.isnan(p)
    ]
    original_repetitions = [r for r in repetition_data.get("original", []) if r is not None]

    if original_perplexities:
        avg_original = np.mean(original_perplexities)
        rep_freq = (sum(original_repetitions) / len(original_repetitions)) * 100 if original_repetitions else 0
        print(f"{'original (ground truth)':<35} {avg_original:<15.4f} {rep_freq:<15.2f}")
    else:
        print(f"{'original (ground truth)':<35} {'N/A':<15} {'N/A':<15}")

    for processor_name in processor_names:
        perplexities = [
            p for p in perplexity_data.get(processor_name, []) if p is not None and not np.isnan(p)
        ]
        repetitions = [r for r in repetition_data.get(processor_name, []) if r is not None]

        avg_perplexity_str = f"{np.mean(perplexities):.4f}" if perplexities else "NaN"
        rep_freq_str = f"{(sum(repetitions) / len(repetitions)) * 100:.2f}" if repetitions else "NaN"

        print(f"{processor_name:<35} {avg_perplexity_str:<15} {rep_freq_str:<15}")

    print("-" * 100)


def evaluate_all(args: argparse.Namespace) -> None:
    """Generate continuations with every selected sampling method and evaluate them.

    Steps, in order:
      1. Configure logging, pick the device and make sure the output directory exists.
      2. Build the logits processors (exits with status 1 if none were requested).
      3. Load the model/tokenizer, optionally quantized.
      4. For each batch of ``webtext.test.jsonl``: take the first
         ``prefix_length`` tokens of every sufficiently long sample, generate a
         continuation with each processor, and record perplexity and a
         repetition flag for both the original and generated texts.
      5. Periodically and finally write ``texts_results_<model>.json`` and
         ``metrics_results_<model>.json`` into ``args.output_path``.
      6. Print an aggregate table, run the analysis/plot helper, and compute
         and plot MAUVE scores.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        SystemExit: when no generation method is selected.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=getattr(logging, args.log_level),
    )

    device = _resolve_device(args.gpu_id)

    # Create the output directory up front so that periodic/final saves do not
    # fail merely because it is missing. Failure is logged, not raised, to keep
    # the script's existing log-and-continue behaviour around saving.
    output_dir = Path(args.output_path)
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.error(f"Could not create output directory {output_dir}: {e}")

    if args.model_name:
        model_name_to_load = args.model_name
        logger.info(f"Using explicitly provided model name: {model_name_to_load}")
    elif args.model_choice == "llama3":
        model_name_to_load = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        logger.info("Using selected model: Llama 3.1 8B Instruct")
    elif args.model_choice == "gpt2":
        model_name_to_load = "gpt2-large"
        logger.info("Using selected model: GPT-2 Large")
    else:
        # Only reachable when called programmatically; argparse restricts the choices.
        logger.error(f"Invalid model choice: {args.model_choice}. Defaulting to Llama 3.")
        model_name_to_load = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # Last path component with dashes replaced, used in the output file names.
    model_name_simple = model_name_to_load.split('/')[-1].replace('-', '_')
    texts_file_path = output_dir / f"texts_results_{model_name_simple}.json"
    metrics_file_path = output_dir / f"metrics_results_{model_name_simple}.json"

    # Build the processors before the (slow) model load so that a run without
    # any selected method fails immediately.
    processors = _build_processors(args, device)
    if not processors:
        logger.error("No generation methods selected. Exiting.")
        sys.exit(1)
    all_processor_names = [name for _, name, _ in processors]
    logger.info(f"Total processors to evaluate: {len(processors)}")

    if os.environ.get("HF_TOKEN"):
        logger.info("HF_TOKEN environment variable is set. Transformers should use it automatically.")
    else:
        logger.warning("HF_TOKEN environment variable is NOT set. Loading private models will likely fail.")

    quantization_config = _build_quantization_config(args.model_choice, args.quantization)

    logger.info(f"Loading model: {model_name_to_load}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name_to_load,
        device_map=device,
        trust_remote_code=True,
        quantization_config=quantization_config,
        torch_dtype=torch.float32,
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name_to_load)
    if tokenizer.pad_token is None:
        # The tokenizer has no pad token: add one and grow the embedding matrix to match.
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): the message says "validation split" but the file loaded is
    # webtext.test.jsonl; confirm which label is correct.
    logger.info("Loading webtext dataset (validation split)")
    # load_dataset('json', ...) exposes the file under the 'train' key.
    webtext_dataset = load_dataset('json', data_files='webtext.test.jsonl')['train']

    if args.num_samples > 0:
        webtext_dataset = webtext_dataset.select(range(min(args.num_samples, len(webtext_dataset))))

    results = []
    batch_size = args.batch_size
    total_samples = len(webtext_dataset)

    for batch_start in tqdm(range(0, total_samples, batch_size), desc="Processing batches"):
        batch_end = min(batch_start + batch_size, total_samples)
        batch_texts = webtext_dataset.select(range(batch_start, batch_end))["text"]

        batch_results = []
        all_prefix_ids = []

        for global_idx, text in enumerate(batch_texts, start=batch_start):
            input_ids = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=args.max_length,
            ).input_ids

            # A sample needs more tokens than the prefix, otherwise there is
            # no reference continuation to compare against.
            if input_ids.shape[1] <= args.prefix_length:
                logger.warning(f"Sample {global_idx} is too short, skipping")
                continue

            prefix_ids = input_ids[:, :args.prefix_length]
            batch_results.append({
                "sample_id": global_idx,
                "prefix": tokenizer.decode(prefix_ids[0], skip_special_tokens=True),
                "original_text": tokenizer.decode(input_ids[0], skip_special_tokens=True),
                "generations": {},
            })
            all_prefix_ids.append(prefix_ids)

        if not batch_results:
            continue

        original_perplexities = calculate_perplexity(
            model,
            tokenizer,
            [result["original_text"] for result in batch_results],
            device,
            only_generated=True,
            prefix_length=args.prefix_length,
            max_length=args.max_length,
        )

        for result, original_perplexity in zip(batch_results, original_perplexities):
            result["original_perplexity"] = original_perplexity
            try:
                tokens = tokenizer(result["original_text"], truncation=True, max_length=args.max_length)
                result["original_has_repetition"] = find_repeated_substrings(tokens["input_ids"])
            except Exception as e:
                logger.error(f"Error checking repetition for original sample {result['sample_id']}: {e}")
                result["original_has_repetition"] = None

        # Every kept prefix has exactly prefix_length tokens, so the prefixes
        # stack into one (num_kept, prefix_length) batch shared by all processors.
        stacked_prefix_ids = torch.cat(all_prefix_ids, dim=0)

        for processor, processor_name, temperature in processors:
            try:
                generated_texts = generate_text_with_processor(
                    model,
                    tokenizer,
                    stacked_prefix_ids,
                    args.max_length,
                    processor,
                    temperature,
                    device,
                    batch_size=batch_size,
                )

                for i, result in enumerate(batch_results):
                    result["generations"][processor_name] = {"text": generated_texts[i]}

            except Exception as e:
                # Batched generation failed: retry sample by sample so that one
                # bad sample does not discard the whole batch.
                logger.error(f"Error during batch generation for {processor_name}: {str(e)}")
                for result, prefix_ids in zip(batch_results, all_prefix_ids):
                    try:
                        generated_text = generate_text_with_processor(
                            model,
                            tokenizer,
                            prefix_ids,
                            args.max_length,
                            processor,
                            temperature,
                            device,
                        )
                        result["generations"][processor_name] = {"text": generated_text}
                    except Exception as inner_e:
                        logger.error(f"Error generating text for sample {result['sample_id']}: {str(inner_e)}")
                        result["generations"][processor_name] = {"text": ""}

            # Both branches above leave a "text" entry for every sample.
            method_texts = [r["generations"][processor_name]["text"] for r in batch_results]

            try:
                generated_perplexities = calculate_perplexity(
                    model,
                    tokenizer,
                    method_texts,
                    device,
                    only_generated=True,
                    prefix_length=args.prefix_length,
                    max_length=args.max_length,
                )

                for i, result in enumerate(batch_results):
                    result["generations"][processor_name]["perplexity"] = generated_perplexities[i]

                    if args.verbose:
                        logger.info(f"Sample {result['sample_id']}, Method {processor_name}, "
                                    f"Perplexity: {generated_perplexities[i]:.4f}")
            except Exception as e:
                logger.error(f"Error calculating perplexities for {processor_name}: {e}")
                for result in batch_results:
                    result["generations"][processor_name]["perplexity"] = float('nan')

            try:
                tokenized_generations = tokenizer(
                    method_texts,
                    truncation=True,
                    max_length=args.max_length,
                    padding=False,
                )

                for i, result in enumerate(batch_results):
                    try:
                        has_rep = find_repeated_substrings(tokenized_generations["input_ids"][i])
                        result["generations"][processor_name]["has_repetition"] = has_rep
                    except Exception as e:
                        logger.error(f"Error checking repetition for sample {result['sample_id']}, method {processor_name}: {e}")
                        result["generations"][processor_name]["has_repetition"] = None
            except Exception as e:
                logger.error(f"Error tokenizing for repetition check: {e}")
                for result in batch_results:
                    result["generations"][processor_name]["has_repetition"] = None

        results.extend(batch_results)

        # Checkpoint when the running sample index is a multiple of 10 and
        # always after the last batch, so an interrupted run keeps partial output.
        if batch_end % 10 == 0 or batch_end == total_samples:
            current_texts_data, current_metrics_data = _transform_results_to_new_format(results, all_processor_names)
            try:
                _dump_json(texts_file_path, current_texts_data)
                _dump_json(metrics_file_path, current_metrics_data)
                logger.info(f"Periodically saved partial results to {texts_file_path} and {metrics_file_path}")
            except Exception as e:
                logger.error(f"Error saving partial results: {e}")

    final_texts_data, final_metrics_data = _transform_results_to_new_format(results, all_processor_names)

    try:
        _dump_json(texts_file_path, final_texts_data)
        logger.info(f"Final texts data saved to {texts_file_path}")

        _dump_json(metrics_file_path, final_metrics_data)
        logger.info(f"Final metrics data saved to {metrics_file_path}")
    except Exception as e:
        logger.error(f"Error saving final results: {e}")

    logger.info("\nAggregate Results:")
    _print_aggregate_table(args, model_name_to_load, len(results), final_metrics_data, all_processor_names)
    logger.info(f"Results saved to {texts_file_path} and {metrics_file_path}")

    logger.info(f"Starting data analysis and plotting for {metrics_file_path}")
    run_analysis_and_plot(str(metrics_file_path), model_name_simple)
    logger.info("Finished data analysis and plotting.")

    logger.info(f"\nCalculating MAUVE scores (direct and bootstrapped) using texts from: {texts_file_path}")
    # NOTE(review): the raw gpu_id is forwarded unchanged; confirm the helper
    # copes with CUDA_VISIBLE_DEVICES having been narrowed to that single GPU.
    mauve_results = bootstrap_mauve_from_file(
        data_file_path=str(texts_file_path),
        num_bootstraps=50,
        gpu_id=args.gpu_id,
        max_text_length=args.max_length,
        batch_size=args.batch_size,
    )

    if not mauve_results:
        logger.warning("Could not calculate MAUVE scores.")
        return

    print("\nMAUVE Scores and Bootstrap Statistics (vs Original):")
    print("-" * 100)
    print(f"{'Method':<35} {'Direct MAUVE':<15} {'Bootstrap Std':<15}")
    print("-" * 100)
    for method, stats in mauve_results.items():
        direct_score_str = _format_score(stats['direct_mauve_score'])
        std_bs_score_str = _format_score(stats['std_bootstrap_score'])
        print(f"{method:<35} {direct_score_str:<15} {std_bs_score_str:<15}")
    print("-" * 100)

    plot_mauve_comparison_lines(mauve_results, model_name_simple, output_dir)


# Script entry point: parse the command-line options and run the full
# evaluation. Nothing here executes when the module is merely imported.
if __name__ == "__main__":
    args = parse_args()
    evaluate_all(args)
