# Public API: `from <this module> import *` exports only the composite `head` class.
__all__ = ['head']


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..models.DLinear import DLinear
from ..models.RevIN import RevIN

class head(nn.Module):
    """Forecast head that combines a history window with future-step values.

    ``forward`` splits its input ``z`` along the last axis into a history part
    (the first ``context_window`` steps of every variable) and a future part
    (the remaining steps of variables ``1:``; variable 0 contributes history
    only). ``head_type`` selects how those parts reach the sub-head:

    * ``'flatten'``      -- flatten the future part plus the history of
      variable 0 only, then apply a single linear layer (``Flatten_Head``).
    * ``'flatten_full'`` -- flatten the future part plus the history of all
      variables, then apply a single linear layer (``Flatten_Head``).
    * ``'flatten_mlp'``  -- same flattened input as ``'flatten_full'``, fed to
      a two-layer MLP of hidden width ``mlp_dim_1`` (``Flatten_mlp_Head``).
    * ``'linear'`` / ``'mlp'`` -- ignore the history and map the future part
      per time step from ``c_in - 1`` to ``c_out`` channels.

    Args:
        c_in: number of variables (``nvars``) in the input ``z``.
        c_out: number of output channels.
        context_window: number of leading (history) steps in ``z``.
        target_window: forecast horizon; the future part of ``z`` is expected
            to have this many steps.
        head_type: one of the strings listed above.
        head_dropout: dropout probability forwarded to the sub-head.
        mlp_dim_1: hidden width used by ``'flatten_mlp'`` and ``'mlp'``.

    Raises:
        ValueError: if ``head_type`` is not one of the supported values.
    """

    # Head types that consume a single flattened vector built in ``forward``.
    _FLAT_TYPES = ('flatten', 'flatten_full', 'flatten_mlp')

    def __init__(self, c_in, c_out, context_window, target_window, head_type, head_dropout, mlp_dim_1=1024):
        super().__init__()
        self.context_window = context_window
        self.target_window = target_window
        self.head_type = head_type
        if head_type == 'flatten':
            nf = self._flat_features(c_in, context_window, target_window, full_history=False)
            self.head = Flatten_Head(c_out, target_window, nf=nf, head_dropout=head_dropout)
        elif head_type == 'flatten_full':
            # Flatten_Head has no hidden layer, so mlp_dim_1 is not passed
            # (passing it positionally would collide with ``nf``).
            nf = self._flat_features(c_in, context_window, target_window, full_history=True)
            self.head = Flatten_Head(c_out, target_window, nf=nf, head_dropout=head_dropout)
        elif head_type == 'flatten_mlp':
            # forward() flattens the history of *all* variables for this head
            # type, so the input size must include c_in * context_window.
            nf = self._flat_features(c_in, context_window, target_window, full_history=True)
            self.head = Flatten_mlp_Head(c_out, mlp_dim_1, target_window, nf=nf, head_dropout=head_dropout)
        elif head_type == 'linear':
            self.head = Linear_Head(c_in - 1, c_out, head_dropout=head_dropout)
        elif head_type == 'mlp':
            self.head = MLP_Head(c_in - 1, c_out, mlp_dim_1, head_dropout=head_dropout)
        else:
            raise ValueError(
                f"unknown head_type {head_type!r}; expected one of "
                f"'flatten', 'flatten_full', 'flatten_mlp', 'linear', 'mlp'"
            )

    @staticmethod
    def _flat_features(c_in, context_window, target_window, full_history):
        """Return the length of the flattened vector assembled in ``forward``.

        The future part always contributes ``(c_in - 1) * target_window``
        values; the history contributes ``context_window`` values (variable 0
        only) or ``c_in * context_window`` values when ``full_history``.
        """
        n_hist_vars = c_in if full_history else 1
        return (c_in - 1) * target_window + n_hist_vars * context_window

    def forward(self, z):
        """Split ``z`` into history/future parts and apply the sub-head.

        Args:
            z: tensor of shape [bs x nvars x (context_window + target_window)].

        Returns:
            The sub-head output, [bs x c_out x target_window].
        """
        z_hist = z[:, :, :self.context_window]       # [bs x nvars x context_window]
        z_future = z[:, 1:, self.context_window:]    # [bs x nvars-1 x target_window]

        if self.head_type in ('linear', 'mlp'):
            return self.head(z_future)
        if self.head_type not in self._FLAT_TYPES:
            raise ValueError(f"unknown head_type {self.head_type!r}")

        if self.head_type == 'flatten':
            # Only the history of variable 0 is used by this head type.
            z_hist = z_hist[:, 0, :]                 # [bs x context_window]
        z_hist = z_hist.reshape(z_hist.size(0), -1)
        z_future = z_future.reshape(z_future.size(0), -1)
        # Future values first, then history -- the layout nf is sized for.
        return self.head(torch.cat((z_future, z_hist), -1))

    
