import argparse
import json
import logging
import os
import time
# --- Third-Party Libraries ---
import torch
import torch.distributed as dist
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.data import DataLoader

# --- Local Project Imports ---
from config import (
    TASK_TYPE, 
    get_type, 
    TASK_PATHS_MAPPING, 
    MASK_ID_MAPPING, 
    MODEL_TYPE, 
    FINETUNING_TYPE
)
from data.basic_tools import CustomDataset
from eval.utils import get_eval_model
from eval.llada_generate import generate
from eval.eval import extract_pred_from_text_commonsense_170k,extract_last_option,extract_number

# --- Logging Configuration ---
# Root logging at INFO, one "time - logger - level - message" line per record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)


# ----------------------------------------------------------------------------
# MOCK FUNCTIONS REMOVED
# ----------------------------------------------------------------------------


def parse_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            lets the parser be driven programmatically.

    Returns:
        argparse.Namespace: The parsed arguments. ``ft_task`` falls back to
        ``task_name`` when it is not given explicitly.
    """
    parser = argparse.ArgumentParser(description="Run custom evaluation on commonsense tasks")

    # --- Argument groups (they only affect --help formatting) ---

    task_group = parser.add_argument_group("Task & Model Arguments")
    task_group.add_argument("--model_name", type=str, required=True, help="Base model name.")
    task_group.add_argument("--task_name", type=str, required=True, choices=[e.value for e in TASK_TYPE], help="Name of the task to evaluate.")
    task_group.add_argument("--peft_name", type=str, default=None, help="PEFT model name or path, if any.")

    output_group = parser.add_argument_group("Output Arguments")
    output_group.add_argument("--output_dir", type=str, required=True, help="Directory to save results.")
    output_group.add_argument("--save_generations", action="store_true", help="Save all generated outputs to a JSON file.")

    data_group = parser.add_argument_group("Data Arguments")
    data_group.add_argument("--split", type=str, default="test", help="Dataset split to use (e.g., 'test', 'validation').")

    gen_group = parser.add_argument_group("Generation Config")
    gen_group.add_argument("--gen_length", type=int, default=256, help="Max new tokens to generate.")
    gen_group.add_argument("--steps", type=int, default=256, help="Steps for custom generate function.")
    gen_group.add_argument("--block_length", type=int, default=8, help="Block length for custom generate function.")
    gen_group.add_argument("--temperature", type=float, default=0.0, help="Generation temperature. 0 means greedy.")
    gen_group.add_argument("--cfg_scale", type=float, default=0.0, help="CFG scale for custom generate function.")
    # The value is a strategy *name* handed to generate(), so this must be a string
    # option. A store_true flag would replace "low_confidence" with True when passed.
    # A bare "--remasking" is still accepted and selects the default strategy.
    gen_group.add_argument(
        "--remasking",
        type=str,
        nargs="?",
        const="low_confidence",
        default="low_confidence",
        help="Remasking strategy for the custom generate function (default: low_confidence).",
    )
    gen_group.add_argument("--till_eos", action="store_true", default=False, help="Use until-EOS mode during generation.")
    gen_group.add_argument("--till_current_eos", action="store_true", default=False, help="Use until-current-EOS mode during generation.")

    loading_group = parser.add_argument_group("Model Loading Arguments")
    loading_group.add_argument("--ft_task", type=str, default=None, help="Finetuning task name. Defaults to --task_name if not set.")
    loading_group.add_argument("--run_time", type=int, default=1)
    loading_group.add_argument("--f_form", type=str, default="linear")
    loading_group.add_argument("--training_mode", type=str, default="joint")
    loading_group.add_argument("--t_mapping", type=str, default="poly")
    loading_group.add_argument("--fnn_hidden_size", type=int, default=32)
    loading_group.add_argument("--fnn_hidden_size_2", type=int, default=512)
    loading_group.add_argument("--lr", type=float, default=5e-5)
    loading_group.add_argument("--clr", type=float, default=1e-4)
    # store_true with default=True can never be switched off by itself;
    # --no_direct_noise writes False into the same destination.
    loading_group.add_argument("--direct_noise", action="store_true", default=True)
    loading_group.add_argument("--no_direct_noise", dest="direct_noise", action="store_false", help="Disable direct noise (enabled by default).")
    loading_group.add_argument("--ckpts", type=str, default="best")
    loading_group.add_argument("--zero_lora_init", action="store_true", default=False)
    loading_group.add_argument("--random_noise", action="store_true", default=False)
    loading_group.add_argument("--use_embedding", action="store_true", default=False)
    loading_group.add_argument("--embedding_dim", type=int, default=1)
    loading_group.add_argument("--init_c", type=str, default="kaiming_uniform_m")
    loading_group.add_argument("--input_mode", type=str, default="noise_level")
    loading_group.add_argument("--Embed_components", type=str, default="nd_nl")
    loading_group.add_argument("--Embed_type", type=str, default="fourier")
    loading_group.add_argument("--density_radius", type=int, default=0)
    loading_group.add_argument("--rank", type=int, default=32)
    loading_group.add_argument("--mapper_num_layers", type=int, default=2)
    loading_group.add_argument("--c_scale", type=float, default=1)
    loading_group.add_argument("--length_alignment", action="store_true", default=False)
    loading_group.add_argument("--whole_length", action="store_true", default=False)
    loading_group.add_argument("--stage_1", type=float, default=0)
    loading_group.add_argument("--scale_ab", type=float, default=1)

    # P-tuning options
    loading_group.add_argument("--nvt", type=int, default=20)
    loading_group.add_argument("--h", type=int, default=550)
    loading_group.add_argument("--epoch", type=int, default=1)
    loading_group.add_argument("--ectype", type=str, default="LSTM")
    # Prompt-tuning options
    loading_group.add_argument("--prompt_tuning_init", type=str, default="TEXT")
    args = parser.parse_args(argv)

    # Default ft_task to task_name if not provided
    if args.ft_task is None:
        args.ft_task = args.task_name

    return args


def setup_logging_and_output(args, accelerator):
    """Prepare the output directory and record the run configuration.

    Only the main process does anything here. The directory is created, and
    the arguments logged, once per run rather than once per process. Other
    processes return immediately.
    """
    if not accelerator.is_main_process:
        return
    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("Accelerator initialized. Main process running.")
    logger.info(f"Evaluation arguments: {vars(args)}")


def load_task_dataset(task_name, split):
    """
    Load one split of the dataset configured for ``task_name``.

    The ``TASK_PATHS_MAPPING`` entry for the task must contain either
    ``data_files`` or ``path``:

    * ``data_files`` is a mapping keyed by split name, loaded with the ``json`` builder.
    * ``path`` is a 2-item sequence passed positionally to ``load_dataset``
      (presumably dataset name and config name; TODO confirm against config).

    The entry must also contain ``q_key`` and ``gt_key``.

    Returns:
        tuple: (dataset, q_key, gt_key)

    Raises:
        ValueError: if ``gt_key`` is missing, the split is not listed in
            ``data_files``, or the entry has neither ``data_files`` nor ``path``.
        KeyError: if ``q_key`` is missing.
    """
    logger.info(f"Loading dataset for task: {task_name}")
    task_type: TASK_TYPE = get_type(TASK_TYPE, task_name)
    task_info = TASK_PATHS_MAPPING[task_type]

    # Validate the column keys before loading anything. A misconfigured task
    # then fails immediately instead of after a potentially slow download.
    q_key = task_info["q_key"]
    gt_key = task_info.get("gt_key", None)
    if gt_key is None:
        raise ValueError(f"Task {task_name} must have a 'gt_key' in TASK_PATHS_MAPPING for evaluation.")

    if "data_files" in task_info:
        data_files = task_info["data_files"]
        if split not in data_files:
            raise ValueError(
                f"Split '{split}' is not defined in data_files for {task_type}. "
                f"Available splits: {list(data_files.keys())}"
            )
        # Pass only the requested split's file(s). The loader prepares every
        # split it is given, and the other splits are not needed for evaluation.
        dataset = load_dataset(
            "json",
            data_files={split: data_files[split]},
            split=split,
        )
    elif "path" in task_info:
        dataset = load_dataset(
            task_info["path"][0],
            task_info["path"][1],
            split=split,
        )
    else:
        raise ValueError(f"Task info for {task_type} has no 'data_files' or 'path' key.")

    logger.info(f"Dataset loaded. Found {len(dataset)} examples in split '{split}'.")
    return dataset, q_key, gt_key


def load_model_and_tokenizer(args):
    """
    Build the evaluation model and tokenizer from the parsed CLI arguments.

    Each loading option is forwarded to ``get_eval_model`` under the same
    name it has on ``args``. The one exception is ``model_name``, which is
    passed as ``base_model_name``.

    Returns:
        Whatever ``get_eval_model`` returns. The caller unpacks it as
        ``(model, tokenizer, _)``.
    """
    logger.info("Loading model and tokenizer...")
    # Attribute names on ``args`` that match get_eval_model's keyword names 1:1.
    forwarded_names = (
        "peft_name",
        "ft_task",
        "run_time",
        "f_form",
        "zero_lora_init",
        "direct_noise",
        "ckpts",
        "training_mode",
        "t_mapping",
        "fnn_hidden_size",
        "lr",
        "use_embedding",
        "embedding_dim",
        "init_c",
        "input_mode",
        "density_radius",
        "rank",
        "fnn_hidden_size_2",
        "Embed_components",
        "Embed_type",
        "mapper_num_layers",
        "c_scale",
        "length_alignment",
        "whole_length",
        "stage_1",
        "scale_ab",
        "clr",
        "nvt",
        "h",
        "epoch",
        "ectype",
        "prompt_tuning_init",
    )
    loader_kwargs = {name: getattr(args, name) for name in forwarded_names}
    return get_eval_model(base_model_name=args.model_name, **loader_kwargs)


def _select_answer_extractor(task_type):
    """Return ``extract(gold_text, answer_decoded) -> (extracted_answer, gold)`` for a task.

    The returned ``gold`` is the value the prediction is compared against with
    ``==``. Arithmetic tasks convert the gold label to ``float``; the other
    tasks keep it unchanged.

    Raises:
        ValueError: if no extraction rule is defined for ``task_type``.
    """
    commonsense_tasks = (
        TASK_TYPE.ARC_CHALLENGE,
        TASK_TYPE.COMMONSENSE170K,
        TASK_TYPE.ARC_EASY,
        TASK_TYPE.BOOLQ,
        TASK_TYPE.HELLASWAG,
        TASK_TYPE.OPENBOOKQA,
        TASK_TYPE.PIQA,
        TASK_TYPE.SOCIAL_I_QA,
        TASK_TYPE.WINOGRANDE,
    )
    arithmetic_tasks = (
        TASK_TYPE.MATH14K,
        TASK_TYPE.ADDSUB,
        TASK_TYPE.SINGLEEQ,
        TASK_TYPE.GSM8K_TEST,
        TASK_TYPE.MULTIARITH,
        TASK_TYPE.SVAMP,
    )

    if task_type in commonsense_tasks:
        def extract(gold_text, answer_decoded):
            return extract_pred_from_text_commonsense_170k(gold_text, answer_decoded), gold_text
    elif task_type in arithmetic_tasks:
        def extract(gold_text, answer_decoded):
            # Extract first, then convert the gold label to float for comparison.
            return extract_number(answer_decoded), float(gold_text)
    elif task_type in (TASK_TYPE.AQUA,):
        def extract(gold_text, answer_decoded):
            return extract_last_option(answer_decoded), gold_text
    else:
        raise ValueError(f"Unknown task type: {task_type}")
    return extract


def run_evaluation_loop(eval_dataloader, model, tokenizer, accelerator, args, q_key, gt_key):
    """
    Generate an answer for every example this process sees and score it.

    Each batch is expected to hold one example: only index 0 of the ``q_key``
    and ``gt_key`` columns is read. Some examples raise during tokenization,
    generation or answer extraction. Those are logged with a traceback and
    skipped, so they are not counted in ``local_total``. The number skipped
    is reported at the end.

    Args:
        eval_dataloader: Iterable of batches (dict-like, indexed by column name).
        model: Model passed to ``generate``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``batch_decode``.
        accelerator: Supplies ``device``, ``is_main_process`` and ``process_index``.
        args: Parsed CLI arguments (generation and task settings).
        q_key: Column holding the question text.
        gt_key: Column holding the ground-truth answer.

    Returns:
        tuple: (local_correct, local_total, results). ``results`` holds
        per-example dicts and is only filled when ``args.save_generations``.

    Raises:
        ValueError: if ``args.task_name`` has no answer-extraction rule. This
            is raised before any generation, not swallowed per example.
    """
    # Resolve everything that does not depend on the individual example once, up front.
    task_type: TASK_TYPE = get_type(TASK_TYPE, args.task_name)
    extract_answer = _select_answer_extractor(task_type)
    mask_id = MASK_ID_MAPPING[MODEL_TYPE(args.model_name)]
    finetuning_type: FINETUNING_TYPE = get_type(FINETUNING_TYPE, args.peft_name) if args.peft_name else None

    if finetuning_type == FINETUNING_TYPE.CLORA and args.stage_1 == 1.0:
        # Pure stage-1 (AB only) configuration: force the model into stage 1.
        # Per the model author's intent, Ceff then acts as the identity, and
        # unused Mapper/Lambda parameters do not influence generation.
        real_model = model.module if hasattr(model, "module") else model
        if hasattr(real_model, "set_training_stage"):
            real_model.set_training_stage(1)
            if accelerator.is_main_process:
                logger.info("[Generate] Forcing CLoRA Stage 1 (Identity C) for evaluation.")

    local_total = 0
    local_correct = 0
    local_errors = 0
    results = []

    evaluating_bar = tqdm(
        eval_dataloader,
        desc="Evaluating",
        disable=not accelerator.is_main_process,
        total=len(eval_dataloader),
    )

    for data in evaluating_bar:
        try:
            question = data[q_key][0]  # batch_size == 1: take the only item
            gold_text = data[gt_key][0]

            # 1) Tokenize the prompt through the chat template.
            prompt_data = tokenizer.apply_chat_template(
                [{"role": "user", "content": question}],
                return_tensors="pt",
                return_dict=True,
                add_generation_prompt=True,
            )
            prompt_ids = prompt_data.input_ids.to(accelerator.device)
            question_length = prompt_ids.shape[1]

            # 2) Generate the model output.
            with torch.no_grad():
                gen_tokens: torch.Tensor = generate(
                    model=model,
                    tokenizer=tokenizer,
                    finetuning_type=finetuning_type,
                    direct_noise=args.direct_noise,
                    prompt=prompt_ids,
                    steps=args.steps,
                    gen_length=args.gen_length,
                    block_length=args.block_length,
                    temperature=args.temperature,
                    cfg_scale=args.cfg_scale,
                    remasking=args.remasking,
                    mask_id=mask_id,
                    is_main_process=accelerator.is_main_process,
                    random_noise=args.random_noise,
                    whole_length=args.whole_length,
                    till_eos=args.till_eos,
                    till_current_eos=args.till_current_eos,
                )  # expected shape: [1, prompt_length + gen_length]

            # 3) Decode only the tokens after the prompt.
            answer_decoded = tokenizer.batch_decode(
                gen_tokens[:, question_length:], skip_special_tokens=True
            )[0]

            # 4) Extract the predicted answer (and normalize the gold label).
            extracted_answer, gold_text = extract_answer(gold_text, answer_decoded)
        except Exception as e:
            # Best effort: keep evaluating, but keep the traceback and count the skip.
            local_errors += 1
            logger.exception(f"Error processing data point: {e}")
            logger.error(f"Data: {data}")
            continue

        # 5) Update counters.
        local_total += 1
        is_correct = (extracted_answer == gold_text)
        local_correct += int(is_correct)

        # Results are collected on *all* processes and gathered later.
        if args.save_generations:
            results.append({
                "question": question,
                "gold_answer": gold_text,
                "generated_text": answer_decoded,
                "extracted_answer": extracted_answer,
                "is_correct": is_correct,
            })

        evaluating_bar.set_postfix(
            {"acc (local)": f"{local_correct / local_total:.4f}"}
        )

    if local_errors:
        logger.warning(
            f"Process {accelerator.process_index}: skipped {local_errors} example(s) "
            "because of errors; they are excluded from the accuracy denominator."
        )

    return local_correct, local_total, results


def aggregate_and_save_results(local_correct, local_total, results, args, accelerator, total_time):
    """
    Combine per-process results and write them to JSON on the main process.

    Must be called on every process: the count reduction and (when
    ``args.save_generations``) ``gather_object`` exchange data across
    processes.

    Files written to ``args.output_dir`` (main process only):
      * ``summary_<run_time>_blk_<block_length>_steps_<steps>.json``
      * ``generations_<run_time>_blk_<block_length>_steps_<steps>.json``
        (only with ``--save_generations``)

    Args:
        local_correct: Correct predictions on this process.
        local_total: Scored examples on this process.
        results: Per-example dicts from this process.
        args: Parsed CLI arguments (stored in the summary).
        accelerator: Used for the reduction, process count and main-process check.
        total_time: Wall-clock evaluation time in seconds.
    """
    logger.info("Aggregating results across all processes...")

    # Sum (correct, total) over all processes.
    local_counts = torch.tensor([local_correct, local_total], dtype=torch.long, device=accelerator.device)
    global_counts = accelerator.reduce(local_counts, reduction="sum")

    # Gathering runs on every process, not only the main one.
    gathered_results = gather_object(results) if args.save_generations else None

    if not accelerator.is_main_process:
        return

    global_correct = global_counts[0].item()
    global_total = global_counts[1].item()

    if global_total == 0:
        logger.warning("No examples were processed. Cannot calculate accuracy.")
        accuracy = 0
    else:
        accuracy = global_correct / global_total
        logger.info("Evaluation Finished!")
        logger.info(f"Task: {args.task_name}")
        logger.info(f"Total: {global_total}, Correct: {global_correct}")
        logger.info(f"Accuracy: {accuracy:.6f}")

    num_gpus = accelerator.num_processes
    # Resource usage: wall time multiplied by the number of processes.
    gpu_seconds = total_time * num_gpus
    time_consumption = {
        "seconds": total_time,
        "minutes": total_time / 60,
        "hours": total_time / 3600,
        "total_gpu_hours": gpu_seconds / 3600,
        "num_gpus_used": num_gpus,
    }

    summary = {
        "task": args.task_name,
        "model": args.model_name,
        "peft": args.peft_name,
        "accuracy": accuracy,
        "total_examples": global_total,
        "total_correct": global_correct,
        "time_consumption": time_consumption,
        "args": vars(args),
    }

    # Shared filename tail so the summary and generations files stay paired.
    run_suffix = f"{args.run_time}_blk_{args.block_length}_steps_{args.steps}.json"

    summary_path = os.path.join(args.output_dir, f"summary_{run_suffix}")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)
    logger.info(f"Summary saved to {summary_path}")

    if args.save_generations:
        generations_path = os.path.join(args.output_dir, f"generations_{run_suffix}")
        with open(generations_path, "w", encoding="utf-8") as f:
            json.dump(gathered_results, f, indent=4)
        logger.info(f"Generated outputs saved to {generations_path}")


def cleanup(accelerator):
    """Finish the accelerator, then tear down the default process group if one is live.

    The ``is_initialized`` guard makes this safe in single-process runs,
    where no process group was ever created.
    """
    accelerator.end_training()
    if not dist.is_initialized():
        return
    dist.destroy_process_group()


def main():
    """Run the evaluation end to end.

    Steps: parse arguments, evaluate this process's share of the dataset,
    aggregate the results across processes, save them, and clean up.
    """
    # Only needed here, to shard the dataset across processes.
    from torch.utils.data import Subset

    # 1. Parse arguments.
    args = parse_args()

    # 2. Initialize the accelerator. torch.cuda.set_device rejects non-CUDA
    #    devices, so skip it on CPU/MPS runs.
    accelerator = Accelerator()
    if accelerator.device.type == "cuda":
        torch.cuda.set_device(accelerator.device)

    # 3. Set up logging and the output dir (main process only).
    setup_logging_and_output(args, accelerator)

    # 4. Load the dataset.
    dataset, q_key, gt_key = load_task_dataset(args.task_name, args.split)

    # 5. Load the model and tokenizer.
    model, tokenizer, _ = load_model_and_tokenizer(args)

    # 6. Shard the data by hand: process i evaluates examples i, i + P, i + 2P, ...
    #    accelerator.prepare(dataloader) defaults to even_batches=True. That pads
    #    the final round with repeated samples, which would then be evaluated and
    #    counted twice in the summed totals when len(dataset) % P != 0.
    eval_dataset = CustomDataset(dataset)
    shard_indices = list(range(accelerator.process_index, len(eval_dataset), accelerator.num_processes))
    eval_dataloader = DataLoader(
        Subset(eval_dataset, shard_indices),
        batch_size=1,
        shuffle=False,
    )

    # 7. Move the (unwrapped) model to this process's device.
    model = model.to(accelerator.device)
    model.eval()

    accelerator.wait_for_everyone()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # let pending GPU work finish before timing starts

    # perf_counter is monotonic and meant for measuring durations.
    start_time = time.perf_counter()

    # 8. Run the evaluation loop.
    local_correct, local_total, results = run_evaluation_loop(
        eval_dataloader, model, tokenizer, accelerator, args, q_key, gt_key
    )

    accelerator.wait_for_everyone()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # include all generated work in the measured time

    total_time = time.perf_counter() - start_time

    # 9. Aggregate and save results (collective: every process must call it).
    aggregate_and_save_results(
        local_correct, local_total, results, args, accelerator, total_time
    )

    # 10. Clean up.
    cleanup(accelerator)


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    main()