import os
import sys

from contexttimer import Timer

import torch
from torch import nn

import pytorch_lightning as pl

class PropertyPredictionModel(pl.LightningModule):
    """Lightning wrapper that regresses properties from spectra with an MSE loss.

    Args:
        base_net: Module that maps a batch of spectra to predicted properties;
            every forward pass is delegated to it.
        optimizer: Optimizer wrapper consumed by ``configure_optimizers``. It
            must provide ``instantiate_optimizer(params)`` and expose
            ``_optimizer`` and ``_scheduler`` attributes afterwards
            (``_scheduler`` may be ``None``). Optional at construction time,
            but required before Lightning configures optimizers for training.
    """

    def __init__(self, base_net, optimizer=None):
        super().__init__()

        self._base_net = base_net
        self._optimizer = optimizer

    def forward(self, batch):
        """Run ``base_net`` on ``batch`` and return its output unchanged."""
        return self._base_net(batch)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss as ``val_loss``."""
        self._shared_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss as ``test_loss``."""
        self._shared_step(batch, "test_loss")

    def predict_step(self, batch, batch_idx):
        """Return the network's predictions for ``batch``.

        Unlike the other steps, the batch is passed straight to the network
        without unpacking. Presumably prediction loaders yield spectra only
        (no targets) -- TODO confirm against the data module.
        """
        return self(batch)

    def configure_optimizers(self):
        """Build the optimizer (and optional LR scheduler) for Lightning.

        Returns:
            The optimizer alone when the wrapper has no scheduler, otherwise
            ``([optimizer], [scheduler])``.

        Raises:
            ValueError: if the model was constructed without an optimizer.
        """
        if self._optimizer is None:
            # Without this guard the default ``optimizer=None`` fails with an
            # opaque AttributeError raised from inside Lightning's setup.
            raise ValueError(
                "PropertyPredictionModel needs an optimizer to train; "
                "pass one via the `optimizer` constructor argument."
            )

        # instantiate_optimizer is expected to populate the wrapper's
        # `_optimizer` / `_scheduler` attributes, which are read right after.
        self._optimizer.instantiate_optimizer(self._base_net.parameters())

        optimizer = self._optimizer._optimizer
        scheduler = self._optimizer._scheduler
        if scheduler is None:
            return optimizer
        return [optimizer], [scheduler]

    def _shared_step(self, batch, log_name):
        """Unpack ``(spectrum, true_prop)``, compute the loss, log it under
        ``log_name`` and return it."""
        spectrum, true_prop = batch
        pred_prop = self(spectrum)
        loss = self._loss(pred_prop, true_prop)
        self.log(log_name, loss)
        return loss

    def _loss(self, output, properties):
        """Return the mean squared error between predictions and targets."""
        # Functional form: same result as nn.MSELoss() (default reduction
        # "mean") without allocating a new loss module on every call.
        return nn.functional.mse_loss(output, properties)
