

import wandb

import os
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.registry import register_env
import gymnasium as gym
import minigrid
from minigrid.wrappers import RGBImgObsWrapper
from minigrid.wrappers import ImgObsWrapper as ImageOnlyObsWrapper
import time
import matplotlib.pyplot as plt
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.policy.sample_batch import concat_samples

from nowarning import *
from ballnd_util import *
from util import explained_variance, mean_absolute_error, compute_utility_sequence

import os
import random
import numpy as np
import torch

# Single source of truth for the process-wide random seed. Every RNG used by
# this module (Python hashing, `random`, NumPy, PyTorch) is seeded from it.
SEED = 1
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


from gymnasium.spaces import Dict, Box

class StripMissionWrapper(gym.ObservationWrapper):
    """Keep only the image and direction of each observation and add an action mask.

    The exposed observation space is a ``Dict`` with three entries:
    ``"image"`` and ``"direction"`` (taken unchanged from the wrapped env)
    and ``"action_mask"``, a float32 vector of length 7 that this wrapper
    always fills with ones (i.e. no action is ever masked out here).
    """

    def __init__(self, env):
        super().__init__(env)
        wrapped_space = env.observation_space
        spaces = {
            "image": wrapped_space["image"],
            "direction": wrapped_space["direction"],
            "action_mask": Box(0, 1, shape=(7,), dtype=np.float32),
        }
        self.observation_space = Dict(spaces)

    def observation(self, observation):
        """Return a new dict with the image, direction and an all-ones mask."""
        filtered = {key: observation[key] for key in ("image", "direction")}
        filtered["action_mask"] = np.ones(7, dtype=np.float32)
        return filtered
        
        

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog
import torch.nn as nn

