import numpy as np
import os
import hydra
import evaluate
import json
import logging
from typing import Dict, List, Tuple, Any, Optional, Union

import torch
import torch.nn as nn 
import torch.distributed as dist

from tqdm import tqdm
from pathlib import Path
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from utils import get_model_identifiers_from_yaml
from eval_utils import get_forget_quality, get_model_utility
from data_module import TextDatasetQA, get_batch_loss, custom_data_collator_with_indices

# Module-wide logging setup: timestamped INFO-level records, plus a logger
# named after this module for every function below.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def setup_distributed() -> Tuple[int, int, bool]:
    """
    Read the launcher environment and initialize ``torch.distributed`` if needed.

    ``LOCAL_RANK`` (default ``0``) and ``WORLD_SIZE`` (default ``1``) are read
    from the environment. When ``WORLD_SIZE`` is greater than one and no
    process group has been initialized yet, the current CUDA device is set to
    ``local_rank`` and an NCCL process group is created.

    Returns:
        tuple: (local_rank, world_size, is_distributed)

    Raises:
        ValueError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is not an integer string.
        Exception: Any failure while selecting the device or creating the
            process group is logged and re-raised unchanged.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    is_distributed = world_size > 1

    if is_distributed and not dist.is_initialized():
        try:
            # Bind this process to its GPU *before* creating the NCCL group so
            # that communicators are not set up against the default device 0.
            torch.cuda.set_device(local_rank)
            dist.init_process_group(backend="nccl")
            logger.info(f"Initialized distributed training with rank {local_rank}/{world_size}")
        except Exception as e:
            logger.error(f"Failed to initialize distributed environment: {e}")
            raise

    return local_rank, world_size, is_distributed


def eval_perturbation_ratio(eval_dataloader, perturb_dataloader, model):
    """
    Compare the model's loss on reference answers with its loss on perturbed answers.

    The two dataloaders are consumed in lockstep with ``zip``; batch ``k`` of
    ``perturb_dataloader`` is assumed to hold the perturbed variants of batch
    ``k`` of ``eval_dataloader`` (TODO confirm against the dataset class).
    ``zip`` stops at the shorter loader, so a warning is logged when their
    lengths differ.

    Args:
        eval_dataloader: Yields ``(input_ids, labels, attention_mask, indices)``
            batches for the reference answers.
        perturb_dataloader: Yields batches with the same 4-tuple layout. When
            ``input_ids`` has more than two dimensions, the first two are
            treated as (batch, variants-per-example) and flattened before the
            forward pass.
        model: Model to evaluate; its ``device`` attribute decides where
            batches are moved.

    Returns:
        dict: ``{metric_name: {data_index: value}}`` with the metric names
        ``average_perturb_loss``, ``avg_paraphrased_loss``, ``truth_ratio``,
        ``paraphrased_loss``, ``perturb_loss``, ``num_token_paraphrased`` and
        ``num_token_perturb``. Empty if the loaders yield no batches.
    """
    eval_logs = {}
    device = model.device

    def _as_float_list(tensor):
        # Cast to float32 on CPU, then to plain Python lists for JSON output.
        return tensor.cpu().float().numpy().tolist()

    # Give tqdm a total when both loaders are sized, and surface silent
    # truncation by ``zip`` if the two loaders disagree.
    try:
        num_eval_batches = len(eval_dataloader)
        num_perturb_batches = len(perturb_dataloader)
    except TypeError:
        total_batches = None
    else:
        total_batches = min(num_eval_batches, num_perturb_batches)
        if num_eval_batches != num_perturb_batches:
            logger.warning(
                f"Dataloader length mismatch ({num_eval_batches} vs {num_perturb_batches} batches); "
                f"only the first {total_batches} batches are evaluated"
            )

    for batch, perturb_batch in tqdm(zip(eval_dataloader, perturb_dataloader),
                                     total=total_batches,
                                     desc="Evaluating perturbation ratio"):
        input_ids, labels, attention_mask, indices = batch
        batch = {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

        perturb_input_ids, perturb_labels, perturb_attention_mask, _ = perturb_batch

        # ``seq_len`` is the size of the second dimension when the perturbed
        # tensors are 3-D (presumably the number of perturbed answers per
        # example -- TODO confirm); otherwise there is one variant per example.
        if len(perturb_input_ids.shape) > 2:
            bsz, seq_len = perturb_input_ids.shape[0:2]
        else:
            bsz = perturb_input_ids.shape[0]
            seq_len = 1

        # Flatten (bsz, seq_len, ...) into (bsz * seq_len, ...) for the forward pass
        perturb_batch = {
            "input_ids": perturb_input_ids.view(bsz * seq_len, -1),
            "labels": perturb_labels.view(bsz * seq_len, -1),
            "attention_mask": perturb_attention_mask.view(bsz * seq_len, -1),
        }

        # Send to device
        batch = {k: v.to(device) for k, v in batch.items()}
        perturb_batch = {k: v.to(device) for k, v in perturb_batch.items()}

        with torch.no_grad():
            outputs = model(**batch)
            perturb_outputs = model(**perturb_batch)

        # Per-sequence losses; the perturbed ones are folded back to (bsz, seq_len)
        gt_loss = get_batch_loss(outputs.logits, batch['labels'])
        perturb_loss = get_batch_loss(perturb_outputs.logits, perturb_batch['labels']).view(bsz, seq_len)

        # Number of label positions that are not the ignore value (-100)
        num_token_gt = (batch['labels'] != -100).sum(-1)
        num_token_perturb = (perturb_batch['labels'] != -100).view(bsz, seq_len, -1).sum(-1)

        # Length-normalized losses and the resulting truth ratio
        perturb_loss_per_token = perturb_loss / num_token_perturb
        gt_loss_per_token = gt_loss / num_token_gt
        truth_ratio = torch.exp(gt_loss_per_token - perturb_loss_per_token.mean(-1))

        indices_np = indices.cpu().numpy().tolist()
        metrics = {
            'average_perturb_loss': perturb_loss_per_token,
            'avg_paraphrased_loss': gt_loss_per_token,
            'truth_ratio': truth_ratio,
            'paraphrased_loss': gt_loss,
            'perturb_loss': perturb_loss,
            'num_token_paraphrased': num_token_gt,
            'num_token_perturb': num_token_perturb,
        }

        # Merge this batch's values into the per-metric {index: value} maps
        for metric_name, values in metrics.items():
            eval_logs.setdefault(metric_name, {}).update(
                zip(indices_np, _as_float_list(values))
            )

    return eval_logs


def get_dataloader(cfg, eval_task, tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key):
    """
    Create dataloaders for the evaluation, base, and perturbed datasets.

    The three datasets differ only in which answer column is used. When
    ``cfg.ds_size`` is truthy, each dataset is truncated to its first
    ``cfg.ds_size`` rows (or fewer if the dataset is smaller).

    Args:
        cfg: Configuration object; reads ``model_family``,
            ``generation.max_length``, ``ds_size`` and ``batch_size``.
        eval_task: Current evaluation task (not used here; kept for callers).
        tokenizer: Tokenizer for the model
        folder: Data folder path
        split: Data split to use
        question_key: Key for questions in the dataset
        answer_key: Key for answers in the dataset
        base_answer_key: Key for base answers in the dataset
        perturbed_answer_key: Key for perturbed answers in the dataset

    Returns:
        tuple: (eval_dataloader, base_eval_dataloader, perturb_dataloader).
        The first uses ``cfg.batch_size``; the other two use a quarter of it,
        but never less than 1.
    """
    def _build_dataset(selected_answer_key):
        # One dataset per answer column; everything else is shared.
        dataset = TextDatasetQA(
            folder,
            tokenizer=tokenizer,
            model_family=cfg.model_family,
            max_length=cfg.generation.max_length,
            split=split,
            question_key=question_key,
            answer_key=selected_answer_key,
        )
        # Limit dataset size if configured
        if cfg.ds_size:
            dataset.data = dataset.data.select(
                range(min(cfg.ds_size, len(dataset.data)))
            )
        return dataset

    def _build_loader(dataset, batch_size):
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size,
            collate_fn=custom_data_collator_with_indices
        )

    torch_format_dataset = _build_dataset(answer_key)
    base_torch_format_dataset = _build_dataset(base_answer_key)
    perturb_torch_format_dataset = _build_dataset(perturbed_answer_key)

    # DataLoader rejects batch_size=0, which ``cfg.batch_size // 4`` yields for
    # batch sizes below 4, so clamp the reduced size to at least one.
    reduced_batch_size = max(1, cfg.batch_size // 4)

    eval_dataloader = _build_loader(torch_format_dataset, cfg.batch_size)
    base_eval_dataloader = _build_loader(base_torch_format_dataset, reduced_batch_size)
    perturb_dataloader = _build_loader(perturb_torch_format_dataset, reduced_batch_size)

    return eval_dataloader, base_eval_dataloader, perturb_dataloader


def get_all_evals(cfg, model, tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, normalize_gt=False):
    """
    Collect loss, generation, ROUGE and perturbation metrics for one task.

    Args:
        cfg: Configuration object (forwarded to ``run_generation``)
        model: Model to evaluate
        tokenizer: Tokenizer for the model
        eval_task: Task name, used only in the progress-bar label
        eval_dataloader: Yields ``(input_ids, labels, attention_mask, indices)``
        base_eval_dataloader: Base-answer loader for the perturbation comparison
        perturb_dataloader: Perturbed-answer loader for the perturbation comparison
        normalize_gt: When true, add ``normalized_gt_loss`` computed from the
            ground-truth and perturbed per-token losses

    Returns:
        dict: ``{metric_name: {data_index: value}}`` for all collected metrics
    """
    eval_logs = {}
    target_device = model.device

    generations, references, seen_indices = [], [], []
    prompts = []  # collected for parity with the other lists; not read afterwards

    for input_ids, labels, attention_mask, indices in tqdm(eval_dataloader, desc=f"Evaluating {eval_task}"):
        batch_ids = indices.cpu().numpy().tolist()
        seen_indices.extend(batch_ids)

        batch = {
            "input_ids": input_ids.to(target_device),
            "labels": labels.to(target_device),
            "attention_mask": attention_mask.to(target_device),
        }

        with torch.no_grad():
            # Forward pass for the loss, then free-form generation
            outputs = model(**batch)
            input_string, gen_output, gt = run_generation(cfg, batch, model, tokenizer=tokenizer)
            generations.extend(gen_output)
            references.extend(gt)
            prompts.extend(input_string)

        # Loss per sequence and per non-ignored (-100) label token
        gt_loss = get_batch_loss(outputs.logits, batch['labels'])
        token_counts = (batch['labels'] != -100).sum(-1)
        per_token_loss = gt_loss / token_counts

        # Conditional probability derived from the per-token loss
        cond_prob = torch.exp(-per_token_loss.float())

        # Plain Python values so the logs are JSON serializable
        batch_metrics = (
            ('avg_gt_loss', per_token_loss.cpu().double().numpy().tolist()),
            ('gt_loss', gt_loss.cpu().double().numpy().tolist()),
            ('num_token_gt', token_counts.cpu().double().numpy().tolist()),
            ('generated_text', list(zip(input_string, gen_output, gt))),
            ('conditional_probability', cond_prob.cpu().numpy().tolist()),
        )
        for metric_name, values in batch_metrics:
            eval_logs.setdefault(metric_name, {}).update(zip(batch_ids, values))

    # ROUGE recall over every generated answer
    eval_logs.update(eval_rouge_recall(generations, references, seen_indices))

    # Loss comparison between base and perturbed answers
    eval_logs.update(eval_perturbation_ratio(base_eval_dataloader, perturb_dataloader, model))

    if normalize_gt:
        # Renormalize the ground-truth probability against the perturbed ones
        perturb_losses = eval_logs['average_perturb_loss']
        normalized_gt_loss = {}
        for idx, gt_value in eval_logs['avg_gt_loss'].items():
            truth_prob = np.exp(-1 * gt_value)
            perturb_prob = np.exp(-1 * np.array(perturb_losses[idx]))
            all_prob = np.array([truth_prob, *perturb_prob])
            normalized_gt_loss[idx] = -1 * np.log(truth_prob / all_prob.sum())

        eval_logs['normalized_gt_loss'] = normalized_gt_loss

    return eval_logs


def eval_accuracy(logits, labels):
    """
    Compute next-token accuracy over the label positions that are not ignored.

    Args:
        logits: Model output logits; the prediction is the argmax over the
            last dimension
        labels: Target token ids, with -100 marking positions to skip

    Returns:
        dict: ``{"eval accuracy": float}``
    """
    # The prediction at position t is scored against the label at t + 1, so
    # drop the first label and the last prediction.
    targets = labels[..., 1:].contiguous()
    predictions = logits.argmax(-1)[..., :-1]

    # Positions labelled -100 contribute neither hits nor to the denominator
    valid = (targets != -100).float()
    hits = (predictions == targets).float() * valid
    accuracy = hits.sum() / valid.sum()

    return {"eval accuracy": accuracy.item()}


def run_generation(cfg, batch, model, tokenizer):
    """
    Generate answers for a batch and return them with their prompts and references.

    Each decoded input is split on a model-family specific marker
    (``" [/INST]"`` for ``llama2-7b``, ``'Answer: '`` otherwise) into a prompt
    and a ground-truth answer. Inputs that do not contain the marker keep
    their full text as the prompt and get an empty ground truth; other
    samples in the same batch are unaffected.

    The tokenizer's pad token is set to its EOS token (this persists after the
    call). ``padding_side`` is switched to ``'left'`` only while the prompts
    are encoded and is restored afterwards, so the caller's tokenizer is not
    left in left-padding mode.

    Args:
        cfg: Configuration object; reads ``model_family``,
            ``generation.max_length`` and ``generation.max_new_tokens``
        batch: Mapping with an ``"input_ids"`` entry to decode
        model: Model to generate with
        tokenizer: Tokenizer for the model

    Returns:
        tuple: (input_strings, generated_outputs, ground_truths). If
        generation raises, the error is logged and every generated output is
        the empty string.
    """
    decoded_inputs = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)

    # Determine split symbol based on model family
    is_llama2 = cfg.model_family == 'llama2-7b'
    split_symbol = " [/INST]" if is_llama2 else 'Answer: '

    # Split each sample on its own so one malformed sample cannot blank the
    # ground truths (or leave the answers inside the prompts) of the batch.
    input_strings = []
    ground_truth = []
    found_unsplittable = False
    for text in decoded_inputs:
        pieces = text.split(split_symbol)
        if len(pieces) > 1:
            prompt, answer = pieces[0], pieces[1]
        else:
            prompt, answer = text, ""
            found_unsplittable = True
        # The llama2 prompt format needs the marker back at the end of the prompt
        if is_llama2:
            prompt = prompt + split_symbol
        input_strings.append(prompt)
        ground_truth.append(answer)

    if found_unsplittable:
        logger.warning(f"Could not split using '{split_symbol}'. Using default processing.")

    # ``tokenizer`` is the caller's object, not a copy: remember its padding
    # side so the left-padding needed for generation does not leak out.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    try:
        inputs = tokenizer.batch_encode_plus(
            input_strings,
            add_special_tokens=True,
            return_tensors='pt',
            padding=True
        ).to(model.device)
    finally:
        tokenizer.padding_side = original_padding_side

    # Generate outputs
    try:
        out = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=cfg.generation.max_length,
            max_new_tokens=cfg.generation.max_new_tokens,
            do_sample=False,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode only the tokens that follow the (left-padded) prompt
        strs = tokenizer.batch_decode(
            out[:, inputs.input_ids.shape[-1]:],
            skip_special_tokens=True
        )
    except Exception as e:
        # Best effort: keep the evaluation running with empty generations
        logger.error(f"Error during generation: {e}")
        strs = ["" for _ in input_strings]

    return input_strings, strs, ground_truth


def eval_bleu(gen_outputs, ground_truths):
    """
    Score generated texts against references with the ROUGE and BLEU metrics.

    Args:
        gen_outputs: Generated outputs (predictions)
        ground_truths: Ground truth outputs (references)

    Returns:
        dict: ``{'rouge': ..., 'bleu': ...}`` holding the metric results, or
        two empty dicts if loading or computing either metric fails (the
        error is logged).
    """
    try:
        rouge_metric = evaluate.load('rouge')
        bleu_metric = evaluate.load('bleu')
        scores = {
            'rouge': rouge_metric.compute(predictions=gen_outputs, references=ground_truths),
            'bleu': bleu_metric.compute(predictions=gen_outputs, references=ground_truths),
        }
    except Exception as e:
        logger.error(f"Error computing BLEU/ROUGE: {e}")
        return {'rouge': {}, 'bleu': {}}

    return scores


def eval_rouge_recall(gen_outputs, ground_truths, indices):
    """
    Compute per-example ROUGE-1 and ROUGE-L recall.

    Args:
        gen_outputs: Generated outputs
        ground_truths: Ground truth outputs
        indices: Data indices used as keys of the result dicts

    Returns:
        dict: ``{'rouge1_recall': {index: recall}, 'rougeL_recall': {index: recall}}``.
        Both inner dicts are empty if scoring raises (the error is logged).
    """
    try:
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
        recalls = {'rouge1_recall': {}, 'rougeL_recall': {}}

        for candidate, reference, idx in zip(gen_outputs, ground_truths, indices):
            # The scorer takes the reference first, then the candidate
            scores = scorer.score(reference, candidate)
            recalls['rouge1_recall'][idx] = scores['rouge1'].recall
            recalls['rougeL_recall'][idx] = scores['rougeL'].recall

        return recalls
    except Exception as e:
        logger.error(f"Error computing ROUGE recall: {e}")
        return {'rouge1_recall': {}, 'rougeL_recall': {}}


def print_rouge(file_path):
    """
    Print the mean ROUGE-1 and ROUGE-L recall for each entry of a results file.

    Entries are visited in sorted key order. An entry is skipped when it lacks
    either recall table or when either table is empty. Any error is logged
    instead of being raised.

    Args:
        file_path: Path to a JSON results file keyed by task file name
    """
    display_names = {
        "eval_forget.json": "Forget Set",
        "eval_retain.json": "Retain Set",
        "eval_real_author_wo_options.json": "Real Author Set",
        "eval_real_world_wo_options.json": "Real World Fact Set",
    }

    try:
        with open(file_path, 'r') as handle:
            results = json.load(handle)

        for task in sorted(results.keys()):
            task_metrics = results[task]
            if "rouge1_recall" not in task_metrics or "rougeL_recall" not in task_metrics:
                continue

            rouge1 = task_metrics["rouge1_recall"]
            rougeL = task_metrics["rougeL_recall"]
            if not (rouge1 and rougeL):
                continue

            avg_rouge1 = np.mean(list(rouge1.values()))
            avg_rougeL = np.mean(list(rougeL.values()))
            # Fall back to the raw key for tasks without a friendly name
            split = display_names.get(task, task)
            print(f">>>>> {split}: \n      rouge1 = {avg_rouge1:.4f}, \n      rougeL = {avg_rougeL:.4f}")
    except Exception as e:
        logger.error(f"Error printing ROUGE metrics: {e}")

def load_model(cfg, device_map=None):
    """
    Load the tokenizer and the model from pretrained weights or a checkpoint.

    The tokenizer and the model config always come from the Hugging Face id
    registered for ``cfg.model_family``. The weights come from that same id
    when ``cfg.use_pretrained`` is set, otherwise from ``cfg.model_path``.
    Loading is attempted up to three times before the last error is re-raised.

    Args:
        cfg: Configuration object; reads ``model_family``,
            ``use_flash_attention_2``, ``use_pretrained``, ``model_path`` and
            ``reinitialize_weights``
        device_map: Device mapping forwarded to ``from_pretrained``. When it
            is ``None`` and a checkpoint is loaded, the model is moved to the
            default CUDA device if one is available.

    Returns:
        tuple: (model, tokenizer), with the model in evaluation mode

    Raises:
        Exception: Tokenizer errors and the final model loading error are
            logged and re-raised.
    """
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]

    # Load tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        logger.error(f"Error loading tokenizer: {e}")
        raise

    # Load model
    model = None
    config = AutoConfig.from_pretrained(model_id)
    # Flash attention needs both the run config and the model family to opt in
    use_flash_attention = cfg.use_flash_attention_2 and model_cfg.get("flash_attention2", "false") == "true"

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            # NOTE: trust_remote_code=True runs code shipped with the model
            # repository/checkpoint; only load sources you trust.
            if cfg.use_pretrained:
                logger.info(f"Loading pretrained model from {model_id}")
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    config=config,
                    use_flash_attention_2=use_flash_attention,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map=device_map
                )
                # NOTE(review): with device_map=None this branch leaves the
                # model wherever from_pretrained puts it (unchanged behavior).
            else:
                logger.info(f"Loading checkpoint from {cfg.model_path}")
                # Forward device_map here as well; without it a distributed
                # run (device_map set) never placed the checkpoint on its GPU.
                model = AutoModelForCausalLM.from_pretrained(
                    cfg.model_path,
                    config=config,
                    use_flash_attention_2=use_flash_attention,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map=device_map
                )
                if device_map is None:
                    if torch.cuda.is_available():
                        model = model.cuda()
                    else:
                        logger.warning("CUDA is not available; keeping the model on CPU")

            # Break the loop if model loads successfully
            break
        except Exception as e:
            logger.warning(f"Attempt {attempt}/{max_attempts} failed: {e}")
            if attempt == max_attempts:
                logger.error("All attempts to load model failed")
                raise

    # Set model to evaluation mode
    model = model.eval()

    # Reinitialize weights if configured
    if cfg.reinitialize_weights:
        logger.info("Reinitializing model weights")
        reinitialize_weights(model)

    return model, tokenizer


def reinitialize_weights(model) -> None:
    """Re-draw every ``nn.Linear`` weight from N(0, 0.02) and zero its bias, in place."""
    linear_layers = (layer for layer in model.modules() if isinstance(layer, nn.Linear))
    for layer in linear_layers:
        nn.init.normal_(layer.weight, mean=0, std=0.02)
        # Linear layers built with bias=False have nothing to reset
        if layer.bias is None:
            continue
        nn.init.zeros_(layer.bias)


def aggregate_results(cfg, local_rank):
    """
    Load the retain and checkpoint result files and print the utility summary.

    Args:
        cfg: Configuration object; reads ``retain_result`` and ``ckpt_result``
            (paths to JSON result files)
        local_rank: Local process rank; only rank 0 performs the aggregation

    Returns:
        dict: The model-utility metrics extended with ``'Forget Quality'``, or
        ``None`` when this is not rank 0, a path is missing or unset, or the
        aggregation fails (failures are logged, not raised).
    """
    # Skip if not rank 0 in distributed setting
    if local_rank != 0:
        return None

    # Both paths must be configured
    if not cfg.retain_result or not cfg.ckpt_result:
        logger.warning("Skipping aggregation: retain_result or ckpt_result not provided")
        return None

    # Check both files before trying to load
    if not os.path.exists(cfg.retain_result):
        logger.error(f"Retain result file not found: {cfg.retain_result}")
        return None

    if not os.path.exists(cfg.ckpt_result):
        logger.error(f"Checkpoint result file not found: {cfg.ckpt_result}")
        return None

    try:
        # Context managers close the files even if parsing fails
        with open(cfg.retain_result) as retain_file:
            retain_result = json.load(retain_file)
        with open(cfg.ckpt_result) as ckpt_file:
            ckpt_result = json.load(ckpt_file)

        # Compute metrics
        model_utility = get_model_utility(retain_result)
        forget_quality = get_forget_quality(retain_result, ckpt_result)
        model_utility['Forget Quality'] = forget_quality['Forget Quality']

        # Print results
        print(">>>>>>> Aggregation Finished:\n")

        eval_nm = ["ROUGE", "Prob.", "Truth Ratio"]
        set_nm = ["Forget", "Retain", "Real Authors", "Real World"]
        extra_nm = ["Model Utility", "Forget Quality"]

        # Per-set metrics are keyed "<metric> <set>"; print those present
        for metric_label in eval_nm:
            for set_label in set_nm:
                key = metric_label + " " + set_label
                if key in model_utility:
                    print(f"{key}: {model_utility[key]}")

        for key in extra_nm:
            if key in model_utility:
                print(f"{key}: {model_utility[key]}")

        return model_utility
    except Exception as e:
        logger.error(f"Error during aggregation: {e}")
        return None


@hydra.main(version_base=None, config_path="config", config_name="eval_everything")
def main(cfg):
    """
    Evaluate the configured model on every task and write the results to disk.

    For each task a JSON file is written to ``cfg.save_dir`` (suffixed with the
    local rank in distributed runs). A task whose file already exists is
    skipped unless ``cfg.overwrite`` is set; on rank 0 the existing file is
    loaded so the aggregated log still covers it. Rank 0 then writes
    ``eval_log_aggregated.json`` (unless it exists and ``cfg.overwrite`` is
    off), and aggregation is attempted if result paths are configured.

    Args:
        cfg: Hydra configuration object

    Raises:
        ValueError: If the per-task configuration lists differ in length.
    """
    # Validate configuration: every per-task list needs one entry per task
    per_task_settings = (
        cfg.data_path, cfg.split_list, cfg.eval_task,
        cfg.question_key, cfg.answer_key,
        cfg.base_answer_key, cfg.perturbed_answer_key,
    )
    if len({len(setting) for setting in per_task_settings}) != 1:
        raise ValueError(
            "Configuration error: data_path, split_list, eval_task, question_key, "
            "answer_key, base_answer_key, and perturbed_answer_key must have the same length"
        )

    # Set up distributed environment
    local_rank, _world_size, is_distributed = setup_distributed()

    # Create save directory
    Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)

    # Disable wandb if not needed
    os.environ["WANDB_DISABLED"] = "true"

    # In distributed runs, pin the whole model to this process's GPU
    device_map = {'': local_rank} if is_distributed else None

    # Load model and tokenizer
    model, tokenizer = load_model(cfg, device_map)

    # Results per task, keyed "<eval_task>.json"
    aggregated_eval_logs = {}

    # Run evaluations for each task
    for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in zip(
        cfg.data_path, cfg.split_list, cfg.question_key, cfg.answer_key,
        cfg.eval_task, cfg.base_answer_key, cfg.perturbed_answer_key
    ):
        # Determine save filename based on distributed setup
        base_filename = f"{eval_task}.json"
        if is_distributed:
            save_filename = os.path.join(cfg.save_dir, f"{eval_task}_{local_rank}.json")
        else:
            save_filename = os.path.join(cfg.save_dir, base_filename)

        logger.info(f'Working on eval task {eval_task} with split {split}')

        # Skip if file exists and overwrite is not enabled
        if os.path.exists(save_filename) and not cfg.overwrite:
            logger.info(f"Skipping {eval_task} because {save_filename} already exists")
            # Reuse the stored results so the aggregated log (written by rank 0
            # only) is not missing the tasks that were skipped.
            if local_rank == 0:
                try:
                    with open(save_filename) as f:
                        aggregated_eval_logs[base_filename] = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    logger.warning(f"Could not reuse existing results from {save_filename}: {e}")
            if is_distributed:
                # Keep the barrier count per task identical on every rank
                torch.distributed.barrier()
            continue

        # Get dataloaders
        eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(
            cfg, eval_task, tokenizer, folder, split,
            question_key, answer_key, base_answer_key, perturbed_answer_key
        )

        # Ground-truth normalization is skipped for tasks named like 'eval_log'
        normalize_gt = 'eval_log' not in eval_task

        # Run evaluations
        eval_logs = get_all_evals(
            cfg, model, tokenizer, eval_task,
            eval_dataloader, base_eval_dataloader, perturb_dataloader,
            normalize_gt=normalize_gt
        )

        # Save results (best effort: a failed write does not stop later tasks)
        try:
            with open(save_filename, "w") as f:
                json.dump(eval_logs, f, indent=4)
            logger.info(f"Results saved to {save_filename}")
        except Exception as e:
            logger.error(f"Error saving results to {save_filename}: {e}")

        # Add to aggregated logs
        aggregated_eval_logs[base_filename] = eval_logs

        # Synchronize processes in distributed setting
        if is_distributed:
            torch.distributed.barrier()

    # Save aggregated logs (rank 0 only in distributed setting)
    if local_rank == 0:
        aggregated_log_filename = os.path.join(cfg.save_dir, "eval_log_aggregated.json")

        if not os.path.exists(aggregated_log_filename) or cfg.overwrite:
            try:
                with open(aggregated_log_filename, "w") as f:
                    json.dump(aggregated_eval_logs, f, indent=4)
                logger.info(f"Aggregated results saved to {aggregated_log_filename}")
            except Exception as e:
                logger.error(f"Error saving aggregated results: {e}")

    # Synchronize before aggregation
    if is_distributed:
        torch.distributed.barrier()

    logger.info("Evaluation Completed")

    # Try to run aggregation if files are specified
    if hasattr(cfg, 'retain_result') and hasattr(cfg, 'ckpt_result'):
        logger.info("Starting aggregation...")
        try:
            # Only try aggregation if we have paths and at least one exists;
            # aggregate_results reports which file is missing.
            if (cfg.retain_result and os.path.exists(cfg.retain_result)) or \
               (cfg.ckpt_result and os.path.exists(cfg.ckpt_result)):
                aggregate_results(cfg, local_rank)
            elif local_rank == 0:
                logger.warning("Skipping aggregation: result files not found")
                if cfg.retain_result:
                    logger.warning(f"retain_result path: {cfg.retain_result}")
                if cfg.ckpt_result:
                    logger.warning(f"ckpt_result path: {cfg.ckpt_result}")
        except Exception as e:
            if local_rank == 0:
                logger.error(f"Error during aggregation: {e}")
    elif local_rank == 0:
        logger.info("Skipping aggregation: retain_result or ckpt_result not provided")


if __name__ == "__main__":
    main()
