import wandb

import os
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.registry import register_env
import gymnasium as gym
import time
import matplotlib.pyplot as plt
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.policy.sample_batch import concat_samples

from nowarning import *
from util import *
from util import explained_variance, mean_absolute_error, compute_utility_sequence

import os
import random
import numpy as np
import torch

# Seed every random number generator this script touches so runs are
# repeatable. (A previous `torch.manual_seed(0)` call was dropped: it was
# overridden by the `torch.manual_seed(SEED)` below before any random draw.)
SEED = 42
# NOTE: PYTHONHASHSEED is read by the interpreter at start-up, so setting it
# here can only influence processes spawned later, not this one.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Request deterministic cuDNN kernels and disable the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False



class PostprocessCallback(DefaultCallbacks):
    """RLlib callbacks for utility-based advantage shaping and value diagnostics.

    Behaviour is controlled by custom keys read from ``policy.config``:

    * ``use_shaping`` (default ``False``): enables the advantage bonus.
    * ``shaping_alpha`` (default ``0.0``): weight of the utility bonus.
    * ``shaping_eta`` (default ``1.0``): multiplier applied to the original
      advantages.
    """

    def on_postprocess_trajectory(
        self,
        *,
        worker,
        episode,
        agent_id,
        policy_id,
        postprocessed_batch,
        original_batches,
        **kwargs,
    ):
        """Add a utility-based bonus to the advantages of each finished episode.

        The batch is split at every index where the ``dones`` column is
        truthy. For each such segment the advantages are overwritten in place
        with ``eta * A + alpha * U * (mean(|eta * A|) + 1e-8)``, where ``U`` is
        the output of ``compute_utility_sequence`` for the segment's
        observations and actions. Steps after the last done flag are left
        untouched. Does nothing unless ``use_shaping`` is set on the policy
        config.
        """
        policy = worker.policy_map[policy_id]
        if not policy.config.get("use_shaping", False):
            return

        # Loop-invariant shaping coefficients.
        alpha = policy.config.get("shaping_alpha", 0.0)
        eta = policy.config.get("shaping_eta", 1.0)

        # assumes the batch columns are numpy arrays at this point — TODO confirm
        dones = postprocessed_batch[SampleBatch.DONES]

        start = 0
        for end in np.where(dones)[0]:
            segment = slice(start, end + 1)
            raw_u = compute_utility_sequence(
                postprocessed_batch["obs"][segment],
                postprocessed_batch["actions"][segment],
            )

            scaled_adv = eta * postprocessed_batch["advantages"][segment]
            # Scale the bonus to the segment's advantage magnitude; the
            # epsilon keeps the factor strictly positive.
            mean_abs_adv = np.mean(np.abs(scaled_adv)) + 1e-8

            # NOTE(review): the bonus is built from the raw utility, not a
            # centred/normalised one, so its magnitude depends on the scale of
            # compute_utility_sequence's output — confirm this is intended.
            shaped_bonus = alpha * (raw_u * mean_abs_adv)

            postprocessed_batch["advantages"][segment] = scaled_adv + shaped_bonus
            start = end + 1

    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        """Log value-function diagnostics once a full train batch has been seen.

        Incoming batches are buffered on the policy (``policy._buffered_batches``)
        until their combined length reaches ``train_batch_size``. The buffered
        batches are then copied, converted to torch tensors on the policy's
        device, concatenated, and the explained variance and mean absolute
        error between ``vf_preds`` and ``value_targets`` are logged to W&B.
        The incoming ``train_batch`` objects are never modified.
        """
        if not hasattr(policy, "_buffered_batches"):
            policy._buffered_batches = []
        policy._buffered_batches.append(train_batch)

        total_timesteps = sum(
            len(batch["rewards"]) for batch in policy._buffered_batches
        )
        expected_batch_size = policy.config["train_batch_size"]
        if total_timesteps < expected_batch_size:
            return
        if total_timesteps > expected_batch_size:
            print(
                f"Warning: accumulated {total_timesteps} timesteps, more than "
                f"train_batch_size={expected_batch_size}; diagnostics use all of them."
            )

        # Take ownership of the buffer and reset it first, so a failure below
        # cannot leave stale batches behind to grow without bound.
        pending, policy._buffered_batches = policy._buffered_batches, []

        device = policy.device
        tensor_batches = []
        for batch in pending:
            # Work on a copy so the batch handed to the learner stays untouched.
            batch = batch.copy()
            for key, val in batch.items():
                batch[key] = convert_to_torch_tensor(val, device=device)
            tensor_batches.append(batch)
        full_batch = concat_samples(tensor_batches)

        values = full_batch[SampleBatch.VF_PREDS]
        returns = full_batch["value_targets"]
        wandb.log({
            "shaped/explained_variance": explained_variance(values, returns),
            "shaped/mae": mean_absolute_error(values, returns),
        })

    def on_episode_step(self, *, worker, base_env, policies, episode, env_index, **kwargs):
        """Decode the most recent state and action of ``agent0``.

        NOTE(review): this hook is currently a placeholder — the decoded
        values are not stored or logged anywhere. It also reads the private
        ``episode._agent_collectors`` attribute and hard-codes the agent id
        (presumably RLlib's default single-agent id; verify).
        """
        agent_id = "agent0"
        agent_collector = episode._agent_collectors[agent_id]

        # Each buffer entry is itself a sequence; take the newest element of
        # the newest entry.
        latest_obs = agent_collector.buffers["obs"][-1][-1]
        latest_action = agent_collector.buffers["actions"][-1][-1]
        latest_reward = agent_collector.buffers["rewards"][-1][-1]

        # Array observations are reduced to the index of their largest entry;
        # anything else is used as-is.
        state = int(np.argmax(latest_obs)) if isinstance(latest_obs, np.ndarray) else latest_obs
        action = int(latest_action)
    

