import itertools

import numpy as np
import pywt
import torch
import torch.nn as nn
import torch.nn.utils.weight_norm as weightNorm

from woods.layers.FrequencyFilter import FrequencyFilter

class FeatNetwork(nn.Module):
    """Two-stage convolutional feature extractor.

    Each stage is Conv2d with a ``(1, ker_size)`` kernel -> BatchNorm2d ->
    ReLU -> MaxPool2d(kernel ``(1, 2)``, stride 2); channels go
    ``in_channels -> 16 -> 32``. The output is flattened with
    ``view(-1, fc_size)``. ``model_hparams['fc_size']`` is therefore expected
    to equal the flattened per-sample size of the second stage's output;
    otherwise ``view`` either raises or returns a leading dimension that is
    not the batch size.

    Args:
        in_channels: number of input channels of the first Conv2d.
        model_hparams: dict read for ``'ker_size'`` and ``'fc_size'``.
    """

    def __init__(self, in_channels, model_hparams):
        super(FeatNetwork, self).__init__()
        kernel = (1, model_hparams['ker_size'])

        def make_stage(c_in, c_out):
            # Layer order inside the Sequential determines the state_dict keys
            # (convN.0 = conv, convN.1 = batch norm), so keep it fixed.
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(1, 2), stride=2),
            )

        self.conv1 = make_stage(in_channels, 16)
        self.conv2 = make_stage(16, 32)
        self.n_features = model_hparams['fc_size']

    def forward(self, x):
        """Run both conv stages and flatten to rows of ``n_features`` values."""
        pooled = self.conv2(self.conv1(x))
        return pooled.view(-1, self.n_features)

class FeatClassifier(nn.Module):
    """Linear head mapping ``bottleneck_dim`` features to ``class_num`` outputs.

    Args:
        class_num: number of output units.
        bottleneck_dim: input feature size.
        type: ``'wn'`` wraps the linear layer with weight normalization on its
            ``weight`` parameter; any other value yields a plain ``nn.Linear``.
    """

    def __init__(self, class_num, bottleneck_dim=256, type="linear"):
        super(FeatClassifier, self).__init__()
        self.type = type
        head = nn.Linear(bottleneck_dim, class_num)
        if self.type == 'wn':
            head = weightNorm(head, name="weight")
        self.fc = head

    def forward(self, x):
        """Apply the linear head."""
        return self.fc(x)

class ActNetwork(nn.Module):
    """
    Empirical Risk Minimization (ERM)

    A ``FeatNetwork`` convolutional featurizer followed by a ``FeatClassifier``
    head.

    Args:
        dataset: object read for ``INPUT_SHAPE`` (its product gives the number
            of input channels when ``input_size`` is None) and ``OUTPUT_SIZE``
            (number of classifier outputs).
        model_hparams: dict read for ``'device'`` and ``'classifier_type'``,
            plus the keys consumed by ``FeatNetwork`` (``'ker_size'``,
            ``'fc_size'``).
        input_size: optional override for the number of input channels.
    """

    def __init__(self, dataset, model_hparams, input_size=None):
        super(ActNetwork, self).__init__()
        self.device = model_hparams['device']
        self.dataset = dataset
        # np.prod returns a NumPy scalar; normalize to a plain int for Conv2d.
        self.input_channel = (
            int(np.prod(self.dataset.INPUT_SHAPE)) if input_size is None else input_size
        )
        self.featurizer = FeatNetwork(self.input_channel, model_hparams)
        self.classifier = FeatClassifier(
            self.dataset.OUTPUT_SIZE, self.featurizer.n_features, model_hparams['classifier_type'])
        self.feat_dim = self.featurizer.n_features

    def forward(self, x):
        """
        Args:
            x: either a 4-D tensor already laid out as ``(B, C, 1, L)`` for the
                Conv2d featurizer, or a 3-D tensor laid out as ``(B, L, C)``.

        Returns:
            ``(logits, feats)``, each reshaped to ``(B, 1, -1)``.
        """
        B = x.shape[0]
        if x.dim() < 4:
            # (B, L, C) -> (B, C, 1, L). The axes must be transposed: a plain
            # view/reshape to (B, C, 1, L) only reinterprets memory and would
            # interleave time steps with channels.
            x = x.permute(0, 2, 1).unsqueeze(2)
        feats = self.featurizer(x)
        logits = self.classifier(feats)
        logits = logits.reshape(B, 1, -1)
        feats = feats.reshape(B, 1, -1)
        return logits, feats

