from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TrainerCallback
from trl import GRPOTrainer, GRPOConfig
import re
import random
from video_datasets import OneVideoDataset, MixedVideoDataset
from peft import LoraConfig, TaskType
import os
from omegaconf import DictConfig
import hydra
from upload_qwen_to_s3 import upload_to_s3
from accelerate import Accelerator
from metric import calculate_VOC
import torch
import numpy as np
from scipy.stats import spearmanr
from metric import extract_task_completion_percentages
from omegaconf import OmegaConf
import json
import wandb
from accelerate.utils import set_seed


@torch.no_grad()
def custom_eval_logic(eval_videos, eval_offsets, model, processor, n_frames, num_shuffles, n_frames_ref):
    """Run `calculate_VOC` on every evaluation video and collect the results.

    Args:
        eval_videos: Video paths, one per evaluation dataset.
        eval_offsets: Offsets paired position-by-position with `eval_videos`.
        model: Model forwarded unchanged to `calculate_VOC`.
        processor: Processor forwarded unchanged to `calculate_VOC`.
        n_frames: Forwarded to `calculate_VOC` as `n_frames`.
        num_shuffles: Forwarded to `calculate_VOC` as `num_shuffles`.
        n_frames_ref: Forwarded to `calculate_VOC` as `n_frames_ref`.

    Returns:
        A dict keyed ``"dataset_<k>"`` (1-based position in the input lists)
        whose values are ``{'voc_score': ..., 'cnt_sorted': ...}``, taken from
        the ``'VOC'`` and ``'cnt_sorted'`` entries of the `calculate_VOC` result.
    """
    # All datasets are built up front, before any evaluation starts.
    datasets = [
        OneVideoDataset(video_path=video, offset=start)
        for video, start in zip(eval_videos, eval_offsets)
    ]
    total = len(datasets)

    summary = {}
    for position, (video, dataset) in enumerate(zip(eval_videos, datasets), start=1):
        print(f"Evaluating dataset {position}/{total}: {video}")

        outcome = calculate_VOC(
            model,
            processor,
            dataset,
            n_frames=n_frames,
            num_shuffles=num_shuffles,
            n_frames_ref=n_frames_ref,
        )

        # Results are keyed by position, not by video path.
        summary[f"dataset_{position}"] = {
            'voc_score': outcome['VOC'],
            'cnt_sorted': outcome['cnt_sorted'],
        }
    return summary


def Single_reward(completions, ground_truth, **kwargs):
    """Score each completion by how close its predicted percentage is to the target.

    The last ``<number>%`` occurrence in the completion text is taken as the
    prediction. When no percentage can be extracted, a random integer in
    [0, 100] is used instead, so an unparseable answer is scored like a
    random guess.

    Args:
        completions: Sequence of completions; each is indexed as
            ``completion[0]['content']`` to obtain the generated text.
        ground_truth: Target values paired position-by-position with
            `completions` (presumably percentages on a 0-100 scale; confirm
            against the dataset).
        **kwargs: Ignored; accepted so extra dataset columns can be passed.

    Returns:
        A list of rewards, each ``-abs(prediction - target) / 100``.
    """
    # Compiled once per call instead of being looked up for every completion.
    percent_pattern = re.compile(r"(\d+(?:\.\d+)?)%")
    rewards = []

    for completion, gt in zip(completions, ground_truth):
        try:
            matches = percent_pattern.findall(completion[0]['content'])
            # If the model predicts multiple values, take the last one.
            pred_percent = float(matches[-1])
        except (IndexError, KeyError, TypeError):
            # Only parse failures trigger the fallback: missing or malformed
            # message structure, non-string content, or no percentage in the
            # text. KeyboardInterrupt/SystemExit and unrelated errors are no
            # longer swallowed, unlike with a bare `except`.
            pred_percent = random.randint(0, 100)

        rewards.append(-abs(float(pred_percent) - gt) / 100)
    return rewards


