from trl import GRPOTrainer, GRPOConfig
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TrainerCallback
from typing import Union, Any, List, Dict
from pyboy_gym import GenericPyBoyEnv
from pyboy import PyBoy
from collections import deque
import numpy as np
import torch
import torch.nn.functional as F
from a2c_ppo_acktr.rl_utils import make_observation
from qwen_vl_utils import process_vision_info
from a2c_ppo_acktr.rl_utils import text_projection
from torch.utils.data import Dataset
from voc_calculation import extract_task_completion_percentages
from scipy.stats import spearmanr
import hydra
from omegaconf import DictConfig, OmegaConf
import wandb
from accelerate.utils import set_seed
import os
from accelerate import Accelerator
from peft import LoraConfig, TaskType
from voc_calculation import voc_reward_fn
from peft import get_peft_model, LoraConfig, TaskType, PeftModel


class DumbDataset(Dataset):
    """Placeholder map-style dataset that returns the same prompt ``rollout_size`` times.

    Every item is the constant ``{"prompt": "some prompt"}``; only the length
    (``rollout_size``) carries information. Presumably the real model inputs are
    built elsewhere (e.g. from environment observations) — confirm with the trainer.
    """

    def __init__(self, rollout_size: int = 1) -> None:
        super().__init__()
        # Reported dataset length; this is the only knob the dataset has.
        self.max_steps = rollout_size

    def __len__(self) -> int:
        return self.max_steps

    def __getitem__(self, idx: int) -> dict:
        """Return the constant placeholder item for a valid index.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Raising here is
                what lets plain ``for item in dataset`` iteration terminate; the
                base ``Dataset`` defines no ``__iter__``, so Python falls back to
                calling ``__getitem__(0), (1), ...`` until ``IndexError``.
        """
        position = idx + self.max_steps if idx < 0 else idx
        if not 0 <= position < self.max_steps:
            raise IndexError(
                f"DumbDataset index {idx} out of range for length {self.max_steps}"
            )
        return {"prompt": "some prompt"}


def voc_reward(completions, voc_img_buffer, reward_model, processor, voc_reward_kwargs, **kwargs):
    """Score every entry of ``voc_img_buffer`` with ``voc_reward_fn``.

    Args:
        completions: Accepted for the reward-function call signature; not read here.
        voc_img_buffer: Iterable of per-sample frame collections; each one is
            converted to a ``list`` before being scored.
        reward_model: Model forwarded unchanged to ``voc_reward_fn``.
        processor: Processor forwarded unchanged to ``voc_reward_fn``.
        voc_reward_kwargs: Mapping of extra keyword arguments for ``voc_reward_fn``.
        **kwargs: Any further caller-supplied arguments; ignored.

    Returns:
        A list holding the first element of each ``voc_reward_fn`` result, in
        buffer order. Each score is also printed as it is computed.

    NOTE(review): the result length follows ``voc_img_buffer``, not
    ``completions``; TRL expects one reward per completion — confirm the two
    always have the same length.
    """

    def _score_one(frames):
        # voc_reward_fn returns a 2-tuple; only the first element is the reward.
        score, _details = voc_reward_fn(list(frames), reward_model, processor, **voc_reward_kwargs)
        print("VOC: ", score)
        return score

    return [_score_one(frames) for frames in voc_img_buffer]
        

def _load_voc_reward_model(args):
    """Load the base VLM in float16, merge the VOC LoRA adapter, and move it to its device in eval mode.

    Reads ``model.model_path`` and ``model.lora_path``, plus the optional
    ``model.reward_device``. That key defaults to ``"cuda:1"``, the value that
    used to be hard-coded.

    NOTE(review): when launched with several accelerate processes, each process
    runs this and places a copy on the same device — confirm that is intended.
    """
    model_cfg = args['model']
    base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_cfg['model_path'], torch_dtype=torch.float16
    )
    pretrained_adapter = PeftModel.from_pretrained(base_model, model_cfg['lora_path'])
    voc_reward_model = pretrained_adapter.merge_and_unload()
    voc_reward_model.to(model_cfg.get('reward_device', "cuda:1"))
    voc_reward_model.eval()
    return voc_reward_model


def _attach_voc_settings(trainer, args, voc_reward_model):
    """Set the VOC/environment attributes on ``trainer`` from the ``general``, ``voc`` and ``environment`` config groups.

    NOTE(review): stock ``trl.GRPOTrainer`` does not read these attributes;
    presumably a patched trainer consumes them during rollouts — confirm.
    """
    general, voc, environment = args['general'], args['voc'], args['environment']

    trainer.cnt_target_voc = 0
    trainer.cnt_proceed = general['cnt_proceed']
    trainer.target_voc = general['target_voc']
    # Game Boy cartridge (.gb) and saved level state. These used to be
    # hard-coded to "" and had to be edited in source; they can now be set from
    # the `environment` config group, and the empty default is kept.
    trainer.cartridge_path = environment.get('cartridge_path', "")
    trainer.level_state_path = environment.get('level_state_path', "")
    trainer.pb = PyBoy
    trainer.env_class = GenericPyBoyEnv
    trainer.voc_reward_model = voc_reward_model
    trainer.voc_buffer_size = voc['buffer_size']
    trainer.voc_n_repeats = voc['n_repeats']
    trainer.voc_prompt_version = voc['prompt_version']
    trainer.video_crop_box = general['video_crop_box']
    trainer.gb_crop_box = general['gb_crop_box']
    trainer.max_img_obs_len = environment['max_img_obs_len']


@hydra.main(version_base=None, config_path="configs/", config_name="config_grpo.yaml")
def main(args: DictConfig):
    """Hydra entry point: build the VOC reward model, the Qwen2.5-VL policy and a GRPO trainer, then train.

    Args:
        args: Hydra config. The groups read are ``model``, ``general``, ``peft``
            (only when ``general.use_peft`` is true), ``grpo``, ``voc`` and
            ``environment``.
    """
    # Not referenced below. Presumably it is constructed for its side effect of
    # initialising accelerate state — confirm before removing.
    accelerator = Accelerator()
    set_seed(args.grpo.seed, device_specific=True)

    args = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
    model_path = args['model']['model_path']

    voc_reward_model = _load_voc_reward_model(args)

    os.makedirs(args['general']['save_path'], exist_ok=True)

    mixed_vid_dataset = DumbDataset(rollout_size=args['general']['rollout_size'])

    if args['general']['use_peft']:
        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, **args['peft'])
    else:
        peft_config = None

    grpo_config = GRPOConfig(**args['grpo'])

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(model_path)
    # NOTE(review): 151645 appears to be the `<|im_end|>` id of the Qwen2.5
    # tokenizer (end-of-turn). Confirm against processor.tokenizer if the
    # base model changes.
    processor.eos_token_id = 151645

    trainer = GRPOTrainer(
        model=model,
        processing_class=processor,
        reward_funcs=voc_reward,
        train_dataset=mixed_vid_dataset,
        args=grpo_config,
        peft_config=peft_config,
    )
    _attach_voc_settings(trainer, args, voc_reward_model)

    trainer.train()


if __name__ == "__main__":
    main()