import torch
from torch import nn
import numpy as np
import pywt
from woods.layers.FrequencyFilter import FrequencyFilter


##################
## Basic Models ##
##################
class LSTM(nn.Module):
    """ A simple LSTM model followed by an MLP classifier

    Args:
        dataset (Multi_Domain_Dataset): dataset that we will be training on. The model reads
            ``INPUT_SHAPE``, ``OUTPUT_SIZE`` and calls ``get_pred_time(input_shape)`` on it.
        model_hparams (dict): The hyperparameters for the model.
        input_size (int, optional): The size of the input to the model. Defaults to None. If None, the input size is calculated from the dataset.

    Attributes:
        state_size (int): The size of the hidden state of the LSTM.
        recurrent_layers (int): The number of recurrent layers stacked on each other.
        hidden_depth (int): The number of hidden layers of the classifier MLP (after LSTM).
        hidden_width (int): The width of the hidden layers of the classifier MLP (after LSTM).
        feat_dim (int): The size of the features given to the classifier (equal to state_size).

    Raises:
        ValueError: If ``hidden_depth`` is negative.

    Notes:
        All attributes except feat_dim need to be in the model_hparams dictionary, along with 'device'.
    """
    def __init__(self, dataset, model_hparams, input_size=None):
        super().__init__()

        ## Save stuff
        # Model parameters
        self.device = model_hparams['device']
        self.state_size = model_hparams['state_size']
        self.hidden_depth = model_hparams['hidden_depth']
        self.hidden_width = model_hparams['hidden_width']
        self.recurrent_layers = model_hparams['recurrent_layers']
        if self.hidden_depth < 0:
            raise ValueError("hidden_depth must be non-negative, got {}".format(self.hidden_depth))

        # Dataset parameters
        self.dataset = dataset
        # np.prod returns a NumPy scalar; keep a plain Python int for the layer constructors
        self.input_size = int(np.prod(dataset.INPUT_SHAPE) if input_size is None else input_size)
        self.output_size = dataset.OUTPUT_SIZE

        ## Recurrent model
        self.lstm = nn.LSTM(self.input_size, self.state_size, self.recurrent_layers, batch_first=True)

        ## Classification model
        # Layer widths: state_size -> hidden_width (x hidden_depth) -> output_size
        widths = [self.state_size]
        widths.extend(self.hidden_width for _ in range(self.hidden_depth))
        widths.append(self.output_size)

        seq_arr = []
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            seq_arr.append(nn.Linear(fan_in, fan_out))
            # Every linear layer except the output one is followed by a ReLU
            if i != self.hidden_depth:
                seq_arr.append(nn.ReLU(True))

        self.feat_dim = self.state_size
        self.classifier = nn.Sequential(*seq_arr)

    def forward(self, input):
        """ Forward pass of the model

        Args:
            input (torch.Tensor): The input to the model. The first two dimensions are used as
                (batch, time); all remaining dimensions are flattened into the feature dimension.

        Returns:
            tuple: ``(all_out, all_features)`` where ``all_features`` holds the LSTM outputs at the
            prediction time steps given by the dataset and ``all_out`` the classifier outputs for them.
        """

        # Get prediction steps
        pred_time = self.dataset.get_pred_time(input.shape)

        # Setup hidden state
        batch_size = input.shape[0]
        hidden = self.initHidden(batch_size, input.device)

        # Flatten everything after (batch, time); reshape also accepts non-contiguous inputs
        sequence = input.reshape(batch_size, input.shape[1], -1)
        # The random initial state is created in the default dtype; match the input dtype
        # so that non-float32 models (e.g. after .double()) can run
        hidden = tuple(state.to(dtype=sequence.dtype) for state in hidden)

        # Forward propagate LSTM
        features, _ = self.lstm(sequence, hidden)

        # Extract features at prediction times (keeps the dtype and device of the LSTM output)
        steps = torch.as_tensor(pred_time, dtype=torch.long, device=features.device)
        all_features = features[:, steps]

        # Make prediction with fully connected
        all_out = self.classify(all_features)

        return all_out, all_features

    def classify(self, features):
        """ Apply the classifier MLP to the features of every prediction step

        Args:
            features (torch.Tensor): Features with the prediction steps on dimension 1 and the
                feature values on the last dimension.

        Returns:
            torch.Tensor: Classifier outputs with the same leading dimensions as ``features``
            and ``output_size`` as last dimension.
        """
        # nn.Linear acts on the last dimension only, so all steps are classified in one call
        return self.classifier(features)

    def initHidden(self, batch_size, device):
        """ Initialize the hidden state of the LSTM with a normal distribution

        Args:
            batch_size (int): The batch size of the model.
            device (torch.device): The device to use.

        Returns:
            tuple: The (hidden state, cell state) pair, each of shape
            (recurrent_layers, batch_size, state_size).
        """
        # Sampled first and moved afterwards so the draws come from the default generator
        state_shape = (self.recurrent_layers, batch_size, self.state_size)
        h_0 = torch.randn(*state_shape).to(device)
        c_0 = torch.randn(*state_shape).to(device)
        return (h_0, c_0)

    def get_classifier_network(self):
        """ Return the classifier MLP applied after the LSTM """
        return self.classifier


