if __name__ == '__main__':
    # When executed as a script, put the current working directory on the
    # import path so the `src.*` imports below can resolve (presumably the
    # script is launched from the repository root — confirm).
    import sys
    import pathlib
    cwd_path = pathlib.Path().absolute()
    print("Adding path: ", cwd_path)
    sys.path.append(str(cwd_path))

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
from einops import rearrange, repeat
from omegaconf import DictConfig
import opt_einsum as oe

optimized = True

# Einsum backend: opt_einsum's path-optimizing `contract` when `optimized`
# is set, otherwise the plain `torch.einsum`.
contract = oe.contract if optimized else torch.einsum

from src.models.nn.krylov import HippoKrylov
# from models.nn.components import TransposedLinear
from src.models.nn import LinearActivation, Activation
import src.utils.train

# Module-level logger obtained from the project's training utilities.
log = src.utils.train.get_logger(__name__)

class StateSpace(nn.Module):
    """Sequence layer that convolves its input with a filter produced by ``HippoKrylov``.

    ``forward`` computes a length-L filter ``k`` (one per channel), convolves
    it with the input through an FFT, and adds the skip term ``D * u``. It
    then applies activation, dropout and an output ``LinearActivation``.
    ``step`` runs the same layer one time step at a time as a recurrence.
    """
    # NOTE(review): presumably read by callers to know this layer needs l_max / the sequence length.
    requires_length = True

    def __init__(
            self,
            H,
            l_max=None,
            d_model=64, # overloading this term, same as memory_order or N
            measure='legs', # 'legs', 'legt' main ones; can also try 'lagt'
            dt_min=0.001,
            dt_max=0.1,
            trainable=None,
            lr=None,
            rank=1,
            stride=1,
            w_bias=0.0,
            dropout=0.0,
            cache=False,
            weight_decay=0.0,
            transposed=True,
            activation='gelu',
            postact=None,
            weight_norm=False,
            initializer=None,
            train_state=False,
            slow=False, # Use slow Krylov function for debugging
            test_resolution=False, # Assume all sequences are same length and different length sequences are subsampled differently
        ):
        """
        H: number of channels (model width).
        l_max: maximum sequence length, or None for no fixed length. ``forward``
            doubles the kernel length on demand when a longer input arrives.
        d_model: state size N (the order of the HiPPO projection). Falls back to
            H when <= 0.
        measure, dt_min, dt_max, trainable, lr, rank, w_bias, slow: forwarded to
            HippoKrylov. The discretization step dt should be roughly inverse
            to the sequence length.
        stride: dilate the filter by this factor. l_max must be divisible by it.
        dropout: dropout rate applied after the activation.
        cache: cache the convolution filter during evaluation.
        weight_decay: coefficient of the filter penalty returned by ``loss``.
        transposed: if True inputs/outputs are (B, H, L); otherwise (B, L, H).
        activation, postact, weight_norm, initializer: configure the activation
            and the output LinearActivation.
        train_state: make the initial recurrent state a learnable parameter.
        test_resolution: stored as an attribute; not otherwise read in this class.
        """
        super().__init__()
        log.info(f"Constructing S3 (H, N, L) = ({H}, {d_model}, {l_max})")

        self.h = H
        self.n = d_model if d_model > 0 else H
        self.L = l_max
        self.stride = stride
        if l_max is not None and stride > 1:
            assert l_max % stride == 0
            # The kernel itself only needs L/stride taps; forward() dilates it back to L.
            l_max = l_max // self.stride
        self.cache = cache
        self.weight_decay = weight_decay
        self.transposed = transposed
        self.test_resolution = test_resolution

        # Skip-connection coefficient (the D term of the state space equation), one per channel.
        self.D = nn.Parameter(torch.randn(self.h))

        self.krylov = HippoKrylov(
            self.n, self.h, l_max,
            dt_min=dt_min, dt_max=dt_max, measure=measure, rank=rank,
            w_bias=w_bias, trainable=trainable, lr=lr, slow=slow,
            use_length=l_max is not None,
        )
        self.K = None # Most recently computed convolution filter (read by loss())

        self.activation = Activation(activation)
        dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        self.output_linear = LinearActivation(
            self.h,
            self.h,
            transposed=self.transposed,
            initializer=initializer,
            activation=postact,
            activate=True,
            weight_norm=weight_norm,
        )

        # Initial recurrent state (H, N) returned by default_state(). It is
        # registered as a parameter/buffer, so it follows the module's device.
        _initial_state = torch.zeros(self.h, self.n)
        if train_state:
            self._initial_state = nn.Parameter(_initial_state)
        else:
            self.register_buffer('_initial_state', _initial_state)

        self.train_L = None

    def forward(self, u, state=None, cache=None, **kwargs): # absorbs return_output and transformer src mask
        """Apply the layer to a whole sequence.

        u: (B H L) if ``self.transposed`` else (B L H).
        state: optional (B H N) recurrent state. If given, its contribution is
            added to the output and the updated state is returned. Not
            supported together with stride > 1.
        cache: True forces and False forbids caching of the convolution filter.
            None caches only when ``self.cache`` is set and the module is in
            eval mode. Ignored when ``state`` is given.
        kwargs: ignored (absorbs e.g. return_output and transformer src masks).

        Returns: (y, next_state). y has the same layout as u. next_state is
            None when ``state`` is None.
        """
        # Work internally in (B H L) layout.
        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)
        if self.L is None and self.training: # Store the length of the latest train batch
            self.train_L = L
        # Grow the kernel by doubling until it covers the input length.
        while self.L is not None and L > self.L:
            log.info(f"S3: Doubling length from L = {self.L} to {2*self.L}")
            self.krylov.double_length()
            self.L *= 2

        if state is not None:
            assert self.stride == 1, "Striding not supported with states"
            k, k_state = self.krylov(state=state, L=L)
        else:
            # Cache the convolutional filter if:
            # 1. No state is passed in
            # 2. This forward call asks us to
            # 3. This model is defined with caching and it is a validation phase
            cache = cache or (cache is None and self.cache and not self.training)
            if cache:
                # NOTE(review): requires l_max to be set; self.L / L raises if self.L is None.
                rate = self.L / L
                k = self.krylov._cache(rate=rate)
            else:
                self.krylov._uncache()
                k = self.krylov()

        # (H, L)
        self.K = k # Store this for computing extra weight decay loss

        if self.stride > 1:
            # Dilate the filter: keep L/S taps and insert S-1 zeros after each,
            # giving a length-L filter again.
            k = k[..., :L // self.stride] # (H, L/S)
            k = F.pad(k.unsqueeze(-1), (0, self.stride-1)) # (H, L/S, S)
            k = rearrange(k, '... h s -> ... (h s)') # (H, L)
        else:
            k = k[..., :L]

        # Causal convolution via FFT. Zero-padding to 2L avoids circular wrap-around.
        k_f = torch.fft.rfft(k, n=2*L) # (H L)
        u_f = torch.fft.rfft(u, n=2*L) # (B H L)
        y_f = k_f * u_f
        y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B H L)

        # Parameters change during training, so make sure to recompute on the next pass.
        if self.training: self.krylov._uncache()

        # D term of the state space equation.
        y = y + u * self.D.unsqueeze(-1)

        if state is not None:
            y = y + k_state[..., :L]
            next_state = self.krylov.next_state(state, u)
        else:
            next_state = None

        y = self.dropout(self.activation(y))

        # Restore the caller's layout *before* the output projection. The
        # non-transposed LinearActivation presumably acts on the last (feature)
        # axis, so it must see (B L H), not the internal (B H L).
        if not self.transposed: y = y.transpose(-1, -2)

        y = self.output_linear(y)

        return y, next_state

    def step(self, u, state):
        """Advance the layer by one time step as a recurrent model.

        Intended for evaluation only (asserts ``not self.training``).

        u: (B H)
        state: (B H N)
        Returns: output (B H), next state (B H N)
        """
        assert not self.training

        y, next_state = self.krylov.step(u, state)
        y = y + u * self.D
        y = self.activation(y)
        if self.transposed:
            # The transposed projection expects a trailing length axis: (B H) -> (B H 1) -> (B H).
            y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
        else:
            # The non-transposed projection acts directly on the feature axis of (B H).
            y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        """Return the initial state tiled to (*batch_shape, H, N).

        ``device`` is accepted for interface compatibility but not used. The
        result lives wherever ``_initial_state`` (a parameter/buffer) lives.
        """
        return self._initial_state.repeat(*batch_shape, 1, 1)

    def cache_all(self):
        """Delegate to the Krylov module's ``_cache_all``."""
        self.krylov._cache_all()

    def loss(self):
        """ Extra train loss (implements weight decay for the filter).

        Returns ``weight_decay / 2 * sum(K ** 2)`` over the filter computed by
        the most recent ``forward`` call, or 0.0 when weight_decay is 0. When
        weight_decay > 0, ``forward`` must have been called first.

        This is probably better than naive weight decay on the individual parameters A, B, C, dt, although we have not tested that.
        Prior work that parameterizes convolution filters implicitly (i.e. CKConv) also implement it this way.
        """
        if self.weight_decay > 0.0:
            return self.weight_decay / 2.0 * F.mse_loss(self.K, torch.zeros_like(self.K), reduction='sum')
        else: return 0.0

    @property
    def d_state(self):
        """Flattened state size H * N."""
        return self.h * self.n

    @property
    def d_output(self):
        """Output feature size H."""
        return self.h

    @property
    def state_to_tensor(self):
        """Function that flattens a (..., H, N) state to (..., H*N)."""
        # einops signature is rearrange(tensor, pattern).
        return lambda state: rearrange(state, '... h n -> ... (h n)')


