from curses import window
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import reduce


class SBlock(nn.Module):
    """Spatial block to convert time to space.

    ``forward`` applies, in order:

    1. ``conv_1d``: a 1-D convolution over the time axis (``padding='same'``,
       so the time length is preserved).
    2. ``hidden_layer``: a linear map from ``window_size`` time steps to
       ``hidden_dim = prod(d + 2 for d in output_dim)`` features.
    3. A reshape of those features into a grid of shape ``hidden_shape``
       (``output_dim`` with 2 added to every side).
    4. ``conv_2d``: an unpadded 3x3 convolution, which trims 2 from every
       side so the grid comes out with shape ``output_dim``.
    5. A sigmoid, divided by the sum of the absolute pre-sigmoid values.
    """

    def __init__(self,
        input_dim: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        output_dim: list = None,
        window_size: int = 80) -> None:
        """
        Args:
            input_dim (int, optional): Channels of the input series. Defaults to 1.
            out_channels (int, optional): Channels produced by both convolutions. Defaults to 1.
            kernel_size (int, optional): Kernel size of the 1-D convolution. Defaults to 3.
            output_dim (list, optional): Spatial shape of the output grid. Defaults to [3, 3].
            window_size (int, optional): Length of the input time axis. Defaults to 80.
        """
        super().__init__()
        # None sentinel: a list literal default would be shared by every instance.
        if output_dim is None:
            output_dim = [3, 3]
        self.output_dim = list(output_dim)
        # +2 per side compensates for the unpadded 3x3 conv_2d below.
        self.hidden_shape = [dim + 2 for dim in self.output_dim]
        self.hidden_dim = reduce(lambda x, y: x * y, self.hidden_shape)
        self.input_dim = input_dim
        self.out_channels = out_channels
        self.conv_1d = nn.Conv1d(in_channels=input_dim, out_channels=out_channels, kernel_size=kernel_size, padding='same')
        self.hidden_layer = nn.Linear(window_size, self.hidden_dim)
        # conv_2d consumes the grid built from conv_1d's output, so its input
        # channel count must be conv_1d's ``out_channels``.
        self.conv_2d = nn.Conv2d(out_channels, out_channels, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a time window onto a normalised 2-D grid.

        Args:
            x (torch.Tensor): Series of shape ``(batch, input_dim, window_size)``
                or unbatched ``(input_dim, window_size)``.

        Returns:
            torch.Tensor: Tensor of shape ``(batch, out_channels, *output_dim)``;
            batch is 1 for unbatched input.
        """
        x = self.conv_1d(x)
        x = self.hidden_layer(x)
        # -1 keeps the batch dimension (1 for unbatched input) rather than
        # hard-coding input_dim there, which failed for any other batch size.
        x = x.reshape(-1, self.out_channels, *self.hidden_shape)
        x = self.conv_2d(x)
        # NOTE(review): the normaliser sums over the whole tensor (all batch
        # items and channels) and uses pre-sigmoid values -- confirm intent.
        return torch.sigmoid(x) / torch.sum(torch.abs(x))

class RBlock(nn.Module):
    """Recurrent block to process time series.

    Two LSTMs are combined: ``lstm`` runs over the block's own series, and
    ``adj_lstm`` runs over a sequence obtained by convolving (``adj_conv``)
    an adjacency volume. The last-time-step outputs of both are concatenated
    and projected to ``forecast_size`` values.

    Both LSTMs are ``batch_first``, so sequences are laid out as
    ``(batch, seq_len, feature)``. That is the layout ``forward`` builds for
    the adjacency sequence, and the one its ``[:, -1, :]`` "last time step"
    indexing relies on. Without ``batch_first``, PyTorch reads dim 0 as time,
    and a ``(1, L, feature)`` input becomes L independent length-1 sequences.
    """

    def __init__(self,
        input_dim: int = 1,
        hidden_size: int = 10,
        num_layers: int = 1,
        forecast_size: int = 20,
        kernel_size: int = 3,
        stride : int = 1,
        batch_size : int = 1
        ) -> None:
        """
        Args:
            input_dim (int, optional): Features per time step. Defaults to 1.
            hidden_size (int, optional): LSTM hidden size; also the channel count adj_conv expects. Defaults to 10.
            num_layers (int, optional): Stacked LSTM layers. Defaults to 1.
            forecast_size (int, optional): Number of predicted values. Defaults to 20.
            kernel_size (int, optional): Kernel size of adj_conv. Defaults to 3.
            stride (int, optional): Stride of adj_conv. Defaults to 1.
            batch_size (int, optional): Batch size used to reshape inputs/outputs in forward. Defaults to 1.
        """
        super().__init__()
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.forecast_size = forecast_size
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.adj_lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

        # Collapses hidden_size channels to 1; padding only on the first spatial axis.
        self.adj_conv = nn.Conv3d(in_channels=hidden_size, out_channels=1, kernel_size=kernel_size, stride=stride, padding=(1, 0, 0))
        # Input is [last adjacency output, last self output] concatenated.
        self.hidden2output = nn.Linear(in_features=hidden_size * 2, out_features=forecast_size)

    def forward(self, x: torch.Tensor, x_adj: torch.Tensor) -> torch.Tensor:
        """Forecast from the block's own series and an adjacency volume.

        Args:
            x (torch.Tensor): Own series, ``(batch_size, seq_len, input_dim)``.
            x_adj (torch.Tensor): Input to ``adj_conv`` with ``hidden_size``
                channels. Presumably the stacked hidden states of adjacent
                blocks (TODO confirm layout with caller).

        Returns:
            torch.Tensor: Shape ``(batch_size, input_dim, forecast_size)``.
        """
        # 3D convolution of the adjacency volume, flattened into a
        # (batch, seq, feature) sequence; -1 infers the sequence length.
        x_adj = self.adj_conv(x_adj)
        x_adj = x_adj.reshape(self.batch_size, -1, self.input_dim)
        x_adj, _ = self.adj_lstm(x_adj)

        # Self-processing of the time series.
        x_self, _ = self.lstm(x)

        # Concatenate the last-time-step outputs of both LSTMs.
        last = torch.cat([x_adj[:, -1, :], x_self[:, -1, :]], dim=1)

        x = self.hidden2output(last)
        # NOTE(review): hidden2output emits forecast_size values per batch
        # item, so this reshape only succeeds when input_dim == 1.
        return x.reshape(self.batch_size, self.input_dim, self.forecast_size)

    def hidden_state(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the per-time-step outputs of the self LSTM.

        Args:
            x (torch.Tensor): Series of shape ``(batch, seq_len, input_dim)``.

        Returns:
            torch.Tensor: LSTM outputs of shape ``(batch, seq_len, hidden_size)``.
        """
        x_self, _ = self.lstm(x)
        return x_self

    def latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the self LSTM's output at the last time step.

        Args:
            x (torch.Tensor): Series of shape ``(batch, seq_len, input_dim)``.

        Returns:
            torch.Tensor: Shape ``(batch, hidden_size)``.
        """
        x, _ = self.lstm(x)
        return x[:, -1, :]

    def copy_adj_lstm(self, other):
        """Make this block use ``other``'s adjacency LSTM.

        Note: this assigns the module object itself, so afterwards both blocks
        share (rather than hold separate copies of) the same parameters.
        """
        self.adj_lstm = other.adj_lstm

class RMatrix(nn.Module):
    """Matrix to store RBlocks.

    ``rmatrix[i][j]`` is the block at row ``i`` (< width) and column ``j``
    (< height). The blocks are held in nested ``nn.ModuleList`` containers so
    they are registered submodules. As a result, ``parameters()``,
    ``state_dict()``, ``to()``, ``double()``/``float()`` and
    ``train()``/``eval()`` on the RMatrix reach every block.
    """

    def __init__(self,
        block_size: list = None,
        model_path: str = None,
        window_size: int = 80,
        forecast_size: int = 20) -> None:
        """

        Args:
            block_size (list, optional): Spatial dimension (width, height) of the R-Matrix. Defaults to [3, 3].
            model_path (str, optional): Path prefix for saving/loading; block (i, j) uses
                ``model_path + f'{i}_{j}.pth'``. Defaults to None.
            window_size (int, optional): Window size; stored on the instance, not used by this class. Defaults to 80.
            forecast_size (int, optional): Forecast size of every RBlock. Defaults to 20.
        """
        super().__init__()
        # None sentinel instead of a shared mutable list default.
        if block_size is None:
            block_size = [3, 3]
        self.width, self.height = block_size
        self.window_size = window_size
        self.forecast_size = forecast_size
        self.model_path = model_path
        self.rmatrix = nn.ModuleList([
            nn.ModuleList([RBlock(forecast_size=forecast_size) for _ in range(self.height)])
            for _ in range(self.width)
        ])

    def get_rblock(self, coordinate: list) -> RBlock:
        """Get R Block associated with a given coordinate.

        Args:
            coordinate (list): A pair ``(i, j)`` of coordinates to get the R block from.

        Returns:
            RBlock: Rblock associated with the coordinate in the R matrix.
        """
        return self.rmatrix[coordinate[0]][coordinate[1]]

    def adjacency_lstm_copy(self, coordinate : list):
        """Point every block's adjacency LSTM at the one of the block at ``coordinate``.

        Described in 'Ensemble Weight Sharing' section. RBlock.copy_adj_lstm
        assigns the module object, so afterwards all blocks share one
        adjacency LSTM.

        Args:
            coordinate (list): Coordinate ``(i, j)`` of the block whose adjacency LSTM is shared.
        """
        best_rblock = self.get_rblock(coordinate=coordinate)
        for row in self.rmatrix:
            for rblock in row:
                rblock.copy_adj_lstm(best_rblock)

    def _grid_apply(self, fn) -> torch.Tensor:
        """Apply ``fn`` to every block in row-major order and stack into (width, height, A, B)."""
        rows = [
            torch.cat([fn(self.rmatrix[i][j]) for j in range(self.height)])
            for i in range(self.width)
        ]
        stacked = torch.cat(rows)
        return stacked.view(self.width, self.height, stacked.shape[-2], stacked.shape[-1])

    def adjacency(self, x: torch.Tensor) -> torch.Tensor:
        """Generates the adjacency input for each R block in the R matrix.

        Args:
            x (torch.Tensor): Input associated with time series.

        Returns:
            torch.Tensor: Every block's ``hidden_state(x)``, stacked to
            shape ``(width, height, ..., ...)`` using its last two dims.
        """
        return self._grid_apply(lambda rblock: rblock.hidden_state(x))

    def forward(self, x: torch.Tensor, x_adj: torch.Tensor) -> torch.Tensor:
        """Generates predictions for each R-block for the time series.

        Args:
            x (torch.Tensor): Input series, passed unchanged to every block.
            x_adj (torch.Tensor): Adjacency input, passed unchanged to every block.

        Returns:
            torch.Tensor: Every block's prediction, stacked to shape
            ``(width, height, ..., ...)`` using the last two dims of each prediction.
        """
        return self._grid_apply(lambda rblock: rblock(x, x_adj))

    def _coordinates(self, coordinate=None):
        """Yield ``coordinate`` alone as ``(i, j)``, or every grid coordinate when it is None."""
        if coordinate is not None:
            i, j = coordinate
            yield i, j
            return
        for i in range(self.width):
            for j in range(self.height):
                yield i, j

    def _block_path(self, i, j) -> str:
        """File path used to save/load the block at ``(i, j)``."""
        return self.model_path + f'{i}_{j}.pth'

    def save(self, coordinate : list=None):
        """Method to save network weights.

        Args:
            coordinate (list, optional): Saves network weights of specific coordinate.
                Defaults to None, which saves every block.
        """
        for i, j in self._coordinates(coordinate):
            torch.save(self.rmatrix[i][j].state_dict(), self._block_path(i, j))

    def load(self, coordinate : list=None):
        """Method to load network weights.

        Args:
            coordinate (list, optional): Load network weights of specific coordinate.
                Defaults to None, which loads every block.
        """
        for i, j in self._coordinates(coordinate):
            # load_state_dict needs the deserialized dict, not a path.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(self._block_path(i, j))
            self.rmatrix[i][j].load_state_dict(state_dict)

    def to_double(self):
        """Cast network to double (in place; blocks are registered submodules)."""
        self.double()

    def to_float(self):
        """Cast network to float (in place; blocks are registered submodules)."""
        self.float()