import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from argparse import ArgumentParser

class NonLinearApproximator(pl.LightningModule):
    '''
        Multilayer perceptron that maps an embedding to a target embedding
        of the same size.

        Batches are (x, y) pairs, i.e. (original word, retrofitted) embeddings.
        Optimizes ||approximator(x) - retrofitted(x)||^2, or the MSE loss.

        The network is `hidden_layers - 1` blocks of Linear -> ReLU -> Dropout
        (the encoder) followed by one Linear projection back to `input_size`
        (the decoder).
    '''
    def __init__(self, hparams):
        '''
            Build the network from a hyperparameter container.

            hparams must provide `input_size`, `hidden_size`, `dropout`,
            `hidden_layers` and `lr` (see `add_model_specific_args`).
        '''
        super().__init__()
        # Register the container itself as the module's hyperparameters.
        # Assigning `self.hparams = ...` is rejected by current Lightning
        # releases, and a no-argument `save_hyperparameters()` would nest the
        # container under the key "hparams" instead of exposing its fields.
        self.save_hyperparameters(hparams)
        self.input_size = self.hparams.input_size
        self.hidden_size = self.hparams.hidden_size
        self.dp = self.hparams.dropout
        self.hidden_layers = self.hparams.hidden_layers
        self.learning_rate = self.hparams.lr

        self.blocks = []
        # Width of the tensor flowing out of the blocks built so far.
        in_features = self.input_size
        for _ in range(self.hidden_layers - 1):
            self.blocks.extend([
                nn.Linear(in_features, self.hidden_size),
                nn.ReLU(),
                nn.Dropout(p=self.dp),
            ])
            in_features = self.hidden_size

        self.encoder = nn.Sequential(*self.blocks)
        # Size the decoder from the encoder's real output width: when no
        # encoder block was built (hidden_layers <= 1) the encoder is the
        # identity and its output still has `input_size` features.
        self.decoder = nn.Linear(in_features, self.input_size)
        self.encoder.apply(self._init_weights)
        self.decoder.apply(self._init_weights)

    def _init_weights(self, m):
        '''Xavier-normal weights and a constant 0.01 bias for Linear layers.'''
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.01)

    def forward(self, embedding):
        '''Return the approximated embedding for `embedding`.'''
        return self.decoder(self.encoder(embedding))

    def configure_optimizers(self):
        '''Adam over all parameters with the configured learning rate.'''
        return torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-5
        )

    def _mse_step(self, batch):
        '''Compute the MSE between the prediction for x and the target y.'''
        x, y = batch
        return F.mse_loss(self(x), y)

    def training_step(self, batch, batch_idx):
        '''Log and return the training MSE for one batch.'''
        loss = self._mse_step(batch)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        '''Log the validation MSE for one batch.'''
        loss = self._mse_step(batch)
        self.log('val_loss', loss, on_epoch=True)

    def test_step(self, batch, batch_idx):
        '''Log the test MSE for one batch.'''
        loss = self._mse_step(batch)
        self.log('test_loss', loss, on_epoch=True)

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        '''
            Register this model's command-line options on `parent_parser`
            (inside a "NonLinearApproximator" argument group) and return
            `parent_parser`.
        '''
        parser = parent_parser.add_argument_group("NonLinearApproximator")
        parser.add_argument(
            '--input_size',
            type=int,
            default=768,
            help="Input embedding size (defaults to 768)."
        )
        parser.add_argument(
            '--hidden_layers',
            type=int,
            default=2,
            help="Number of hidden layers (defaults to 2)."
        )
        parser.add_argument(
            '--hidden_size',
            type=int,
            default=512,
            help="Hidden Size (defaults to 512)."
        )
        parser.add_argument(
            '--dropout',
            type=float,
            default=0.5,
            help="Dropout probability."
        )
        parser.add_argument(
            '--lr',
            type=float,
            default=1e-3,
            help="Learning rate for approximation."
        )
        return parent_parser


class LinearApproximator(pl.LightningModule):
    '''
        Single bias-free linear layer mapping an embedding of size
        `input_size` to another embedding of the same size, trained with
        the MSE loss on (x, y) batches.

        `hidden_size` is stored on the instance but is not used to build
        the layer.
    '''
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.encoder = nn.Linear(
            in_features=input_size,
            out_features=input_size,
            bias=False,
        )

    def forward(self, embedding):
        '''Apply the linear map to `embedding`.'''
        return self.encoder(embedding)

    def configure_optimizers(self):
        '''Adam over all parameters (lr 1e-3, weight decay 1e-5).'''
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)

    def _batch_mse(self, batch):
        '''MSE between the linear prediction for the inputs and the targets.'''
        inputs, targets = batch
        predictions = self.encoder(inputs)
        return F.mse_loss(predictions, targets)

    def training_step(self, batch, batch_idx):
        '''Log and return the training MSE for one batch.'''
        train_loss = self._batch_mse(batch)
        self.log('train_loss', train_loss, on_epoch=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        '''Log the validation MSE for one batch.'''
        self.log('val_loss', self._batch_mse(batch), on_epoch=True)

    def test_step(self, batch, batch_idx):
        '''Log the test MSE for one batch.'''
        self.log('test_loss', self._batch_mse(batch), on_epoch=True)