from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from pyboy_gym import GenericPyBoyEnv
from pyboy import PyBoy
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import torch.optim as optim
from collections import deque
import os
import numpy as np
import time
import torch
from tqdm import tqdm
from functools import partial
import random

from upload_qwen_to_s3 import upload_to_s3
from voc_calculation import voc_reward_fn

from a2c_ppo_acktr import algo, utils, rl_utils
from a2c_ppo_acktr.rl_utils import text_projection
from a2c_ppo_acktr.storage import RolloutStorage
from a2c_ppo_acktr.model import VLMPolicy, VLMValue

from omegaconf import DictConfig, OmegaConf
import hydra
import wandb

import accelerate 
from accelerate.state import AcceleratorState
from accelerate.utils import set_seed
import copy
from a2c_ppo_acktr.utils import RunningNormalizer, AdaptiveKLController


def _voc_seed(base_seed: int, update_idx: int, step_idx: int) -> int:
    """Return a deterministic per-call RNG seed for VOC context shuffles.

    Distinct (update, step) pairs map to distinct seeds as long as the number
    of steps per update stays below 1,000,000.
    """
    return int(base_seed + update_idx * 1000000 + step_idx)


def _to_reward_tensor(reward) -> torch.Tensor:
    """Convert an environment reward into a flat float32 CPU tensor.

    ``torch.Tensor(x)`` interprets an integer ``x`` as a *size* (returning an
    uninitialised tensor of that length) and rejects Python floats, so scalar
    rewards used to be either garbage or a crash. ``torch.as_tensor`` treats
    scalars, (nested) lists, ndarrays and tensors uniformly as data.
    """
    return torch.as_tensor(reward, dtype=torch.float32, device="cpu").reshape(-1)


def _mean_std(values):
    """Return ``(mean, std)`` of ``values`` as floats, or ``(nan, nan)`` if empty.

    Produces the same NaN that ``np.array([]).mean()`` would log, but without
    the "Mean of empty slice" RuntimeWarning (e.g. the dense-reward buffers
    stay empty when ``use_dense_rewards`` is off).
    """
    if len(values) == 0:
        return float("nan"), float("nan")
    arr = np.asarray(values)
    return float(arr.mean()), float(arr.std())


