
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.distributions as dist
from woods.layers.VAE_EncDec import Encoder, Decoder
from woods.layers.Transformer_EncDec import Encoder as TransEncoder
from woods.layers.Transformer_EncDec import EncoderLayer as TransEncoderLayer
from woods.layers.SelfAttention_Family import FullAttention, AttentionLayer
from woods.layers.Embed import PatchEmbedding
from woods.layers.FrequencyFilter import FrequencyFilter
import numpy as np
import pywt

class PriorBlock(nn.Module):
    """Learned conditional diagonal-Gaussian prior over a latent code.

    An integer condition (e.g. a domain or class index) is one-hot encoded and
    passed through ``Linear -> BatchNorm1d -> ReLU``. Two linear heads then
    produce the location and a strictly positive scale (``Softplus`` + 1e-7).

    Args:
        d_dim: Number of distinct condition values (width of the one-hot code).
        zd_dim: Dimensionality of the latent code.
    """

    def __init__(self, d_dim, zd_dim):
        super(PriorBlock, self).__init__()
        self.d_dim = d_dim

        self.fc1 = nn.Sequential(
            nn.Linear(d_dim, zd_dim, bias=False),
            nn.BatchNorm1d(zd_dim),
            nn.ReLU()
        )

        self.fc21 = nn.Sequential(
            nn.Linear(zd_dim, zd_dim)
        )

        self.fc22 = nn.Sequential(
            nn.Linear(zd_dim, zd_dim),
            nn.Softplus()
        )

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every linear weight and zero the head biases."""
        torch.nn.init.xavier_uniform_(self.fc1[0].weight)
        torch.nn.init.xavier_uniform_(self.fc21[0].weight)
        self.fc21[0].bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.fc22[0].weight)
        self.fc22[0].bias.data.zero_()

    def forward(self, d):
        """Return ``(loc, scale)`` of the prior for condition indices ``d``.

        Args:
            d: Integer tensor of condition indices with shape ``(B,)`` or
                ``(B, 1)``. Every value must lie in ``[0, d_dim)``.

        Returns:
            Tuple ``(zd_loc, zd_scale)``, each of shape ``(B, zd_dim)``.
            ``zd_scale`` is strictly positive.
        """
        # Vectorised one-hot encoding built directly on d's device. This avoids
        # a Python loop with a per-element ``.item()``, which costs one
        # host/device synchronisation per sample.
        indices = d.reshape(d.shape[0]).long()
        d_onehot = F.one_hot(indices, num_classes=self.d_dim).to(torch.get_default_dtype())

        hidden = self.fc1(d_onehot)
        zd_loc = self.fc21(hidden)
        # Softplus output plus a small epsilon keeps the scale strictly positive.
        zd_scale = self.fc22(hidden) + 1e-7

        return zd_loc, zd_scale

# Auxiliary tasks
class Projection(nn.Module):
    """Linear read-out head applied to ReLU-activated latent features.

    Computes ``fc1(relu(zd))``, mapping a ``zd_dim``-wide code to ``d_dim``
    outputs. Callers in this module treat these outputs as class/domain logits.

    Args:
        d_dim: Number of outputs.
        zd_dim: Width of the incoming latent code.
    """

    def __init__(self, d_dim, zd_dim):
        super().__init__()

        # Xavier-uniform weights and zero bias, applied in the same order as
        # before so random-number consumption is unchanged.
        linear = nn.Linear(zd_dim, d_dim)
        nn.init.xavier_uniform_(linear.weight)
        nn.init.zeros_(linear.bias)
        self.fc1 = linear

    def forward(self, zd):
        """Return ``fc1(relu(zd))``."""
        activated = F.relu(zd)
        return self.fc1(activated)


class TimeStochasticBlock(nn.Module):
    """Patch-based Transformer encoder for the stochastic input component.

    The series is cut into patches by ``PatchEmbedding`` and encoded by
    ``e_layers`` full-attention Transformer layers. The per-channel patch
    tokens are then flattened and linearly projected to one ``hidden_size``
    vector per sample.

    Args:
        len_seq: Length of the input sequence.
        n_features: Number of input channels.
        hidden_size: Width of the output vector.
        model_hparams: Dict providing ``hidden_size``, ``patch_len``,
            ``stride``, ``d_model``, ``factor``, ``dropout``, ``d_ff``,
            ``n_heads``, ``output_attention``, ``activation`` and ``e_layers``.
    """

    def __init__(self, len_seq, n_features, hidden_size, model_hparams) -> None:
        super(TimeStochasticBlock, self).__init__()
        hp = model_hparams
        self.len_seq = len_seq
        self.enc_in = n_features
        # NOTE(review): stored from the dict, while the final projection uses
        # the ``hidden_size`` argument; FEDNet passes the same value for both.
        self.hidden_dim = hp['hidden_size']

        # Hyper-parameters are read in a fixed order so a missing key surfaces
        # as the same KeyError as before.
        patch_len, stride = hp['patch_len'], hp['stride']
        padding = stride
        d_model, factor, dropout = hp['d_model'], hp['factor'], hp['dropout']
        d_ff, n_heads = hp['d_ff'], hp['n_heads']
        output_attention, activation = hp['output_attention'], hp['activation']
        e_layers = hp['e_layers']

        # Patching + embedding. Padding equals the stride (presumably
        # end-padding as in PatchTST, adding one extra patch — TODO confirm in
        # PatchEmbedding).
        self.patch_embedding = PatchEmbedding(d_model, patch_len, stride, padding, dropout)

        # Stack of full (non-causal) self-attention encoder layers.
        attention_layers = [
            TransEncoderLayer(
                AttentionLayer(
                    FullAttention(False, factor, attention_dropout=dropout,
                                  output_attention=output_attention),
                    d_model,
                    n_heads,
                ),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation,
            )
            for _ in range(e_layers)
        ]
        self.encoder = TransEncoder(attention_layers, norm_layer=torch.nn.LayerNorm(d_model))

        # Flatten-and-project head: d_model features for each of the n_patches tokens.
        n_patches = int((self.len_seq - patch_len) / stride + 2)
        self.head_nf = d_model * n_patches

        self.flatten = nn.Flatten(start_dim=-2)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Linear(self.head_nf * self.enc_in, hidden_size)

    def forward(self, x_enc):
        """Encode ``x_enc`` (unpacked as batch, channels, length) to ``(batch, hidden_size)``."""
        _batch, _channels, _length = x_enc.shape  # unpack also requires a 3-D input

        # tokens: [batch * n_vars, n_patches, d_model]
        tokens, n_vars = self.patch_embedding(x_enc)
        tokens, _attns = self.encoder(tokens)

        # -> [batch, n_vars, d_model, n_patches]
        n_patches, width = tokens.shape[-2], tokens.shape[-1]
        tokens = tokens.reshape(-1, n_vars, n_patches, width).permute(0, 1, 3, 2)

        # -> [batch, n_vars * d_model * n_patches] -> [batch, hidden_size]
        flat = self.dropout(self.flatten(tokens))
        flat = flat.reshape(flat.shape[0], -1)
        return self.projection(flat)

    
class FEDNet(nn.Module):
    """Frequency-disentangled VAE classifier for multi-domain time series.

    The input is split by ``FrequencyFilter`` into a stochastic part ``x_sto``
    and a deterministic part ``x_det``. The filter is driven by a frequency
    mask that :meth:`_get_mask_spectrum` estimates from the training loaders.

    * ``x_sto`` is encoded by :class:`TimeStochasticBlock` into a
      ``hidden_size`` vector.
    * ``x_det`` is encoded by two VAE encoders into a domain-specific latent
      ``z_spc`` and a label-invariant latent ``z_inv``. These are regularised
      towards the learned priors p(z_spc | d) and p(z_inv | y), and ``x_det``
      is reconstructed by ``Decoder`` from ``concat(z_spc, z_inv)``.

    Classification reads ``concat(z_inv, encoded x_sto)``.

    Args:
        dataset: Object exposing ``get_train_loaders()``,
            ``get_nb_training_domains()``, ``OUTPUT_SIZE``, ``INPUT_SHAPE`` and
            ``SEQ_LEN``.
        model_hparams: Hyper-parameter dict. Keys used here: ``alpha``,
            ``freq_type``, ``device``, ``hidden_size``, ``fc_dim``,
            ``kernel_size``, ``normalize``, ``beta_d``, ``beta_y``,
            ``weight_true``, ``weight_false``, ``aux_loss_multiplier_y``,
            ``aux_loss_multiplier_d``, ``constraint_type`` and
            ``temperature``, plus the keys read by
            :class:`TimeStochasticBlock`.
        input_size: Unused by this class.
    """

    def __init__(self, dataset, model_hparams, input_size = None):
        super(FEDNet, self).__init__()
        self.dataset = dataset

        # Frequency indices shared across training domains, estimated once
        # from the training loaders at construction time.
        self.mask_spectrum = self._get_mask_spectrum(alpha = model_hparams['alpha'], freq_type = model_hparams['freq_type'])
        self.d_dim = dataset.get_nb_training_domains() #configs.n_domains
        self.y_dim = self.dataset.OUTPUT_SIZE # configs.n_classes
        self.input_channel = np.prod(dataset.INPUT_SHAPE) # configs.n_features
        self.len_seq = self.dataset.SEQ_LEN

        self.device = model_hparams['device']
        self.w_det = 1.0 # model_hparams['w_det']
        self.w_sto = 1.0 # model_hparams['w_sto']
        hidden_size = model_hparams['hidden_size']
        fc_dim = model_hparams['fc_dim']
        kernel_size = model_hparams['kernel_size']
        normalize = model_hparams['normalize']
        beta_d = model_hparams['beta_d']
        beta_y = model_hparams['beta_y']
        weight_true = model_hparams['weight_true']
        weight_false = model_hparams['weight_false']
        aux_loss_multiplier_y = model_hparams['aux_loss_multiplier_y']
        aux_loss_multiplier_d = model_hparams['aux_loss_multiplier_d']

        self.zd_dim = hidden_size
        self.zx_dim = 0 # keep but not used
        self.zy_dim = hidden_size
        self.fc_dim = fc_dim
        self.kernel_size = kernel_size
        self.freq_type = model_hparams['freq_type']
        self.constraint_type = model_hparams['constraint_type']
        self.temperature = model_hparams['temperature']

        self.z_dim = self.zd_dim + self.zy_dim

        self.disentanglement = FrequencyFilter(self.mask_spectrum, freq_type = model_hparams['freq_type'])

        self.time_stochastic = TimeStochasticBlock(self.len_seq, self.input_channel, hidden_size, model_hparams)

        self.decoder = Decoder(self.z_dim, self.fc_dim, self.input_channel, kernel_size=self.kernel_size, output_channels=[64,128,512,1024])
        self.domain_prior = PriorBlock(self.d_dim, self.zd_dim)
        self.label_prior = PriorBlock(self.y_dim, self.zy_dim)

        self.specific_encoder = Encoder(output_dim=self.zd_dim, input_dim=self.input_channel, fc_dim=self.fc_dim, kernel_size=self.kernel_size, out_channels=[1024, 512, 128, 64])
        self.invariant_encoder = Encoder(output_dim=self.zy_dim, input_dim=self.input_channel, fc_dim=self.fc_dim, kernel_size=self.kernel_size, out_channels=[1024, 512, 128, 64])

        self.domain_projection = Projection(self.d_dim, self.zd_dim)
        self.label_projection = Projection(self.y_dim, self.zy_dim)
        self.sto_projection = Projection(self.y_dim, self.zy_dim)

        # Final classifier over concat(z_inv, h_sto): zy_dim + hidden_size == 2 * zy_dim.
        # NOTE: the attribute name is kept as-is; renaming it would change state_dict keys.
        self.classifer = Projection(self.y_dim, self.zy_dim*2)

        self.normal = normalize
        self.beta_d = beta_d
        self.beta_y = beta_y

        self.weight_true = weight_true
        self.weight_false = weight_false
        self.aux_loss_multiplier_y = aux_loss_multiplier_y
        self.aux_loss_multiplier_d = aux_loss_multiplier_d

    def _get_mask_spectrum(self, alpha, freq_type):
        """Select the frequency components shared across training domains.

        For every batch of every training loader, the method takes the mean
        absolute spectral coefficient per frequency index, averaged over the
        batch and channel axes. These values are summed across batches. It
        returns the indices of the ``int(n_freq * alpha)`` largest totals.

        Args:
            alpha: Fraction of frequency indices to keep.
            freq_type: ``"fft"`` for a real FFT along dim 1 (time), or one of
                ``'db2'``, ``'sym2'``, ``'coif1'``, ``'bior1.3'``,
                ``'rbio1.3'`` for a single-level PyWavelets DWT. For the DWT,
                the approximation and detail coefficients are concatenated.

        Returns:
            1-D tensor of selected frequency indices (used as the spectrum of
            the time-invariant component).

        Raises:
            ValueError: if ``freq_type`` is unsupported or the training loaders
                yield no batches.
        """
        supported_wavelets = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')
        # Validate up front rather than with an ``assert`` after the first
        # batch (asserts are stripped under ``python -O``). Build the wavelet
        # once instead of once per batch.
        if freq_type == "fft":
            wavelet = None
        elif freq_type in supported_wavelets:
            wavelet = pywt.Wavelet(freq_type)
        else:
            raise ValueError(
                f"Unsupported freq_type {freq_type!r}; expected 'fft' or one of {supported_wavelets}"
            )

        loader_names, train_loaders = self.dataset.get_train_loaders()
        amps = 0.0
        for name, train_loader in zip(loader_names, train_loaders):
            loader_len = len(train_loader)
            print("loader name:", name, " trainloaderlen:", loader_len)
            for i, data in enumerate(train_loader):
                lookback_window = data[0]
                B, L, C = lookback_window.shape  # unpack also requires a 3-D batch
                if wavelet is None:
                    frequency_feature = torch.fft.rfft(lookback_window, dim=1)
                else:
                    device = lookback_window.device
                    # pywt operates on NumPy arrays. Detach and move to host so
                    # GPU or grad-tracking batches do not fail in .numpy().
                    X = lookback_window.permute(0, 2, 1).detach().cpu().numpy()
                    # pywt.dwt transforms along the last axis by default (time, after the permute).
                    cA, cD = pywt.dwt(X, wavelet)
                    frequency_feature = np.concatenate((cA, cD), axis=2).transpose((0, 2, 1)) # B D C
                    frequency_feature = torch.from_numpy(frequency_feature).to(device)

                amps += abs(frequency_feature).mean(dim=0).mean(dim=1)
                # NOTE(review): presumably a stop guard for endless loaders. As
                # written it fires only at i == loader_len + 1, i.e. after
                # loader_len + 2 batches. Kept unchanged — TODO confirm intent.
                if i > loader_len:
                    break

        if not torch.is_tensor(amps):
            raise ValueError("No training batches available to compute the frequency mask")

        mask_spectrum = amps.topk(int(amps.shape[0]*alpha)).indices
        print("mask_spectrum:", mask_spectrum)
        return mask_spectrum # as the spectrums of time-invariant component

    def normalize(self, x):
        """Standardise ``x`` along dim 1 using detached mean and std (eps 1e-5).

        The comment below assumes ``x`` is laid out as B L C.
        """
        # x_enc: B L C
        mean_enc = x.mean(1, keepdim=True).detach() # B x 1 x E
        x = x - mean_enc
        std_enc = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x = x / std_enc
        return x

    def _preprocess(self, x):
        """Split ``x`` into ``(x_sto, x_det)``, optionally normalise, and swap dims 1 and 2.

        This logic is shared by :meth:`forward` and :meth:`forward_test`.
        """
        x_sto, x_det = self.disentanglement(x)
        if self.normal:
            x_sto = self.normalize(x_sto)
            x_det = self.normalize(x_det)
        return x_sto.permute(0, 2, 1), x_det.permute(0, 2, 1)

    def forward(self, x, y=None, d=None):
        """Run the model.

        In training mode, returns ``(y_hat_latent, total_loss, cross_loss)``:

        * the classifier logits on ``concat(z_inv, h_sto)``;
        * ``w_det * loss_det + w_sto * loss_sto + loss_cls``;
        * the cross-label constraint loss, which is ``None`` unless
          ``constraint_type == 'cross'``.

        In eval mode it delegates to :meth:`forward_test`.

        Args:
            x: Input batch. Presumably ``(B, L, C)``, given the permute before
                :class:`TimeStochasticBlock` — TODO confirm.
            y: Class labels; dim 1 is squeezed (training only).
            d: Domain indices (training only).
        """
        if not self.training:
            return self.forward_test(x)

        assert y is not None and d is not None
        x_sto, x_det = self._preprocess(x)

        x_sto = self.time_stochastic(x_sto)
        x_sto_label = self.sto_projection(x_sto)
        y = y.squeeze(1)

        loss_sto = F.cross_entropy(x_sto_label, y, reduction='sum')

        x_recon, d_hat, y_hat, qzd, pzd, z_spc, qzy, pzy, z_inv = self.forward_train(d, x_det, y)

        # The decoder may emit a longer sequence; crop to x_det's last-dim length.
        if x_recon.shape[-1] > x_det.shape[-1]:
            x_recon = x_recon[:, :, :x_det.shape[-1]]
        # reconstruction loss
        reconstruction_loss = F.mse_loss(x_recon, x_det.float())
        # Single-sample estimates of -KL(q || p): log p(z) - log q(z) at the
        # sampled z. Subtracting them in loss_ELOB below adds the KL penalty.
        domain_kl_divergence = torch.sum(pzd.log_prob(z_spc) - qzd.log_prob(z_spc))
        label_kl_divergence = torch.sum(pzy.log_prob(z_inv) - qzy.log_prob(z_inv))
        # Auxiliary losses
        domain_cls = F.cross_entropy(d_hat, d, reduction='sum')
        label_cls = F.cross_entropy(y_hat, y, reduction='sum')

        loss_ELOB = reconstruction_loss - self.beta_d * domain_kl_divergence - self.beta_y * label_kl_divergence
        auxil_loss = self.aux_loss_multiplier_d * domain_cls + self.aux_loss_multiplier_y * label_cls

        cross_loss = None
        if self.constraint_type == 'cross':
            cross_loss = self.cross_label_loss(x_det)
            constraint_loss = auxil_loss + cross_loss
        else:
            contrastive_loss = self.contrastive_loss(z_inv, z_spc, temperature=self.temperature)
            constraint_loss = auxil_loss + contrastive_loss

        latent = torch.concat([z_inv,x_sto],dim=-1)

        loss_det = loss_ELOB + constraint_loss

        y_hat_latent = self.classifer(latent)

        loss_cls = F.cross_entropy(y_hat_latent, y, reduction='sum')

        return y_hat_latent, self.w_det * loss_det + self.w_sto * loss_sto + loss_cls, cross_loss

    def forward_train(self, d, x, y):
        """Run the VAE branch on the deterministic input ``x``.

        Both encoders are applied and latents are sampled with the
        reparameterisation trick (``rsample``). ``concat(z_spc, z_inv)`` is
        decoded, the learned priors are built from ``d`` and ``y``, and the
        auxiliary domain/label heads are applied.

        Returns:
            ``(x_recon, d_hat, y_hat, q(z_spc|x), p(z_spc|d), z_spc,
            q(z_inv|x), p(z_inv|y), z_inv)``.
        """
        # Encode q(z|x). The invariant encoder's extra outputs (idxs, sizes) are
        # passed to the decoder (presumably pooling indices/sizes — TODO confirm).
        zd_mean, zd_scale, _, _ = self.specific_encoder(x)
        zy_mean, zy_scale, idxs_y, sizes_y = self.invariant_encoder(x)

        # Reparameterization trick
        domain_space = dist.Normal(zd_mean, zd_scale) # domain posterior distribution
        domain_feature = domain_space.rsample() # z_spc = mu_spc + sigma_spc * epsilon
        invariant_space = dist.Normal(zy_mean, zy_scale) # invariant posterior distribution
        invariant_feature = invariant_space.rsample() # z_inv = mu_inv + sigma_inv * epsilon

        # Decode p(x|z)
        z = torch.cat((domain_feature, invariant_feature), dim=-1)

        x_recon = self.decoder(z, idxs_y, sizes_y)

        # Prior Distributions learn from label data d, y
        domain_prior_mean, domain_prior_scale = self.domain_prior(d)
        label_prior_mean, label_prior_scale = self.label_prior(y)

        domain_prior_distribution = dist.Normal(domain_prior_mean, domain_prior_scale)
        label_prior_distribution = dist.Normal(label_prior_mean, label_prior_scale)

        # Auxiliary losses
        d_hat = self.domain_projection(domain_feature)
        y_hat = self.label_projection(invariant_feature)

        return x_recon, d_hat, y_hat, domain_space, domain_prior_distribution, domain_feature, invariant_space, label_prior_distribution, invariant_feature

    def contrastive_loss(self, z_inv, z_spc, temperature=0.07):
        """Contrastive separation loss between the invariant and specific latents.

        Rows are L2-normalised. Every ``z_inv`` row is labelled 1 and every
        ``z_spc`` row 0. The loss is the mean negative log-probability of
        same-label pairs under a temperature-scaled softmax over all other
        rows, with self-pairs excluded. The structure matches the
        supervised-contrastive (SupCon) formulation.

        Returns:
            Scalar tensor.
        """
        eps = 1e-6
        device = z_inv.device
        z_inv = F.normalize(z_inv, dim=1)
        z_spc = F.normalize(z_spc, dim=1)

        features = torch.cat([z_inv.unsqueeze(1), z_spc.unsqueeze(1)], dim=1).to(device)
        labels = torch.cat([torch.ones(z_inv.shape[0]), torch.zeros(z_spc.shape[0])]).to(device)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)

        batch_size = features.shape[0]
        contrast_count = features.shape[1]
        # Row order [z_inv; z_spc] matches the label order above.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature
        anchor_count = contrast_count

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # Zero the diagonal so an anchor is never its own positive/negative.
        logits_mask = torch.scatter(
            torch.ones_like(mask).to(device),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)+eps)

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1)+eps)

        loss = - mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss

    def cross_label_loss(self, x):
        """Cross-projection constraint between the domain and label branches.

        All predictions are made with the encoder means and converted to
        one-hot argmax vectors under ``torch.no_grad()``.

        NOTE(review):
        * ``Tensor.scatter_`` is in-place and returns its receiver, so
          ``pred_d`` *is* ``d`` and ``pred_y`` *is* ``y``. The "true" terms
          therefore compare each one-hot with itself, and the "false" terms
          compare cross-predictions against the branch's own predictions.
        * Every input is built under ``no_grad``. ``loss.requires_grad = True``
          only turns the result into a leaf: it changes the loss value but
          sends no gradient to model parameters.

        Behaviour is kept unchanged — TODO confirm the intended design.
        """
        with torch.no_grad():
            zd_mean, _, _, _ = self.specific_encoder(x)
            zd = zd_mean
            domain_prob = F.softmax(self.domain_projection(zd), dim=1)
            # the maximum predicted class probability
            _, ind = torch.topk(domain_prob, 1)
            # convert the digit(s) to one-hot tensor(s)
            d = x.new_zeros(domain_prob.size())
            pred_d = d.scatter_(1, ind, 1.0)

            zy_mean, _, _, _ = self.invariant_encoder(x)
            zy = zy_mean
            label_prob = F.softmax(self.label_projection(zy), dim=1)
            _, ind = torch.topk(label_prob, 1)
            y = x.new_zeros(label_prob.size())
            pred_y = y.scatter_(1, ind, 1.0)

            # cross projection for constraint loss
            alpha_y2d = F.softmax(self.domain_projection(zy), dim=1)
            _, ind = torch.topk(alpha_y2d, 1)
            d_false = x.new_zeros(alpha_y2d.size())
            pred_d_false = d_false.scatter_(1, ind, 1.0)

            alpha_d2y = F.softmax(self.label_projection(zd), dim=1)
            _, ind = torch.topk(alpha_d2y, 1)
            y_false = x.new_zeros(alpha_d2y.size())
            pred_y_false = y_false.scatter_(1, ind, 1.0)

        loss_classify_true = self.weight_true * (F.cross_entropy(pred_d, d, reduction='sum') + F.cross_entropy(pred_y, y, reduction='sum'))
        loss_classify_false = self.weight_false * (F.cross_entropy(pred_d_false, d, reduction='sum') + F.cross_entropy(pred_y_false, y, reduction='sum'))

        loss = loss_classify_true - loss_classify_false
        loss.requires_grad = True

        return loss

    def forward_test(self, x):
        """Inference path.

        Returns:
            ``(logits, features)``. ``logits`` is the one-hot argmax of the
            classifier output with shape ``(B, 1, y_dim)``; despite the name,
            these are predictions, not raw logits. ``features`` is
            ``concat(z_inv, h_sto)`` reshaped to ``(B, 1, -1)``.
        """
        B = x.shape[0]
        x_sto, x_det = self._preprocess(x)

        x_sto = self.time_stochastic(x_sto)

        zy_mean, _, _, _ = self.invariant_encoder(x_det)
        z_inv = zy_mean
        latent = torch.concat([z_inv,x_sto],dim=-1)

        label_prob = F.softmax(self.classifer(latent), dim=1)
        _, ind = torch.topk(label_prob, 1)
        y = x.new_zeros(label_prob.size())
        pred_y = y.scatter_(1, ind, 1.0)

        logits = pred_y.unsqueeze(1)
        features = latent.reshape(B, 1, -1)
        return logits, features



