import copy
import itertools

import numpy as np
import pywt
import torch
from torch import nn

# new package
from braindecode.models import Deep4Net
from woods.layers.FrequencyFilter import FrequencyFilter

#########################
## EEG / Signal Models ##
#########################
class deep4(nn.Module):
    """ The DEEP4 model

    This is from the Braindecode package:
        https://github.com/braindecode/braindecode

    A braindecode ``Deep4Net`` is built and split in two. Its final
    convolution (``conv_classifier``) is moved to ``self.classifier``, and its
    ``softmax`` and ``squeeze`` layers are deleted. As a result, ``self.model``
    returns features and ``forward`` can return both the classifier output and
    those features.

    Args:
        dataset (Multi_Domain_Dataset): dataset that we will be training on
        model_hparams (dict): The hyperparameters for the model. Must contain
            'device' and 'feat_dim'.

    Attributes:
        dataset (Multi_Domain_Dataset): The dataset passed at construction.
        device: Value of ``model_hparams['device']`` (stored, not used here).
        input_size (int): Product of ``dataset.INPUT_SHAPE``; passed as the
            first positional argument of ``Deep4Net`` (presumably the number
            of input channels).
        output_size (int): The size of the outputs of the model (number of classes).
        seq_len (int): The length of the sequences.
        feat_dim: Value of ``model_hparams['feat_dim']`` (stored, not used here).
        model (Deep4Net): Feature extractor with its output layers removed.
        classifier (nn.Module): The final convolution taken from ``Deep4Net``.
    """

    def __init__(self, dataset, model_hparams):
        super(deep4, self).__init__()
        self.dataset = dataset
        # Save stuff
        self.device = model_hparams['device']
        self.input_size = np.prod(dataset.INPUT_SHAPE)
        self.output_size = dataset.OUTPUT_SIZE
        self.seq_len = dataset.SEQ_LEN
        self.feat_dim = model_hparams['feat_dim']

        self.model = Deep4Net(
            self.input_size,
            self.output_size,
            input_window_samples=self.seq_len,
            final_conv_length='auto',
            n_filters_time=32,
            n_filters_spat=32,
            filter_time_length=10,
            pool_time_length=3,
            pool_time_stride=3,
            n_filters_2=64,
            filter_length_2=10,
            n_filters_3=128,
            filter_length_3=10,
            n_filters_4=256,
            filter_length_4=10
        )

        # Move the final convolution out of the Deep4Net so that self.model
        # stops at the features; the remaining output layers are dropped.
        # The module is moved rather than deep-copied: once it is deleted from
        # self.model nothing else references it, so a copy only wasted memory.
        self.classifier = self.model.conv_classifier
        del self.model.conv_classifier
        del self.model.softmax
        del self.model.squeeze

    def forward(self, input):
        """Run the feature extractor and the classifier.

        Args:
            input (torch.Tensor): 3-D batch; dims 1 and 2 are swapped before
                it reaches ``self.model`` (assumed (batch, seq_len, channels)
                — TODO confirm against the dataset).

        Returns:
            tuple: ``(out, features)``. Each is flattened past the batch dim
            and given a singleton time-prediction dim, i.e. shape (batch, 1, -1).
        """
        features = self.model(input.permute((0, 2, 1)))
        out = self.classify(features)

        # Remove all extra dimensions and add the time-prediction dimension
        out = torch.flatten(out, start_dim=1).unsqueeze(1)
        features = torch.flatten(features, start_dim=1).unsqueeze(1)
        return out, features

    def classify(self, features):
        """Apply ``self.classifier`` to (possibly flattened) features.

        Args:
            features (torch.Tensor): Tensor whose first dim is the batch and
                whose remaining elements can be reshaped to
                ``(self.classifier.in_channels, -1, 1)``.

        Returns:
            torch.Tensor: Classifier output flattened past the batch dim, with
            a singleton dim inserted at position 1.
        """
        # Derive the channel count from the classifier itself instead of
        # repeating the constructor's n_filters_4, so the two cannot drift.
        features = features.reshape(
            features.shape[0], self.classifier.in_channels, -1, 1
        )
        output = self.classifier(features)
        return torch.flatten(output, start_dim=1).unsqueeze(1)

    def get_classifier_network(self):
        """Return the classification head (the detached final convolution)."""
        return self.classifier
    