class LSTM_Freq(LSTM):
    """ An LSTM model whose input first goes through a frequency filter

    A set of dominant frequency components is estimated once, at construction time, from the
    training loaders of the dataset. These indices are given to a FrequencyFilter layer, and
    the second output of that layer is what is fed to the LSTM.

    Args:
        dataset (Multi_Domain_Dataset): dataset that we will be training on
        model_hparams (dict): The hyperparameters for the model.
        input_size (int, optional): The size of the input to the model. Defaults to None. If None, the input size is calculated from the dataset.

    Attributes:
        state_size (int): The size of the hidden state of the LSTM.
        recurrent_layers (int): The number of recurrent layers stacked on each other.
        hidden_depth (int): The number of hidden layers of the classifier MLP (after LSTM).
        hidden_width (int): The width of the hidden layers of the classifier MLP (after LSTM).
        mask_spectrum (torch.Tensor): Indices of the selected frequency components.
        disentanglement (FrequencyFilter): The filter applied to the input before the LSTM.

    Notes:
        All attributes of the base LSTM need to be in the model_hparams dictionary, as well as
        'alpha' (fraction of frequency components to select) and 'freq_type' ('fft' or one of the
        supported wavelet names).
    """

    # Wavelet names accepted as freq_type, besides "fft"
    _WAVELET_TYPES = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, dataset, model_hparams, input_size=None):
        # input_size must be forwarded, otherwise an explicit value is silently ignored
        super(LSTM_Freq, self).__init__(dataset, model_hparams, input_size)
        self.mask_spectrum = self._get_mask_spectrum(alpha = model_hparams['alpha'], freq_type = model_hparams['freq_type'])
        self.disentanglement = FrequencyFilter(self.mask_spectrum, freq_type = model_hparams['freq_type'])
        self.feat_dim = self.state_size

    def _get_mask_spectrum(self, alpha, freq_type):
        """ Get the shared frequency spectrum of the training data

        The amplitude of every frequency component is averaged over batch and channels and
        accumulated over the batches of all training loaders. The indices of the largest
        accumulated amplitudes are returned.

        Args:
            alpha (float): Fraction of the frequency components to keep.
            freq_type (str): "fft" for a real FFT along dimension 1, or one of the supported
                wavelet names for a single-level discrete wavelet transform.

        Returns:
            torch.Tensor: Indices of the ``int(n_components * alpha)`` components with the largest amplitude.

        Raises:
            ValueError: If freq_type is not supported, if a batch is not 3-dimensional or if
                the training loaders do not provide any batch.
        """
        # Validate once up front instead of asserting on every batch
        use_fft = freq_type == "fft"
        if not use_fft and freq_type not in self._WAVELET_TYPES:
            raise ValueError("Unsupported freq_type: {!r}".format(freq_type))
        # The wavelet does not depend on the batch, build it once
        wavelet = None if use_fft else pywt.Wavelet(freq_type)

        loader_names, train_loaders = self.dataset.get_train_loaders()
        amps = None
        for name, train_loader in zip(loader_names, train_loaders):
            loader_len = len(train_loader)
            print("loader name:", name, " trainloaderlen:", loader_len)
            for i, data in enumerate(train_loader):
                # Loaders may iterate endlessly: stop after exactly loader_len batches
                if i >= loader_len:
                    break

                lookback_window = data[0]
                if lookback_window.dim() != 3:
                    raise ValueError("Expected 3-dimensional batches, got shape {}".format(tuple(lookback_window.shape)))

                if use_fft:
                    frequency_feature = torch.fft.rfft(lookback_window, dim=1)
                else:
                    device = lookback_window.device
                    # pywt works on NumPy arrays and transforms the last axis, so the
                    # dimension 1 of the batch is moved last and the data brought to the CPU
                    X = lookback_window.permute(0,2,1).detach().cpu().numpy()
                    cA, cD = pywt.dwt(X, wavelet)
                    # Approximation and detail coefficients are stacked, then moved back to dimension 1
                    frequency_feature = np.concatenate((cA, cD), axis=2).transpose((0,2,1))
                    frequency_feature = torch.from_numpy(frequency_feature).to(device)

                # Mean amplitude per component, over the batch (dim 0) then the channels
                batch_amps = frequency_feature.abs().mean(dim=0).mean(dim=1)
                amps = batch_amps if amps is None else amps + batch_amps

        if amps is None:
            raise ValueError("Cannot compute the mask spectrum: the training loaders provided no batch")

        mask_spectrum = amps.topk(int(amps.shape[0]*alpha)).indices
        print("mask_spectrum:", mask_spectrum)
        return mask_spectrum # as the spectrums of time-invariant component

    def forward(self, input):
        """ Forward pass of the model

        Args:
            input (torch.Tensor): The input to the model.

        Returns:
            tuple: ``(all_out, all_features)`` where ``all_features`` holds the LSTM outputs at the
            prediction time steps given by the dataset and ``all_out`` the classifier outputs for them.
        """
        # Get prediction steps (from the shape of the unfiltered input)
        pred_time = self.dataset.get_pred_time(input.shape)

        # Only the second output of the frequency filter is used by the LSTM
        _, x_det = self.disentanglement(input)

        # Setup hidden state
        batch_size = x_det.shape[0]
        hidden = self.initHidden(batch_size, x_det.device)

        # Forward propagate LSTM; reshape also accepts a non-contiguous filter output
        sequence = x_det.reshape(batch_size, x_det.shape[1], -1)
        features, _ = self.lstm(sequence, hidden)

        # Extract features at prediction times
        steps = torch.as_tensor(pred_time, dtype=torch.long, device=features.device)
        all_features = features[:, steps]

        # Make prediction with fully connected
        all_out = self.classify(all_features)

        return all_out, all_features