# %%
from typing import Callable, Tuple

import ipywidgets as widgets
import plotly.graph_objects as go
import torch
import torch.distributions as tdist
import torch.nn as nn
from plotly.subplots import make_subplots

from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer

# Use the project's "emacs" Plotly renderer by default; the interactive-widget
# cells further down switch to "notebook" before displaying their figures.
set_plotly_renderer("emacs")
# Inference-only by default: autograd is switched off globally.  The training
# cell re-enables it locally via ``with torch.set_grad_enabled(True)``.
torch.set_grad_enabled(False)


# %%
def make_heatmap(
    fn: Callable[[torch.Tensor], torch.Tensor],
    other_input: torch.Tensor,
    xlim: Tuple[float, float] = (-1, 1),
    ylim: Tuple[float, float] = (-1, 1),
    n: int = 100,
) -> "Tuple[go.Heatmap, torch.Tensor, torch.Tensor]":
    """Evaluate ``fn`` on a 2-D slice of its input space and build a heatmap.

    The first two input coordinates sweep an ``n`` x ``n`` grid covering
    ``xlim`` and ``ylim``; every grid point gets ``other_input`` appended as
    its remaining coordinates.  The batch is moved to the module-level
    ``DEVICE`` before ``fn`` is called.

    Args:
        fn: Callable applied to a tensor of shape ``(n, n, 2 + k)``.  Its
            result is expected to carry a trailing dimension of size 1, which
            is squeezed away to obtain the ``(n, n)`` heatmap values.
        other_input: 1-D tensor with the ``k`` fixed coordinates.
        xlim: Range of the first input coordinate.
        ylim: Range of the second input coordinate.
        n: Number of grid points along each axis.

    Returns:
        ``(heatmap, x, y)`` where ``heatmap`` is a ``go.Heatmap`` trace and
        ``x`` / ``y`` are the ``(n, n)`` mesh-grid tensors (``indexing="ij"``)
        so callers can re-evaluate the same grid later.
    """
    x_range = torch.linspace(*xlim, n)
    y_range = torch.linspace(*ylim, n)
    x, y = torch.meshgrid(x_range, y_range, indexing="ij")
    other_tensor_expanded = other_input.expand(n, n, -1)
    cat_inputs = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), other_tensor_expanded], dim=-1)
    z = fn(cat_inputs.to(DEVICE)).squeeze(dim=-1).cpu().numpy()

    # With indexing="ij", z[i, j] = fn(x_range[i], y_range[j], ...).  Plotly
    # draws z[i][j] at (x[j], y[i]): columns run along the horizontal axis and
    # rows along the vertical one.  The horizontal ticks therefore belong to
    # y_range and the vertical ticks to x_range; passing them the other way
    # round mislabels the axes whenever xlim != ylim.
    heatmap = go.Heatmap(z=z, x=y_range.numpy(), y=x_range.numpy(), colorscale="viridis")
    return heatmap, x, y


