import math

import torch
from einops import reduce
from einops.layers.torch import Rearrange

from WM import WM
from net.memory.base import BaseMemory

# Hyper-parameters shared by every memory class in this module; each class
# forwards them unchanged as WM(tstr_min, tstr_max, n_taus, k=k, ...).
# NOTE(review): the meanings noted below are inferred from the names only --
# confirm against the WM implementation.
tstr_min = 5  # first positional WM argument; presumably the smallest time constant
tstr_max = 100  # second positional WM argument; presumably the largest time constant
n_taus = 20  # third positional WM argument; also used by several classes to compute output_size
k = 8  # passed to WM as keyword `k`; semantics are defined by WM


class SITHMemory(BaseMemory):
    """Memory that wraps a ``WM`` instance built with ``use_sub=False``.

    The advertised ``output_size`` is ``n_taus * latent_size``, i.e. the size
    of the last axis once the ``f`` and ``tau`` axes of the WM output have
    been merged.

    Args:
        latent_size: Size of the latent feature axis; forwarded to
            ``BaseMemory.__init__``.
        use_F: When true, ``forward`` uses the third value returned by WM
            (bound as ``F``) instead of the first one (``til_f``).
    """

    def __init__(self, latent_size, use_F=False):
        super().__init__(latent_size)

        self.use_F = use_F
        self.output_size = latent_size * n_taus

        # Submodules are created in the same order as before so that module
        # registration order is unaffected.
        self.sith = WM(
            tstr_min,
            tstr_max,
            n_taus,
            k=k,
            g=0,
            batch_first=True,
            dt=1,
            use_sub=False,
        )
        self.relu = torch.nn.ReLU()
        self.flatten = Rearrange('b t f tau -> b t (f tau)')

    def forward(self, z, h):
        """Run WM over ``z`` and return the flattened, rectified output.

        Args:
            z: Input tensor, indexed here as ``z[:, i, :]`` (3-D); it is cast
                to float64 before use and is fed to WM as ``alpha``.
            h: Previous WM state, or ``None`` to start from scratch.

        Returns:
            ``(ctx, h)``: ``ctx`` is the selected WM output after a float32
            cast, ReLU, and merging of its last two axes; ``h`` is the state
            returned by the last WM call.
        """
        z = z.double()

        if h is None:
            # No state yet: prime WM with a single step whose input and alpha
            # are both all ones, keeping only the resulting state.
            first_step = z[:, (0,), :]
            _, h, _ = self.sith(
                torch.ones_like(first_step),
                h,
                alpha=torch.ones_like(first_step),
            )

        # The actual pass feeds zeros as input and uses z as alpha.
        til_f, h, F = self.sith(torch.zeros_like(z), h, alpha=z)

        if self.use_F:
            selected = F
        else:
            selected = til_f

        rectified = self.relu(selected.float())
        return self.flatten(rectified), h


