# -*- coding: utf-8 -*-

from typing import Optional

import torch


def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, T, D = x.shape

    h = torch.zeros(B, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state

    for i in range(T):
        h = g[:, i].exp() * h + x[:, i]
        o[:, i] = h

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state


def naive_chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False,
    chunk_size: int = 64
) -> torch.Tensor:
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, T, D = x.shape

    gc = g.view(B, chunk_size, D).cumsum(-2).view_as(g)
    h = torch.zeros(B, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state

    for i in range(0, T, chunk_size):
        hp = h
        h = torch.zeros(B, D, dtype=torch.float, device=x.device)
        for j in range(i, i + chunk_size):
            h = g[:, j].exp() * h + x[:, j]
            o[:, j] = hp * gc[:, j].exp() + h
        h = o[:, j].clone()

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
