import models.layers as layers
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torch
from itertools import chain
import numpy as np

class ConcreteAutoencoder(pl.LightningModule):
    '''MLP on top of a concrete feature-selection input layer.

    The input passes through a ``layers.ConcreteSelector`` and the result is
    fed to a fully connected ReLU network. After every training step the
    selector temperature is multiplied by a constant factor chosen so that it
    moves from ``start_temperature`` to ``end_temperature`` over
    ``(train_length // mbsize) * max_epochs`` steps.

    Args:
      input_size: number of input features given to the selector.
      output_size: width of the final linear layer.
      start_temperature: initial temperature of the selector.
      end_temperature: temperature the schedule is designed to reach.
      train_length: number of training examples.
      mbsize: minibatch size; with train_length it gives the steps per epoch.
      max_epochs: number of epochs the temperature schedule spans.
      hidden: sequence of hidden layer widths (may be empty).
      k: number of features to select; also the input width of the MLP.
      lr: learning rate for Adam.
      loss_fn: callable invoked as loss_fn(output, y), returning the loss.
    '''
    def __init__(self,
                 input_size,
                 output_size,
                 start_temperature,
                 end_temperature,
                 train_length,
                 mbsize,
                 max_epochs,
                 hidden,
                 k,
                 lr,
                 loss_fn):
        super().__init__()

        self.input_layer = layers.ConcreteSelector(input_size, k=k)
        self.input_layer.temperature = start_temperature
        self.lr = lr

        # Copy to a list so tuples (or other sequences) are accepted and the
        # caller's object is never aliased.
        hidden_sizes = list(hidden)

        # Linear -> ReLU pairs; the trailing ReLU is dropped so that the
        # network output is linear.
        fc_layers = []
        for d_in, d_out in zip([k] + hidden_sizes, hidden_sizes + [output_size]):
            fc_layers.append(nn.Linear(d_in, d_out))
            fc_layers.append(nn.ReLU())
        fc_layers = fc_layers[:-1]

        # Per-step multiplicative decay: start * r ** total_steps == end.
        # Floor the step count at 1 so a dataset smaller than one minibatch
        # (or max_epochs == 0) does not raise ZeroDivisionError.
        # NOTE(review): steps per epoch are taken as train_length // mbsize;
        # if the dataloader also yields a partial last batch, the temperature
        # ends slightly past end_temperature.
        total_steps = max(1, (train_length // mbsize) * max_epochs)
        self.r = np.power(end_temperature / start_temperature, 1 / total_steps)

        self.fc = nn.Sequential(*fc_layers)
        self.loss_fn = loss_fn
        # Convergence monitor updated in training_step; see
        # on_train_batch_start for how it is used.
        self.avg_gate_vals = 0

    def forward(self, x, **kwargs):
        '''Apply the selector to x and run the MLP on the result.

        Extra keyword arguments are accepted and ignored.
        '''
        selected = self.input_layer(x)
        return self.fc(selected)

    def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int = 0):
        '''Skip the rest of the epoch once the selection has converged.

        Returning -1 from this hook tells Lightning to skip training for the
        remainder of the current epoch. dataloader_idx defaults to 0 so the
        hook works with Lightning versions that pass it and with versions
        that call the hook with only (batch, batch_idx).
        '''
        if self.avg_gate_vals > 0.99:
            return -1

    def training_step(self, batch, batch_idx):
        '''Run one optimisation step and update the temperature and monitor.

        Args:
          batch: an (x, y) pair.
          batch_idx: index of the batch (unused).

        Returns:
          The loss computed by loss_fn(self(x), y).
        '''
        x, y = batch

        output = self(x)
        # Decay after the forward pass so this step used the current value.
        self.input_layer.temperature *= self.r

        loss = self.loss_fn(output, y)
        self.log('Loss', loss, prog_bar=True)

        # Monitoring only: no gradients are needed here, and building them
        # would keep an autograd graph alive through self.avg_gate_vals.
        # Presumably sample() returns (n_samples, k, input_size) -- TODO
        # confirm against layers.ConcreteSelector.
        with torch.no_grad():
            samples = self.input_layer.sample(n_samples=256)
            mean_sample = torch.mean(samples, dim=0)
            self.avg_gate_vals = torch.max(mean_sample, dim=1).values.mean()

        self.log('AvgGateVals', self.avg_gate_vals, prog_bar=True)

        return loss

    def get_inds(self, **kwargs):
        '''Return the selector's get_inds(**kwargs) result.'''
        return self.input_layer.get_inds(**kwargs)

    def configure_optimizers(self):
        '''Return an Adam optimizer over the selector and MLP parameters.'''
        optimizer = optim.Adam(
            chain(self.input_layer.parameters(), self.fc.parameters()),
            lr=self.lr)
        return optimizer