class SITHSubMemory(BaseMemory):
    """Memory that wraps a ``WM`` instance built with ``use_sub=True``.

    The feature count used for ``output_size`` is
    ``2 * C(latent_size, 2) + latent_size``.
    NOTE(review): presumably the extra ``2 * C(latent_size, 2)`` features are
    the pairwise "sub" features WM emits when ``use_sub=True`` -- confirm
    against WM.

    Args:
        latent_size: Size of the latent feature axis; forwarded to
            ``BaseMemory.__init__``.
        add_sum_neurons: When true, one extra binary entry per feature is
            appended along the ``tau`` axis (see ``forward``).
        use_F: When true, ``forward`` uses the third value returned by WM
            (bound as ``F``) instead of the first one (``til_f``).

    Raises:
        ValueError: If both ``add_sum_neurons`` and ``use_F`` are true.
    """

    def __init__(self, latent_size, add_sum_neurons=True, use_F=False):
        super().__init__(latent_size)

        if add_sum_neurons and use_F:
            raise ValueError("Cannot use F and sum neurons simultaneously")

        self.add_sum_neurons = add_sum_neurons
        self.use_F = use_F

        pair_dims = 2 * math.comb(latent_size, 2)
        feature_count = pair_dims + latent_size
        # The sum neurons contribute one additional slot on the tau axis.
        tau_slots = n_taus + 1 if add_sum_neurons else n_taus
        self.output_size = tau_slots * feature_count

        self.sith = WM(
            tstr_min,
            tstr_max,
            n_taus,
            k=k,
            g=0,
            batch_first=True,
            dt=1,
            use_sub=True,
        )
        self.relu = torch.nn.ReLU()
        self.flatten = Rearrange('b t f tau -> b t (f tau)')

    def forward(self, z, h):
        """Run WM over ``z`` and return the flattened, rectified output.

        Args:
            z: Input tensor, indexed here as ``z[:, i, :]`` (3-D); it is cast
                to float64 before use and is fed to WM as ``alpha``.
            h: Previous WM state, or ``None`` to start from scratch.

        Returns:
            ``(ctx, h)``: ``ctx`` is the selected WM output after a float32
            cast, ReLU, optional sum-neuron concatenation, and merging of its
            last two axes; ``h`` is the state returned by the last WM call.
        """
        z = z.double()

        if h is None:
            # No state yet: prime WM with a single step whose input and alpha
            # are both all ones, keeping only the resulting state.
            first_step = z[:, (0,), :]
            _, h, _ = self.sith(
                torch.ones_like(first_step),
                h,
                alpha=torch.ones_like(first_step),
            )

        # The actual pass feeds zeros as input and uses z as alpha.
        til_f, h, F = self.sith(torch.zeros_like(z), h, alpha=z)

        if self.use_F:
            selected = F
        else:
            selected = til_f
        o = self.relu(selected.float())

        if self.add_sum_neurons:
            # Sum over tau per feature, binarise at 0.1, and append the
            # resulting indicator as one more entry on the tau axis.
            tau_total = reduce(o, 'b t f tau -> b t f 1', 'sum')
            indicator = torch.where(tau_total > 0.1, 1.0, 0.0)
            o = torch.cat((o, indicator), dim=-1)

        return self.flatten(o), h


class SITHSubOnlyMemory(BaseMemory):
    """Memory wrapping ``WM`` (``use_sub=True``) that drops the leading features.

    ``forward`` discards the first ``latent_size`` entries of the feature
    axis, so ``output_size`` is computed from ``2 * C(latent_size, 2)``
    features only.
    NOTE(review): presumably the dropped entries are the plain latent
    features and the kept ones are WM's pairwise "sub" features -- confirm
    against WM.

    Args:
        latent_size: Size of the latent feature axis; forwarded to
            ``BaseMemory.__init__``.
        add_sum_neurons: When true, one extra binary entry per feature is
            appended along the ``tau`` axis (see ``forward``).
    """

    def __init__(self, latent_size, add_sum_neurons=True):
        super().__init__(latent_size)

        self.add_sum_neurons = add_sum_neurons

        sub_feature_count = 2 * math.comb(latent_size, 2)
        # The sum neurons contribute one additional slot on the tau axis.
        tau_slots = n_taus + 1 if add_sum_neurons else n_taus
        self.output_size = tau_slots * sub_feature_count

        self.sith = WM(
            tstr_min,
            tstr_max,
            n_taus,
            k=k,
            g=0,
            batch_first=True,
            dt=1,
            use_sub=True,
        )
        self.relu = torch.nn.ReLU()
        self.flatten = Rearrange('b t f tau -> b t (f tau)')

    def forward(self, z, h):
        """Run WM over ``z`` and return the flattened trailing features.

        Args:
            z: Input tensor, indexed here as ``z[:, i, :]`` (3-D); it is cast
                to float64 before use and is fed to WM as ``alpha``.
            h: Previous WM state, or ``None`` to start from scratch.

        Returns:
            ``(ctx, h)``: ``ctx`` is the first WM output after a float32
            cast, ReLU, optional sum-neuron concatenation, removal of the
            first ``self.latent_size`` features, and merging of its last two
            axes; ``h`` is the state returned by the last WM call.
        """
        z = z.double()

        if h is None:
            # No state yet: prime WM with a single step whose input and alpha
            # are both all ones, keeping only the resulting state.
            first_step = z[:, (0,), :]
            _, h, _ = self.sith(
                torch.ones_like(first_step),
                h,
                alpha=torch.ones_like(first_step),
            )

        # The actual pass feeds zeros as input and uses z as alpha.
        raw, h, _ = self.sith(torch.zeros_like(z), h, alpha=z)
        o = self.relu(raw.float())

        if self.add_sum_neurons:
            # Sum over tau per feature, binarise at 0.1, and append the
            # resulting indicator as one more entry on the tau axis.
            tau_total = reduce(o, 'b t f tau -> b t f 1', 'sum')
            indicator = torch.where(tau_total > 0.1, 1.0, 0.0)
            o = torch.cat((o, indicator), dim=-1)

        # Skip the first <latent_size> entries of the feature axis.
        # NOTE(review): self.latent_size is presumably set by BaseMemory.
        kept = o[..., self.latent_size:, :]

        return self.flatten(kept), h