def VOC_reward(completions, image, shuffled_indices, **kwargs):
    """Reward each completion by the rank correlation of its predicted frame order.

    For every completion the per-frame percentages are parsed with
    `extract_task_completion_percentages`, the frames are ranked by those
    values, and the Spearman correlation between the recovered order and the
    true order is used as the reward.

    Args:
        completions: Sequence of completions; each is indexed as
            ``completion[0]['content']`` to obtain the generated text.
        image: Batch of image lists; only ``len(image[0])`` is used, as the
            expected number of values per sample.
        shuffled_indices: One list of shuffled frame indices per completion.
        **kwargs: Ignored.

    Returns:
        A list with one reward per completion: the Spearman coefficient, or
        -1 when the number of parsed values does not match ``len(image[0])``.
    """
    scores = []

    for sample, permutation in zip(completions, shuffled_indices):
        text = sample[0]['content']
        n_images = len(image[0])

        # A -1.0 entry is prepended to the parsed values (and a 0 to the
        # shuffle below) so both sequences line up with `image[0]`.
        # NOTE(review): presumably position 0 is a fixed reference frame that
        # always sorts first; confirm against the dataset.
        predicted = [-1.0] + extract_task_completion_percentages(text, n_images - 1)

        if len(predicted) != n_images:
            # Wrong number of predictions: worst possible reward.
            scores.append(-1)
            continue

        ranking = np.argsort(predicted)
        recovered = np.array([0] + permutation)[ranking]
        reference = np.arange(len(recovered))

        # Position 0 is excluded; the remaining indices are shifted to start at 0.
        correlation, _ = spearmanr(recovered[1:] - 1, reference[1:] - 1)
        scores.append(correlation)
    return scores


