import math
from typing import Any, Dict, List, Optional, Tuple

import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from popgym.baselines.ray_models.base_model import BaseModel
from ray.rllib.algorithms.callbacks import DefaultCallbacks
import wandb

from .lif_gate import StackedLIFGate


class LoggingCallback(DefaultCallbacks):
    """Training callback that inspects the policy model after each result.

    Currently a placeholder: when the policy's model is a ``RayLIFGate`` its
    ``lif`` core is looked up, but nothing is recorded or emitted.
    """

    def on_train_result(self, *, algorithm, result, **kwargs):
        """Look up the LIF core of the policy model, if there is one.

        Args:
            algorithm: Object exposing ``get_policy()``; the returned policy's
                ``model`` attribute is inspected.
            result: Training result passed in by the caller (not read here).
            **kwargs: Extra keyword arguments (ignored).
        """
        policy_model = algorithm.get_policy().model
        if not isinstance(policy_model, RayLIFGate):
            return
        # The core is fetched but not used; no metrics are logged here.
        lif_core = policy_model.lif  # noqa: F841


class RayLIFGate(BaseModel):
    """``BaseModel`` subclass whose memory core is a ``StackedLIFGate``.

    The core is built from sizes read out of ``self.cfg`` and is invoked from
    ``forward_memory`` with the first two axes of both the inputs and the
    recurrent state swapped.
    """

    # Model-specific configuration defaults. Within this class only
    # "hidden_size" is read (via ``self.cfg``); the other keys are not
    # referenced by the code below.
    MODEL_CONFIG = {
        "hidden_size": 128,
        "memory_size": 32,
        "context_size": 8,
        "record_stats": True,
        "agg_kwargs": {},
    }

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: Dict[str, Any],
        name: str,
        **custom_model_kwargs,
    ):
        """Initialise the base model and construct the LIF core.

        Args:
            obs_space: Observation space, forwarded to ``BaseModel``.
            action_space: Action space, forwarded to ``BaseModel``.
            num_outputs: Number of outputs, forwarded to ``BaseModel``.
            model_config: Model configuration, forwarded to ``BaseModel``.
            name: Model name, forwarded to ``BaseModel``.
            **custom_model_kwargs: Accepted but not used or forwarded here.
                NOTE(review): confirm ``BaseModel`` does not expect these.
        """
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # ``self.cfg`` is presumably populated by ``BaseModel.__init__`` --
        # TODO confirm against the base class.
        core_in, core_hidden, core_out = (
            self.cfg["preprocessor_output_size"],
            self.cfg["hidden_size"],
            self.cfg["post_size"],
        )

        # Kept on the instance because ``initial_state`` needs it.
        self.hidden_size = core_hidden

        self.lif = StackedLIFGate(core_in, core_hidden, core_out)

    def initial_state(self) -> List[torch.Tensor]:
        """Return the starting recurrent state.

        Returns:
            A one-element list holding a zero tensor of shape
            ``(1, 4 * hidden_size)``.
        """
        # NOTE(review): the factor of 4 presumably matches the state layout
        # of ``StackedLIFGate`` -- confirm against that module.
        state_width = 4 * self.hidden_size
        return [torch.zeros(1, state_width)]

    def forward_memory(
        self,
        z: torch.Tensor,
        state: List[torch.Tensor],
        t_starts: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the LIF core over ``z`` starting from ``state``.

        Args:
            z: Rank-3 input tensor; assumed ``(batch, time, feature)`` --
                TODO confirm.
            state: List whose first element is the rank-3 recurrent state.
            t_starts: Not used by this model.
            seq_lens: Not used by this model.

        Returns:
            ``(out, state)`` where ``out`` is the core output and ``state``
            is a one-element list with the updated memory, both with their
            first two axes swapped back to the layout of the inputs.
        """
        swap = (1, 0, 2)  # exchange the first two axes, keep the last

        prev_mem = state[0].permute(*swap)
        # The core is fed a copy of the permuted inputs, not a view of ``z``.
        core_inputs = z.permute(*swap).clone()

        core_out, next_mem = self.lif(core_inputs, prev_mem)

        new_state = [next_mem.permute(*swap)]
        return core_out.permute(*swap), new_state