import torch
import torch.nn as nn
import math
from typing import Callable, Tuple


def do_nothing(x, mode=None):
    """Identity stand-in for the ToMe ``merge`` / ``unmerge`` callables.

    ``bipartite_soft_matching`` returns this function (twice) when there is
    nothing to merge, so callers can invoke it exactly like a real merge.

    Args:
        x: Any value; handed back untouched.
        mode: Accepted so the signature matches ``merge(x, mode=...)``; ignored.

    Returns:
        The very same object that was passed in as ``x``.
    """
    return x

def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """
    SOURCE: https://github.com/facebookresearch/ToMe
    Build ToMe merge/unmerge functions from a balanced (50%, 50%) bipartite
    matching: even-indexed tokens are merge candidates ("src"), odd-indexed
    tokens are merge targets ("dst").

    Args:
        metric: similarity features, shape [batch, tokens, channels].
        r: number of tokens to remove; clamped to at most half of the
           non-protected tokens.
        class_token: whether token 0 is a class token.
        distill_token: whether there is also a distillation token.

    When enabled, the class token and distillation token won't get merged.

    Returns:
        ``(merge, unmerge)``. ``merge(x, mode="mean")`` reduces the token
        axis by ``r``; ``unmerge(x)`` expands a merged tensor back to the
        original token count. If ``r`` ends up <= 0, both are ``do_nothing``.
    """
    # Count tokens that are exempt from merging.
    n_protected = sum(1 for flag in (class_token, distill_token) if flag)

    # At most 50% of the (non-protected) tokens can be removed.
    n_tokens = metric.shape[1]
    r = min(r, (n_tokens - n_protected) // 2)
    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        # Cosine similarity between every even token and every odd token.
        unit = metric / metric.norm(dim=-1, keepdim=True)
        even_set = unit[..., ::2, :]
        odd_set = unit[..., 1::2, :]
        similarity = even_set @ odd_set.transpose(-1, -2)

        # -inf scores keep protected tokens out of every merge.
        if class_token:
            similarity[..., 0, :] = -math.inf
        if distill_token:
            similarity[..., :, 0] = -math.inf

        # Best partner of each even token, then rank even tokens by that score.
        best_score, best_partner = similarity.max(dim=-1)
        ranking = best_score.argsort(dim=-1, descending=True)[..., None]

        src_idx = ranking[..., :r, :]  # Merged Tokens
        unm_idx = ranking[..., r:, :]  # Unmerged Tokens
        dst_idx = best_partner[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # Split exactly like the metric: even -> candidates, odd -> targets.
        candidates = x[..., ::2, :]
        targets = x[..., 1::2, :]
        batch, n_candidates, channels = candidates.shape

        kept = candidates.gather(
            dim=-2, index=unm_idx.expand(batch, n_candidates - r, channels))
        moved = candidates.gather(
            dim=-2, index=src_idx.expand(batch, r, channels))
        # Fold each selected candidate into its matched target token.
        targets = targets.scatter_reduce(
            -2, dst_idx.expand(batch, r, channels), moved, reduce=mode)

        if not distill_token:
            return torch.cat([kept, targets], dim=1)
        # Interleave so the first kept and first target token stay in front.
        return torch.cat(
            [kept[:, :1], targets[:, :1], kept[:, 1:], targets[:, 1:]], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        n_kept = unm_idx.shape[1]
        kept = x[..., :n_kept, :]
        targets = x[..., n_kept:, :]
        batch, _, channels = kept.shape

        # Every merged candidate receives a copy of its target's value.
        restored = targets.gather(
            dim=-2, index=dst_idx.expand(batch, r, channels))

        out = torch.zeros(batch, n_tokens, channels,
                          device=x.device, dtype=x.dtype)
        # Odd slots hold targets; even slots (2 * index) hold candidates.
        out[..., 1::2, :] = targets
        out.scatter_(dim=-2,
                     index=(2 * unm_idx).expand(batch, n_kept, channels),
                     src=kept)
        out.scatter_(dim=-2,
                     index=(2 * src_idx).expand(batch, r, channels),
                     src=restored)
        return out

    return merge, unmerge

def basis_coeffs_initialisation(times, values, frequencies, lamb=1):
    """
    Fit sine/cosine basis coefficients to (time, value) samples by solving
    the regularised normal equations ``(X^T X + lamb * R) p = X^T y``.

    ``X`` stacks ``sin(2*pi*f*t)`` and ``cos(2*pi*f*t)`` along the last axis.
    ``R`` is the identity with its [0, 0] entry zeroed, so the first sine
    coefficient is not penalised. The system is solved with
    ``torch.linalg.solve`` when there are at least as many samples as
    frequencies, and with ``torch.linalg.lstsq`` otherwise.

    Args:
        times: sample times, shape [M, C]
        values: sample values, shape [M, C]
        frequencies: basis frequencies with the frequency axis last; must be
                     3-D (e.g. [1, 1, F]) and broadcast against [M, C, 1]
        lamb: regularisation strength; default=1
    Returns:
        (sin_coeffs, cos_coeffs), each of shape [M, F]
    """
    n_freq = frequencies.size(-1)

    # Design matrix [M, C, 2F]: sines first, cosines second.
    phase = 2 * torch.pi * frequencies * times.unsqueeze(-1)
    design = torch.concatenate((torch.sin(phase), torch.cos(phase)), -1)

    # Penalty [M, 2F, 2F]; entry [0, 0] is left unpenalised.
    penalty = torch.eye(design.size(-1), device=times.device)
    penalty[0, 0] = 0
    penalty = (lamb * penalty.unsqueeze(0)).repeat_interleave(
        repeats=times.size(0), dim=0)

    # Both sides of the normal equations, built once for either solver.
    design_t = design.transpose(1, 2)
    gram = torch.bmm(design_t, design) + penalty
    moment = torch.bmm(design_t, values.unsqueeze(-1)).sum(-1)

    few_samples = design.size(1) < frequencies.size(2)
    if few_samples:
        parameters = torch.linalg.lstsq(gram, moment).solution
    else:
        parameters = torch.linalg.solve(gram, moment)

    sin_coeffs = parameters[:, :n_freq]
    cos_coeffs = parameters[:, n_freq:n_freq * 2]
    return sin_coeffs, cos_coeffs


class D3A2M(nn.Module):
    """
    The DAM wrapper that builds the embedders, backbone, and collapsors.

    ``forward`` estimates the basis coeffs (plus normalisation and affine
    parameters) from a context of (time, value) pairs and stores them on the
    module; ``forecast`` then evaluates them at any given times.

    Args:
        d_model: latent width of model (e.g. 256)
        d_ff: width of hidden layer for each feed-forward block (e.g. 256)
        n_layers: number of DAM layers (e.g. 4 or 12)
        n_heads: number of MHSA and cross attention heads (e.g. 4)
        n_tome: number of TV-tokens after iterative ToME (e.g. 250)
        dropout: default=0
        base_frequency: 86400 seconds are divided by this value to obtain
                        the model's time unit; default=1 (i.e. one day)
    """
    # Number of interior quantiles of the context fed to the affine token;
    # must match the input width of ``affine_embedding``.
    _N_AFFINE_QUANTILES = 50

    def __init__(self, d_model, d_ff, n_layers, n_heads, n_tome,
                 dropout=0, base_frequency=1):
        super().__init__()
        # The DAM's base frequency is 1 day (86400 seconds)
        self.register_buffer('time_scaling',
                             torch.Tensor([86400/base_frequency]))
        # Basis frequencies, shape [1, 1, 219]
        self.register_buffer('frequencies',
            torch.concatenate((1440/torch.arange(5, 61, 10),
                               24/(torch.arange(1, 48, 0.5)),
                               1/torch.arange(2, 28, 0.5),
                               1/torch.arange(28, 52*7, 7),
                               1/torch.arange(52*7, 52*7*10+1, 26*7))
                               ).unsqueeze(0).unsqueeze(0))
        # Single source of truth for the number of basis frequencies, so the
        # embedders and the backbone cannot drift apart.
        n_frequencies = self.frequencies.size(-1)
        # Embeddings for TV-tokens, B-tokens, and affine
        self.temporal_embedding = nn.Linear(n_frequencies, d_model)
        self.value_embedding = nn.Linear(1, d_model)
        self.btoken_period_embedder = nn.Linear(2, d_model)
        self.btoken_coeffs_embedder = nn.Linear(2, d_model)
        self.affine_embedding = nn.Linear(self._N_AFFINE_QUANTILES, d_model)
        # Backbone model
        self.backbone = TransformerBackbone(d_model=d_model,
                                            d_freq=n_frequencies,
                                            d_ff=d_ff,
                                            n_layers=n_layers,
                                            n_tome=n_tome,
                                            n_heads=n_heads,
                                            dropout=dropout)
        # Collapse B-tokens into 2 coeffs each
        self.basis_collapsor = nn.Linear(d_model, 2)
        # Affine collapse into offset and scale
        self.affine_collapser = nn.Linear(d_model, 2)

    def forecast(self, times):
        """
        The actual forecasting method (after the DAM forward pass).

        Reads ``self.basis_coeffs`` and ``self.normalisation_params``, which
        are only set by ``forward``; call ``forward`` first.
        args:
            times: forecast times (past or future) in seconds; shape=[M, T]
        returns:
            predictions in the original value scale; shape=[M, T]
        """
        times = times.unsqueeze(-1)/self.time_scaling  # days
        # Last axis of basis_coeffs is (cos coeff, sin coeff) per frequency.
        cos_coeffs = self.basis_coeffs[:,:,0].unsqueeze(1)
        sin_coeffs = self.basis_coeffs[:,:,1].unsqueeze(1)
        angles = 2*torch.pi*self.frequencies*times
        # Sum the weighted basis functions over the frequency axis.
        pred = (sin_coeffs * torch.sin(angles)).sum(-1) \
            + (cos_coeffs * torch.cos(angles)).sum(-1)
        # Reverse the robust standardisation and apply affine params:
        med = self.normalisation_params[:,:,0]
        iqr = self.normalisation_params[:,:,1]
        affine_offset = self.normalisation_params[:,:,2]
        affine_scale = self.normalisation_params[:,:,3]
        pred = iqr * ((pred - affine_offset)/affine_scale) + med
        return pred

    def forward(self, times, values):
        """
        The DAM wrapper expects times in seconds, where 0 is 'now',
        the future is positive time and the past is negative time.
        Values are always univariate.

        Nothing is returned: the results are stored as
        ``self.basis_coeffs`` and ``self.normalisation_params`` for a
        subsequent call to ``forecast``.
        args:
            times: time in seconds; shape=[M, C]
            values: values at times; shape=[M, C]
        M: minibatch size
        C: context size (e.g., 540)
        """
        times = times/self.time_scaling  # Convert time to days
        with torch.no_grad():
            # Robust (median / inter-quartile range) standardisation
            med = torch.median(values, axis=1, keepdim=True)[0]
            iqr = torch.quantile(values, q=0.75, axis=1, keepdim=True)-\
                torch.quantile(values, q=0.25, axis=1, keepdim=True)
            iqr[iqr<1e-6] = 1e-6  # guard against division by ~0
            values = (values - med)/iqr
            sin_coeffs, cos_coeffs = basis_coeffs_initialisation(times,
                                                  values,
                                                  self.frequencies)
        # Embed times for TV-tokens at the scales of the frequencies
        temporal_embedding = self.temporal_embedding(
            torch.sin(2*torch.pi*self.frequencies * times.unsqueeze(-1)))
        value_embedding = self.value_embedding(values.unsqueeze(-1))
        tv_tokens = value_embedding + temporal_embedding
        # Embed frequencies as periods similarly
        periods_embed = self.btoken_period_embedder(
            torch.concatenate((
                torch.sin(2*torch.pi/self.frequencies).transpose(1,2),
                torch.cos(2*torch.pi/self.frequencies).transpose(1,2)),
                -1))
        # arcsinh damps absolutely high coeffs
        b_tokens = self.btoken_coeffs_embedder(
            torch.arcsinh(0.1*\
                torch.stack((cos_coeffs, sin_coeffs), -1))/0.1)\
                + periods_embed
        # Quantiles for affine token. torch.quantile requires q to live on
        # the same device and have the same dtype as its input, so build the
        # levels from `values` rather than on the default (CPU/float32).
        quantile_levels = torch.linspace(
            0., 1, self._N_AFFINE_QUANTILES+2,
            device=values.device, dtype=values.dtype)[1:-1]
        quantiles = torch.quantile(values, q=quantile_levels, dim=1)
        affine_token = self.affine_embedding(quantiles.T.unsqueeze(1))
        # Backbone
        b_tokens_out, affine_token_out = self.backbone(tv_tokens,
                                                       b_tokens,
                                                       affine_token)
        # Collapse into usable function space
        affine_collapsed = self.affine_collapser(affine_token_out)
        # Set coeffs and normalisation params
        # as attributes for the forecasting step (hereafter).
        # Last axis: (median, iqr, affine offset, affine scale).
        self.normalisation_params = torch.concatenate(
                                        (med.unsqueeze(-1),
                                         iqr.unsqueeze(-1),
                                         affine_collapsed), -1)
        self.basis_coeffs = self.basis_collapsor(b_tokens_out)
    
class TransformerBackbone(nn.Module):
    r"""
    Backbone for the DAM, taking in embedded TV-tokens, B-tokens,
    and affine token.
    Args:
        d_model: width of model (e.g. 256)
        d_freq: number of frequencies and B-tokens (e.g. 219)
        d_ff: width of hidden layer for each feed-forward block (e.g. 256)
        n_layers: number of DAM layers (e.g. 4 or 12)
        n_tome: number of TV-tokens after iterative ToME (e.g. 250)
        n_heads: number of MHSA and cross attention heads (e.g. 4)
        dropout: default=0
    """
    def __init__(self,
                 d_model,
                 d_freq,
                 d_ff,
                 n_layers,
                 n_tome,
                 n_heads,
                 dropout=0,
                 ):
        super().__init__()
        self.n_layers = n_layers
        self.n_tome = n_tome
        self.encoder_layers = nn.ModuleList([
            DAMLayer(d_model, d_freq, d_ff, n_heads, dropout) \
                                     for _ in range(n_layers)
        ])

    def forward(self, tv_tokens, p_tokens, affine_token):
        r"""
        Run every DAM layer in turn. Logic in this wrapper includes making
        sure ToME reduces by the right amount with each layer.
        Args:
            tv_tokens: embedded time-value tokens; token axis is dim 1
            p_tokens: embedded B-tokens, passed through every layer
            affine_token: embedded affine token, passed through every layer
        Returns:
            (p_tokens, affine_token) as produced by the last layer.
        """
        n_tv_tokens = tv_tokens.size(1)
        # Reduce TV-tokens iteratively at each layer using ToME: spread the
        # total reduction evenly (after rounding) over the layers.
        r = int(round((n_tv_tokens-self.n_tome)/self.n_layers))
        for li in range(self.n_layers):
            if li == self.n_layers - 1:
                # Absorb the rounding error in the last layer: remove however
                # many TV-tokens remain above the target. This must be based
                # on the TV-token count, not on the number of B-tokens.
                r = tv_tokens.size(1) - self.n_tome
            tv_tokens, p_tokens, affine_token = \
                self.encoder_layers[li](tv_tokens, p_tokens, affine_token, r)
        return p_tokens, affine_token
    
class DAMLayer(nn.Module):
    r"""
    One DAM layer: self-attention over the affine token and TV-tokens,
    ToME merging of the TV-tokens, cross-attention from B-tokens onto the
    result, and a feed-forward block across the B-token axis.
    Args:
        d_model: width of every token
        d_freq: number of B-tokens (input width of the across-token block)
        d_ff: hidden width of the per-token feed-forward blocks
        n_heads: number of heads of both attention modules
        dropout: dropout probability; default=0
    """
    @staticmethod
    def _mlp(d_in, d_hidden, dropout):
        # Linear -> GELU -> Linear -> Dropout, mapping d_in back to d_in.
        return nn.Sequential(nn.Linear(d_in, d_hidden),
                             nn.GELU(),
                             nn.Linear(d_hidden, d_in),
                             nn.Dropout(dropout))

    def __init__(self,
                 d_model,
                 d_freq,
                 d_ff,
                 n_heads,
                 dropout=0,
                 ):
        super().__init__()
        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation consumes the random generator reproducibly.
        attention_options = dict(batch_first=True, add_zero_attn=True)
        self.mhsa_tv = nn.MultiheadAttention(d_model, n_heads, dropout,
                                             **attention_options)
        self.cross_attention = nn.MultiheadAttention(d_model, n_heads,
                                                     dropout,
                                                     **attention_options)
        self.feed_forward_tv = self._mlp(d_model, d_ff, dropout)
        self.feed_forward_b = self._mlp(d_model, d_ff, dropout)
        # Acts along the B-token axis, hence sized by d_freq.
        self.feed_forward_b_cross = self._mlp(d_freq, d_freq*2, dropout)
        self.feed_forward_affine = self._mlp(d_model, d_ff, dropout)
        # layernorm1 ... layernorm7, one per residual connection.
        for index in range(1, 8):
            setattr(self, f"layernorm{index}", nn.LayerNorm(d_model))

    def forward(self, tv_tokens, b_tokens, affine_token, r):
        r"""
        Concatenate the affine token and TV-tokens to compute
        self-attention, then split them up again for ToME.
        Args:
            tv_tokens: Time-value tokens; shape=[M, Ntv, D]
            b_tokens: B-tokens for basis coeffs; shape=[M, Nb, D]
            affine_token: for affine adjustment; shape=[M, 1, D]
            r: ToME reduction amount; scalar.
        M: minibatch size
        D: model dimension
        Returns:
            (tv_tokens, b_tokens, affine_token) after this layer; the
            TV-token axis is shortened by the ToME merge.
        """
        n_affine = affine_token.size(1)
        joint = torch.concatenate((affine_token, tv_tokens), 1)
        attended, _ = self.mhsa_tv(joint, joint, joint)
        affine_attended = attended[:, :n_affine]
        tv_attended = attended[:, n_affine:]

        # Bipartite soft matching from ToME
        # https://github.com/facebookresearch/ToMe
        # The same merge is applied to the tokens and their attention output
        # so the residual sum below stays aligned.
        merge, _ = bipartite_soft_matching(tv_attended, r)
        tv_tokens = merge(tv_tokens)
        tv_attended = merge(tv_attended)
        tv_tokens = self.layernorm1(tv_tokens + tv_attended)
        tv_tokens = self.layernorm2(
            self.feed_forward_tv(tv_tokens) + tv_tokens)

        affine_token = self.layernorm3(affine_token + affine_attended)
        affine_token = self.layernorm4(
            self.feed_forward_affine(affine_token) + affine_token)

        # Transfer information from TV-tokens into B-tokens
        context = torch.concatenate((affine_token, tv_tokens), 1)
        cross, _ = self.cross_attention(b_tokens, context, context)
        b_tokens = self.layernorm5(b_tokens + cross)
        b_tokens = self.layernorm6(self.feed_forward_b(b_tokens) + b_tokens)

        # Process ACROSS B-tokens: transpose so the token axis is last.
        across = self.feed_forward_b_cross(
            b_tokens.transpose(1, 2)).transpose(1, 2)
        b_tokens = self.layernorm7(across + b_tokens)
        return tv_tokens, b_tokens, affine_token

if __name__ == '__main__':
    # Smoke test: one forward pass on dummy data, then a forecast whose
    # shape must match the query times.
    batch_size, context_size = 32, 540   # [M, C]
    ten_days = 864000                    # seconds

    dam = D3A2M(d_model=256, d_ff=256, n_layers=4, n_heads=4, n_tome=250)

    # Dummy data and times would be sampled from a long tail distribution
    # and are not necessarily sequential.
    context_values = torch.zeros((batch_size, context_size)).float()
    # From 10 days ago until just before 'now':
    context_times = torch.zeros(
        (batch_size, context_size)).random_(-ten_days, -1).float()

    # Estimate the basis coeffs
    dam(context_times, context_values)

    # Query times can be past or future
    query_times = torch.zeros(
        (batch_size, context_size)).random_(-ten_days, ten_days).float()
    forecast = dam.forecast(query_times)
    assert forecast.shape == query_times.shape