# %%
class RBFModule(nn.Module):
    """Soft Voronoi diagram over a set of key points.

    The module holds ``hidden_dim`` keys living in input space.  For an input
    ``x`` it computes the Euclidean distance to every key, turns those
    distances into weights with ``softmax(-40 * distance)`` and returns the
    weighted sum of per-key values.  The per-key values come from a linear
    layer of ``x``; after ``re_initialize`` its weight is zero, so each key's
    value is just the corresponding slice of the bias.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        # Set the sizes first so they are available to everything below.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Registration order (value, then keys) is kept so that
        # ``parameters()`` / ``state_dict()`` ordering does not change.
        self.value = nn.Linear(input_dim, hidden_dim * output_dim, bias=True)
        self.keys = nn.Parameter(torch.zeros(hidden_dim, input_dim))
        self.re_initialize()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        diffs = x[..., None, :] - self.keys
        assert diffs.shape == (*x.shape[:-1], self.hidden_dim, self.input_dim)
        # Squared Euclidean distance from each input to each key.
        sq_dists = torch.einsum("...hj,...hj->...h", diffs, diffs)
        # Sharp softmax over negative distances: a soft Voronoi diagram.
        # (A Gaussian kernel, ``sq_dists.mul(-4).exp()``, is an alternative.)
        weights = nn.functional.softmax(sq_dists.sqrt().mul(-40), dim=-1)

        values = self.value(x).reshape((*x.shape[:-1], self.hidden_dim, self.output_dim))
        return torch.einsum("...h,...ho->...o", weights, values)

    def re_initialize(self) -> None:
        """Resample keys and per-key values in place.

        Keys are drawn uniformly from ``[-1.5, 1.5)``, the value bias
        uniformly from ``[-2, 2)``, and the value weight is zeroed.
        """
        # In-place writes to leaf parameters raise a RuntimeError while
        # autograd is recording, so do not depend on the caller having
        # disabled gradients globally.
        with torch.no_grad():
            self.keys.copy_(torch.rand_like(self.keys).mul_(3).sub_(1.5))
            self.value.bias.copy_(torch.rand(self.value.bias.size()) * 4 - 2)
            self.value.weight.zero_()


# %%
class Sin(nn.Module):
    """Element-wise sine activation, usable as a layer in ``nn.Sequential``."""

    def forward(self, x):
        """Return the sine of every element of ``x``."""
        return x.sin()


# %% Try various NNs
# Number of extra input coordinates besides the two swept by the heatmap.
N_DIM = 8
# Hidden width doubles with every extra input dimension.
HIDDEN_DIM = 16 * (2**N_DIM)
print(f"{HIDDEN_DIM=}")
# Sanity guard: refuse to build layers that wide by accident.
assert HIDDEN_DIM < 1e6, "HIDDEN_DIM is too large; lower N_DIM"

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first ``.to(DEVICE)``.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture to visualise; change this one name to switch models.
#   "tanh_mlp"   - MLP with two hidden layers and tanh activations
#   "rbf"        - RBFModule (soft Voronoi diagram over random keys)
#   "sin_mlp"    - MLP with two hidden layers and sine activations
#   "checkpoint" - whole-module checkpoint read from "first_nonlinear.pt"
MODEL_KIND = "rbf"


def _make_mlp(activation):
    """Build the two-hidden-layer MLP; ``activation`` is a zero-arg layer factory."""
    return nn.Sequential(
        nn.Linear(2 + N_DIM, HIDDEN_DIM),
        activation(),
        nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
        activation(),
        nn.Linear(HIDDEN_DIM, 1),
    )


if MODEL_KIND == "tanh_mlp":
    model = _make_mlp(nn.Tanh)
elif MODEL_KIND == "rbf":
    model = RBFModule(2 + N_DIM, HIDDEN_DIM, 1)
elif MODEL_KIND == "sin_mlp":
    model = _make_mlp(Sin)
elif MODEL_KIND == "checkpoint":
    # NOTE(review): torch.load unpickles the file and can therefore execute
    # arbitrary code; only load checkpoints from a trusted source.  Recent
    # PyTorch releases may also need ``weights_only=False`` to load a whole
    # pickled module -- confirm against the installed version.
    model = torch.load("first_nonlinear.pt")
else:
    raise NotImplementedError(f"Unknown MODEL_KIND: {MODEL_KIND!r}")

# Every variant is evaluated on DEVICE in inference mode.
model = model.to(DEVICE).eval()

# The interactive FigureWidget below is shown with the "notebook" renderer.
set_plotly_renderer("notebook")

# Initial slice: all N_DIM extra coordinates held at zero.  The mesh grids
# x and y are kept so the slider callback can re-evaluate the same grid.
trace, x, y = make_heatmap(model, torch.zeros(N_DIM))
layout = {
    "title": "Reward Function",
    "xaxis_title": "x",
    "yaxis_title": "y",
    # Tie the x axis scale to the y axis so the plot keeps a 1:1 aspect ratio.
    "xaxis": {"scaleanchor": "y", "scaleratio": 1},
}

fig = go.FigureWidget(go.Figure(data=trace, layout=layout))
if hasattr(model, "keys") and N_DIM < 4:
    # Overlay the key points: column 1 horizontally, column 0 vertically.
    key_points = model.keys.detach()
    fig.add_trace(
        go.Scatter(
            x=key_points[:, 1].cpu(),
            y=key_points[:, 0].cpu(),
            mode="markers",
            marker_color="magenta",
        )
    )

# One slider per extra input coordinate; continuous_update=False is passed so
# the value is not updated continuously while dragging.
slider_kwargs = dict(min=-4, max=4, step=0.02, continuous_update=False)
sliders = [widgets.FloatSlider(**slider_kwargs) for _ in range(N_DIM)]


def changed(event):
    """Recompute the heatmap of ``model`` for the current slider values.

    Registered as the slider observer and also called directly by the button
    handlers; the ``event`` argument is not used.
    """
    with fig.batch_update():
        grid_size = len(x)
        slider_values = torch.tensor([slider.value for slider in sliders])
        # Broadcast the slider values to every point of the (n, n) grid.
        fixed_coords = slider_values.expand(grid_size, grid_size, -1)
        grid_inputs = torch.cat(
            [x.unsqueeze(-1), y.unsqueeze(-1), fixed_coords],
            dim=-1,
        )
        outputs = model(grid_inputs.to(DEVICE))
        fig.data[0].z = outputs.squeeze(-1).cpu().numpy()

        if hasattr(model, "keys") and N_DIM < 4:
            # Keep the key-point overlay in sync with the current keys.
            fig.data[1].y = model.keys[:, 0].detach().cpu()
            fig.data[1].x = model.keys[:, 1].detach().cpu()


# Redraw only when a slider's ``value`` changes.  Without ``names`` the
# handler is registered for every trait of the widget, so unrelated trait
# updates would each trigger another full model evaluation.
for s in sliders:
    s.observe(changed, names="value")


def reset_parameters(event):
    """Resample the parameters of ``model`` and redraw the heatmap.

    Models that provide ``re_initialize`` use it; any other model gets every
    parameter redrawn from a normal distribution with mean 0 and standard
    deviation 0.02.  ``event`` is only forwarded to ``changed``.
    """
    # In-place writes to parameters are rejected while autograd is recording,
    # so disable it locally rather than relying on a global switch.
    with torch.no_grad():
        if hasattr(model, "re_initialize"):
            model.re_initialize()
        else:
            # The distribution does not depend on the parameter: build it once.
            init_dist = tdist.Normal(torch.zeros(()), torch.ones(()) * 0.02)
            for p in model.parameters():
                p.copy_(init_dist.rsample(p.size()))
    changed(event)


def reset_sliders(event):
    """Assign every slider an independent uniform sample from [-2, 2), then redraw."""
    samples = (torch.rand(len(sliders)) * 4 - 2).tolist()
    for slider, sample in zip(sliders, samples):
        slider.value = sample
    # Explicit refresh once all sliders hold their new values.
    changed(event)


# "Randomize" draws new slider values; "Re-initialize" resamples the model.
randomize, reinitialize = (
    widgets.Button(description=label) for label in ("Randomize", "Re-initialize")
)
randomize.on_click(reset_sliders)
reinitialize.on_click(reset_parameters)

# Show the figure followed by its controls.
controls = [*sliders, randomize, reinitialize]
widgets.display(fig, *controls)

# %%
# Regression data taken from the RBF model: inputs are the key locations,
# targets are the value-layer bias reshaped to (hidden_dim, output_dim).
train_x = model.keys.detach().clone()
per_key_values = model.value.bias.reshape((model.hidden_dim, model.output_dim))
train_y = per_key_values.detach().clone()

NN_HIDDEN_DIM = 64

# Small one-hidden-layer tanh network that will be fitted to the data above.
train_model = nn.Sequential(
    nn.Linear(2 + N_DIM, NN_HIDDEN_DIM),
    nn.Tanh(),
    nn.Linear(NN_HIDDEN_DIM, 1),
)
train_model = train_model.to(DEVICE).train()

# Preview the untrained network on the slice with all extra coordinates at zero.
trace, *_ = make_heatmap(train_model, torch.zeros(N_DIM))

go.Figure(trace, layout={"xaxis": {"scaleanchor": "y", "scaleratio": 1}})

# %%
# Autograd is off globally, so switch it on just for the fitting loop.
with torch.set_grad_enabled(True):
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(train_model.parameters(), lr=0.001)
    for i in range(10_000):
        preds = train_model(train_x)
        loss = criterion(preds, train_y)
        # Report the loss every 100 steps, before the update is applied.
        if not i % 100:
            print(f"Loss: {loss.item()}")
        opt.zero_grad()
        loss.backward()
        opt.step()

# %%
# Side-by-side comparison on the slice with all extra coordinates at zero:
# the fitted network on the left, the RBF model it was fitted to on the right.
trace, *_ = make_heatmap(train_model, torch.zeros(N_DIM))
trace2, *_ = make_heatmap(model, torch.zeros(N_DIM))


fig = make_subplots(rows=1, cols=2)

fig.update_layout(
    dict(
        # Lock the aspect ratio of the first subplot ...
        xaxis=dict(
            scaleanchor="y",
            scaleratio=1,
        ),
        # ... and of the second one, whose axes are xaxis2 / yaxis2.  Without
        # this only the left panel is drawn with a 1:1 aspect ratio.
        xaxis2=dict(
            scaleanchor="y2",
            scaleratio=1,
        ),
    )
)
fig.add_trace(trace, row=1, col=1)
fig.add_trace(trace2, row=1, col=2)
fig

# %%
# Static figure of the fitted network, shown with the "emacs" renderer.
set_plotly_renderer("emacs")

# Put the fitted network into eval mode on DEVICE.
trained_model = train_model.eval().to(DEVICE)

# x and y are kept so the slider callback defined later can reuse the grid.
trace, x, y = make_heatmap(trained_model, torch.zeros(N_DIM))
layout = {
    "title": "Reward Function",
    "xaxis_title": "x",
    "yaxis_title": "y",
    # 1:1 aspect ratio.
    "xaxis": {"scaleanchor": "y", "scaleratio": 1},
}

fig = go.Figure(data=trace, layout=layout)
fig

# %%

# Promote the static figure to an interactive widget for the notebook renderer.
set_plotly_renderer("notebook")
fig = go.FigureWidget(fig)

# Fresh set of sliders, one per extra input coordinate.
slider_options = dict(min=-4, max=4, step=0.02, continuous_update=False)
sliders = [widgets.FloatSlider(**slider_options) for _ in range(N_DIM)]


def changed(event):
    """Recompute the heatmap of ``trained_model`` for the current slider values.

    Registered as the slider observer and also called directly by the button
    handlers; the ``event`` argument is not used.
    """
    with fig.batch_update():
        n = len(x)
        other_input = torch.tensor([s.value for s in sliders])
        # Broadcast the slider values to every point of the (n, n) grid.
        other_tensor_expanded = other_input.expand(n, n, -1)
        cat_inputs = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), other_tensor_expanded], dim=-1)
        # This figure was built from ``trained_model``; evaluating ``model``
        # here would silently swap in a different network on the first
        # slider move.
        fig.data[0].z = trained_model(cat_inputs.to(DEVICE)).squeeze(-1).cpu().numpy()


# Redraw only when a slider's ``value`` changes.  Without ``names`` the
# handler is registered for every trait of the widget, so unrelated trait
# updates would each trigger another full model evaluation.
for s in sliders:
    s.observe(changed, names="value")


def reset_sliders(event):
    """Move every slider to a fresh uniform sample from [-2, 2) and redraw."""
    new_values = torch.rand(len(sliders)) * 4 - 2
    for index, slider in enumerate(sliders):
        slider.value = new_values[index].item()
    # Explicit refresh once all sliders hold their new values.
    changed(event)


# New "Randomize" button wired to the slider reset defined in this cell.
randomize = widgets.Button(description="Randomize")
randomize.on_click(reset_sliders)

# NOTE(review): ``reinitialize`` is the button created in the earlier widget
# cell; its handler resamples ``model`` -- confirm it is meant to be shown here.
controls = [*sliders, randomize, reinitialize]
widgets.display(fig, *controls)

# %%
# torch.save(trained_model.cpu(), "hard_reward.pt")