class Flatten_mlp_Head(nn.Module):
    """Two-layer MLP mapping a flat feature vector to a [c_out x target_window] grid.

    Args:
        c_out: number of output channels.
        mlp_dim_1: width of the hidden layer.
        target_window: number of output steps per channel.
        nf: length of the flat input feature vector.
        head_dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, c_out, mlp_dim_1, target_window, nf, head_dropout=0):
        super().__init__()
        # Output grid dimensions, used to un-flatten the final projection.
        self.c_out, self.target_window = c_out, target_window
        out_features = c_out * target_window
        self.linear1 = nn.Linear(nf, mlp_dim_1)
        self.linear2 = nn.Linear(mlp_dim_1, out_features)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``x`` of shape [bs x nf] to [bs x c_out x target_window]."""
        hidden = self.dropout(F.relu(self.linear1(x)))   # [bs x mlp_dim_1]
        flat = self.linear2(hidden)                      # [bs x c_out * target_window]
        return flat.reshape(flat.size(0), self.c_out, self.target_window)

    
class Flatten_Head(nn.Module):
    """Single linear layer mapping a flat feature vector to a [c_out x target_window] grid.

    Args:
        c_out: number of output channels.
        target_window: number of output steps per channel.
        nf: length of the flat input feature vector.
        head_dropout: dropout probability applied to the projected output.
    """

    def __init__(self, c_out, target_window, nf, head_dropout=0):
        super().__init__()
        self.c_out = c_out
        self.target_window = target_window
        self.linear1 = nn.Linear(nf, self.c_out * self.target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``x`` of shape [bs x nf] to [bs x c_out x target_window]."""
        # Dropout acts on the projected output, not on the input features.
        projected = self.dropout(self.linear1(x))        # [bs x c_out * target_window]
        grid_shape = (projected.shape[0], self.c_out, self.target_window)
        return projected.reshape(grid_shape)
    
    
class Linear_Head(nn.Module):
    """Per-time-step linear map from ``d_model`` channels to ``c_out`` channels.

    Args:
        d_model: number of input channels (axis 1 of the input).
        c_out: number of output channels.
        head_dropout: dropout probability applied to the projected output.
    """

    def __init__(self, d_model, c_out, head_dropout=0):
        super().__init__()
        self.linear = nn.Linear(d_model, c_out)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``x`` of shape [bs x d_model x T] to [bs x c_out x T]."""
        # nn.Linear acts on the last axis, so move channels there and back.
        channels_last = x.permute(0, 2, 1)                  # [bs x T x d_model]
        mixed = self.dropout(self.linear(channels_last))    # [bs x T x c_out]
        return mixed.permute(0, 2, 1)                       # [bs x c_out x T]

    
class MLP_Head(nn.Module):
    """Per-time-step two-layer MLP from ``d_model`` channels to ``c_out`` channels.

    Args:
        d_model: number of input channels (axis 1 of the input).
        c_out: number of output channels.
        mlp_dim_1: width of the hidden layer.
        head_dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, c_out, mlp_dim_1=64, head_dropout=0):
        super().__init__()
        self.linear = nn.Linear(d_model, mlp_dim_1)   # input -> hidden
        self.linear1 = nn.Linear(mlp_dim_1, c_out)    # hidden -> output
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map ``x`` of shape [bs x d_model x T] to [bs x c_out x T]."""
        # Work channels-last so both linear layers act on the channel axis.
        tokens = x.permute(0, 2, 1)                          # [bs x T x d_model]
        hidden = self.dropout(F.relu(self.linear(tokens)))   # [bs x T x mlp_dim_1]
        return self.linear1(hidden).permute(0, 2, 1)         # [bs x c_out x T]