"""
Entry point for RQ3: Quantize a model and run perplexity evaluation.
This script handles model loading, quantization, and real perplexity
calculation on a specified dataset.
"""
import argparse
import logging
import json
import os
import torch
from tqdm import tqdm
from itertools import islice
import time
import gc

# Add project root to Python path to allow for local package imports
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from benford_quant.utils import load_config, setup_logging
from benford_quant.models.wrapper import apply_quantization

def get_model_size_mb(model: torch.nn.Module) -> float:
    """
    Return the memory occupied by the model's parameters, in MiB.

    Each parameter contributes ``numel() * element_size()`` bytes, so models
    whose parameters use different dtypes (e.g. some layers kept in higher
    precision than others) are measured correctly. Only ``model.parameters()``
    are counted; registered buffers are not. A model without parameters has
    size ``0.0``.

    Args:
        model: the module to measure.

    Returns:
        Parameter memory in MiB (bytes / 1024**2).
    """
    # Per-parameter element size: assuming one dtype for the whole model
    # (the first parameter's) mis-sizes mixed-precision / quantized models.
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return total_bytes / (1024 ** 2)

# Import necessary libraries, assuming they are in requirements.txt
try:
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        pipeline,
        AutoModelForSequenceClassification,
        AutoModelForSeq2SeqLM,
        DataCollatorWithPadding
    )
    from datasets import load_dataset
    import evaluate
    import numpy as np
except ImportError:
    print("Please install the required libraries: pip install -r requirements.txt")
    sys.exit(1)

def evaluate_perplexity(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    dataset_name: str,
    dataset_subset: str,
    device: str,
    n_samples: int = 128,
    seq_len: int = 2048,
    batch_size: int = 8
) -> float:
    """
    Compute corpus-level perplexity, GPT-2 style (Radford et al. 2019).

    The ``validation`` split is concatenated into one string (documents joined
    by blank lines) and tokenized once. The token stream is cut into
    non-overlapping blocks of ``seq_len`` tokens, and a trailing remainder
    shorter than ``seq_len`` is dropped. Each block is scored with the model's
    own loss (``model(input_ids, labels=input_ids).loss``). The perplexity is
    ``exp`` of the average block loss.

    Args:
        model: language model whose forward accepts ``labels=`` and returns
            an object with a ``.loss`` attribute.
        tokenizer: tokenizer matching ``model``.
        dataset_name: dataset name passed to ``load_dataset``.
        dataset_subset: dataset configuration passed to ``load_dataset``.
        device: device each block is moved to before the forward pass.
        n_samples: currently unused; kept for call-site compatibility.
        seq_len: tokens per block; must be >= 2 so at least one next-token
            prediction exists per block.
        batch_size: currently unused (blocks are scored one at a time); kept
            for call-site compatibility.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if ``seq_len < 2`` or the split yields fewer than
            ``seq_len`` tokens (no complete block to score).
    """
    logger = logging.getLogger(__name__)
    if seq_len < 2:
        raise ValueError(f"seq_len must be >= 2 for perplexity evaluation, got {seq_len}")

    logger.info(f"Loading full dataset '{dataset_name}' subset '{dataset_subset}' for perplexity evaluation.")

    # 1. Load the validation split and concatenate every document.
    dataset = load_dataset(dataset_name, dataset_subset, split="validation")
    full_text = "\n\n".join(dataset["text"])

    # 2. Tokenize the whole corpus once.
    encodings = tokenizer(full_text, return_tensors="pt")
    input_ids = encodings.input_ids[0]  # 1-D tensor of token ids

    n_tokens = len(input_ids)
    logger.info(f"Total tokens in validation set: {n_tokens}")

    # 3. Non-overlapping blocks of seq_len tokens; the remainder is dropped.
    n_blocks = n_tokens // seq_len
    if n_blocks == 0:
        raise ValueError(
            f"Validation split has {n_tokens} tokens, fewer than seq_len={seq_len}; "
            "no complete block to evaluate."
        )
    # Blocks stay on the CPU; each one is moved to `device` just before its
    # forward pass so the full corpus never occupies accelerator memory.
    blocks = input_ids[: n_blocks * seq_len].view(n_blocks, seq_len)

    total_loss = 0.0
    total_tokens = 0

    model.eval()
    with torch.no_grad():
        for block in tqdm(blocks, desc="Perplexity (concat mode)"):
            batch = block.unsqueeze(0).to(device)  # [1, seq_len]
            labels = batch.clone()

            outputs = model(batch, labels=labels)
            loss = outputs.loss.item()

            # Every block has the same length, so this weighting is a plain
            # mean over blocks.
            total_loss += loss * seq_len
            total_tokens += seq_len

    avg_neg_log_likelihood = total_loss / total_tokens
    ppl = torch.exp(torch.tensor(avg_neg_log_likelihood))
    return ppl.item()

