import torch
import torch.nn.functional as F
import torch.nn as nn

import numpy as np

import pytorch_lightning as pl
from pytorch_lightning import Trainer
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from torch.nn import functional as F
from argparse import ArgumentParser
import os
from torch.optim.lr_scheduler import StepLR

class Encoder(nn.Module):
    """Two-layer MLP that maps inputs to a latent code.

    The network is ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, latent_size)``; no activation is applied to the
    latent output.

    Args:
        input_size: Size of the last dimension of the input tensor.
        hidden_size: Width of the intermediate layer.
        latent_size: Size of the last dimension of the returned code.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=latent_size)

    def forward(self, x):
        """Encode ``x`` (last dimension ``input_size``) into the latent space."""
        hidden = self.linear1(x)      # (..., hidden_size)
        activated = F.relu(hidden)
        return self.linear2(activated)  # (..., latent_size)

class Decoder(torch.nn.Module):
    """Two-layer MLP that maps a latent code back to output space.

    The network is ``Linear(latent_size, hidden_size) -> ReLU ->
    Linear(hidden_size, output_size)``; the final layer is linear, so the
    output is unbounded.

    Args:
        latent_size: Size of the last dimension of the incoming code.
        hidden_size: Width of the intermediate layer.
        output_size: Size of the last dimension of the reconstruction.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.linear1 = torch.nn.Linear(in_features=latent_size, out_features=hidden_size)
        self.linear2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Decode ``x`` (last dimension ``latent_size``) into output space."""
        hidden = F.relu(self.linear1(x))  # (..., hidden_size)
        reconstruction = self.linear2(hidden)  # (..., output_size)
        return reconstruction


class LitAutoEncoder(pl.LightningModule):
    """Lightning module wrapping an ``Encoder``/``Decoder`` pair.

    Training minimises the summed squared error between the flattened input
    batch and its reconstruction.

    Args:
        input_size: Feature size fed to the encoder.
        output_size: Feature size produced by the decoder.
        latent_size: Size of the bottleneck representation.
        hidden_size: Hidden width used by both encoder and decoder.
        learning_rate: Initial learning rate for Adam.
    """

    def __init__(self, input_size, output_size, latent_size, hidden_size, learning_rate):
        super().__init__()
        # Stores the constructor arguments on ``self.hparams``.
        self.save_hyperparameters()
        self.encoder = Encoder(input_size, hidden_size, latent_size)
        self.decoder = Decoder(latent_size, hidden_size, output_size)
        # 'sum' reduction: the loss is not averaged over elements or batch.
        self.loss_function = nn.MSELoss(reduction='sum')

    def forward(self, x):
        """Return ``(latent_code, reconstruction)`` for input ``x``."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)

    def training_step(self, batch, batch_idx=None):
        """Compute, log and return the reconstruction loss for one batch."""
        # Flatten everything except the leading (batch) dimension.
        flat = batch.view(batch.shape[0], -1)
        reconstruction = self.forward(flat)[1]
        loss = self.loss_function(reconstruction, flat)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (x0.1 every 100 scheduler steps)."""
        adam = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        step_decay = StepLR(adam, step_size=100, gamma=0.1)
        return [adam], [step_decay]


def estimate_alpha(feature, cluster_center, std=3):
    """Estimate a per-cluster distance scale from log-distance statistics.

    For every cluster center, the L2 distances from all features to that
    center are taken in log space; the result is
    ``exp(mean - std * stdev)`` of those log-distances, where the statistics
    are computed along the feature (first) axis.

    Args:
        feature: Feature tensor; presumably ``(num_samples, dim)`` given the
            ``unsqueeze(1)`` broadcast below.
        cluster_center: Center tensor; presumably ``(num_clusters, dim)``.
        std: Number of standard deviations subtracted from the mean
            log-distance.

    Returns:
        A detached tensor with one value per cluster center.
    """
    # Broadcast samples against centers so every (sample, center) pair is compared.
    samples = feature.unsqueeze(1)
    centers = cluster_center.unsqueeze(0)
    log_dist = torch.pairwise_distance(samples, centers, p=2).log()

    # Statistics over the sample axis, one entry per center.
    center_mean = log_dist.mean(dim=0)
    center_spread = log_dist.std(dim=0)

    return torch.exp(center_mean - center_spread * std).detach()