def test_state():
    """Compare a single full-sequence pass against chunked passes that carry state.

    Runs the layer once over the whole input from a random initial state.
    It then runs again over the same input split into chunks, threading the
    returned state from one chunk to the next. Both results are printed, and
    ``utils.compare_outputs`` reports the output and final-state discrepancies.
    """
    B, H, N, L = 1, 64, 64, 1024
    s3 = StateSpace(H, d_model=N, l_max=L, slow=False)
    for submodule in s3.modules():
        if hasattr(submodule, 'setup'):
            submodule.setup()

    u = torch.ones(B, H, L)
    init_state = torch.randn(B, H, N)

    # Reference: one pass over the full sequence.
    y, final_state = s3(u, init_state.clone())
    print("output:\n", y, y.shape)
    print("final state:\n", final_state, final_state.shape)

    # Chunked: feed each piece with the state produced by the previous one.
    n_chunks = 2
    carried = init_state.clone()
    pieces = []
    for chunk in u.chunk(n_chunks, dim=-1):
        piece_out, carried = s3(chunk, state=carried)
        pieces.append(piece_out)
    outs = torch.cat(pieces, dim=-1)

    print("step outputs:\n", outs)
    print("step final state:\n", carried)
    print("step output error:")
    utils.compare_outputs(y, outs)
    print("step final state error:")
    utils.compare_outputs(final_state, carried)

