"""
Kernel classes 

Call forward to output the samples
One class to take care of the number of hippos
"""
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
import opt_einsum as oe

from models.spacetime import OurModule
from models.functional.krylov import krylov

from models.functional.companion_krylov import companion_krylov


class SpaceTimeKernel(OurModule):
    """
    Base class for SpaceTime SSM kernels.

    Stores the shared hyperparameters and provides FFT-convolution and
    normalization helpers. Subclasses implement ``initialize_weights`` and
    ``forward``.

    Args:
        channels: number of kernel channels (C).
        d_state: state dimension of each SSM (D).
        n_hippos: number of SSMs handled by this kernel (H).
        lr: learning rate stored for subclasses' parameter registration.
        random_a, random_b, random_c, random_d: random-initialization flags
            for the A, B, C, D parameters.
        learn_a, learn_b, learn_c, learn_d: trainability flags for A, B, C, D.
        **kernel_args: extra options; ignored here (consumed by subclasses).
    """
    def __init__(self, 
                 channels,
                 d_state,
                 n_hippos,
                 lr,
                 random_a,
                 random_b,
                 random_c,
                 random_d,
                 learn_a,
                 learn_b,
                 learn_c,
                 learn_d,
                 **kernel_args):
        super().__init__()
        
        self.n_hippos = n_hippos
        self.d_state = d_state   
        self.channels = channels
        self.lr = lr
        
        self.random_a = random_a
        self.random_b = random_b
        self.random_c = random_c
        self.random_d = random_d
        
        self.learn_a = learn_a
        self.learn_b = learn_b
        self.learn_c = learn_c
        self.learn_d = learn_d
        
    def initialize_weights(self):
        raise NotImplementedError
                
    def forward(self, u):
        raise NotImplementedError
        
    def fft_conv(self, u, v):
        """
        Causal convolution of u (B, H, L) with kernels v (C, H, L) via FFT.

        Both signals are zero-padded to length 2L so the circular FFT
        product equals the linear convolution on the first L outputs.

        Returns:
            Tensor of shape (B, C, H, L).
        """
        L   = u.shape[-1]
        u_f = torch.fft.rfft(u, n=2*L) # (B H L)
        v_f = torch.fft.rfft(v, n=2*L) # (C H L)

        y_f = oe.contract('bhl,chl->bchl', u_f, v_f) 
        y   = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L)
        return y
        
    def fft_conv_d(self, u, v):
        """
        Like ``fft_conv`` but with an extra state axis on the kernels.

        Args:
            u: input of shape (B, H, L).
            v: kernels of shape (C, H, D, L).

        Returns:
            Tensor of shape (B, C, H, L, D).
        """
        L   = u.shape[-1]
        u_f = torch.fft.rfft(u, n=2*L, dim=2).unsqueeze(-1) # (B H L 1)
        v_f = torch.fft.rfft(v, n=2*L, dim=3).unsqueeze(-1) # (C H D L 1)

        y_f = oe.contract('bhli,chdli->bchld', u_f, v_f) 
        y   = torch.fft.irfft(y_f, n=2*L, dim=3)[:, :, :, :L, :] # (B C H L D)
        return y
    
    def norm(self, x, ord=1):
        """
        Normalize x (C, H, D) to unit ``ord``-norm along the state axis (dim 2).

        Rows whose norm is zero are returned unchanged instead of being
        divided by zero. This works for any number of channels; a scalar
        ``.item()`` check would only work when C == 1 and would still
        produce NaNs for zero rows other than the first hippo.
        """
        x_norm = torch.linalg.norm(x, ord=ord, dim=2, keepdim=True)
        safe_norm = torch.where(x_norm == 0, torch.ones_like(x_norm), x_norm)
        return x / safe_norm


