import pytorch_lightning as pl
import models.layers as layers
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from itertools import chain
import torch
import numpy as np


class Autoencoder(nn.Module):
    """Fully connected autoencoder with one hidden layer on each side.

    Args:
        input_size: Width of the input and of the reconstruction.
        k: Width of the bottleneck embedding produced by the encoder.
        hidden_size: Width of the hidden layer in both the encoder and the
            decoder. Defaults to 100, the value that used to be hard-coded.
    """

    def __init__(self, input_size, k, hidden_size=100):
        super().__init__()
        # Encoder is built before decoder so the parameter creation order
        # (and therefore seeded initialisation) matches the original layout.
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, k),
        )

        self.decoder = nn.Sequential(
            nn.Linear(k, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the embedding.

        Returns:
            A tuple ``(embedding, recon)``: the encoder output and the
            decoder output computed from that embedding.
        """
        embedding = self.encoder(x)
        recon = self.decoder(embedding)

        return embedding, recon


class CFS_Pretrained(pl.LightningModule):
    """Feature selector trained on top of a pretrained background autoencoder.

    Training is split into two phases keyed on ``self.current_epoch``:

    1. While ``current_epoch < max_pretrain_epochs`` only the samples with
       label ``0`` (the background samples) are used, and only
       ``background_autoencoder`` is trained, with an MSE reconstruction loss.
    2. Afterwards the fully connected head ``fc`` is trained on both groups.
       For samples with a non-zero label (the target samples) ``fc`` receives
       the selector output concatenated with the detached background
       embedding; for background samples the selector part is replaced by
       zeros. The selector temperature is multiplied by ``self.r`` after every
       such step.

    Args:
        input_size: Number of input features; passed to the selector and to
            the background autoencoder.
        output_size: Width of the last layer of ``fc``. The head's output is
            compared against the input by ``loss_fn``, so this presumably
            equals ``input_size`` -- TODO confirm.
        start_temperature: Initial value of ``input_layer.temperature``.
        end_temperature: Temperature the per-step decay factor is derived for.
        train_length: Used as ``train_length // mbsize`` to estimate the
            number of steps per epoch (presumably the number of training
            samples -- verify against caller).
        mbsize: Minibatch size used in that estimate.
        max_pretrain_epochs: Number of initial epochs spent in phase 1.
        max_fs_epochs: Number of epochs over which the temperature is meant
            to decay from ``start_temperature`` to ``end_temperature``.
        hidden: Sequence of hidden layer widths for ``fc`` (may be empty).
        k: Passed to ``layers.ConcreteSelector``; also the width of the
            selector part of the input of ``fc``.
        k_prime: Bottleneck width of the background autoencoder.
        lr: Learning rate for Adam.
        loss_fn: Callable invoked as ``loss_fn(x, reconstruction)``.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 start_temperature,
                 end_temperature,
                 train_length,
                 mbsize,
                 max_pretrain_epochs,
                 max_fs_epochs,
                 hidden,
                 k,
                 k_prime,
                 lr,
                 loss_fn):
        super().__init__()

        self.input_layer = layers.ConcreteSelector(input_size, k=k)
        self.input_layer.temperature = start_temperature
        self.lr = lr
        self.max_pretrain_epochs = max_pretrain_epochs
        self.k = k

        self.background_autoencoder = Autoencoder(input_size=input_size, k=k_prime)

        # Accept any sequence of widths (tuple, list, ...) for `hidden`.
        hidden = list(hidden)
        fc_layers = []
        for d_in, d_out in zip([k + k_prime] + hidden, hidden + [output_size]):
            fc_layers.append(nn.Linear(d_in, d_out))
            fc_layers.append(nn.ReLU())
        # No activation after the output layer.
        fc_layers = fc_layers[:-1]

        # Per-step multiplicative decay so that, after the estimated number of
        # feature-selection steps, the temperature has gone from
        # start_temperature to end_temperature. Both factors are clamped to at
        # least 1 so a dataset smaller than one minibatch (or zero epochs)
        # does not cause a division by zero.
        anneal_steps = max(train_length // mbsize, 1) * max(max_fs_epochs, 1)
        self.r = np.power(end_temperature / start_temperature, 1 / anneal_steps)

        self.fc = nn.Sequential(*fc_layers)
        self.loss_fn = loss_fn
        # Monitoring statistic refreshed in training_step and read by
        # on_train_batch_start; starts at 0 so nothing is skipped initially.
        self.avg_salient_gate_vals = 0

    def forward(self, x, **kwargs):
        """Run ``fc`` on the selector output joined with the background embedding.

        The background embedding is detached, so no gradient flows from the
        head into ``background_autoencoder``. Extra keyword arguments are
        accepted and ignored.
        """
        gated_x = self.input_layer(x)
        bg_representation, _ = self.background_autoencoder(x)
        bg_representation = bg_representation.detach()
        return self.fc(torch.cat([gated_x, bg_representation], dim=1))

    def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Skip the rest of the epoch once the gate statistic exceeds 0.99.

        Returning ``-1`` from this hook tells Lightning to skip training for
        the remainder of the current epoch. ``dataloader_idx`` is unused; it
        has a default so the hook works both with Lightning versions that pass
        a third argument and with those that call it with
        ``(batch, batch_idx)`` only.
        """
        if self.avg_salient_gate_vals > 0.99:
            return -1

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch according to the current phase.

        Args:
            batch: Pair ``(x, y)``; rows of ``x`` are split by ``y == 0``
                (background) versus ``y != 0`` (target).
            batch_idx: Index of the batch (unused).

        Returns:
            The background reconstruction MSE during pretraining, otherwise
            the sum of ``loss_fn`` on the target and background groups.
        """
        x, y = batch

        if self.current_epoch < self.max_pretrain_epochs:
            # When pretraining, only use the background samples
            x = x[y == 0]
            _, recon = self.background_autoencoder(x)
            loss = F.mse_loss(recon, x)
            self.log('bg_recon_loss', loss)
            return loss

        x_tar = x[y != 0]
        x_bg = x[y == 0]

        # Target samples go through the regular forward pass
        # (selector output + detached background embedding).
        x_tar_recon = self(x_tar)

        # Background samples get zeros in place of the selector output.
        # The zeros are created from the embedding so that device and dtype
        # always match it.
        bg_embedding = self.background_autoencoder(x_bg)[0].detach()
        no_salient = bg_embedding.new_zeros(bg_embedding.shape[0], self.k)
        x_bg_recon = self.fc(torch.cat([no_salient, bg_embedding], dim=1))

        # Anneal the selector temperature once per feature-selection step.
        self.input_layer.temperature *= self.r

        loss_tar = self.loss_fn(x_tar, x_tar_recon)
        loss_bg = self.loss_fn(x_bg, x_bg_recon)
        loss = loss_tar + loss_bg

        # Monitoring only: average the sampled selector output over the
        # samples, take the maximum along dim 1 and average the result.
        # No gradient is needed, so this is kept out of the autograd graph.
        with torch.no_grad():
            M = self.input_layer.sample(n_samples=256)
            values = torch.mean(M, dim=0)
            self.avg_salient_gate_vals = torch.max(values, dim=1).values.mean()
        self.log('AvgSalientGateVals', self.avg_salient_gate_vals, prog_bar=True)

        return loss

    def get_inds(self, **kwargs):
        """Delegate to ``input_layer.get_inds`` with the same keyword arguments."""
        return self.input_layer.get_inds(**kwargs)

    def configure_optimizers(self):
        """Return a single Adam optimizer over autoencoder, selector and head."""
        # NOTE(review): the autoencoder parameters stay in this optimizer
        # after pretraining. Its embedding is detached from then on, but Adam
        # can still move parameters whose grads are zero tensors rather than
        # None -- confirm for the torch/Lightning versions in use.
        optimizer = optim.Adam(
            chain(
                self.background_autoencoder.parameters(),
                self.input_layer.parameters(),
                self.fc.parameters()),
            lr=self.lr)
        return optimizer
