from collections import namedtuple
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn


# Result bundle produced by ActorCritic.forward: action logits, value
# estimate, and the updated (hx, cx) LSTM state.
ActorCriticOutput = namedtuple(
    'ActorCriticOutput',
    ['logits_act', 'val', 'hx_cx'],
)


@dataclass
class FrameCnnConfig:
    """Hyper-parameters read by ``FrameEncoder`` when it builds its layer stack."""

    # Channel count of the input image; in_channels of the encoder's first conv.
    image_channels: int
    # Not read by any code visible in this file; presumably the side length of
    # the (square) input frame in pixels -- TODO confirm.
    image_size: int
    # Not read by any code visible in this file. NOTE(review): ActorCritic
    # passes None here, so the `int` annotation is not strictly honoured.
    latent_dim: int
    # Base width: out_channels of the first conv, and the unit that every
    # entry of `mult` is multiplied by.
    num_channels: int
    # One multiplier per encoder stage; a stage outputs `m * num_channels` channels.
    mult: List[int]
    # One flag per encoder stage; a truthy entry appends a 2x max-pool after
    # that stage. Must have the same length as `mult` (asserted by FrameEncoder).
    down: List[int]


class ActorCritic(nn.Module):
    """Recurrent actor-critic: frame encoder -> LSTM cell -> policy and value heads.

    Both linear heads are created with all-zero weights and biases, so a
    freshly constructed model outputs zero logits and zero values.
    """

    def __init__(self, num_actions: int) -> None:
        """Build the encoder, the LSTM cell and the two output heads.

        Args:
            num_actions: Number of action logits produced by the policy head.
        """
        super().__init__()

        # Residual encoder used in place of the plain IRIS conv stack
        # (4 x [Conv3x3 -> MaxPool2 -> ReLU], channels 3 -> 32 -> 32 -> 64 -> 64).
        encoder_config = FrameCnnConfig(
            3,
            64,
            latent_dim=None,
            num_channels=32,
            mult=(1, 1, 2, 2),
            down=(1, 1, 1, 1),
        )
        self.cnn = FrameEncoder(encoder_config)

        self.lstm_dim = 256
        # 1024 is hard-coded: with the config above the encoder ends at 64
        # channels after four 2x max-pools, i.e. 64 * 4 * 4 = 1024 features
        # for 64x64 input frames (assumed -- other resolutions will not match).
        self.lstm = nn.LSTMCell(1024, self.lstm_dim)

        self.critic_linear = nn.Linear(self.lstm_dim, 1)
        self.actor_linear = nn.Linear(self.lstm_dim, num_actions)

        # Start both heads at exactly zero.
        for head in (self.actor_linear, self.critic_linear):
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)

        # LSTM cell: Xavier for input weights, orthogonal for recurrent
        # weights, zero biases except the forget gate.
        cell = self.lstm
        nn.init.xavier_uniform_(cell.weight_ih)
        nn.init.orthogonal_(cell.weight_hh)
        with torch.no_grad():
            cell.bias_ih.zero_()
            # Set forget-gate bias to 1: the gates are stacked along dim 0,
            # and the forget gate occupies the second quarter.
            n = cell.bias_ih.size(0)
            cell.bias_ih[n // 4 : n // 2].fill_(1)
            cell.bias_hh.zero_()

    def forward(self, obs: torch.Tensor, hx_cx):
        """Run one recurrent step.

        Args:
            obs: Batch of observations fed to the frame encoder.
            hx_cx: Recurrent state handed to the LSTM cell unchanged
                (presumably an ``(hx, cx)`` pair -- confirm against callers).

        Returns:
            ActorCriticOutput with the action logits, the value estimate with
            its trailing size-1 dimension squeezed away, and the new
            ``(hx, cx)`` state.
        """
        features = self.cnn(obs).flatten(start_dim=1)
        hx, cx = self.lstm(features, hx_cx)
        logits = self.actor_linear(hx)
        value = self.critic_linear(hx).squeeze(dim=1)
        return ActorCriticOutput(logits, value, (hx, cx))

    @property
    def device(self) -> torch.device:
        """Device holding the LSTM's recurrent weights."""
        return self.lstm.weight_hh.device


class SmallResBlock(nn.Module):
    """Residual block: GroupNorm -> SiLU -> 3x3 conv, added to a skip path.

    The skip path is the identity when the channel count is unchanged and a
    1x1 convolution otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, num_groups_norm: int = 8) -> None:
        """Create the residual branch and the skip projection.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels of the produced feature map.
            num_groups_norm: Number of groups used by the GroupNorm layer.
        """
        super().__init__()
        residual_layers = [
            nn.GroupNorm(num_groups_norm, in_channels),
            # In-place is applied to the GroupNorm output, not to the block input.
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.f = nn.Sequential(*residual_layers)

        if in_channels == out_channels:
            self.skip_projection = nn.Identity()
        else:
            # Channel counts differ: match them with a 1x1 convolution.
            self.skip_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the skip path plus the residual branch evaluated on ``x``."""
        shortcut = self.skip_projection(x)
        residual = self.f(x)
        return shortcut + residual


class FrameEncoder(nn.Module):
    """Convolutional encoder: a 3x3 stem conv followed by residual stages.

    Each stage is one SmallResBlock, optionally followed by a 2x max-pool.
    """

    def __init__(self, config: FrameCnnConfig) -> None:
        """Assemble the layer stack described by ``config``.

        Args:
            config: Supplies the input channel count, the base width, and the
                per-stage channel multipliers (``mult``) and down-sampling
                flags (``down``), which must have equal length.
        """
        super().__init__()
        assert len(config.mult) == len(config.down)

        base_width = config.num_channels
        stem = nn.Conv2d(config.image_channels, base_width, kernel_size=3, stride=1, padding=1)
        stages = [stem]

        current_width = base_width
        for multiplier, downsample in zip(config.mult, config.down):
            stage_width = multiplier * base_width
            stages.append(SmallResBlock(current_width, stage_width))
            if downsample:
                # Halve the spatial resolution after this stage.
                stages.append(nn.MaxPool2d(2))
            current_width = stage_width

        self.encoder = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the full encoder stack."""
        return self.encoder(x)