def _get_sentence_keys(task_name):
    """Auxiliary function to return the sentence keys for each GLUE task."""
    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }
    return task_to_keys[task_name]

def _evaluate_glue(model, tokenizer, task_name, batch_size, device, logger):
    """Score ``model`` on the labelled evaluation split of one GLUE task and return the GLUE metric dict."""
    logger.info(f"Evaluating GLUE task: {task_name}")

    # Plain "mnli" has no "validation" split (only validation_matched /
    # validation_mismatched); every other labelled config has "validation".
    eval_split = "validation_matched" if task_name == "mnli" else "validation"
    # STS-B is a regression task: float labels, single-logit head.
    is_regression = task_name == "stsb"

    # 1. Load only the split that is scored, plus the metric. Tokenizing the
    #    train splits (hundreds of thousands of rows for QQP/MNLI) is wasted work.
    dataset = load_dataset("glue", task_name, split=eval_split)
    metric = evaluate.load("glue", task_name)

    # 2. Preprocess dataset
    sentence1_key, sentence2_key = _get_sentence_keys(task_name)
    label_cast = float if is_regression else int

    def preprocess_function(examples):
        texts = (examples[sentence1_key],)
        if sentence2_key is not None:
            texts += (examples[sentence2_key],)
        # No fixed-length padding here: the collator pads each batch to its own
        # longest sequence instead of every example to model_max_length.
        result = tokenizer(*texts, truncation=True)
        result["labels"] = [label_cast(label) for label in examples["label"]]
        return result

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names
    )

    # 3. Data collator (dynamic per-batch padding) and loader
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
    data_loader = torch.utils.data.DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )

    # 4. Evaluation loop
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=f"Evaluating {task_name}"):
            labels = batch["labels"]  # stays on CPU; only needed for the metric
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}

            logits = model(**inputs).logits
            if is_regression:
                # [batch, 1] -> [batch]; upcast so the metric never sees fp16.
                predictions = logits.squeeze(-1).float()
            else:
                predictions = torch.argmax(logits, dim=-1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.numpy())

    # 5. Compute metrics
    results = metric.compute(predictions=all_predictions, references=all_labels)
    logger.info(f"GLUE task '{task_name}' results: {results}")
    return results


def _evaluate_lambada(model, tokenizer, device, logger):
    """Greedy next-token LAMBADA accuracy on the test split; returns ``{"accuracy": float}``."""
    logger.info("Evaluating on LAMBADA (last word prediction accuracy).")

    # Reuse EOS as the pad token. NOTE: this mutates the shared tokenizer and
    # the model config, so the change persists for any later task.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    # 1. Load dataset
    dataset = load_dataset("lambada", 'plain_text', split="test")

    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for example in tqdm(dataset, desc="Evaluating LAMBADA"):
            words = example["text"].split()
            if len(words) < 2:  # need at least one context word and a target
                continue

            context = " ".join(words[:-1])
            target = words[-1]

            # 2. Tokenize the context
            inputs = tokenizer(context, return_tensors="pt").to(device)

            #model.config.use_cache = False # QWEN-MODEL: Uncomment this line for qwen models
            # 3. Generate exactly one token greedily. The prediction is the last
            #    whitespace-separated word of context + that token, so a target
            #    that needs several tokens only matches if one token completes it.
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                #use_cache=False, # QWEN-MODEL: Uncomment this line for qwen models
            )
            decoded_words = tokenizer.decode(outputs[0], skip_special_tokens=True).split()
            pred = decoded_words[-1] if decoded_words else ""

            if pred == target:
                correct += 1
            total += 1

    accuracy = correct / total if total > 0 else 0.0
    results = {"accuracy": accuracy}
    logger.info(f"LAMBADA accuracy: {accuracy:.4f}")
    return results


