"""
Data input / output transforms  
- General workflow is: for each data batch:  
  1) Apply input transform to data  
  2) Feed transformed data to model    
  3) Collect model output
  4) Apply the output transform (the inverse of the input transform) to the model output
  5) Use transformed output as final output
"""  
import torch.nn as nn


def get_data_transforms(method, horizon):
    """
    Build a matched (input_transform, output_transform) pair.

    Args:
        method: one of 'mean', 'last', 'ts_mean', 'ts_last' or 'none'.
        horizon: forwarded to the input transform; the transforms in this
            module exclude that many trailing time steps when computing
            their statistics.

    Returns:
        (input_transform, output_transform). For 'none' both are identity
        functions. Otherwise the output transform is an
        InverseAffineTransform wrapping the input transform, so the input
        transform must be applied to a batch before its output is inverted.

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.
    """
    if method == 'none':
        return (lambda x: x), (lambda x: x)

    # All affine variants share the same construction; only the class differs.
    transform_classes = {
        'mean': MeanAffineTransform,
        'last': LastAffineTransform,
        'ts_mean': TSNormMeanTransform,
        'ts_last': TSNormLastTransform,
    }
    try:
        transform_cls = transform_classes[method]
    except KeyError:
        raise NotImplementedError(
            f"Unknown data transform method {method!r}; expected one of "
            f"{sorted([*transform_classes, 'none'])}"
        ) from None

    input_transform = transform_cls(method=method, horizon=horizon)
    output_transform = InverseAffineTransform(method=method,
                                              transform=input_transform)
    return input_transform, output_transform


class AffineTransform(nn.Module):
    """
    Base class for per-series affine transforms of the form f(x) = a * x - b.

    Subclasses compute ``self.a`` and ``self.b`` from the incoming batch
    inside ``forward`` and keep them on the instance, so that
    ``InverseAffineTransform`` can later read them to undo the mapping.
    """

    def __init__(self, method, horizon=0):
        """
        Args:
            method: name of the transform variant; stored, not interpreted here.
            horizon: stored for subclasses, which exclude this many trailing
                time steps when computing their statistics.
        """
        super().__init__()
        self.method, self.horizon = method, horizon

    def forward(self, x):
        # x is assumed to be B x L x D; concrete subclasses must override.
        raise NotImplementedError
    
    
class InverseAffineTransform(nn.Module):
    """
    Undo an affine transform f(x) = a * x - b on model output.

    The exact inverse is f^{-1}(y) = (y + b) / a. ``a`` and ``b`` are read
    from the wrapped transform at call time. The affine transforms in this
    module set them inside their own ``forward``, so the wrapped transform
    must have been applied to the current batch before this module is called.
    """

    def __init__(self, method, transform):
        super().__init__()
        self.method = method
        self.transform = transform  # AffineTransform object

    @staticmethod
    def _to_device(value, device):
        """Move a tensor-like ``value`` to ``device``; pass Python scalars through."""
        # a/b may be plain floats (e.g. a = 1. or b = 0.), which have no .to().
        return value.to(device) if hasattr(value, 'to') else value

    def forward(self, x):
        # NOTE(review): assumes a/b broadcast against x (e.g. B x 1 x D
        # statistics vs. B x L' x D output) -- confirm model output layout.
        a = self._to_device(self.transform.a, x.device)
        b = self._to_device(self.transform.b, x.device)
        return (x + b) / a
        
        
class MeanAffineTransform(AffineTransform):
    """
    Zero-center each series by its context mean: f(x) = x - mean(context).

    The mean is taken over the first ``L - horizon`` time steps only, so the
    trailing ``horizon`` steps do not contribute to the statistic. ``a`` is
    fixed at 1 and ``b`` has shape B x 1 x D.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        # Assumes x is B x L x D.
        # Slicing with ``:-self.horizon`` is empty when horizon == 0 (``:-0`` is
        # ``:0``), which made the mean NaN for the default horizon. Use an
        # explicit context length; clamping at 0 keeps horizon >= L an empty
        # slice, as before.
        context_len = max(x.shape[1] - self.horizon, 0)
        self.a = 1.
        self.b = x[:, :context_len, :].mean(dim=1, keepdim=True)
        return self.a * x - self.b
    
    
class LastAffineTransform(AffineTransform):
    """
    Shift each series by an anchor value: f(x) = x - x[:, -(horizon + 1)].

    The anchor is the time step just before the trailing ``horizon`` steps
    (the very last step when ``horizon == 0``). ``a`` is fixed at 1 and
    ``b`` has shape B x 1 x D.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        anchor_idx = -(self.horizon + 1)
        # unsqueeze(1) keeps a singleton time axis so b broadcasts over L.
        self.b = x[:, anchor_idx, :].unsqueeze(1)
        self.a = 1.
        return self.a * x - self.b
    
    
class TSNormMeanTransform(AffineTransform):
    """
    Normalize each series by the mean absolute value of its context window.

    Based on the old TSNormalization code in the hippo repo. The scale is
    ``mean(|x|)`` over the first ``L - horizon`` time steps. As an affine
    transform f(x) = a * x - b this is ``a = 1 / scale`` (shape B x 1 x D)
    and ``b = 0``, so the inverse (y + b) / a multiplies by ``scale`` again.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        # Assumes x is B x L x D. ``:-self.horizon`` would be empty for
        # horizon == 0, giving a NaN scale; use an explicit context length.
        context_len = max(x.shape[1] - self.horizon, 0)
        scale = x.abs()[:, :context_len, :].mean(dim=1, keepdim=True)
        # Normalizing divides by the scale. Previously ``a`` was the scale
        # itself, so the input was multiplied instead of normalized.
        # NOTE(review): an all-zero context gives scale == 0 -> inf/NaN.
        self.a = 1. / scale
        self.b = 0.
        return self.a * x - self.b
    
    
class TSNormLastTransform(AffineTransform):
    """
    Normalize each series by the absolute value at an anchor time step.

    Based on the old TSNormalization code in the hippo repo. The scale is
    ``|x[:, -(horizon + 1)]|``, i.e. the step just before the trailing
    ``horizon`` steps. As an affine transform f(x) = a * x - b this is
    ``a = 1 / scale`` (shape B x 1 x D) and ``b = 0``, so the inverse
    (y + b) / a multiplies by ``scale`` again.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        # Assumes x is B x L x D.
        scale = x.abs()[:, -self.horizon - 1, :].unsqueeze(1)
        # Normalizing divides by the scale. Previously ``a`` was the scale
        # itself, so the input was multiplied instead of normalized.
        # NOTE(review): a zero anchor value gives scale == 0 -> inf/NaN.
        self.a = 1. / scale
        self.b = 0.
        return self.a * x - self.b
        