import dataclasses
from pathlib import Path
from typing import Optional

import optree as ot
import plotly.express as px
import sklearn.linear_model
import torch
import torch.nn.functional as F
import torch.utils.data
import torch.utils.hooks
from stable_baselines3.common.type_aliases import check_cast
from stable_baselines3.ppo import PPO

from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer
from learned_planners.train import TrainConfig, create_vec_env_and_eval_callbacks

# Select the "emacs" plotly renderer through the project helper imported above.
set_plotly_renderer("emacs")

# %%

# Default-constructed training config; it is only used to build the environment below.
cfg = TrainConfig()

# Only the first returned value is kept (presumably the vectorised env, judging by
# the factory's name — TODO confirm); the second returned value is discarded.
env, _ = create_vec_env_and_eval_callbacks(cfg, run_dir=Path("."), eval_freq=10)

# This should work so long as gym is installed
# The checkpoint path is relative, so it is resolved against the current working directory.
model = PPO.load("rl_model_589824000_steps.zip", env=env)
# `policy` is the network that every later cell hooks and probes.
policy = model.policy


# %%
# Load the stand-alone reward function from disk (path relative to the working directory).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): `.state_dict()` is called on the result, so the file presumably holds a whole
# pickled nn.Module; recent PyTorch releases default to weights_only=True, in which case this
# call would need weights_only=False — confirm against the installed version.
reward_fn = torch.load("learned_planners/notebooks/hard_reward.pt")
rsd = reward_fn.state_dict()

# State dict of the reward function that lives inside the policy's features extractor.
other_rsd = check_cast(torch.nn.Module, policy.features_extractor.reward_fn).state_dict()

# Sanity check: every tensor of the loaded reward function equals the policy's copy.
# Only the keys of `rsd` are visited, so extra keys in `other_rsd` would go unnoticed.
for k in rsd.keys():
    assert torch.equal(rsd[k], other_rsd[k])
# %%


def obsh(x):
    """Split the last axis of ``x`` into the dict observation fed to the policy.

    The first two entries along the last axis are returned under ``"obs"`` and
    all remaining entries under ``"hidden"``; leading axes are left untouched.
    """
    observed_part = x[..., :2]
    hidden_part = x[..., 2:]
    return dict(obs=observed_part, hidden=hidden_part)


# %%

N_data = 100000  # number of random input rows
D = 10  # columns per row; obsh() sends the first 2 to "obs" and the rest to "hidden"

# Random inputs, uniform on [0, 1).
X = torch.rand((N_data, D))
with torch.no_grad():
    # Output of the stand-alone reward function on the raw inputs.
    actual_reward = reward_fn(X)

# NOTE(review): this forward pass runs outside torch.no_grad(), and `y` is not read
# anywhere in the visible part of the file — confirm whether it is still needed.
y = policy(obsh(X))


# %%


@dataclasses.dataclass(frozen=False)
class Cache:
    """Record the positional inputs and the output of a module's forward pass.

    A forward hook registered on ``mod`` stores the most recent call's positional
    arguments in ``in_args`` and its result in ``out``. ``from_module`` additionally
    registers one child ``Cache`` per submodule, keyed by the dotted name that
    ``named_modules()`` reports.

    The object is a context manager: leaving the ``with`` block removes the hook of
    this cache *and* of every child, while the recorded tensors stay readable.
    """

    # Handle of the registered forward hook; None only while the object is being built.
    handle: Optional[torch.utils.hooks.RemovableHandle]
    # Caches for the submodules, keyed by their dotted name (filled by from_module).
    children: dict[str, "Cache"] = dataclasses.field(default_factory=dict)
    # The module the hook is attached to.
    mod: Optional[torch.nn.Module] = None
    # Positional arguments of the most recent forward call, or None if none was seen.
    in_args: Optional[tuple[torch.Tensor, ...]] = None
    # Output of the most recent forward call, or None if none was seen.
    out: Optional[torch.Tensor] = None

    def set_value_hook(self, mod, args, output):
        """Forward hook: remember ``args`` and ``output`` of the call.

        Returns None so that the module's output is passed through unchanged.
        """
        self.in_args = args
        self.out = output
        return None

    @classmethod
    def from_single_module(cls, module: torch.nn.Module):
        """Create a cache hooked onto ``module`` only (no children)."""
        # The hook is a bound method of the new object, so the object has to exist
        # before the hook can be registered; the handle is filled in right after.
        obj = cls(None)
        obj.handle = module.register_forward_hook(obj.set_value_hook, with_kwargs=False)
        obj.mod = module
        return obj

    @classmethod
    def from_module(cls, module: torch.nn.Module):
        """Create a cache hooked onto ``module`` and onto each of its submodules.

        All submodules, at any depth, are stored flat in ``children`` under their
        dotted ``named_modules()`` name.
        """
        obj = cls.from_single_module(module)
        try:
            for name, m in module.named_modules():
                if name == "":
                    continue  # Do not include self in the children dict
                obj.children[name] = cls.from_single_module(m)
        except Exception:
            # Do not leave a half-hooked module behind if registration fails midway.
            obj.remove()
            raise
        return obj

    def remove(self) -> None:
        """Remove the forward hook of this cache and of all children.

        Safe to call repeatedly; recorded ``in_args``/``out`` values are kept.
        """
        # getattr guards keep this usable from __del__ on a partially built object.
        handle = getattr(self, "handle", None)
        if handle is not None:
            handle.remove()
        for child in getattr(self, "children", {}).values():
            child.remove()

    def __del__(self):
        self.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Child hooks hold a reference to their Cache, so they would never be
        # garbage-collected (and would keep firing) unless removed explicitly here.
        self.remove()

    @property
    def value(self):
        """Return ``(in_args, out, {child_name: child.value})`` for this cache."""
        return (
            self.in_args,
            self.out,
            {k: v.value for k, v in self.children.items()},
        )