class CustomGRPOTrainer(GRPOTrainer):
    """GRPO trainer that adds per-video VOC metrics to every evaluation pass.

    The instance attributes ``eval_videos``, ``eval_offsets``,
    ``eval_n_frames``, ``eval_num_shuffles`` and ``eval_n_frames_ref`` are not
    set by this class; they must be assigned on the trainer before the first
    evaluation runs.
    """

    def evaluation_loop(
        self,
        dataloader,
        description,
        prediction_loss_only=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Run the parent evaluation loop, then append VOC metrics.

        The last three parameters default to the same values as the
        Hugging Face ``Trainer.evaluation_loop`` this method overrides, so the
        override accepts every call the parent method accepts.

        Args:
            dataloader: Forwarded to the parent evaluation loop.
            description: Forwarded to the parent evaluation loop.
            prediction_loss_only: Forwarded to the parent evaluation loop.
            ignore_keys: Forwarded to the parent evaluation loop.
            metric_key_prefix: Forwarded to the parent evaluation loop. It is
                NOT applied to the VOC metric keys added here.

        Returns:
            The parent's output object, with ``<dataset>_VOC`` and
            ``<dataset>_sorted`` entries merged into its ``metrics`` dict.
        """
        output = super().evaluation_loop(
            dataloader,
            description,
            prediction_loss_only,
            ignore_keys,
            metric_key_prefix,
        )

        print("Running custom evaluation logic...")
        voc_results = custom_eval_logic(
            self.eval_videos,
            self.eval_offsets,
            self.model,
            self.processing_class,
            self.eval_n_frames,
            self.eval_num_shuffles,
            self.eval_n_frames_ref,
        )

        log_result = {}
        print("\n=== VOC Results Summary ===")
        for dataset_name, results in voc_results.items():
            log_result[f"{dataset_name}_VOC"] = results['voc_score']
            log_result[f"{dataset_name}_sorted"] = results['cnt_sorted']
            print(f"{dataset_name}: VOC = {results['voc_score']:.4f}, Sorted = {results['cnt_sorted']}")

        # Merge into the parent's metrics so they are reported with the rest.
        output.metrics.update(log_result)
        return output


class UploadCheckpointCallback(TrainerCallback):
    """Trainer callback that uploads freshly saved checkpoints via `upload_to_s3`.

    Args:
        upload_every_n_steps: Upload only when the global step is a positive
            multiple of this value.
        remote_path: Leading component of the remote destination path.
        run_name: Second component of the remote destination path.
    """

    def __init__(self, upload_every_n_steps, remote_path, run_name):
        self.upload_every_n_steps = upload_every_n_steps
        self.remote_path = remote_path
        self.run_name = run_name

    def on_save(self, args, state, control, **kwargs):
        """Upload ``<output_dir>/checkpoint-<step>/`` if it exists and the step qualifies.

        Returns:
            The unmodified `control` object.
        """
        step = state.global_step
        is_upload_step = step % self.upload_every_n_steps == 0 and step > 0
        if not is_upload_step:
            return control

        checkpoint_name = f"checkpoint-{step}"
        checkpoint_dir = os.path.join(args.output_dir, f"{checkpoint_name}/")

        if not os.path.exists(checkpoint_dir):
            print(f"Checkpoint {checkpoint_dir} does not exist")
            return control

        print(f"Uploading checkpoint {step} to server...")

        s3_bucket = ""  # Replace with your bucket name
        s3_path = os.path.join(self.remote_path, self.run_name, checkpoint_name)
        upload_to_s3(checkpoint_dir, s3_bucket, s3_path)
        return control


class UploadFirstHighVOCCallback(TrainerCallback):
    """Save and upload the model the first time the logged VOC reward hits a target.

    Args:
        remote_path: Leading component of the remote destination path.
        run_name: Second component of the remote destination path.
        accelerator: Object whose ``is_main_process`` flag gates the save/upload.
        target_VOC: Threshold compared against the logged
            ``rewards/VOC_reward/mean`` value.
    """

    def __init__(self, remote_path, run_name, accelerator, target_VOC):
        self.remote_path = remote_path
        self.run_name = run_name
        self.target = target_VOC
        self.saved = False
        self.accelerator = accelerator

    def on_log(self, args, state, control, model, logs=None, **kwargs):
        """Check the logged VOC reward and upload once when it reaches the target.

        Returns:
            The unmodified `control` object.
        """
        # `logs` defaults to None, so guard before reading from it; the
        # original dereferenced it unconditionally.
        if self.saved or not logs:
            return control

        voc = logs.get("rewards/VOC_reward/mean")
        reached_target = voc is not None and voc >= self.target
        if reached_target and self.accelerator.is_main_process:
            checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}/")
            model.save_pretrained(checkpoint_dir)
            # Marked as saved before uploading, so the upload is attempted
            # only once even if it fails.
            self.saved = True
            s3_bucket = ""  # Replace with your bucket name
            s3_path = os.path.join(self.remote_path, self.run_name, "checkpoint-target")
            upload_to_s3(checkpoint_dir, s3_bucket, s3_path)

        return control


class LogConfigCallback(TrainerCallback):
    """Push the run configuration to Weights & Biases once, on the main process.

    Args:
        cfg: Configuration object passed to ``wandb.config.update``.
        accelerator: Object whose ``is_main_process`` flag gates the update.
    """

    def __init__(self, cfg, accelerator):
        self.cfg = cfg
        self.logged = False
        self.accelerator = accelerator

    def on_log(self, args, state, control, **kwargs):
        """Upload the config on the first log event where a W&B run is active."""
        if self.logged or not self.accelerator.is_main_process:
            return

        # Updating the config without an active run raises, so skip until one
        # exists (e.g. when W&B reporting is disabled). `logged` stays False,
        # so the update is retried on later log events.
        if wandb.run is None:
            return

        wandb.config.update(self.cfg)
        self.logged = True
    

@hydra.main(version_base=None, config_path="configs/hydra", config_name="config")
def main(args: DictConfig):
    """Train Qwen2.5-VL with GRPO using the Hydra-provided configuration.

    Steps: seed the run, build the training dataset, optional LoRA config and
    GRPO config, load the model and processor, choose the reward function,
    train with `CustomGRPOTrainer`, and optionally save and upload the final
    model from the main process.

    Args:
        args: Hydra config with the ``general``, ``video``, ``peft``,
            ``grpo``, ``model`` and ``eval`` sections read below.

    Raises:
        NotImplementedError: If ``video.reward_type`` is neither ``'VOC'``
            nor ``'Single'``.
    """
    accelerator = Accelerator()
    set_seed(args.grpo.seed, device_specific=True)

    # From here on the config is a plain, fully resolved container.
    cfg = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)

    general_cfg = cfg['general']
    os.makedirs(general_cfg['save_path'], exist_ok=True)

    train_dataset = MixedVideoDataset(**cfg['video'])

    peft_config = (
        LoraConfig(task_type=TaskType.CAUSAL_LM, **cfg['peft'])
        if general_cfg['use_peft']
        else None
    )

    grpo_cfg = cfg['grpo']
    grpo_config = GRPOConfig(**grpo_cfg)

    model_path = cfg['model']['model_path']
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(model_path)

    reward_type = cfg['video']['reward_type']
    if reward_type == 'VOC':
        reward_fn = VOC_reward
    elif reward_type == 'Single':
        reward_fn = Single_reward
    else:
        raise NotImplementedError

    eval_cfg = cfg['eval']
    eval_videos = eval_cfg['eval_videos']
    eval_offsets = eval_cfg['eval_offsets']
    # Only the first evaluation video is handed to the trainer as its
    # eval_dataset; the full lists are attached to the trainer further down.
    eval_dataset = OneVideoDataset(video_path=eval_videos[0], offset=eval_offsets[0])

    callbacks = [
        UploadCheckpointCallback(
            upload_every_n_steps=grpo_cfg['save_steps'],
            remote_path="<your upload path>",
            run_name=grpo_cfg['run_name'],
        ),
        LogConfigCallback(cfg=cfg, accelerator=accelerator),
        UploadFirstHighVOCCallback(
            remote_path="<your upload path>",
            run_name=grpo_cfg['run_name'],
            accelerator=accelerator,
            target_VOC=general_cfg['target_VOC'],
        ),
    ]

    trainer = CustomGRPOTrainer(
        model=model,
        processing_class=processor,
        reward_funcs=reward_fn,
        train_dataset=train_dataset,
        args=grpo_config,
        peft_config=peft_config,
        eval_dataset=eval_dataset,
        callbacks=callbacks,
    )

    # Settings read by CustomGRPOTrainer.evaluation_loop.
    trainer.eval_n_frames = eval_cfg['n_frames']
    trainer.eval_num_shuffles = eval_cfg['num_shuffles']
    trainer.eval_n_frames_ref = eval_cfg['n_frames_ref']
    trainer.eval_videos = eval_videos
    trainer.eval_offsets = eval_offsets
    trainer.save_results_path = os.path.join(
        general_cfg['save_path'], f"voc_results_{grpo_cfg['run_name']}.json"
    )

    trainer.train()

    if not (accelerator.is_main_process and general_cfg['save_last']):
        return

    # Final save and upload happen on the main process only.
    chkpt_path = os.path.join(general_cfg['save_path'], 'checkpoints')
    os.makedirs(chkpt_path, exist_ok=True)
    model_save_path = os.path.join(chkpt_path, grpo_cfg['run_name'])
    trainer.model.save_pretrained(model_save_path)
    s3_bucket = ""  # Replace with your bucket name
    s3_path = ""
    upload_to_s3(model_save_path, s3_bucket, s3_path)


# Script entry point. main() is called without arguments because the
# @hydra.main decorator supplies the config object.
if __name__ == "__main__":
    main()