class ActFreqNetwork(ActNetwork):
    """
    ``ActNetwork`` that routes its input through a ``FrequencyFilter`` before
    the featurizer.

    At construction time a dataset-level frequency mask is estimated from the
    training loaders (``_get_mask_spectrum``) and passed to
    ``FrequencyFilter``. ``model_hparams`` is additionally read for
    ``'alpha'`` (fraction of frequency bins kept in the mask) and
    ``'freq_type'`` (``"fft"`` or a name in ``_WAVELET_TYPES``).
    """

    # Wavelet names accepted by _get_mask_spectrum (single-level pywt.dwt).
    _WAVELET_TYPES = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, dataset, model_hparams, input_size=None):
        super(ActFreqNetwork, self).__init__(dataset, model_hparams, input_size)
        self.mask_spectrum = self._get_mask_spectrum(
            alpha=model_hparams['alpha'], freq_type=model_hparams['freq_type'])
        self.disentanglement = FrequencyFilter(
            self.mask_spectrum, freq_type=model_hparams['freq_type'])

    @staticmethod
    def _frequency_feature(lookback_window, freq_type):
        """
        Transform a 3-D ``(B, L, C)`` batch to the frequency domain along dim 1.

        Args:
            lookback_window: 3-D tensor; dim 1 is the transformed axis.
            freq_type: ``"fft"`` or a name in ``_WAVELET_TYPES``.

        Returns:
            For ``"fft"``: ``torch.fft.rfft`` along dim 1 (complex tensor).
            For a wavelet: single-level DWT approximation and detail
            coefficients concatenated along dim 1, i.e. shape
            ``(B, len(cA) + len(cD), C)``, on the input's device.

        Raises:
            ValueError: if the input is not 3-D or ``freq_type`` is unsupported.
        """
        if lookback_window.dim() != 3:
            raise ValueError(
                f"expected a 3-D (B, L, C) batch, got shape {tuple(lookback_window.shape)}")
        if freq_type == "fft":
            return torch.fft.rfft(lookback_window, dim=1)
        if freq_type in ActFreqNetwork._WAVELET_TYPES:
            wavelet = pywt.Wavelet(freq_type)
            device = lookback_window.device
            # pywt works on NumPy arrays: detach and move to host first so CUDA
            # tensors (or tensors requiring grad) can be converted.
            X = lookback_window.permute(0, 2, 1).detach().cpu().numpy()  # B C L
            cA, cD = pywt.dwt(X, wavelet)  # transforms the last axis (L)
            feature = np.concatenate((cA, cD), axis=2).transpose((0, 2, 1))  # B D C
            return torch.from_numpy(feature).to(device)
        raise ValueError(
            f"unsupported freq_type {freq_type!r}; expected 'fft' or one of "
            f"{ActFreqNetwork._WAVELET_TYPES}")

    def _get_mask_spectrum(self, alpha, freq_type):
        """
        Get the shared frequency spectrum indices.

        For every training loader, iterates ``len(loader)`` batches and sums
        the mean absolute frequency amplitude per frequency bin (averaged over
        the batch and channel axes). Returns the indices of the
        ``int(n_bins * alpha)`` bins with the largest summed amplitude.

        Args:
            alpha: fraction of frequency bins to select.
            freq_type: ``"fft"`` or a name in ``_WAVELET_TYPES``.

        Returns:
            Tensor of selected bin indices, ordered by descending amplitude
            (``topk`` default). The original author describes these as the
            spectrums of the time-invariant component.

        Raises:
            ValueError: if no training batch was seen, a batch is not 3-D, or
                ``freq_type`` is unsupported.
        """
        loader_names, train_loaders = self.dataset.get_train_loaders()
        amps = None
        for name, train_loader in zip(loader_names, train_loaders):
            loader_len = len(train_loader)
            print("loader name:", name, " trainloaderlen:", loader_len)
            # islice bounds the pass to exactly one epoch's worth of batches,
            # even if the loader is infinite (the previous `if i > loader_len`
            # check let such a loader run loader_len + 2 batches).
            for data in itertools.islice(train_loader, loader_len):
                # data[0] is taken as the input window of the batch.
                frequency_feature = self._frequency_feature(data[0], freq_type)
                batch_amps = frequency_feature.abs().mean(dim=0).mean(dim=1)
                amps = batch_amps if amps is None else amps + batch_amps

        if amps is None:
            raise ValueError("no training batches available to estimate the mask spectrum")

        mask_spectrum = amps.topk(int(amps.shape[0] * alpha)).indices
        print("mask_spectrum:", mask_spectrum)
        return mask_spectrum  # as the spectrums of time-invariant component

    def forward(self, x):
        """
        Filter the input with ``self.disentanglement`` and classify its second
        output.

        Args:
            x: input accepted by ``FrequencyFilter``. If the filter's second
                output is 3-D, it is treated as ``(B, L, C)`` and transposed to
                ``(B, C, 1, L)`` for the Conv2d featurizer.

        Returns:
            ``(logits, feats)``, each reshaped to ``(B, 1, -1)``.
        """
        B = x.shape[0]

        # Only the second output of the filter is used; presumably the
        # component with the masked (shared) spectrum removed — confirm in
        # FrequencyFilter.
        _, x_det = self.disentanglement(x)
        x = x_det

        if x.dim() < 4:
            B = x.shape[0]
            # (B, L, C) -> (B, C, 1, L): transpose the axes rather than
            # reshaping, which would interleave time steps with channels.
            x = x.permute(0, 2, 1).unsqueeze(2)
        feats = self.featurizer(x)
        logits = self.classifier(feats)
        logits = logits.reshape(B, 1, -1)
        feats = feats.reshape(B, 1, -1)
        return logits, feats
    