import os
import sys

from contexttimer import Timer

import torch
from torch import nn

import pytorch_lightning as pl

class SiameseModel(pl.LightningModule):
    """Siamese network trained so the cosine similarity of two embeddings matches a target score.

    Both members of a pair are passed through the same ``base_net`` (shared
    weights). The cosine similarity of the two outputs is regressed onto the
    ``tanimoto`` target with a mean-squared-error loss.

    Args:
        base_net: Module applied to each member of a pair.
        optimizer: Optimizer wrapper used by ``configure_optimizers``. It must
            expose ``instantiate_optimizer(params)`` and, after that call, the
            ``_optimizer`` and ``_scheduler`` attributes. It may be ``None`` when
            the model is only used for inference.
    """

    def __init__(self, base_net, optimizer=None):
        super().__init__()

        self._base_net = base_net
        self._optimizer = optimizer

    def forward(self, batch):
        """Embed ``batch`` with the shared base network."""
        return self._base_net(batch)

    def _shared_step(self, batch, stage):
        """Compute and log the pair loss for one batch.

        Args:
            batch: A ``(spectrum_0, spectrum_1, tanimoto)`` triple.
            stage: Metric-name prefix; the loss is logged as ``f"{stage}_loss"``.

        Returns:
            The scalar loss tensor.
        """
        spectrum_0, spectrum_1, tanimoto = batch
        output_0 = self(spectrum_0)
        output_1 = self(spectrum_1)
        loss = self._loss(output_0, output_1, tanimoto)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as ``train_loss``) for Lightning to optimize."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log the validation loss as ``val_loss``."""
        # Intentionally returns None, as before, so no per-batch outputs are retained.
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log the test loss as ``test_loss``."""
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return embeddings for ``batch``.

        ``dataloader_idx`` is accepted (and ignored) because Lightning passes it
        when several predict dataloaders are used.
        """
        return self(batch)

    def configure_optimizers(self):
        """Build the optimizer (and optional scheduler) from the optimizer wrapper.

        Returns:
            The optimizer alone when the wrapper has no scheduler. Otherwise it
            returns ``([optimizer], [scheduler])``, following Lightning's
            return conventions.

        Raises:
            ValueError: If the model was constructed without an ``optimizer``.
        """
        if self._optimizer is None:
            raise ValueError(
                "SiameseModel was constructed without an `optimizer`; "
                "one is required for training."
            )

        self._optimizer.instantiate_optimizer(self._base_net.parameters())
        optimizer = self._optimizer._optimizer
        scheduler = self._optimizer._scheduler

        if scheduler is None:
            return optimizer
        return [optimizer], [scheduler]

    def _loss(self, output_0, output_1, tanimoto):
        """MSE between the pairwise cosine similarity of the outputs and ``tanimoto``.

        Cosine similarity is taken along ``dim=1`` with the default ``eps``,
        matching ``nn.CosineSimilarity()``. MSE uses mean reduction, matching
        ``nn.MSELoss()``.
        """
        similarity = torch.nn.functional.cosine_similarity(output_0, output_1, dim=1)

        # Match the prediction dtype: a float64 target (e.g. collated from numpy)
        # against float32 similarities can make the backward pass fail.
        target = tanimoto.to(dtype=similarity.dtype)

        # A target of e.g. shape (batch, 1) would otherwise broadcast against the
        # (batch,) similarities into a (batch, batch) comparison, silently giving
        # a wrong loss. Align shapes when they hold the same number of elements.
        if target.shape != similarity.shape and target.numel() == similarity.numel():
            target = target.reshape(similarity.shape)

        return torch.nn.functional.mse_loss(similarity, target)