def _load_base_models(args):
    """Load the policy backbone and, optionally, a frozen VOC reward model.

    Returns:
        ``(base, voc_reward_model)`` where ``base`` is the Qwen2.5-VL model
        (with the pretrained LoRA merged in when ``args.voc.use_pretrained``)
        and ``voc_reward_model`` is ``None`` unless
        ``args.voc.use_different_reward_model`` is set.
    """
    base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model.model_path, torch_dtype=torch.float16
    )

    # Build the reward model FIRST, from a pristine copy. PeftModel.from_pretrained
    # injects the adapter into the module it is given in place and
    # merge_and_unload() returns that same (now merged) module, so deep-copying
    # base_model after merging the policy adapter would apply the LoRA twice.
    voc_reward_model = None
    if args.voc.use_different_reward_model:
        reward_adapter = PeftModel.from_pretrained(copy.deepcopy(base_model), args.model.lora_path)
        voc_reward_model = reward_adapter.merge_and_unload()
        # NOTE(review): device is hard-coded; consider making it configurable.
        voc_reward_model.to('cuda:1')
        voc_reward_model.eval()

    # Load lora weight and merge into the policy backbone.
    if args.voc.use_pretrained:
        pretrained_adapter = PeftModel.from_pretrained(base_model, args.model.lora_path)
        base = pretrained_adapter.merge_and_unload()
    else:
        # No copy needed: the reward model (if any) already has its own copy.
        base = base_model
    return base, voc_reward_model


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(args: DictConfig):
    """Run PPO training of a Qwen2.5-VL policy on a PyBoy game.

    The rewards written into the rollout buffer come from VOC scoring of the
    recent frames (either the sparse VOC value or a potential-shaped dense
    reward when ``args.voc.use_dense_rewards`` is set). The environment's own
    reward is only used for episode statistics/logging.
    """
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.training.grad_accum_steps)
    device = accelerator.device
    set_seed(args.general.seed, device_specific=False)

    # Load model(s)
    base, voc_reward_model = _load_base_models(args)

    processor = AutoProcessor.from_pretrained(args.model.model_path)
    reference_model = None

    # Only patch the micro-batch size when accelerate actually runs under DeepSpeed;
    # deepspeed_plugin is None otherwise.
    deepspeed_plugin = AcceleratorState().deepspeed_plugin
    if deepspeed_plugin is not None:
        deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1
    model_device = device

    # apply another lora for rl training
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        **args.peft
    )
    base = get_peft_model(base, peft_config)

    # get KL controller
    if args.training.kl_ctl:
        kl_controller = AdaptiveKLController(init_kl_coef=args.training.kl_beta,
                                             target=args.training.target_kl,
                                             horizon=args.training.kl_horizon,
                                             kl_beta_lb=args.training.kl_beta_lb)
    else:
        kl_controller = None

    # define value model
    value_model = VLMValue(base)
    value_model = value_model.to(model_device)

    # Get pyboy emulator and env
    pyboy = PyBoy(args.general.cartridge_path)
    env = GenericPyBoyEnv(pyboy=pyboy, state_path=args.general.level_path, max_episode_steps=args.environment.max_episode_steps)

    # define critic
    projection_f = partial(text_projection, env_name=args.environment.env_name)

    actor_critic = VLMPolicy(
        accelerator=accelerator,
        processor=processor,
        value_model=value_model,
        reference_model=reference_model,
        projection_f=projection_f,
        args=args
    )

    # now start learning
    optimizer = optim.Adam([
        {'params': actor_critic.value_model.base_model.parameters(), 'lr': args.training.init_lr},
        {'params': actor_critic.value_model.value_head.parameters(), 'lr': 1e-4}
    ], eps=args.training.eps, weight_decay=args.training.weight_decay)

    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.training.lr_max_steps, eta_min=args.training.end_lr)

    actor_critic, optimizer, lr_scheduler = accelerator.prepare(actor_critic, optimizer, lr_scheduler)
    if reference_model is not None:
        reference_model = accelerator.prepare_model(reference_model, evaluation_mode=True)
        actor_critic.reference_model = reference_model

    agent = algo.PPO(
        actor_critic,
        optimizer,
        accelerator,
        args.training.clip_param,
        args.training.ppo_epoch,
        args.training.mini_batch_size,
        args.training.value_loss_coef,
        args.training.entropy_coef,
        args.training.kl_beta,
        max_grad_norm=args.training.max_grad_norm,
        grad_accum_steps=args.training.grad_accum_steps,
        kl_adapter=kl_controller
    )

    rollouts = RolloutStorage(args.environment.num_steps, 1, env.action_space,
                              args.model.max_new_tokens,
                              log_path=os.path.join(args.general.log_dir, args.environment.env_name, "rollouts"))

    # Track episode stats
    running_episode_rewards = torch.zeros(args.environment.num_processes).flatten()
    episode_rewards = deque(maxlen=args.environment.eval_num_per_episode)
    episode_success_rate = deque(maxlen=args.environment.eval_num_per_episode)

    # wandb run name and init
    run_name_prefix = "debug-" if args.general.debug else ""
    run_name = run_name_prefix + args.wandb.wandb_run
    if args.wandb.use_wandb:
        wandb.init(project=args.wandb.wandb_project, name=run_name, group=run_name, config=OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True))

    # main loop logic
    start = time.time()
    num_updates = int(args.environment.num_env_steps) // args.environment.num_steps // args.environment.num_processes
    image_observations = deque(maxlen=args.environment.max_image_obs_len)
    prev_actions = deque(maxlen=args.environment.max_image_obs_len - 1)
    infos = []
    action_list = [
        'MOVE_UP',
        'MOVE_DOWN',
        'MOVE_LEFT',
        'MOVE_RIGHT'
    ]
    # Frame size used for the VOC image buffer (width, height) from video_crop_box.
    voc_frame_size = (args.general.video_crop_box[2], args.general.video_crop_box[3])

    # --- State ---
    voc_images_buffer = deque(maxlen=args.voc.buffer_size)
    phi_t = 0.0            # current per-step potential estimate
    phi_anchor = 0.0       # last anchored (normalized) potential
    phi_s = []
    dense_rewards = []
    anchors = []
    all_vocs_std, voc_rewards, all_vocs = [], [], []
    cur_traj_step = 0

    # Running stats to keep scale stable across time/episodes
    norm = RunningNormalizer()

    def compute_voc_reward(images_buffer, update_idx: int, step_idx: int):
        """Score ``images_buffer`` with ``voc_reward_fn`` under ``torch.no_grad``.

        Uses the frozen reward model when ``use_different_reward_model`` is set,
        otherwise the policy's own backbone. Returns whatever ``voc_reward_fn``
        returns (unpacked by callers as ``(voc_reward, all_vocs)``).
        """
        if args.voc.use_different_reward_model:
            scorer = voc_reward_model
        else:
            scorer = actor_critic.value_model.base_model
        with torch.no_grad():
            return voc_reward_fn(
                list(images_buffer),
                scorer,
                processor,
                args.voc.n_repeats,
                context_len=args.voc.context_len,
                shuffle_context=args.voc.shuffle_context,
                use_percentage=args.voc.use_percentage,
                context_path=args.voc.context_path,
                context_offset=args.voc.context_offset,
                crop_box=args.general.video_crop_box,
                prompt_version=args.voc.prompt_version,
                rng_seed=_voc_seed(args.general.seed, update_idx, step_idx),
            )

    def apply_voc_threshold(raw_voc):
        """Replace low VOC values with the configured penalty (trajectory deemed bad)."""
        return args.voc.penalty_value if raw_voc < args.voc.threshold else raw_voc

    def reset_episode_buffers():
        """Drop per-episode frame/action history at an episode boundary."""
        image_observations.clear()
        voc_images_buffer.clear()
        prev_actions.clear()

    for j in tqdm(range(num_updates)):
        for per_update_buffer in (phi_s, dense_rewards, anchors, all_vocs_std, voc_rewards):
            per_update_buffer.clear()
        for step in tqdm(range(args.environment.num_steps), desc="Step on epoch {}, on rank {}".format(j, accelerator.process_index)):
            # Sample actions
            if j == 0 and step == 0:
                image_obs = env.reset()[0][0].copy().crop(args.general.gb_crop_box) # cut off time from frame
                image_observations.append(image_obs.copy())
                voc_images_buffer.append(image_obs.copy().resize(voc_frame_size))
                cur_traj_step = 1
                observation = rl_utils.make_observation(image_observations, prev_actions)
                rollouts.obs[0] = observation

            value, output_id, action, tokens_log_probs, text_action = actor_critic.act(observation)
            reference_log_probs = actor_critic.get_reference_model_logits(observation, output_id)

            image_obs, reward, done, _, infos = env.step(action)
            prev_actions.append(action_list[action])
            image_obs = image_obs[0].copy().crop(args.general.gb_crop_box) # cut off time from frame
            reward = _to_reward_tensor(reward)
            accelerator.print("REWARD: ", reward)

            voc_scored_this_step = False
            if cur_traj_step % args.voc.reward_steps == 0 and len(voc_images_buffer) == args.voc.buffer_size and args.voc.use_voc:
                voc_reward, all_vocs = compute_voc_reward(voc_images_buffer, j, step)
                voc_scored_this_step = True
                # cut off reward if voc is low (thus trajectory is bad)
                voc_reward = apply_voc_threshold(voc_reward)

                norm.update(voc_reward)
                phi_anchor = norm.normalize(voc_reward)

                voc_rewards.append(voc_reward)
                all_vocs_std.append(np.std(all_vocs))
                anchors.append(phi_anchor)

                accelerator.print("VOC REWARD: ", voc_reward)
            else:
                voc_reward = 0

            for d in done:
                if d:
                    # Score the leftover (possibly partial) window so the episode's tail
                    # still gets a VOC reward. Skipped when VOC is disabled, and when the
                    # periodic check above already scored this exact buffer this step.
                    # NOTE: this value does not update the normalizer/anchor, so it only
                    # affects the sparse (non-dense) reward path.
                    if args.voc.use_voc and not voc_scored_this_step and len(voc_images_buffer) > 0:
                        voc_reward, all_vocs = compute_voc_reward(voc_images_buffer, j, step)
                        voc_reward = apply_voc_threshold(voc_reward)
                        voc_rewards.append(voc_reward)
                        all_vocs_std.append(np.std(all_vocs))
                        accelerator.print("VOC REWARD: ", voc_reward)

                    reset_episode_buffers()
                    cur_traj_step = 0
                    accelerator.print(step, "Episode finished")

            for info in infos:
                if info.get('TimeLimit.truncated', 0):
                    cur_traj_step = 0
                    reset_episode_buffers()
                    accelerator.print(step, "Episode finished")

            image_observations.append(image_obs.copy())
            voc_images_buffer.append(image_obs.copy().resize(voc_frame_size))
            cur_traj_step += 1
            observation = rl_utils.make_observation(image_observations, prev_actions)

            if args.voc.use_dense_rewards:
                # 2) Exponential interpolation toward the latest anchor (dense φ_t)
                #    Move a small step ETA toward phi_anchor each environment step.
                phi_next = phi_t + args.voc.eta * (phi_anchor - phi_t)
                phi_s.append(phi_next)

                # 3) Potential-based shaping: r_t = γ φ_{t+1} − φ_t
                r_t = args.voc.gamma * phi_next - phi_t
                r_t = np.clip(r_t, -1, 1)
                dense_rewards.append(r_t)
                phi_t = phi_next

            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])

            running_episode_rewards += reward.flatten()
            for i, d, _ in zip(range(args.environment.num_processes), done, reward):
                if d:
                    episode_rewards.append(running_episode_rewards[i].item())
                    episode_success_rate.append(1 if running_episode_rewards[i] > 0 else 0)
                    running_episode_rewards[i] = 0

            bad_masks = torch.FloatTensor([[0.0] if info.get('TimeLimit.truncated', 0) else [1.0] for info in infos])

            # we only add voc_reward (or r_t) to rollout, so reward from the env doesn't affect the learning
            step_reward = r_t if args.voc.use_dense_rewards else voc_reward
            rollouts.insert(observation, output_id, action, tokens_log_probs, reference_log_probs, value, step_reward, masks, bad_masks)

        accelerator.print("****** iteration number:{} ******".format(j))
        accelerator.print("reward:{}".format(episode_rewards))
        accelerator.print("")
        next_value = actor_critic.get_value(rollouts.obs[-1]).detach()
        accelerator.print(rollouts.obs[-1], next_value)

        rollouts.compute_returns(next_value, args.training.use_gae, args.training.gamma,
                                 args.training.gae_lambda, args.training.use_proper_time_limits)
        value_warmup_enabled = (args.training.value_warmup == "yes")
        only_value_loss = (j < 2) if value_warmup_enabled else False

        kl = (args.training.use_kl == "yes")
        value_loss, action_loss, dist_entropy, value_losses, action_losses, kls, kl_ctl, advantage_stats = agent.update(rollouts, only_value_loss=only_value_loss, kl=kl)
        lr_scheduler.step()

        rollouts.after_update(j)

        if len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.environment.num_processes * args.environment.num_steps
            end = time.time()
            fps = int(total_num_steps / (end - start))

            accelerator.print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.2f}/{:.2f}, min/max reward {:.2f}/{:.2f}, success_rate {:.2f}\n dist_entropy {}, value_loss {}, action_loss {}\n"
                .format(j, total_num_steps,
                        fps,
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), np.mean(episode_success_rate),
                        dist_entropy, value_loss, action_loss))
            if args.wandb.use_wandb:
                voc_mean, voc_std = _mean_std(voc_rewards)
                shuffle_std_mean, shuffle_std_std = _mean_std(all_vocs_std)
                phi_mean, phi_std = _mean_std(phi_s)
                dense_mean, dense_std = _mean_std(dense_rewards)
                anchor_mean, anchor_std = _mean_std(anchors)
                td_error = rollouts.returns[:-1] - rollouts.value_preds[:-1]
                wandb.log(
                    {
                        "iteration": j,
                        "num_timesteps": total_num_steps,
                        "FPS": fps,
                        "episode_reward/mean": np.mean(episode_rewards),
                        "episode_reward/median": np.median(episode_rewards),
                        "episode_reward/min": np.min(episode_rewards),
                        "episode_reward/max": np.max(episode_rewards),
                        "episode_success_rate/mean": np.mean(episode_success_rate),
                        "distribution_entropy": dist_entropy,
                        "value/loss": value_loss,
                        "action/loss": action_loss,
                        "reward/max": rollouts.rewards.max().item(),
                        "reward/min": rollouts.rewards.min().item(),
                        "reward/mean": rollouts.rewards.mean().item(),
                        "reward/std": rollouts.rewards.std().item(),
                        "reward/median": rollouts.rewards.median().item(),
                        "return/max": rollouts.returns.max().item(),
                        "return/min": rollouts.returns.min().item(),
                        "return/mean": rollouts.returns.mean().item(),
                        "return/std": rollouts.returns.std().item(),
                        "value/max": rollouts.value_preds.max().item(),
                        "value/min": rollouts.value_preds.min().item(),
                        "value/mean": rollouts.value_preds.mean().item(),
                        "value/std": rollouts.value_preds.std().item(),
                        "kl": np.array(kls).mean(),
                        "voc/mean": voc_mean,
                        "voc/std": voc_std,
                        "voc/shuffle_voc_std/mean": shuffle_std_mean,
                        "voc/shuffle_voc_std/std": shuffle_std_std,
                        "voc/phi_t/mean": phi_mean,
                        "voc/phi_t/std": phi_std,
                        "voc/ema_rewards/mean": dense_mean,
                        "voc/ema_rewards/std": dense_std,
                        "voc/phi_anchor/mean": anchor_mean,
                        "voc/phi_anchor/std": anchor_std,
                        "kl_ctl": kl_ctl,
                        # Advantage statistics
                        "advantage/raw_mean": advantage_stats['raw_mean'],
                        "advantage/raw_std": advantage_stats['raw_std'],
                        "advantage/raw_min": advantage_stats['raw_min'],
                        "advantage/raw_max": advantage_stats['raw_max'],
                        "advantage/raw_median": advantage_stats['raw_median'],
                        "advantage/normalized_mean": advantage_stats['normalized_mean'],
                        "advantage/normalized_std": advantage_stats['normalized_std'],
                        "advantage/normalized_min": advantage_stats['normalized_min'],
                        "advantage/normalized_max": advantage_stats['normalized_max'],
                        "advantage/normalized_median": advantage_stats['normalized_median'],
                        "advantage/zero_ratio": advantage_stats['zero_advantage_ratio'],
                        "advantage/small_ratio": advantage_stats['small_advantage_ratio'],
                        # Additional diagnostic metrics
                        "advantage/td_error_mean": td_error.mean().item(),
                        "advantage/td_error_std": td_error.std().item(),
                        "advantage/value_pred_std": rollouts.value_preds[:-1].std().item(),
                        "advantage/returns_std": rollouts.returns[:-1].std().item(),
                    }
                )

if __name__ == "__main__":
    # Script entry point; main() is wrapped by @hydra.main, which supplies the
    # config object as its `args` argument.
    main()