class CNNWithDirection(TorchModelV2, nn.Module):
    """Policy/value network over an RGB image plus a discrete direction.

    The image goes through a four-layer convolutional stack, the direction
    index through a 16-dimensional embedding; both are concatenated and fed
    to two independent linear heads (action logits and state value).

    ``model_config["custom_model_config"]["direction_dim"]`` sets the number
    of direction categories (default 4).

    If the observation carries an ``"action_mask"`` entry, it is applied to
    the logits so that actions with mask value 0 receive (numerically)
    minimal logits. A mask of all ones leaves the logits unchanged.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # If the space handed in is a wrapped/preprocessed one that exposes the
        # original Dict space, use that; otherwise use the space as given.
        dict_space = getattr(obs_space, "original_space", obs_space)
        image_shape = dict_space["image"].shape

        # Probe the conv stack once with a zero image to learn its flattened
        # output width instead of hard-coding it for one image resolution.
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, image_shape[0], image_shape[1])
            self._cnn_out_dim = self.cnn(dummy_input).view(1, -1).shape[1]

        direction_dim = model_config.get("custom_model_config", {}).get("direction_dim", 4)
        self.dir_embed = nn.Embedding(direction_dim, 16)
        self.policy_branch = nn.Sequential(nn.Linear(self._cnn_out_dim + 16, num_outputs))
        self.value_branch = nn.Sequential(nn.Linear(self._cnn_out_dim + 16, 1))
        # Value prediction of the most recent forward pass (see value_function).
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        """Compute action logits and cache the value estimate.

        Args:
            input_dict: Must contain ``"obs"``, a mapping with ``"image"`` and
                ``"direction"`` tensors and optionally ``"action_mask"``.
            state: Passed through unchanged (the model is not recurrent).
            seq_lens: Unused.

        Returns:
            ``(logits, state)``.
        """
        obs = input_dict["obs"]
        img = obs["image"]
        direction = obs["direction"].long()

        if img.dim() == 4 and img.shape[-1] == 3:
            # Channel-last batch: move channels first and scale by 1/255.
            img = img.permute(0, 3, 1, 2).float() / 255.0
        else:
            # Already channel-first (or unexpected layout): only make sure the
            # conv layers get a floating-point tensor. No rescaling is applied
            # because the value range of such input is not known here.
            img = img.float()

        cnn_out = self.cnn(img)
        dir_embed = self.dir_embed(direction)
        x = torch.cat([cnn_out, dir_embed], dim=1)
        self._value_out = self.value_branch(x).squeeze(1)
        logits = self.policy_branch(x)

        mask = obs.get("action_mask", None)
        if mask is not None:
            mask = mask.to(device=logits.device, dtype=logits.dtype)
            # log(1) == 0 keeps allowed logits as they are; log(0) == -inf is
            # clamped to the smallest finite value so downstream softmax /
            # sampling never sees an infinity.
            floor = torch.finfo(logits.dtype).min
            logits = logits + torch.clamp(torch.log(mask), min=floor)
        return logits, state

    def value_function(self):
        """Return the value estimate cached by the latest ``forward`` call."""
        return self._value_out


# Make CNNWithDirection selectable by name: the PPO model config below refers
# to it through "custom_model": "cnn_dir".
ModelCatalog.register_custom_model("cnn_dir", CNNWithDirection)


class PostprocessCallback(DefaultCallbacks):
    """Callbacks that track the agent's trajectory and optionally reshape advantages.

    * ``on_episode_start`` stores the red ball position and the grid.
    * ``on_episode_step`` appends ``(position, direction, action)`` per step.
    * ``on_postprocess_trajectory`` (only when the policy config has
      ``use_shaping`` set) rewrites the batch advantages in place as
      ``eta * A + zi * U * mean(|eta * A|)``, where ``U`` comes from
      ``compute_utility_sequence``.
    * ``on_learn_on_batch`` buffers train batches until ``train_batch_size``
      timesteps are collected, then logs value-function diagnostics to wandb.
    """

    def on_episode_start(self, *, worker, base_env, policies, episode, env_index, **kwargs):
        """Store the first red ball's ``(x, y)`` and the grid in ``episode.user_data``."""
        env = base_env.get_sub_environments()[env_index].unwrapped
        grid = env.grid

        def is_red_ball(cell):
            return bool(cell) and cell.type == "ball" and cell.color == "red"

        # Stop at the first match; a bare `break` inside the nested loops would
        # only leave the inner loop and keep scanning the rest of the grid.
        red_pos = next(
            (
                (x, y)
                for x in range(grid.width)
                for y in range(grid.height)
                if is_red_ball(grid.get(x, y))
            ),
            None,
        )
        if red_pos is not None:
            episode.user_data["red_pos"] = red_pos
        # Store the full grid for use in reward shaping
        episode.user_data["grid"] = grid

    def on_postprocess_trajectory(self, *, worker, episode, agent_id, policy_id, postprocessed_batch, original_batches, **kwargs):
        """Blend a utility term into the advantages of each finished segment.

        Does nothing unless ``use_shaping`` is truthy in the policy config.
        ``postprocessed_batch["advantages"]`` is modified in place.
        """
        policy = worker.policy_map[policy_id]
        if not policy.config.get("use_shaping", False):
            return

        # These do not change between segments, so read them once.
        zi = policy.config.get("zi", 0.0)
        eta = policy.config.get("eta", 1.0)
        red_pos = episode.user_data.get("red_pos")
        trajectory = episode.user_data.get("agent_trajectory", [])

        start = 0
        for end in np.where(postprocessed_batch[SampleBatch.DONES])[0]:
            stop = end + 1
            # NOTE(review): the batch row indices are used to slice the
            # per-episode trajectory, which assumes the batch rows line up
            # with the recorded episode steps — confirm for multi-episode batches.
            agent_trajectory = list(trajectory[start:stop])
            # Coerce to an array so the scaling below also works should the
            # helper return a plain Python sequence.
            raw_u = np.asarray(compute_utility_sequence(agent_trajectory, red_pos))

            shaped_adv = eta * postprocessed_batch["advantages"][start:stop]
            # Scale the utility to the magnitude of the advantages; the
            # epsilon keeps the scale strictly positive.
            mean_abs_adv = np.mean(np.abs(shaped_adv)) + 1e-8
            utility = raw_u * mean_abs_adv

            postprocessed_batch["advantages"][start:stop] = shaped_adv + zi * utility
            start = stop

    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        """Buffer train batches and log explained variance / MAE once enough are collected.

        Batches are accumulated on ``policy._buffered_batches`` until their
        combined number of timesteps reaches ``train_batch_size``. The
        diagnostics are computed on tensor copies, so ``train_batch`` itself
        is never modified by this callback.
        """
        if not hasattr(policy, "_buffered_batches"):
            policy._buffered_batches = []
        policy._buffered_batches.append(train_batch)

        # Only process when full batch is accumulated
        total_timesteps = sum(len(batch["rewards"]) for batch in policy._buffered_batches)
        expected_batch_size = policy.config["train_batch_size"]
        if total_timesteps < expected_batch_size:
            return
        if total_timesteps > expected_batch_size:
            # Nothing is trimmed: every buffered timestep is used below.
            print(
                f"Warning: accumulated {total_timesteps} timesteps, more than the expected "
                f"batch size {expected_batch_size}; using all of them."
            )

        device = policy.device
        tensor_batches = []
        for buffered in policy._buffered_batches:
            batch_copy = buffered.copy()
            for key, val in list(batch_copy.items()):
                batch_copy[key] = convert_to_torch_tensor(val, device=device)
            tensor_batches.append(batch_copy)
        # Concatenate all minibatches into one full batch
        full_batch = concat_samples(tensor_batches)

        values = full_batch[SampleBatch.VF_PREDS]
        returns = full_batch["value_targets"]
        ev = explained_variance(values, returns)
        mae = mean_absolute_error(values, returns)

        wandb.log({
            "shaped/explained_variance": ev,
            "shaped/mae": mae
        })

        policy._buffered_batches.clear()

    def on_episode_step(self, *, worker, base_env, policies, episode, env_index, **kwargs):
        """Append ``((x, y), direction, action)`` for the current step to the episode trajectory."""
        # NOTE(review): relies on the episode's private per-agent collectors and
        # on the agent being registered under the id "agent0" — confirm both
        # against the RLlib version in use.
        agent_collector = episode._agent_collectors["agent0"]
        latest_action = agent_collector.buffers["actions"][-1][-1]

        env = base_env.get_sub_environments()[env_index].unwrapped
        agent_pos = env.agent_pos
        episode.user_data.setdefault("agent_trajectory", []).append(
            ((int(agent_pos[0]), int(agent_pos[1])), int(env.agent_dir), int(latest_action))
        )
        # --- End trajectory logging ---

        # print(f"step:{episode.length} | state :{state} |  action : {action} | latest_reward = {latest_reward}")


    
