import models.layers as layers
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
from itertools import chain
import torch
import numpy as np

class CFS_Joint(pl.LightningModule):
    '''Joint reconstruction model with a salient selector and a background encoder.

    The module is built from three parts:

      * ``salient_input_layer``: a ``layers.ConcreteSelector`` over the inputs
        configured with ``k``; its ``temperature`` attribute starts at
        ``start_temperature`` and is multiplied by ``self.r`` after every
        training step.
      * ``background_input_layer``: an MLP ``input_size -> 128 -> k_prime``.
      * ``fc``: an MLP mapping the ``k + k_prime`` concatenated features
        through ``hidden`` to ``output_size``, with ReLU between layers and no
        activation after the last layer.

    During training, rows whose label is non-zero are treated as target
    samples and are decoded from ``[salient, background]`` features; rows
    whose label is zero are treated as background samples and are decoded from
    ``[zeros, background]`` features.

    Args:
      input_size: number of inputs.
      output_size: number of outputs of the final layer of ``fc``.
      start_temperature: initial temperature of the salient selector.
      end_temperature: temperature the schedule is designed to reach after
        ``(train_length // mbsize) * max_epochs`` training steps.
      train_length: number of training examples (used only for the schedule).
      mbsize: minibatch size (used only for the schedule).
      max_epochs: number of training epochs (used only for the schedule).
      hidden: sequence of hidden layer widths for ``fc``.
      k: value passed as ``k`` to ``ConcreteSelector``; also the width of the
        salient part of the input to ``fc``.
      k_prime: width of the background embedding.
      lr: learning rate for Adam.
      loss_fn: callable ``loss_fn(x, x_recon)`` returning the loss.
    '''
    def __init__(self,
                 input_size,
                 output_size,
                 start_temperature,
                 end_temperature,
                 train_length,
                 mbsize,
                 max_epochs,
                 hidden,
                 k,
                 k_prime,
                 lr,
                 loss_fn):
        super().__init__()

        # Sub-modules are created in the same order as before so that seeded
        # parameter initialisation and state_dict key order are unchanged.
        self.salient_input_layer = layers.ConcreteSelector(input_size, k=k)
        self.salient_input_layer.temperature = start_temperature

        self.lr = lr

        self.background_input_layer = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, k_prime),
        )

        # list(hidden) lets callers pass a tuple as well as a list.
        widths = [k + k_prime] + list(hidden) + [output_size]
        fc_layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            fc_layers.append(nn.Linear(d_in, d_out))
            fc_layers.append(nn.ReLU())
        fc_layers = fc_layers[:-1]  # no activation after the output layer

        # Per-step multiplicative decay: start * r ** total_steps == end.
        # Guard against a zero step count (train_length < mbsize or
        # max_epochs == 0), which previously raised ZeroDivisionError.
        total_steps = max(1, (train_length // mbsize) * max_epochs)
        self.r = float(np.power(end_temperature / start_temperature, 1 / total_steps))

        self.fc = nn.Sequential(*fc_layers)
        self.loss_fn = loss_fn
        self.k = k
        self.avg_background_gate_vals = 0
        # Updated every training step; read by on_train_batch_start.
        self.avg_salient_gate_vals = 0

    def forward(self, x, **kwargs):
        '''Decode ``x`` from its salient selection and background embedding.

        This mirrors what ``training_step`` does for target samples. Extra
        keyword arguments are accepted for interface compatibility and are
        ignored.
        '''
        features = torch.cat(
            [self.salient_input_layer(x), self.background_input_layer(x)],
            dim=1,
        )
        return self.fc(features)

    def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int = 0):
        '''Skip the rest of the epoch once the selector statistic exceeds 0.99.

        ``dataloader_idx`` has a default so the hook works whether or not the
        installed Lightning version passes it. Returning -1 is Lightning's
        signal to skip training for the remainder of the current epoch.
        '''
        if self.avg_salient_gate_vals > 0.99:
            return -1
        return None

    def training_step(self, batch, batch_idx):
        '''Compute the summed target and background reconstruction losses.

        Also anneals the selector temperature by one step and logs the losses
        and the selector statistic.
        '''
        x, y = batch
        x_tar = x[y != 0]
        x_bg = x[y == 0]
        # NOTE(review): if a batch contains only one of the two groups, one of
        # these subsets is empty; whether loss_fn copes with an empty tensor
        # depends on the loss supplied by the caller -- confirm.

        tar_embedding = self.background_input_layer(x_tar)
        bg_embedding = self.background_input_layer(x_bg)

        x_tar_recon = self.fc(
            torch.cat([self.salient_input_layer(x_tar), tar_embedding], dim=1)
        )
        # Background samples get zeros in place of the salient features.
        # new_zeros allocates directly on the embedding's device with its dtype.
        salient_blank = bg_embedding.new_zeros((x_bg.shape[0], self.k))
        x_bg_recon = self.fc(torch.cat([salient_blank, bg_embedding], dim=1))

        self.salient_input_layer.temperature *= self.r

        loss_tar = self.loss_fn(x_tar, x_tar_recon)
        loss_bg = self.loss_fn(x_bg, x_bg_recon)
        loss = loss_tar + loss_bg

        # Monitoring only: no gradient is needed, so avoid building (and
        # retaining on self) an autograd graph for it.
        # Presumably sample() returns (n_samples, k, input_size) -- TODO confirm.
        with torch.no_grad():
            M = self.salient_input_layer.sample(n_samples=256)
            values = torch.mean(M, dim=0)
            self.avg_salient_gate_vals = torch.max(values, dim=1).values.mean()
        self.log('AvgSalientGateVals', self.avg_salient_gate_vals, prog_bar=True)

        self.log('loss_tar', loss_tar, prog_bar=True)
        self.log('loss_bg', loss_bg, prog_bar=True)
        self.log('loss', loss, prog_bar=True)

        return loss

    def get_inds(self, **kwargs):
        '''Return ``salient_input_layer.get_inds(**kwargs)``.'''
        return self.salient_input_layer.get_inds(**kwargs)

    def configure_optimizers(self):
        '''Return an Adam optimizer over the selector, encoder and decoder.'''
        optimizer = optim.Adam(
            chain(
                self.salient_input_layer.parameters(),
                self.background_input_layer.parameters(),
                self.fc.parameters()
            ), lr=self.lr)
        return optimizer