class CompanionKernel(SpaceTimeKernel):
    """
    SSM kernel whose state matrix A is a companion matrix.

    A = shift_matrix + a_padding (x) a has ones on the sub-diagonal and the
    (normalized) vector ``a`` in its last column. ``convolve`` builds the
    kernel via ``companion_power`` (``krylov`` / ``companion_krylov``) and
    applies it with an FFT convolution. In recurrent mode, ``rollout``
    instead iterates the closed-loop matrix A + B K from a hidden state.

    Parameters a, b, c, k are created with shape (channels, n_hippos,
    d_state), and d with shape (channels, n_hippos).

    ``kernel_args`` must contain everything ``SpaceTimeKernel`` needs plus
    ``recurrent`` and ``horizon``. Optional keys are ``a_init``, ``b_init``,
    ``c_init`` (default None) and ``memory_norm`` (default 0).
    """
    def __init__(self, 
                 random_a=True,
                 random_b=True,
                 random_c=True,
                 random_d=True,
                 learn_a=True,
                 learn_b=True,
                 learn_c=True,
                 learn_d=True,
                 skip_connection=True,
                 norm_b=True,
                 norm_c=True,
                 norm_k=True,  # for rollout
                 norm_ord=1,  # 1 or 2
                 **kernel_args):
        super().__init__(random_a=random_a, random_b=random_b, 
                         random_c=random_c, random_d=random_d,
                         learn_a=learn_a, learn_b=learn_b, 
                         learn_c=learn_c, learn_d=learn_d,
                         **kernel_args)
        
        self._fp = (self.channels, self.n_hippos, self.d_state)
        self.skip_connection = skip_connection
        
        # Optional initialization schemes ('randn', 'ones', 'kaiming', ...)
        self.a_init = kernel_args.get('a_init')
        self.b_init = kernel_args.get('b_init')
        self.c_init = kernel_args.get('c_init')
        # 1 -> scale `a` per hippo at init and use `companion_krylov`
        self.memory_norm = kernel_args.get('memory_norm', 0)
        
        self.initialize_weights()
        self.recurrent = kernel_args['recurrent']
        self.horizon   = kernel_args['horizon']        
        
        # Whether b, c, k are normalized in `rollout`
        self.norm_b = norm_b
        self.norm_c = norm_c
        self.norm_k = norm_k  # Rollout / control
        self.norm_ord = norm_ord
        
        self.ssm_grad = True
        self.feedback_grad = True  
        
        self.train_lag = False
        self.train_horizon = False
        
        # For supervising rollout with control
        if self.recurrent:
            self.reference_inputs = None
            self.feedback_inputs = None
            
        self.joint_train = False
        
    def initialize_weights(self):
        """
        Create a, b, c, k and d, plus the fixed companion-structure tensors.

        a and b follow their ``random_*`` flags and ``*_init`` scheme. k
        (the feedback vector) is created alongside c. When no c scheme
        matches, neither c nor k is registered here, and subclasses must
        register them.
        """
        # A
        if self.random_a and (self.a_init is None or self.a_init == 'randn'):
            a = torch.randn(*self._fp)
        elif self.a_init == 'basis':
            a = repeat(torch.zeros(self.d_state).float(), 'd -> c h d',
                       c=self.channels, h=self.n_hippos).clone().contiguous()
            a[..., 0] = 1.
        elif self.a_init == 'ones':
            a = torch.ones(*self._fp)
        elif self.a_init == 'kaiming':
            a = torch.randn((self.n_hippos, self.channels, self.d_state))
            nn.init.kaiming_uniform_(a, a=math.sqrt(5))  
            a = rearrange(a, 'h c d -> c h d')
        else:
            a = repeat(torch.zeros(self.d_state).float(), 'd -> c h d',
                       c=self.channels, h=self.n_hippos).clone().contiguous()
        
        if self.memory_norm == 1:
            # Scale hippo h's `a` by (1 + alpha_h) * ||a||_1, alpha in [0, 10]
            alphas = torch.linspace(0, 10, self.n_hippos)
            norm_factors = (1 + alphas) * torch.linalg.norm(a, ord=1, dim=2)
            # Leave zero-norm rows unscaled rather than dividing by zero.
            # A scalar `.item()` check fails for channels > 1.
            norm_factors = torch.where(norm_factors == 0,
                                       torch.ones_like(norm_factors),
                                       norm_factors)
            a = a / norm_factors[..., None]
        
        self.register("a", a, self.learn_a, lr=self.lr, wd=None)
        # Fixed companion structure: ones on the sub-diagonal ...
        self.shift_matrix = torch.zeros(self.channels, 
                                        self.n_hippos, 
                                        self.d_state, 
                                        self.d_state)
        self.shift_matrix[:, :, 1:, :-1] = torch.eye(self.d_state - 1)
        # ... and a one-hot selector for the last column, where `a` is placed
        self.a_padding = torch.zeros(*self._fp)
        self.a_padding[:, :, -1] = 1.
        
        # B
        if self.random_b and (self.b_init is None or self.b_init == 'randn'):
            b = torch.randn(*self._fp)            
        elif self.b_init == 'ones':
            b = torch.ones(*self._fp)
        elif self.b_init == 'kaiming':
            b = torch.randn((self.n_hippos, self.channels, self.d_state))
            nn.init.kaiming_uniform_(b, a=math.sqrt(5))  
            b = rearrange(b, 'h c d -> c h d')
        else:
            # First standard basis vector
            b    =  torch.zeros(self.d_state).float()
            b[0] = 1.
            b    = repeat(b, 'd -> c h d', 
                          c=self.channels, h=self.n_hippos).clone().contiguous()
        self.register("b", b, self.learn_b, lr=self.lr, wd=None)
        
        # C (and the feedback vector K, initialized the same way)
        if self.random_c and (self.c_init is None or self.c_init == 'randn'):
            c = torch.randn(*self._fp)
            self.register("c", c, self.learn_c, lr=self.lr, wd=None)
            k = torch.randn(*self._fp)
            self.register("k", k, True, lr=self.lr, wd=None)
            
        elif self.c_init == 'ones':
            c = torch.ones(*self._fp)
            self.register("c", c, self.learn_c, lr=self.lr, wd=None)
            k = torch.ones(*self._fp)
            self.register("k", k, True, lr=self.lr, wd=None)
            
        elif self.c_init == 'kaiming':
            c = torch.randn((self.n_hippos, self.channels, self.d_state))
            nn.init.kaiming_uniform_(c, a=math.sqrt(5))  
            c = rearrange(c, 'h c d -> c h d')
            self.register("c", c, self.learn_c, lr=self.lr, wd=None)
            
            k = torch.randn((self.n_hippos, self.channels, self.d_state))
            nn.init.kaiming_uniform_(k, a=math.sqrt(5))  
            k = rearrange(k, 'h c d -> c h d')
            self.register("k", k, True, lr=self.lr, wd=None)
        else:
            pass  # Handle specifically with children
        
        # D: always a trainable Parameter here (random_d is not consulted)
        self.d = nn.Parameter(
            torch.randn(self.channels, self.n_hippos)
        )
        self.ssm_params = ['a', 'b', 'c', 'd']
            
        # K (feedback)
        self.feedback_params = ['k']
        
    def train_ssm(self):
        """Re-enable gradients for the SSM parameters whose learn_* flag is set."""
        self.ssm_grad = True
        for p in self.ssm_params:
            if getattr(self, f'learn_{p}'):
                getattr(self, p).requires_grad = True
                
    def freeze_ssm(self):
        """Disable gradients for all SSM parameters (a, b, c, d)."""
        self.ssm_grad = False
        for p in self.ssm_params:
            getattr(self, p).requires_grad = False
            
    def train_feedback(self):
        """Enable gradients for the feedback vector k."""
        self.feedback_grad = True
        self.k.requires_grad = True
            
    def freeze_feedback(self):
        """Disable gradients for the feedback vector k."""
        self.feedback_grad = False
        self.k.requires_grad = False
        
    def process_lag(self):
        """Switch to the lag-term phase (k trainable)."""
        self.train_lag = True
        self.train_horizon = False
        self.k.requires_grad = True
        
    def process_horizon(self, train_k):
        """Switch to the horizon-term phase; k is trainable iff ``train_k``."""
        self.train_lag = False
        self.train_horizon = True
        self.k.requires_grad = bool(train_k)
        # Clear out the cached lag-phase feedback terms too
        if self.recurrent:
            self.reference_inputs = None
            self.feedback_inputs = None
            
    def _companion_matrix(self, p):
        """Build A (C, H, D, D): ones on the sub-diagonal, ``p`` in the last column."""
        return self.shift_matrix.to(p.device) + oe.contract(
            'c h i, c h j -> c h j i', self.a_padding.to(p.device), p
        )
         
    def companion_power(self, l, c, b, p):
        """
        Compute ``l`` Krylov terms of the companion system with last column ``p``.

        Args:
            l: number of terms (kernel length).
            c: output vector (C, H, D), or None. ``get_hidden_state`` passes
               None and expects a (C, H, D, L) result.
            b: input vector (C, H, D).
            p: last column of the companion matrix (C, H, D).
        """
        if self.memory_norm == 1 and c is not None:
            # Specialized routine; note its (l, p, b, c) argument order
            return companion_krylov(l, p, b, c)
        # Use repeated squares
        return krylov(l, self._companion_matrix(p), b, c)
    
    def companion_power_roll(self, l, c, b, p):
        """
        Loop-based alternative to ``companion_power``.

        Returns g of shape (C, H, l) with g[..., i] = <c, A^i b>. Each step
        applies A: shift the state down one slot and add the previous
        last entry times p.
        """
        ch, h, _ = b.shape
        g = torch.zeros((ch, h, l), device=c.device)
        x = b.clone()
        for i in range(l):   
            g[:, :, i] = oe.contract('chd, chd-> ch', c, x) 
            u          = x[:, :, -1].unsqueeze(-1) * p 
            x          = torch.roll(x, shifts=1, dims=2) 
            x[:, :, 0] = 0.
            x = x + u
        return g
    
    def forward(self, u, v=0, c=None):
        """Dispatch to ``_forward_joint`` (recurrent joint training) or ``_forward``."""
        if self.recurrent and self.joint_train:
            return self._forward_joint(u, v, c)
        return self._forward(u, v, c)
            
    def _forward(self, u, v=0, c=None):  
        """
        Standard forward pass; returns (y reshaped to (B, C*H, L'), v).

        In the convolutional path, v is a dummy extra input. It could be
        used for mean residual fitting, but is not used right now (that
        would need y + v). In recurrent rollout mode, v may instead be a
        tuple (lag, horizon, n_hidden_state).
        """
        if self.recurrent:  
            # For now, we will use two passes to handle lag and horizon terms
            # First, for the lag terms, we collect the last layer's inputs and Kx
            # - We do this via forward hooks, (ignoring the actual model outputs)
            # - We then compute the RMSE off of these
            # Then, for the horizon terms, we freeze the last layer's K, 
            # - We compute the outputs using y = C(A + BK)^h x[L] for h in 1, ..., H
            # - We compute the RMSE with the true outputs
            if self.train_lag:
                # Kx
                # NOTE(review): these slices act on dim 1, not the time axis -- confirm intended
                self.reference_inputs = u[:, 1:, :][:, self.d_state:, :].flatten()
                y = self.convolve(u, self.k, k=True)
                self.feedback_inputs = rearrange(y, 'b c h l -> b (c h) l')[:, :-1, :][:, self.d_state:, :].flatten()
            elif self.feedback_grad or self.train_horizon:
                # We currently have a hack where v is a tuple: (lag, horizon, hidden_state start)  
                # - More generally, v is a tuple of (starting x[n], and n_steps)
                if v == 0:
                    n_steps = self.horizon; n_hidden_state = u.shape[-1]
                    concatenate_input = 1
                else:
                    lag, horizon, n_hidden_state = v  # can set horizon to be 0 if we just want to do output over lag terms
                    n_steps = lag - n_hidden_state + horizon 
                    concatenate_input = 0

                y = self.rollout(u, n_steps, n_hidden_state,
                                 v=concatenate_input)  
            else:
                y = self.convolve(u, c)
        else:
            y = self.convolve(u, c)
                
        return rearrange(y, 'b c h l -> b (c h) l'), v
    
    def _forward_joint(self, u, v=0, c=None):
        """
        Joint lag/horizon pass. Optionally caches the lag feedback terms,
        then rolls out using v = (lag, horizon, n_hidden_state).
        """
        if self.recurrent:
            if self.train_lag:
                self.reference_inputs = u[:, 1:, :][:, self.d_state:, :].flatten()
                _y = rearrange(self.convolve(u, self.k, k=True), 
                               'b c h l -> b (c h) l')[:, :-1, :][:, self.d_state:, :]
                self.feedback_inputs = _y.flatten()
            
            lag, horizon, n_hidden_state = v
            n_steps = lag - n_hidden_state + horizon 
            concatenate_input = 0
            y = self.rollout(u, n_steps, n_hidden_state, v=concatenate_input)
        else:
            y = self.convolve(u, c)
        return rearrange(y, 'b c h l -> b (c h) l'), v
    
    def get_transfer_func_from_companion(self, hippo_ix, n_steps):
        """
        Return the first ``n_steps`` kernel values for hippo(s) ``hippo_ix``.

        Uses ``companion_power_roll`` with the normalized ``a`` and the raw
        ``b`` and ``c``, as ``convolve`` does when no ``c`` is given. The
        result is on CPU, indexed as f[:, hippo_ix, :] from (C, H, n_steps).
        """
        with torch.no_grad():
            a = self.norm(self.a, ord=self.norm_ord)
            f = self.companion_power_roll(n_steps, self.c, self.b, a).cpu()
        return f[:, hippo_ix, :] 
    
    def get_hidden_state(self, l, a, b, u):
        """Compute hidden states x for input u; returns B x C x H x L x D."""
        # a, b shapes are C x H x D
        a = self.norm(a, ord=self.norm_ord)
        f = self.companion_power(l, None, b, a) # C x H x D x L
        x = self.fft_conv_d(u, f)  # B x C x H x L x D
        return x
    
    def convolve(self, u, c, k=False):
        """
        Convolve u (B, H, L) with the companion kernel; returns (B, C, H, L).

        Args:
            u: input sequence.
            c: output vector (C, H, D). None uses ``self.c``.
            k: if True, ``c`` is the feedback vector K, and the raw
               convolution is returned without the skip connection.
        """
        # Normalize for stability
        a = self.norm(self.a, ord=self.norm_ord)
        c = self.c if c is None else c
        f = self.companion_power(u.size(-1), c, self.b, a).to(u.device)
        y = self.fft_conv(u, f)
        
        if k:  # "c" is K
            return y

        if self.skip_connection and self.learn_d:
            y = y + oe.contract('bhl,ch->bchl', u, self.d)  
        elif self.skip_connection:
            y = y + u.unsqueeze(1)  # Fixed identity
        return y
    
    def rollout(self, u, n_steps, n_hidden_state, v=0):
        """
        Roll the closed-loop system forward from hidden state x[n_hidden_state].

        Computes C (A + BK)^0 x[n], C (A + BK)^1 x[n], ...,
        C (A + BK)^{n_steps - 1} x[n] with ``krylov``. This assumes ``u`` has
        no trailing zeros (e.g. u[:-horizon] was passed in). If ``v == 1``,
        ``u`` is prepended along the time axis of the output.
        """
        x = self.get_hidden_state(n_hidden_state, self.a, self.b, u)   # B x C x H x L x D

        # 1) Construct A
        a = self.norm(self.a, ord=self.norm_ord) 
        A = self._companion_matrix(a)
        # 2) Construct BK -> explodes without normalization?
        b = self.norm(self.b, ord=self.norm_ord) if self.norm_b else self.b
        c = self.norm(self.c, ord=self.norm_ord) if self.norm_c else self.c
        k = self.norm(self.k, ord=self.norm_ord) if self.norm_k else self.k
        BK = oe.contract('c h i, c h j -> c h i j',  b, k)
        # 3) Krylov terms from the last hidden state, using the squaring trick
        y = krylov(n_steps, A + BK, x[:, :, :, n_hidden_state - 1, :], c)
        
        if v == 1:
            y = torch.concat([u.unsqueeze(1), y], dim=3) # B x C x H x L   
        return y
    