import babyai 

def make_env(_):
    """Build the wrapped BabyAI "go to red ball" environment.

    The single argument (the env config passed by the registry) is ignored.
    The bare environment is reset once with seed 1 before the wrappers are
    applied; the observation from that reset is discarded.

    NOTE(review): ``render_mode='human'`` is requested for every created
    environment — confirm this is wanted for headless/remote training.
    """
    base_env = gym.make(
        "BabyAI-GoToRedBallNoDists-v0",
        max_steps=40,
        render_mode='human',
    )
    base_env.reset(seed=1)
    return StripMissionWrapper(RGBImgObsWrapper(base_env))


# Register make_env as the creator for the env id that the PPO config passes
# to .environment(env=...).
register_env("BabyAI-GoToRedBallNoDists-v0", make_env)

# ─── 3) Main: configure & run with PPOConfig ───────────────────────────────  
if __name__ == "__main__":
    # Train PPO with the shaping callback for a fixed number of iterations and
    # stream per-iteration metrics to Weights & Biases.
    np.random.seed(1)
    random.seed(1)
    torch.manual_seed(1)

    ray.init(ignore_reinit_error=True)
    algo_shaped = None
    try:
        register_env("BabyAI-GoToRedBallNoDists-v0", make_env)
        ModelCatalog.register_custom_model("cnn_dir", CNNWithDirection)

        config = (
            PPOConfig()
            # Environment & framework
            .environment(env="BabyAI-GoToRedBallNoDists-v0", env_config={})
            .framework("torch")
            .env_runners(
                num_env_runners=1,
                num_envs_per_env_runner=1,
                batch_mode="complete_episodes",
            )
            .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)
            .training(
                use_gae=True,
                lambda_=0.95,
                train_batch_size=128,
                minibatch_size=64,
                num_epochs=1,
                lr=0.0002,
                gamma=0.99,
                entropy_coeff=0.01,
                model={
                    "vf_share_layers": True,
                    "conv_filters": [
                        [32, [5, 5], 2],
                        [64, [4, 4], 2],
                        [64, [3, 3], 2],
                        [64, [3, 3], 1],
                    ],
                    "post_fcnet_hiddens": [64],
                    "post_fcnet_activation": "relu",
                    "use_lstm": False,
                    "custom_model_config": {
                        "direction_dim": 4,  # assuming Discrete(4)
                    },
                    "custom_model": "cnn_dir",
                },
            )
            .experimental(_disable_preprocessor_api=True)
            .callbacks(PostprocessCallback)
        )
        config.seed = 3

        wandb.init(project="ppo_minigrid_redballND", config={"zi": 0.0, "eta": 1.0},)

        shaped_returns = []
        prev_vf_preds = None
        convergence_step = None
        delta_v_threshold = 0.1  # ε value

        RANGE = 750

        # Run shaped PPO
        ppo_shaped = config.copy(copy_frozen=False)
        ppo_shaped["zi"] = wandb.config.zi
        ppo_shaped["eta"] = wandb.config.eta
        ppo_shaped.callbacks(PostprocessCallback)
        ppo_shaped["use_shaping"] = True
        ppo_shaped["experiment_name"] = "PPO_shaped"
        algo_shaped = ppo_shaped.build()

        for i in range(RANGE):
            # NOTE(review): the driver RNGs are reset to the same seed before
            # every iteration, so each iteration replays the same driver-side
            # random stream — confirm this is intended.
            random.seed(1)
            np.random.seed(1)
            torch.manual_seed(1)

            res = algo_shaped.train()
            learner_info = res.get("info", {}).get("learner", {}).get("default_policy", {})
            learner_stats = learner_info.get("learner_stats", {})

            kl_div = learner_stats.get("kl", None)
            if kl_div is not None:
                wandb.log({"shaped/kl_divergence": kl_div})

            ret = res.get("env_runners", {}).get("episode_return_mean", None)
            policy_loss = learner_stats.get("policy_loss", learner_info.get("loss", None))
            value_loss = learner_stats.get("vf_loss", learner_info.get("value_function_loss", None))

            # The return may be absent from the result; formatting None with
            # ":.2f" would raise, so fall back to a placeholder.
            ret_text = "n/a" if ret is None else f"{ret:.2f}"
            print(f"[SHAPED] Iter {i+1:2d} | Return: {ret_text}")

            # Track how much the value predictions move between iterations
            # (only possible when the learner stats expose them).
            vf_preds = learner_stats.get("vf_preds", None)
            if prev_vf_preds is not None and vf_preds is not None:
                delta_V = np.mean(np.abs(np.array(vf_preds) - np.array(prev_vf_preds)))
                wandb.log({"shaped/delta_V": delta_V})
                if convergence_step is None and delta_V < delta_v_threshold:
                    convergence_step = i
                    wandb.log({"shaped/convergence_step": i + 1})
            prev_vf_preds = vf_preds

            wandb.log({
                "shaped/return": ret,
                "shaped/policy_loss": policy_loss,
                "shaped/value_loss": value_loss,
                "shaped/iteration": i + 1
            })
            shaped_returns.append(ret)

        # Average the last ten iterations, ignoring iterations without a usable
        # return (missing or NaN), which would otherwise break or poison the mean.
        recent_returns = [
            r for r in shaped_returns[-10:] if r is not None and not np.isnan(r)
        ]
        if recent_returns:
            wandb.log({
                "shaped/final_return": np.mean(recent_returns)
            })
    finally:
        # Release the algorithm's workers and the Ray runtime even when
        # building or training fails part-way.
        if algo_shaped is not None:
            algo_shaped.stop()
        ray.shutdown()