class deep4_Freq(deep4):
    """ The DEEP4 model with a frequency-domain input filter.

    This is from the Braindecode package:
        https://github.com/braindecode/braindecode

    Identical to ``deep4`` except that every input batch first goes through a
    ``FrequencyFilter``. The filter is built from a spectrum mask: the indices
    of the ``alpha`` fraction of frequency (or wavelet) coefficients whose
    amplitude, accumulated over one pass of every training loader, is largest.
    The second output of the filter is what is fed to the Deep4Net.

    Args:
        dataset (Multi_Domain_Dataset): dataset that we will be training on
        model_hparams (dict): The hyperparameters for the model. In addition to
            the keys read by ``deep4``, must contain 'alpha' (fraction of
            coefficients kept in the mask) and 'freq_type' ('fft' or one of
            ``_WAVELET_TYPES``).

    Attributes:
        mask_spectrum (torch.Tensor): 1-D tensor of selected coefficient indices.
        disentanglement (FrequencyFilter): Filter applied to every input batch.
        (plus every attribute of ``deep4``)
    """

    # Wavelet names accepted as ``freq_type`` besides "fft".
    _WAVELET_TYPES = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, dataset, model_hparams):
        super(deep4_Freq, self).__init__(dataset, model_hparams)

        # Save stuff
        self.mask_spectrum = self._get_mask_spectrum(alpha = model_hparams['alpha'], freq_type=model_hparams['freq_type'])
        self.disentanglement = FrequencyFilter(self.mask_spectrum, freq_type=model_hparams['freq_type'])

    @classmethod
    def _batch_amplitude(cls, lookback_window, freq_type, wavelet=None):
        """Return the mean absolute coefficient per frequency index for one batch.

        Args:
            lookback_window (torch.Tensor): 3-D batch with time on dim 1
                (assumed (batch, seq_len, channels) — TODO confirm).
            freq_type (str): 'fft' or one of ``_WAVELET_TYPES``.
            wavelet (pywt.Wavelet, optional): Prebuilt wavelet; built from
                ``freq_type`` when omitted.

        Returns:
            torch.Tensor: 1-D tensor, the absolute coefficients averaged over
            dim 0 (batch) and then over the channel dim.

        Raises:
            ValueError: If the batch is not 3-D or ``freq_type`` is unsupported.
        """
        if lookback_window.dim() != 3:
            raise ValueError(
                f"Expected a 3-D batch, got shape {tuple(lookback_window.shape)}"
            )

        if freq_type == "fft":
            # Real FFT along the time axis (dim 1).
            frequency_feature = torch.fft.rfft(lookback_window, dim=1)
        elif freq_type in cls._WAVELET_TYPES:
            if wavelet is None:
                wavelet = pywt.Wavelet(freq_type)
            device = lookback_window.device
            # pywt.dwt transforms the last axis by default, so move time last;
            # pywt needs a host NumPy array, hence detach().cpu().
            X = lookback_window.permute(0, 2, 1).detach().cpu().numpy()
            cA, cD = pywt.dwt(X, wavelet)
            # Stack approximation and detail coefficients along the coefficient
            # axis, then restore (batch, coefficient, channel) order.
            frequency_feature = np.concatenate((cA, cD), axis=2).transpose((0, 2, 1))
            frequency_feature = torch.from_numpy(frequency_feature).to(device)
        else:
            raise ValueError(f"Unsupported freq_type: {freq_type!r}")

        return frequency_feature.abs().mean(dim=0).mean(dim=1)

    def _get_mask_spectrum(self, alpha, freq_type):
        """Select the frequency coefficients with the largest shared amplitude.

        Amplitudes from ``_batch_amplitude`` are summed over exactly one pass
        (``len(loader)`` batches) of every training loader. The indices of the
        ``int(n_coeffs * alpha)`` largest sums are returned.

        Args:
            alpha (float): Fraction of coefficients to keep.
            freq_type (str): 'fft' or one of ``_WAVELET_TYPES``.

        Returns:
            torch.Tensor: 1-D tensor of coefficient indices.

        Raises:
            ValueError: If ``freq_type`` is unsupported or the training loaders
                yield no batch.
        """
        # Fail fast, before iterating any data.
        if freq_type != "fft" and freq_type not in self._WAVELET_TYPES:
            raise ValueError(f"Unsupported freq_type: {freq_type!r}")
        # Build the wavelet once instead of once per batch.
        wavelet = None if freq_type == "fft" else pywt.Wavelet(freq_type)

        loader_names, train_loaders = self.dataset.get_train_loaders()
        amps = None
        for name, train_loader in zip(loader_names, train_loaders):
            loader_len = len(train_loader)
            print("loader name:", name, " trainloaderlen:", loader_len)
            # islice stops after exactly one pass (loader_len batches), even
            # for a loader that would otherwise keep yielding.
            for data in itertools.islice(train_loader, loader_len):
                batch_amps = self._batch_amplitude(data[0], freq_type, wavelet)
                if amps is None:
                    amps = batch_amps
                else:
                    amps += batch_amps

        if amps is None:
            raise ValueError(
                "Training loaders yielded no batches; cannot build the spectrum mask."
            )

        mask_spectrum = amps.topk(int(amps.shape[0] * alpha)).indices
        print("mask_spectrum:", mask_spectrum)
        return mask_spectrum  # presumably the spectrum of the time-invariant component

    def forward(self, input):
        """Filter the input with ``self.disentanglement`` and run ``deep4`` on it.

        The second output of ``self.disentanglement`` is passed to
        ``deep4.forward``; the first is discarded.

        Returns:
            tuple: ``(out, features)`` exactly as returned by ``deep4.forward``.
        """
        _, x_det = self.disentanglement(input)
        return super(deep4_Freq, self).forward(x_det)