def test_recurrence():
    """Compare the convolutional forward pass with step-by-step recurrence.

    First runs the layer over a whole all-ones sequence, starting from a zero
    state. It then switches the layer to eval mode and steps through the same
    input one time step at a time, starting from ``default_state``. Both
    outputs and final states are printed for manual comparison.
    """
    B = 2
    H = 3
    N = 4
    L = 6
    s3 = StateSpace(H, d_model=N, l_max=L)
    for module in s3.modules():
        if hasattr(module, 'setup'): module.setup()

    u = torch.ones(B, H, L)
    state = torch.zeros(B, H, N)
    y, state = s3(u, state=state)
    print(y, y.shape)
    print("state", state, state.shape)

    for module in s3.modules():
        if hasattr(module, 'setup_step'): module.setup_step()
    s3.eval()

    # Use the input's own device. The original read a `device` global that
    # only exists when this file runs as __main__, so calling this function
    # from anywhere else raised NameError.
    state = s3.default_state(*u.shape[:-2], device=u.device)
    ys = []
    for u_ in torch.unbind(u, dim=-1):
        y_, state = s3.step(u_, state=state)
        ys.append(y_)
    y = torch.stack(ys, dim=-1)
    print(y, y.shape)
    print("state", state, state.shape)

if __name__ == '__main__':
    # Bound at module level on purpose: test_state() reads this `utils` global.
    from benchmark import utils

    device = torch.device('cuda') # or torch.device('cpu')

    # test_state()
    test_recurrence()