class SITHSubSumOnlyMemory(BaseMemory):
    """Memory wrapping ``WM`` (``use_sub=True``) that emits only sum indicators.

    ``forward`` reduces the ``tau`` axis to a single binary value per feature
    and discards the first ``latent_size`` features, so ``output_size`` is
    ``2 * C(latent_size, 2)``.
    NOTE(review): presumably the kept entries are WM's pairwise "sub"
    features -- confirm against WM.

    Args:
        latent_size: Size of the latent feature axis; forwarded to
            ``BaseMemory.__init__``.
    """

    def __init__(self, latent_size):
        super().__init__(latent_size)

        self.output_size = 2 * math.comb(latent_size, 2)

        self.sith = WM(
            tstr_min,
            tstr_max,
            n_taus,
            k=k,
            g=0,
            batch_first=True,
            dt=1,
            use_sub=True,
        )
        self.relu = torch.nn.ReLU()
        self.flatten = Rearrange('b t f tau -> b t (f tau)')

    def forward(self, z, h):
        """Run WM over ``z`` and return one binary indicator per kept feature.

        Args:
            z: Input tensor, indexed here as ``z[:, i, :]`` (3-D); it is cast
                to float64 before use and is fed to WM as ``alpha``.
            h: Previous WM state, or ``None`` to start from scratch.

        Returns:
            ``(ctx, h)``: ``ctx`` holds, for every feature after the first
            ``self.latent_size``, 1.0 where the tau-sum of the rectified WM
            output exceeds 0.1 and 0.0 otherwise; ``h`` is the state returned
            by the last WM call.
        """
        z = z.double()

        if h is None:
            # No state yet: prime WM with a single step whose input and alpha
            # are both all ones, keeping only the resulting state.
            first_step = z[:, (0,), :]
            _, h, _ = self.sith(
                torch.ones_like(first_step),
                h,
                alpha=torch.ones_like(first_step),
            )

        # The actual pass feeds zeros as input and uses z as alpha.
        raw, h, _ = self.sith(torch.zeros_like(z), h, alpha=z)
        rectified = self.relu(raw.float())

        # Collapse tau to a single slot per feature, then binarise at 0.1.
        tau_total = reduce(rectified, 'b t f tau -> b t f 1', 'sum')
        indicator = torch.where(tau_total > 0.1, 1.0, 0.0)

        # Skip the first <latent_size> entries of the feature axis.
        # NOTE(review): self.latent_size is presumably set by BaseMemory.
        kept = indicator[..., self.latent_size:, :]

        return self.flatten(kept), h