def evaluate_downstream_task(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    eval_config: dict,
    device: str
) -> dict:
    """
    Downstream task evaluation (GLUE, LAMBADA).

    Args:
        model: model to evaluate (sequence-classification head for GLUE,
            generative model for LAMBADA).
        tokenizer: corresponding tokenizer.
        eval_config: task config dict; reads ``type``, ``batch_size``
            (default 8) and, for GLUE, ``dataset_subset`` (the task name).
        device: execution device.

    Returns:
        A dict with the evaluation metrics, or ``{"status": "not_implemented"}``
        for summarization / QA.

    Raises:
        ValueError: for an unsupported ``type``.
    """
    logger = logging.getLogger(__name__)
    task_type = eval_config.get("type")
    batch_size = eval_config.get("batch_size", 8)

    logger.info(f"Running downstream evaluation for task type: {task_type}")

    if task_type == "glue":
        return _evaluate_glue(
            model, tokenizer, eval_config["dataset_subset"], batch_size, device, logger
        )

    if task_type == "lambada":
        return _evaluate_lambada(model, tokenizer, device, logger)

    if task_type == "summarization":
        logger.warning("Summarization evaluation is not yet implemented.")
        return {"status": "not_implemented"}

    if task_type == "qa":
        logger.warning("QA evaluation is not yet implemented.")
        return {"status": "not_implemented"}

    raise ValueError(f"Unsupported downstream task type: {task_type}")


def _glue_num_labels(task_name):
    """Return the classifier output size for a GLUE task (1 for the float-labelled STS-B regression task)."""
    # Load every available split and read the label feature from the first one:
    # some configs (e.g. mnli_matched / mnli_mismatched) have no "train" split.
    dataset_dict = load_dataset("glue", task_name)
    label_feature = next(iter(dataset_dict.values())).features["label"]
    # ClassLabel features expose num_classes; STS-B's float Value feature does not.
    return getattr(label_feature, "num_classes", 1)


def _load_eval_model(model_cls, model_config, config, quant_stats, device, model_kwargs):
    """
    Load ``model_cls`` in fp16 and quantize it when ``config['quantization']`` is set.

    Returns:
        ``(model, seconds)``. ``seconds`` is the wall-clock quantization time,
        or ``None`` for an FP16 (unquantized) run. The model is in eval mode.
    """
    logger = logging.getLogger(__name__)
    base_model = model_cls.from_pretrained(
        model_config['name'],
        trust_remote_code=model_config.get('trust_remote_code', True),
        torch_dtype=torch.float16,
        #cache_dir=local_model_dir,
        device_map='auto', # CPU-QUANT: Comment out this line
        **model_kwargs
    )

    elapsed = None
    if config.get('quantization', None):
        start = time.perf_counter()
        model = apply_quantization(base_model, config, quant_stats).to(device)
        elapsed = time.perf_counter() - start
        del base_model
        logger.info(f"Quantization took {elapsed:.2f} seconds.")
    else:
        logger.info("FP16 selected, skipping quantization.")
        model = base_model.to(device) # FP16 tests

    # Memory clear: collect garbage first so tensors released above are freed,
    # then hand the CUDA allocator's cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    model.eval()
    return model, elapsed