# ─── 3) Main: configure & run with PPOConfig ───────────────────────────────  
def _make_frozenlake_env(cfg):
    """Create the FrozenLake-v1 environment described by an RLlib env config.

    Args:
        cfg: Mapping-like env config. ``map_name`` defaults to ``"8x8"`` and
            ``is_slippery`` defaults to ``True`` when absent.

    Returns:
        The environment object returned by ``gym.make``.
    """
    return gym.make(
        "FrozenLake-v1",
        map_name=cfg.get("map_name", "8x8"),
        is_slippery=cfg.get("is_slippery", True),
        render_mode=None,
    )


def _learner_stats(result):
    """Return the default policy's ``learner_stats`` dict ({} if absent)."""
    return (
        result.get("info", {})
        .get("learner", {})
        .get("default_policy", {})
        .get("learner_stats", {})
    )


def _train_and_log(algo, tag, num_iterations):
    """Train ``algo`` for ``num_iterations`` iterations, logging to W&B.

    Per iteration this logs ``<tag>/kl_divergence`` (when reported),
    ``<tag>/return`` and ``<tag>/iteration``, and prints a one-line summary.

    Args:
        algo: A built RLlib algorithm exposing ``train()``.
        tag: Prefix for the W&B metric names and the console output.
        num_iterations: Number of ``algo.train()`` calls to make.

    Returns:
        List with one mean episode return per iteration; an entry is ``None``
        when the train result did not contain ``episode_return_mean``.
    """
    returns = []
    for i in range(num_iterations):
        # NOTE(review): re-seeding before every iteration resets the
        # driver-side RNGs to the same state each time — presumably done for
        # reproducibility; confirm that repeating the same random stream in
        # every iteration is intended.
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)

        res = algo.train()

        kl_div = _learner_stats(res).get("kl")
        if kl_div is not None:
            wandb.log({f"{tag}/kl_divergence": kl_div})

        ret = res.get("env_runners", {}).get("episode_return_mean")
        wandb.log({f"{tag}/return": ret, f"{tag}/iteration": i + 1})
        returns.append(ret)

        # The metric can be missing from the result; formatting None with
        # ":.2f" would raise TypeError.
        ret_text = "n/a" if ret is None else f"{ret:.2f}"
        print(f"[{tag.upper()}] Iter {i+1:2d} | Return: {ret_text}")
    return returns