# Run the policy once with forward hooks on it and on each of its submodules, then
# pull the captured tensors out. `cache` is read again by a later cell.
with torch.no_grad(), Cache.from_module(policy) as cache:
    policy(obsh(X))

    # cache.value is (in_args, out, children); [0][0] is the policy's first positional
    # input, i.e. the dict built by obsh().
    hook_X_obs = cache.value[0][0]["obs"]  # type: ignore
    # The hook must have seen exactly the observation that was fed in.
    assert torch.equal(hook_X_obs, obsh(X)["obs"])

    # First positional input of the policy's internal reward_fn, cut down to its first
    # two entries along the last axis by obsh().
    pos_input = obsh(cache.value[2]["features_extractor.reward_fn"][0][0])["obs"]

    # Output of the policy's internal reward_fn for the same forward pass.
    internal_reward = cache.value[2]["features_extractor.reward_fn"][1]


# %% Can we predict actual position from the input to the reward function?


def linear_fit(X, y):
    """Fit a Lasso probe from ``X`` to ``y`` and print its held-out score.

    The first 80% of the rows (by position, no shuffling) are used for fitting
    and the remaining rows for scoring. Each row is flattened to a vector
    before being handed to scikit-learn.

    Returns the fitted ``sklearn.linear_model.Lasso`` instance.
    """
    n_total = len(y)
    n_train = int(n_total * 0.8)

    def as_rows(tensor, n_rows):
        # One flat feature vector per row, as a numpy array.
        return tensor.contiguous().view(n_rows, -1).numpy()

    train_X = as_rows(X[:n_train], n_train)
    train_y = as_rows(y[:n_train], n_train)

    n_held_out = n_total - n_train
    test_X = as_rows(X[n_train:], n_held_out)
    test_y = as_rows(y[n_train:], n_held_out)

    # L1-regularised probe; sklearn.linear_model.LinearRegression() is the
    # unregularised alternative that could be swapped in here.
    probe = sklearn.linear_model.Lasso(alpha=0.01)
    probe.fit(train_X, train_y)
    held_out_score = probe.score(test_X, test_y)
    print("Coeff of determination:", held_out_score)
    return probe


# Probe: observation columns seen by the policy -> position input of the internal reward_fn.
_ = linear_fit(hook_X_obs, pos_input)

# %% Can we predict reward from the internal reward?

# Probe: output of the policy's internal reward_fn -> output of the stand-alone reward_fn.
predictor = linear_fit(internal_reward, actual_reward)
# Only about 0.89 coefficient of determination (the score linear_fit prints) from this.
# Plot the fitted probe coefficients, flattened to one line.
px.line(predictor.coef_.reshape(-1)).show()

# Probe: raw observation columns -> stand-alone reward.
_ = linear_fit(hook_X_obs, actual_reward)
# 0.01 coefficient of determination from this. It's highly nonlinear

# Probe: raw observation columns -> internal reward.
_ = linear_fit(hook_X_obs, internal_reward)
# Also 0.01, bit lower (probably noise)

# %% Can we predict reward from all the NN activations?

# Collect every float32 tensor the hooks captured. Entries are keyed by id() so that a
# tensor object recorded by several modules is kept once (the last module name wins).
# NOTE: despite the name, these are hooked inputs/outputs (activations), not parameters.
named_flat_params = {}
for k, v in cache.value[2].items():
    # v is the (in_args, out, children) tuple of submodule k; flatten it to its leaves.
    vals, _ = ot.tree_flatten(v)
    # Assumes every leaf has a .dtype attribute (i.e. is a tensor) — TODO confirm.
    named_flat_params.update({id(vv): (k, vv) for vv in vals if vv.dtype == torch.float32})

