from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TrainerCallback
from trl import GRPOTrainer, GRPOConfig
import re
import random
from video_datasets import OneVideoDataset, MixedVideoDataset
from peft import LoraConfig, TaskType, PeftModel
import os
from omegaconf import DictConfig
import hydra
from upload_qwen_to_s3 import upload_to_s3
from accelerate import Accelerator
from metric import calculate_VOC
import torch
import numpy as np
from scipy.stats import spearmanr
from metric import extract_task_completion_percentages
from omegaconf import OmegaConf
import json
import wandb
from accelerate.utils import set_seed

@torch.no_grad()
def custom_eval_logic(eval_videos, eval_offsets, model, processor, n_frames, num_shuffles, n_frames_ref) -> dict:
    """Compute the VOC metric for every evaluation video.

    One ``OneVideoDataset`` is built per ``(video_path, offset)`` pair and
    passed to ``calculate_VOC``. The whole function runs with gradient
    tracking disabled.

    Args:
        eval_videos: Iterable of video paths to evaluate.
        eval_offsets: Iterable of offsets, one per entry of ``eval_videos``;
            each is forwarded to ``OneVideoDataset(offset=...)``.
        model: Forwarded unchanged to ``calculate_VOC``.
        processor: Forwarded unchanged to ``calculate_VOC``.
        n_frames: Forwarded to ``calculate_VOC`` as ``n_frames``.
        num_shuffles: Forwarded to ``calculate_VOC`` as ``num_shuffles``.
        n_frames_ref: Forwarded to ``calculate_VOC`` as ``n_frames_ref``.

    Returns:
        A dict keyed ``"dataset_<k>"`` (1-based, in input order). Each value
        is ``{"voc_score": res["VOC"], "cnt_sorted": res["cnt_sorted"]}``
        where ``res`` is the result of ``calculate_VOC`` for that video.

    Raises:
        ValueError: If ``eval_videos`` and ``eval_offsets`` differ in length.
    """
    videos = list(eval_videos)
    offsets = list(eval_offsets)

    # zip() would silently drop the surplus entries of the longer list, so a
    # misconfigured run would skip videos without any warning.
    if len(videos) != len(offsets):
        raise ValueError(
            f"eval_videos has {len(videos)} entries but eval_offsets has "
            f"{len(offsets)}; exactly one offset is required per video"
        )

    # All datasets are built before any evaluation starts, so a bad path or
    # offset surfaces before time is spent on the metric.
    eval_datasets = [
        OneVideoDataset(video_path=path, offset=offset)
        for path, offset in zip(videos, offsets)
    ]

    voc_results = {}
    for i, (dataset, video_path) in enumerate(zip(eval_datasets, videos)):
        print(f"Evaluating dataset {i+1}/{len(eval_datasets)}: {video_path}")

        res = calculate_VOC(model, processor, dataset, n_frames=n_frames,
                            num_shuffles=num_shuffles, n_frames_ref=n_frames_ref)

        # Results are keyed by position, not by video path.
        dataset_name = f"dataset_{i+1}"
        voc_results[dataset_name] = {
            'voc_score': res['VOC'],
            'cnt_sorted': res['cnt_sorted']
        }
        print(dataset_name, " -> ", res['VOC'])
    return voc_results


def _checkpoint_sort_key(path):
    """Return a natural-order sort key for a checkpoint directory path.

    Digit runs in the directory name compare numerically, so
    ``checkpoint-200`` sorts before ``checkpoint-1000``.
    """
    name = os.path.basename(path.rstrip(os.sep))
    # re.split with a capturing group alternates text / digit-run tokens, so
    # odd indices are always digit runs and keys stay comparable item by item.
    return [int(tok) if idx % 2 else tok
            for idx, tok in enumerate(re.split(r"([0-9]+)", name))]


def _find_lora_checkpoints(lora_base_path):
    """Return the LoRA adapter directories at or directly under a path.

    If ``lora_base_path`` itself contains ``adapter_config.json`` it is the
    only checkpoint. Otherwise its immediate subdirectories that contain
    ``adapter_config.json`` are returned in natural order. Returns an empty
    list when nothing is found.
    """
    if os.path.exists(os.path.join(lora_base_path, "adapter_config.json")):
        return [lora_base_path]
    if not os.path.isdir(lora_base_path):
        return []

    found = []
    for entry in os.listdir(lora_base_path):
        candidate = os.path.join(lora_base_path, entry)
        if not os.path.isdir(candidate):
            continue
        if os.path.isfile(os.path.join(candidate, "adapter_config.json")):
            found.append(candidate)
        else:
            # Skip instead of failing later inside PeftModel.from_pretrained.
            print(f"Skipping {candidate}: no adapter_config.json")
    return sorted(found, key=_checkpoint_sort_key)


def _json_default(obj):
    """Convert NumPy / torch values to plain Python objects for ``json.dumps``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


@hydra.main(version_base=None, config_path="configs/hydra", config_name="config_eval")
def main(args: DictConfig):
    """Evaluate one or more LoRA checkpoints and print their VOC results.

    Reads ``model.model_path``, ``model.lora_path``, ``general.seed`` and the
    ``eval.*`` settings from the Hydra config. For every discovered adapter
    directory a fresh base model is loaded, the adapter is merged into it,
    and ``custom_eval_logic`` is run on the merged model. The collected
    results, keyed by checkpoint directory name, are printed as JSON.

    Args:
        args: Hydra config; converted to a plain container after seeding.

    Raises:
        ValueError: If no adapter directory is found under ``model.lora_path``.
    """
    import gc  # only needed here, to release models between checkpoints

    accelerator = Accelerator()
    set_seed(args.general.seed, device_specific=True)

    args = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)

    model_path = args['model']['model_path']
    lora_base_path = args['model']['lora_path']

    processor = AutoProcessor.from_pretrained(model_path)

    eval_videos = args['eval']['eval_videos']
    eval_offsets = args['eval']['eval_offsets']

    # Either lora_path is itself an adapter directory, or it holds one
    # adapter directory per checkpoint.
    checkpoints = _find_lora_checkpoints(lora_base_path)
    if not checkpoints:
        raise ValueError(f"No LoRA checkpoints found under: {lora_base_path}")

    all_results = {}
    for ckpt_path in checkpoints:
        print(f"Evaluating LoRA checkpoint: {ckpt_path}")

        # Load a fresh base model for each checkpoint: merging an adapter
        # modifies the base weights, so the model cannot be reused.
        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            attn_implementation="sdpa"
        )

        peft_model = PeftModel.from_pretrained(base_model, ckpt_path)
        model = peft_model.merge_and_unload().to(accelerator.device)
        model.eval()

        voc_results = custom_eval_logic(
            eval_videos,
            eval_offsets,
            model,
            processor,
            args['eval']['n_frames'],
            args['eval']['num_shuffles'],
            args['eval']['n_frames_ref']
        )

        ckpt_name = os.path.basename(ckpt_path.rstrip(os.sep))
        all_results[ckpt_name] = voc_results

        # Drop every reference to this checkpoint's model before the next one
        # is loaded, so two full models do not need to fit in memory at once.
        del model, peft_model, base_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # NOTE(review): the value types returned by calculate_VOC are not visible
    # here; the default hook only kicks in for NumPy/torch values that plain
    # json.dumps would reject.
    print(json.dumps(all_results, indent=2, default=_json_default))


# Script entry point. `main` is called without arguments because it is wrapped
# by the `@hydra.main` decorator, which supplies the config argument.
if __name__ == "__main__":
    main()