def main():
    """
    Main function to run quantization and evaluation.

    Reads the YAML config given by ``--config_path``. For each task in
    ``config['evaluation']['tasks']`` it loads (and, if configured, quantizes)
    the needed model, runs the evaluation, and writes all metrics to
    ``config['evaluation']['metric_output_file']`` as JSON.
    """
    parser = argparse.ArgumentParser(description="Quantize a model and run evaluations.")
    parser.add_argument('--config_path', type=str, required=True, help='Path to the YAML configuration file.')
    args = parser.parse_args()

    # 1. Setup
    setup_logging()
    logger = logging.getLogger(__name__)
    logger.info("Starting quantization and evaluation script...")

    # 2. Load config
    logger.info(f"Loading configuration from: {args.config_path}")
    config = load_config(args.config_path)

    # 3. Determine device and load tokenizer
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    model_config = config['model']
    #local_model_dir = "/your/local/cache/dir"  # Change as needed

    # NOTE(review): trust_remote_code defaults to True, which lets the model
    # repository execute its own Python code; set it explicitly in the config
    # for untrusted models.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config['name'],
        trust_remote_code=model_config.get('trust_remote_code', True),
        #cache_dir=local_model_dir,
        device_map='auto' # CPU-QUANT: Comment out this line
    )

    # 4. Evaluate the quantized model on all specified tasks
    q_config = config.get('quantization', None)
    all_metrics = {}
    cached_models = {} # Cache for (quantized) models, keyed by architecture / head size
    quant_stats = {}
    quantization_times = []  # one entry per model actually quantized
    logger.info("Evaluating quantized model on specified tasks...")

    for task_config in config['evaluation']['tasks']:
        task_type = task_config.get("type")
        logger.info(f"--- Running task: {task_type} ---")

        # Determine model class. GLUE subset options: ['ax', 'cola', 'mnli',
        # 'mnli_matched', 'mnli_mismatched', 'mrpc', 'qnli', 'qqp', 'rte',
        # 'sst2', 'stsb', 'wnli']
        model_kwargs = {}
        if task_type == 'glue':
            task_name = task_config.get('dataset_subset')
            num_labels = _glue_num_labels(task_name)
            model_kwargs['num_labels'] = num_labels
            logger.info(f"Task '{task_name}' requires {num_labels} labels.")
            # Key by head size so GLUE tasks with different label counts do not
            # share a classifier head of the wrong width.
            model_key = f'sequence_classification_{num_labels}'
            model_cls = AutoModelForSequenceClassification
        else:
            model_key = 'causal_lm'
            model_cls = AutoModelForCausalLM

        if model_key in cached_models:
            logger.info(f"Using cached model for '{model_key}'.")
            quantized_model = cached_models[model_key]
        else:
            logger.info(f"Quantized model for '{model_key}' not found in cache. Loading and quantizing...")
            quantized_model, elapsed = _load_eval_model(
                model_cls, model_config, config, quant_stats, device, model_kwargs
            )
            if elapsed is not None:
                quantization_times.append(elapsed)
            cached_models[model_key] = quantized_model

        # Run evaluation
        if task_type == 'perplexity':
            perplexity = evaluate_perplexity(
                model=quantized_model,
                tokenizer=tokenizer,
                dataset_name=task_config.get('dataset'),
                dataset_subset=task_config.get('dataset_subset'),
                device=device,
                n_samples=task_config.get('num_samples', 128),
                seq_len=task_config.get('seq_len', 2048),
                batch_size=task_config.get('batch_size', 8)
            )
            all_metrics[f'perplexity_{task_config.get("dataset_subset")}'] = perplexity

        elif task_type in ['glue', 'summarization', 'qa', 'lambada']:
            downstream_metrics = evaluate_downstream_task(
                model=quantized_model,
                tokenizer=tokenizer,
                eval_config=task_config,
                device=device
            )
            metric_key = f'{task_type}_{task_config.get("dataset_subset", "")}'
            all_metrics[metric_key] = downstream_metrics
        else:
            logger.warning(f"Unsupported evaluation task type: {task_type}")

    # 5. Save results
    output_file = config['evaluation']['metric_output_file']
    output_dir = os.path.dirname(output_file)
    if output_dir:  # a bare filename has no directory to create
        os.makedirs(output_dir, exist_ok=True)

    logger.info(f"Saving all metrics to: {output_file}")
    results = {
        "model": model_config['name'],
        "quantization_config": config.get('quantization', 'fp16'),
        "evaluation_configs": config['evaluation']['tasks'],
        "metrics": all_metrics,
        # Total seconds spent quantizing all loaded models; -1 marks an FP16
        # (unquantized) run.
        "quantization_time_sec": sum(quantization_times) if q_config else -1,
        "quantization_stats": quant_stats,
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    logger.info("Script finished successfully.")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()