class ShiftKernel(CompanionKernel):
    """
    Companion kernel with a fixed all-zero ``a``, so A is a pure shift matrix.

    By default only c is learned and random. b falls back to the first
    basis vector, and there is no skip connection.
    """
    def __init__(self, 
                 random_a=False,
                 random_b=False,
                 random_c=True,
                 random_d=False,
                 learn_a=False,
                 learn_b=False, # Shift by default only learns c (companion with large d_state acts as a shift that learns b and c)
                 learn_c=True,                 
                 learn_d=False,
                 skip_connection=False,
                 **kernel_args):
        super().__init__(random_a=random_a, random_b=random_b, 
                         random_c=random_c, random_d=random_d,
                         learn_a=learn_a, learn_b=learn_b, 
                         learn_c=learn_c, learn_d=learn_d,
                         skip_connection=skip_connection, **kernel_args)
        
    def initialize_weights(self):
        """Run the companion initialization, then overwrite ``a`` with zeros."""
        super().initialize_weights()
        a = repeat(torch.zeros(self.d_state).float(), 'd -> c h d',
                   c=self.channels, h=self.n_hippos).clone().contiguous()
        self.register("a", a, self.learn_a, lr=self.lr, wd=None)
    
    def _round_divisors(self, x):
        """
        Return the distinct values x // n for n in 1..x-1, excluding 1.

        Used for sampling different orders. Filtering on ``> 1`` (rather
        than dropping the first sorted element) keeps 2 when x == 2.
        """
        divisors = np.unique([x // n for n in range(1, x)])
        return divisors[divisors > 1]
    
    
class ARKernel(ShiftKernel):  
    """
    Shift kernel for autoregressive processes (b is not learned by default).

    ``ar_d_state`` optionally overrides ``d_state`` (e.g. set it to the
    horizon). A C mask then lets each hippo model an AR process of a
    different order by zeroing the tail of its C vector. A mask passed
    via ``c_mask`` is "inherited" and re-applied on every forward pass.
    Otherwise one is built at init: orders are sampled per hippo when
    n_hippos > 1, and an all-ones mask is used for a single hippo.
    """
    def __init__(self, 
                 ar_d_state=None, 
                 learn_b=False, 
                 c_mask=None, 
                 **kernel_args):
        # Must exist before the parent constructor calls initialize_weights()
        self.c_mask = c_mask
        if ar_d_state is not None:
            kernel_args['d_state'] = ar_d_state
        super().__init__(learn_b=learn_b, **kernel_args)
        
    def _sample_order_mask(self):
        """Sample an AR order per hippo; return a mask keeping c[:, h, :order]."""
        mask = torch.zeros(*self._fp).float()
        # Heuristic right now for sampling different possible orders
        candidates = torch.randint(low=2, high=self.d_state - 1,
                                   size=(1, self.n_hippos))[0]
        picks = torch.randint(len(candidates), size=(1, self.n_hippos))[0]
        self.orders = candidates[picks]
        for hippo, order in enumerate(self.orders):
            mask[:, hippo, :order] = 1.
        return mask
        
    def initialize_weights(self):
        super().initialize_weights()
        if self.c_mask is not None:
            self.inherit_c_mask = True
        elif self.n_hippos == 1:
            # A single hippo keeps the AR order equal to d_state
            self.c_mask = torch.ones(*self._fp).float()
            self.inherit_c_mask = False
        elif self.n_hippos > 1:
            self.c_mask = self._sample_order_mask()
            self.inherit_c_mask = False
        else:
            self.inherit_c_mask = True
            
        # Re-register c with the mask applied
        masked = oe.contract('chd, chd -> chd', self.c,
                             self.c_mask.to(self.c.device))
        self.register("c", masked, self.learn_c, lr=self.lr, wd=None)
        
    def forward(self, u, v=0):
        """Run the companion forward, re-applying an inherited C mask (keeps the MA order)."""
        if not self.inherit_c_mask:
            return super().forward(u, v=v)
        masked = oe.contract('chd, chd -> chd', self.c,
                             self.c_mask.to(self.c.device))
        return super().forward(u, v=v, c=masked)
        
        
class MAZeroKernel(ShiftKernel):
    """
    Shift kernel with one extra state dimension whose first C coefficient
    is initialized to 1 (c_0 = 1); the remaining C entries are randn.
    """
    def __init__(self, **kernel_args):
        # Account for error_t + c_1 * error_{t+1} + ... c_q * error_{t+q}:
        # widen the state by one for c_0 (maybe shouldn't).
        # Adjust the constructor argument: `self.d_state` does not exist
        # until the parent constructor sets it from kernel_args (and would
        # overwrite any earlier assignment). The parent derives `_fp` from it.
        if 'd_state' in kernel_args:
            kernel_args['d_state'] = kernel_args['d_state'] + 1
        super().__init__(**kernel_args)
        
    def initialize_weights(self):
        """Run the shift initialization, then set a random c with c_0 = 1."""
        super().initialize_weights()
        c = torch.randn(*self._fp)
        c[:, :, 0] = 1  # c_0 = 1
        self.register("c", c, self.learn_c, lr=self.lr, wd=None)
        
        
class MAKernel(ShiftKernel):
    """
    Moving-average kernel.

    Each hippo's C vector is initialized to a uniform 1/order window over
    its first ``order`` entries, with ``order`` sampled per hippo.
    ``forward`` subtracts the last kernel output (the mean) from the input
    and returns (residuals, means).
    """
    def __init__(self, 
                 ma_d_state=None,  # should be the same as lag
                 learn_b=False, 
                 **kernel_args):
        # Override the specified d_state
        # -> So we can have multiple diff-ordered MA models  
        #    by masking out some elements of the C vector
        if ma_d_state is not None:
            kernel_args['d_state'] = ma_d_state
        super().__init__(learn_b=learn_b, **kernel_args)
        
    def initialize_weights(self):
        """Sample a window order per hippo and build c, c_mask and y_mask from it."""
        super().initialize_weights()
        
        # Moving average window kernels
        c = torch.zeros(*self._fp).float()
        self.learn_c = True
        self.c_mask = torch.zeros(*self._fp).float()
        # Optional output masking (currently unused in forward)
        self.y_mask = torch.zeros(*self._fp).float()
        
        # Heuristic right now for sampling different possible orders
        self.orders = torch.randint(low=2, high=self.d_state - 1, 
                                    size=(1, self.n_hippos))[0]
        self.orders = self.orders[
            torch.randint(len(self.orders), size=(1, self.n_hippos))[0]
        ]
        
        for ix, order in enumerate(self.orders):
            c[:, ix, :order] = 1. / order
            self.c_mask[:, ix, :order] = 1.
            self.y_mask[:, ix, -order:] = 1.

        self.register("c", c, self.learn_c, lr=self.lr, wd=None)
        
    def forward(self, u, v=0):
        """
        Return (u - m, m), where m is the last output of the masked kernel.

        Assumes channels == 1 so that m (B, C*H, 1) broadcasts against u (B, H, L).
        """
        # Apply C mask -> so we can update weights but keep the MA order
        c = oe.contract('chd, chd -> chd', self.c, self.c_mask.to(self.c.device))
        m, _ = super().forward(u, v, c=c)  # B x H x L, assume channels = 1
        # Only keep the last output; a rolling average is currently not used
        m = m[:, :, -1:]
        
        # Compute residuals
        y = u - m  
        # This includes inputs we don't want to model for that hippo, so the
        # c_mask is passed to the next layer's AR kernels instead. Optionally
        # (for testing), terms outside the MA order could be zeroed with:
        # y = oe.contract('bhl, chl -> bchl', y, self.y_mask.to(y.device))        
        # y = rearrange(y, 'b c h l -> b (c h) l')
        return y, m
    
    def _round_divisors(self, x):
        """
        Return the distinct values x // n for n in 1..x-1, excluding 1.

        Used for sampling moving average windows. Filtering on ``> 1``
        (rather than dropping the first sorted element) keeps 2 when x == 2.
        """
        divisors = np.unique([x // n for n in range(1, x)])
        return divisors[divisors > 1]
        
        
        
class CompleteMAKernel(nn.Module):
    """
    Consists of 2 "layers" of kernels:

    * ``error_kernel`` (MAKernel) computes input residuals based on moving averages.
    * ``fitting_kernel`` (ARKernel) fits coefficients to those residuals to
      predict the output, reusing the error kernel's C mask.
    """
    def __init__(self, n_hippos, ma_d_state, **kernel_args):
        super().__init__()
        kernel_args.update(n_hippos=n_hippos, d_state=ma_d_state)
        self.n_hippos = n_hippos
        self.d_state = ma_d_state

        self.error_kernel = MAKernel(ma_d_state=ma_d_state, **kernel_args)
        # NOTE: does not account for mu + error_t
        self.fitting_kernel = ARKernel(
            ar_d_state=ma_d_state,
            learn_b=False,
            c_mask=self.error_kernel.c_mask,
            **kernel_args,
        )
        
    def forward(self, u, v=0):  
        """
        Compute y_t = mu + error_t + theta_1 * error_{t-1} + ... + theta_q * error_{t-q}.

        -> Ignores the first error_t
        -> Also may have some 0s
        u is B x H x L (assumes channels = 1). Returns (y, v).
        """
        residuals, means = self.error_kernel(u, v=0)
        fitted = self.fitting_kernel(residuals)[0]
        # Means broadcast across the time axis
        return fitted + means, v
        

class DifferencingKernel(ShiftKernel):
    """
    Shift kernel whose C vectors hold finite-differencing coefficients.

    Hippo i gets differencing order ``i % num_orders`` (orders 0-3 are
    supported), or ``explicit_order`` for every hippo when that is given.
    """
    def __init__(self,
                 diff_d_state,
                 num_orders=None,  # Hardcoded for now
                 explicit_order=None,
                 random_a=False,
                 random_b=False,
                 random_c=False,
                 learn_a=False,
                 learn_b=False,
                 learn_c=False,
                 skip_connection=False,                 
                 **kernel_args):
        kernel_args['d_state'] = diff_d_state
        # Pass by keyword: ShiftKernel's positional order also contains
        # random_d and learn_d, so positional passing would shift each
        # learn_* flag (and skip_connection) into the wrong parameter.
        super().__init__(random_a=random_a, random_b=random_b,
                         random_c=random_c,
                         learn_a=learn_a, learn_b=learn_b, learn_c=learn_c,
                         skip_connection=skip_connection, **kernel_args)
        self.num_orders = 4 if num_orders is None else num_orders
        self.explicit_order = explicit_order
        self._initialize_weights()
        
    def _get_differencing_c(self, c, order, hippo_ix):
        """Write the order-``order`` differencing coefficients into c[:, hippo_ix]."""
        # Hardcoded for now, but this is just binomial coefficients (with negatives)
        if self.explicit_order is not None:
            order = self.explicit_order
        if order == 0:  
            c[:, hippo_ix, 0] = 1.
        elif order == 1: # should test  -> y_t - y_{t - 1}
            c[:, hippo_ix, :2] += torch.tensor([1, -1]).float()
        elif order == 2:
            c[:, hippo_ix, :3] += torch.tensor([1, -2, 1]).float()
        elif order == 3:
            c[:, hippo_ix, :4] += torch.tensor([1, -3, 3, -1]).float()
        else:
            raise NotImplementedError
        return c
        
    def _initialize_weights(self):
        """Re-run the shift initialization, then register the differencing c."""
        # NOTE: this repeats the parent initialization already done in the
        # constructor (num_orders / explicit_order only exist afterwards).
        super().initialize_weights()
        a = repeat(torch.zeros(self.d_state).float(), 'd -> c h d',
                   c=self.channels, h=self.n_hippos).clone().contiguous()
        self.register("a", a, self.learn_a, lr=self.lr, wd=None)
        
        c = torch.zeros(*self._fp).float()
        for i in range(c.shape[1]):
            self._get_differencing_c(c, i % self.num_orders, i)
        self.register("c", c, self.learn_c, lr=self.lr, wd=None)
  

class MAErrorKernel(ARKernel):
    """
    Fixed (non-learned c) kernel whose C is set so the shift matrix computes
    residuals from a moving average, i.e. y - y_mean, with a random
    window per hippo.
    """
    def __init__(self, ma_d_state,  # should be the same as horizon
                 **kernel_args):
        # Route the state size through kernel_args: the parent constructor
        # sets self.d_state from kernel_args and would overwrite any earlier
        # assignment. skip_connection already defaults to False in ShiftKernel.
        kernel_args['d_state'] = ma_d_state
        super().__init__(**kernel_args)
        
    def initialize_weights(self):
        """Sample a mean window per hippo and register the residual c (not learned)."""
        super().initialize_weights()
        
        # Moving average window kernels
        c = torch.zeros(*self._fp).float()
        c[:, :, 0] = 1.
            
        # Low is a heuristic for now
        self.mean_windows = torch.randint(low=4, high=self.d_state,
                                          size=(1, self.n_hippos))[0]
        # Set c values s.t. shift matrix computes residuals from moving average
        # i.e., y - y_mean
        for ix, mean_window in enumerate(self.mean_windows):  # could be faster with matrix addition
            c[:, ix, :mean_window] -= 1. / mean_window

        self.learn_c = False
        # NOTE(review): c may already be registered as trainable by the
        # parent; confirm `register` allows switching it to non-trainable.
        self.register("c", c, self.learn_c, lr=self.lr, wd=None)
        
    
class ResidualKernel(ARKernel):
    """
    Combined kernel for differencing and moving average residual errors.

    ``n_diff`` hippos (rounded down to a multiple of ``num_orders``) get
    differencing coefficients of orders 0..num_orders-1. The remaining
    hippos get moving-average residual coefficients with random windows.
    c is trainable only if ``kernel_args['train']`` is truthy.
    """
    # ma_d_state should be the same as horizon
    def __init__(self, n_diff, n_ma_error, ma_d_state, num_orders=4, **kernel_args):
        # Read by initialize_weights(), which the parent constructor calls,
        # so these must be set before super().__init__().
        self.n_diff = n_diff  # Number of differencing kernels
        self.n_ma_error = n_ma_error  # Number of MA error kernels
        self.num_orders = num_orders
        # n_hippos and d_state must go through kernel_args: the parent
        # constructor sets them from there, overwriting earlier attributes.
        kernel_args['n_hippos'] = n_diff + n_ma_error
        kernel_args['d_state'] = ma_d_state
        learn_c = kernel_args.get('train', False)
        super().__init__(learn_c=learn_c, **kernel_args)
        
    def initialize_weights(self):
        """Assign differencing and MA-error coefficients to random hippos in c."""
        # May want to refactor the differencing and MA error weight initialization into their own methods
        super().initialize_weights()
        hippo_ix = np.arange(self.n_hippos)
        
        c = torch.zeros(*self._fp).float()
        c[:, :, 0] = 1.
        
        # Initialize differencing weights
        # Split orders equally across hippos (as much as possible)
        n_diff = (self.n_diff - (self.n_diff % self.num_orders))
        hippo_ix_diff = np.random.choice(hippo_ix, size=n_diff, replace=False)
        
        # Shuffle hippos to assign to differencing kernels
        shuffle_ix = np.arange(n_diff)
        np.random.shuffle(shuffle_ix)
        
        for order_ix, _hippo_ix in enumerate(np.split(hippo_ix_diff[shuffle_ix],  self.num_orders)):
            self._get_differencing_c(c, order_ix, _hippo_ix)
        
        # Initialize moving average error weights on the remaining hippos
        hippo_ix_ma_error = hippo_ix[~np.isin(hippo_ix, hippo_ix_diff)]
        
        # Low is a heuristic for now
        self.mean_windows = torch.randint(low=4, high=self.d_state,
                                          size=(1, self.n_hippos - n_diff))[0]
        # Set c values s.t. shift matrix computes residuals from moving average
        # i.e., y - y_mean
        for ix, mean_window in enumerate(self.mean_windows):  # could be faster with matrix addition
            c[:, hippo_ix_ma_error[ix], :mean_window] -= 1. / mean_window
        
        self.register("c", c, self.learn_c, lr=self.lr, wd=None)
        
    def _get_differencing_c(self, c, order, hippo_ix):
        """Write the order-``order`` differencing coefficients into c[:, hippo_ix]."""
        # Hardcoded for now, but this is just binomial coefficients (with negatives)
        if order == 0:  
            c[:, hippo_ix, 0] = 1.
        elif order == 1: # should test  -> y_t - y_{t - 1}
            c[:, hippo_ix, :2] = torch.tensor([1, -1]).float()
        elif order == 2:
            c[:, hippo_ix, :3] = torch.tensor([1, -2, 1]).float()
        elif order == 3:
            c[:, hippo_ix, :4] = torch.tensor([1, -3, 3, -1]).float()
        else:
            raise NotImplementedError
        return c
