from curses import window
import torch
import torch.nn as nn
import torch.nn.functional as F

from deephfts.mats import coord_matrix
from deephfts.modules import SBlock, RMatrix

class SETR(nn.Module):
    """Jointly composed S-block and R-matrix model.

    The model owns one ``SBlock`` and one ``RMatrix``. A forward pass runs
    the S-block on the input, runs the R-matrix on the permuted input, and
    returns the R-matrix predictions of the block whose S-block score is
    the largest.
    """

    def __init__(self,
        block_size: list = [3, 3],
        window_size : int = 80,
        forecast_size : int = 20) -> None:
        """Build the S-block, the R-matrix and the coordinate lookup.

        Args:
            block_size: Two-element sequence; element 0 is passed to
                ``coord_matrix`` as ``x_shape`` and element 1 as ``y_shape``.
                The whole sequence is forwarded to ``SBlock`` (as
                ``output_dim``) and to ``RMatrix`` (as ``block_size``).
            window_size: Forwarded to ``SBlock`` and ``RMatrix``.
            forecast_size: Forwarded to ``RMatrix``.
        """
        super().__init__()

        # Work on a private copy: the default is a module-level list object
        # shared by every call, and the value is stored on the instance and
        # handed to sub-modules, so aliasing it would let one instance's
        # mutation leak into the others.
        block_size = list(block_size)

        self.window_size = window_size
        self.forecast_size = forecast_size
        self.block_size = block_size
        self.sblock = SBlock(output_dim=block_size, window_size=window_size)
        self.rmatrix = RMatrix(block_size=block_size, forecast_size=forecast_size, window_size=window_size)
        self.coordmatrix = coord_matrix(x_shape=block_size[0], y_shape=block_size[1])

    def get_sblock(self):
        """Return the S-block sub-module."""
        return self.sblock

    def get_rmatrix(self):
        """Return the R-matrix sub-module."""
        return self.rmatrix

    def get_coordmatrix(self):
        """Return the matrix of coordinates built by ``coord_matrix``."""
        return self.coordmatrix

    def get_blocksize(self):
        """Return the block size this model was configured with."""
        return self.block_size

    def extract_space(self, x: torch.Tensor) -> torch.Tensor:
        """Run the S-block on ``x`` and return its output unchanged.

        Args:
            x (torch.Tensor): Input passed straight to the S-block.

        Returns:
            torch.Tensor: Output of the S-block.
        """
        return self.sblock(x)

    def rmatrix_hidden(self, x: torch.Tensor) -> torch.Tensor:
        """Return the R-matrix adjacency output, rearranged for ``rmatrix_preds``.

        Args:
            x (torch.Tensor): Input time series; cast to float64 before use.

        Returns:
            torch.Tensor: ``self.rmatrix.adjacency(x)`` with its last two axes
            moved to the front (in reversed order) and a new leading axis of
            size 1 added.
        """
        x = x.double()
        hidden = self.rmatrix.adjacency(x)
        # The four-index permute requires ``adjacency`` to return a 4-D
        # tensor: (a, b, c, d) -> (d, c, a, b).
        hidden = hidden.permute(-1, -2, 0, 1)
        return torch.unsqueeze(hidden, dim=0)

    def rmatrix_preds(self, x: torch.Tensor, x_adj: torch.Tensor) -> torch.Tensor:
        """Return the R-matrix predictions for ``x`` given ``x_adj``.

        Args:
            x (torch.Tensor): Input time series; cast to float64 before use.
            x_adj (torch.Tensor): Second R-matrix input (``forward`` passes
                the result of ``rmatrix_hidden``); cast to float64 before use.

        Returns:
            torch.Tensor: Output of calling the R-matrix on both inputs.
        """
        x = x.double()
        x_adj = x_adj.double()
        return self.rmatrix(x, x_adj)

    def s_preds(self, x: torch.Tensor)->torch.Tensor:
        """Return the predictions of the S-block.

        Args:
            x (torch.Tensor): Tensor of input values.

        Returns:
            torch.Tensor: Output of the S-block.
        """
        return self.sblock(x)

    def forward(self, x : torch.Tensor) -> tuple:
        """Forward pass.

        Args:
            x (torch.Tensor): 3-D input; axes 1 and 2 are swapped before the
                tensor is handed to the R-matrix.

        Returns:
            tuple: ``(s_preds, current_R_preds)`` where ``s_preds`` is the
            S-block output and ``current_R_preds`` is the slice of the
            R-matrix predictions at the coordinates of the best-scoring
            block.
        """
        s_preds = self.s_preds(x)
        t = x.permute(0, 2, 1)
        rmatrix_hidden = self.rmatrix_hidden(t)
        rmatrix_preds = self.rmatrix_preds(x=t, x_adj=rmatrix_hidden)

        # argmax without ``dim`` indexes the *flattened* S-block output.
        # NOTE(review): if s_preds carries more than one sample, this index
        # can exceed the number of blocks -- confirm the expected batch size.
        best_index = torch.argmax(s_preds).item()

        # Call the getter (the original took the bound method itself) and
        # collapse the first two axes so one flat index selects one entry.
        coords = self.get_coordmatrix()
        flat_coord = torch.flatten(coords, start_dim=0, end_dim=1)
        best_rblock = flat_coord[best_index].int().tolist()

        # Use both coordinates; the original indexed with element 0 twice.
        # Assumes each entry of the coordinate matrix is a (row, col) pair
        # -- TODO confirm against ``coord_matrix``.
        current_R_preds = rmatrix_preds[best_rblock[0], best_rblock[1], :, :]
        return s_preds, current_R_preds