if __name__ == "__main__":
    np.random.seed(SEED)
    random.seed(SEED)
    torch.manual_seed(SEED)

    NUM_ITERATIONS = 650

    ray.init(ignore_reinit_error=True)
    try:
        register_env("FrozenLake-v1", _make_frozenlake_env)

        # Base PPO configuration on the old API stack (no RLModule/Learner,
        # no EnvRunner/ConnectorV2).
        config = (
            PPOConfig()
            .environment(
                env="FrozenLake-v1",
                env_config={"is_slippery": True, "map_name": "8x8"},
            )
            .framework("torch")
            .env_runners(
                num_env_runners=1,
                num_envs_per_env_runner=1,
                batch_mode="complete_episodes",
            )
            .api_stack(
                enable_rl_module_and_learner=False,
                enable_env_runner_and_connector_v2=False,
            )
            .training(
                use_gae=True,
                lambda_=0.95,
                train_batch_size=512,
                minibatch_size=128,
                num_epochs=2,
                lr=1e-4,
                gamma=0.99,
                entropy_coeff=0.01,
            )
            .callbacks(PostprocessCallback)
            .debugging(seed=SEED)
        )

        wandb.init(
            project="ppo_frozenlake",
            config={"shaping_alpha": 0.0, "shaping_eta": 1.0},
        )

        # Shaped PPO: the custom keys below are read back by
        # PostprocessCallback through policy.config.
        ppo_shaped = config.copy(copy_frozen=False)
        ppo_shaped["shaping_alpha"] = wandb.config.shaping_alpha
        ppo_shaped["shaping_eta"] = wandb.config.shaping_eta
        ppo_shaped.callbacks(PostprocessCallback)
        ppo_shaped["use_shaping"] = True
        ppo_shaped["experiment_name"] = "PPO_shaped"
        algo_shaped = ppo_shaped.build()
        try:
            shaped_returns = _train_and_log(algo_shaped, "shaped", NUM_ITERATIONS)

            # Average of the last (up to) ten iterations, ignoring iterations
            # for which no return was reported.
            recent = [r for r in shaped_returns[-10:] if r is not None]
            if recent:
                wandb.log({"shaped/final_return": np.mean(recent)})
        finally:
            # Release the algorithm's resources even if training fails.
            algo_shaped.stop()

        # An unshaped baseline can be run the same way: copy `config`, set
        # "use_shaping" to False, build it, and call
        # _train_and_log(algo, "unshaped", NUM_ITERATIONS).
    finally:
        ray.shutdown()

    # ------------------------
    # # Roll out and render the trained shaped policy using algo_shaped.compute_single_action()
    # import gymnasium as gym

    # render_env = gym.make("FrozenLake-v1", map_name="4x4", is_slippery=True, render_mode="human")

    # print("\n=== Rendering one shaped episode ===")
    # obs, info = render_env.reset(seed=SEED)
    # done = False
    # while not done:
    #     action = algo_shaped.compute_single_action(obs, explore=False)
    #     obs, reward, terminated, truncated, info = render_env.step(action)
    #     # Render to the human window
    #     render_env.render()
    #     # Brief pause so you can watch the steps
    #     # time.sleep(0.5)
    #     done = terminated or truncated

    # plt.figure(figsize=(10, 6))
    # plt.plot(shaped_returns, label="Shaped Return")
    # plt.plot(unshaped_returns, label="Unshaped Return")
    # plt.title("Episode Return Comparison")
    # plt.xlabel("Iteration")
    # plt.ylabel("Mean Episode Return")
    # plt.legend()
    # plt.grid(True)
    # plt.show()