# One (name, matrix) pair per tensor, each reshaped to N_data rows.
flat_params = [(n, v.view(N_data, -1)) for n, v in named_flat_params.values()]

# flat_params.append(("actual_reward", actual_reward.view(N_data, -1)))

# Concatenate all activations column-wise and probe the stand-alone reward from them.
all_activations = torch.cat([v for _, v in flat_params], dim=1)
predictor = linear_fit(all_activations, actual_reward)


# %%
# Plot the probe coefficients and mark where each activation tensor's columns begin.
fig = px.line(predictor.coef_.reshape(-1))
x = 0  # running column offset into the concatenated activation matrix
for name, v in flat_params:
    # Dashed vertical line at the first column belonging to this tensor.
    fig.add_vline(x=x, line_dash="dash", line_color="black")
    # Print name
    fig.add_annotation(
        x=x,
        y=0,
        text=name,
        showarrow=False,
        yshift=0,
        xshift=-4,
        font=dict(size=8),
        textangle=-90,
    )
    # Advance by the number of columns this tensor contributed.
    x += v.shape[1]
fig.show()
# The score is about 0.89, same as when probing from the internal reward only.
# It mostly does use the internal reward! We needed to use LASSO instead of un-regularized though and play with alpha.

# %% If we predict the input from the internal positions, and the reward from the NN activations, how good is the fit we
# obtain?

# Probe 1: internal position input -> observation columns.
obs_predictor = linear_fit(pos_input, hook_X_obs)
# Probe 2: internal reward -> stand-alone reward.
internal_reward_predictor = linear_fit(internal_reward, actual_reward)

# Apply both probes to all N_data rows (this includes the rows they were fitted on).
pred_obs = obs_predictor.predict(pos_input.contiguous().view(N_data, -1).numpy())
pred_internal_reward = internal_reward_predictor.predict(internal_reward.contiguous().view(N_data, -1).numpy())

# Rebuild a full input row: probed observation columns + the original "hidden" columns.
pred_X = torch.cat([torch.from_numpy(pred_obs), obsh(X)["hidden"]], dim=1)

with torch.no_grad():
    # Add a trailing axis to the probe's prediction (presumably to match the shape of
    # reward_fn's output — TODO confirm).
    t_pred_internal_reward = torch.from_numpy(pred_internal_reward)[:, None]
    # Mean squared difference between the probed reward and reward_fn on the rebuilt inputs.
    residual_sos = F.mse_loss(t_pred_internal_reward, reward_fn(pred_X))
    # Despite the name, this is the mean squared deviation of reward_fn(pred_X) from its
    # mean over dim 0; the mean is broadcast against the full tensor inside mse_loss.
    # NOTE(review): reward_fn(pred_X) is evaluated three times in this cell; if reward_fn
    # is deterministic it could be computed once.
    residual_var = F.mse_loss(reward_fn(pred_X), reward_fn(pred_X).mean(dim=0))
    fvar_explained = 1 - residual_sos / residual_var
    print("Fraction of variance explained:", fvar_explained.item())
    # 0.9, no improvement

# %% What about the other way around?
# Probe: observation columns -> internal position input, applied to all N_data rows.
internal_pos_predictor = linear_fit(hook_X_obs, pos_input)
# The flat prediction is reshaped to a hard-coded (N_data, 5, 2)
# (assumes pos_input has that shape — TODO confirm).
pred_internal_pos = torch.from_numpy(internal_pos_predictor.predict(hook_X_obs.contiguous().view(N_data, -1).numpy())).view(
    N_data, 5, 2
)

# Repeat the original "hidden" columns along a new second-to-last axis so they line up
# with each predicted position.
non_observation = obsh(X)["hidden"].unsqueeze(-2).expand(-1, pred_internal_pos.shape[-2], -1)

# Predicted positions followed by the hidden columns, joined along the last axis.
pred_X = torch.cat([pred_internal_pos, non_observation], dim=-1)

with torch.no_grad():
    # Mean squared difference between reward_fn on the rebuilt inputs and the hooked internal reward.
    residual_sos = F.mse_loss(reward_fn(pred_X), internal_reward)
    # Mean squared deviation of reward_fn(pred_X) from its mean over dim 0.
    # NOTE(review): the denominator is the variance of reward_fn(pred_X), not of
    # internal_reward (the other signal in residual_sos) — confirm this is intended.
    residual_var = F.mse_loss(reward_fn(pred_X), reward_fn(pred_X).mean(dim=0))
    fvar_explained = 1 - residual_sos / residual_var
    print("Fraction of variance explained:", fvar_explained.item())
    # 0.968 when predicting internal states, then using the reward function to predict the reward. Very good!


# %% How can we predict these internal states in general?
