from models import ClassificationFeatureModel
from dataset import get_dataloaders

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import CombinedLoader

from dgllife.utils import (
    AttentiveFPAtomFeaturizer,
    AttentiveFPBondFeaturizer,
)


from pytorch_lightning.trainer.supporters import CombinedLoader


class LitModuleCombined(pl.LightningModule):
    """Jointly train a classifier and an energy-based OOD objective.

    Every batch is a dict with an ``"id"`` entry ``(inputs, labels, smiles)`` and
    an ``"ood"`` entry ``(inputs, smiles)``, as yielded by the ``CombinedLoader``s
    built in the ``*_dataloader`` hooks. The optimized loss is the cross-entropy
    on the in-distribution batch plus ``ood_factor`` times a squared-hinge loss
    on the per-sample energies ``-logsumexp(logits, dim=1)``.
    """

    def __init__(self, params):
        """Build dataloaders, model and loss from a hyper-parameter dict.

        Args:
            params: dict that must provide at least ``"data_type"``,
                ``"split_type"``, ``"batch_size"``, ``"ood_factor"`` and
                ``"embed_size"``. It is not modified; a private copy
                (with the energy margins added) is kept in ``self.params``.
        """
        super().__init__()

        self.save_hyperparameters(params)
        self.dataloaders = get_dataloaders(
            params["data_type"], params["split_type"], params["batch_size"]
        )

        # Copy so the margins written below do not leak into the caller's dict.
        self.params = dict(params)
        self.ood_factor = params["ood_factor"]

        # Energy margins for the OOD hinge loss. They always override any
        # values supplied in ``params`` (same as the previous behaviour).
        self.params["m_in"] = -35
        self.params["m_out"] = -5

        node_featurizer = AttentiveFPAtomFeaturizer()
        edge_featurizer = AttentiveFPBondFeaturizer()
        model_config = {
            "num_nodes": node_featurizer.feat_size(),
            "num_edges": edge_featurizer.feat_size(),
            "num_feat": self.params["embed_size"],
            "n_steps": 3,
            "num_outputs": 2,
            "ood_head": True,
        }

        self.model = ClassificationFeatureModel(model_config)
        self.loss = nn.CrossEntropyLoss()

    def predict(self, batch):
        """Compute the classification and energy losses for one combined batch.

        Args:
            batch: dict with ``"id"`` -> ``(inputs, labels, smiles)`` and
                ``"ood"`` -> ``(inputs, smiles)``.

        Returns:
            ``(losses, ood_values)``. ``losses`` has keys ``"id"``
            (cross-entropy), ``"ood"`` (energy hinge loss) and ``"loss"``
            (``id + ood_factor * ood``). ``ood_values`` has keys ``"id"`` and
            ``"ood"`` with the per-sample energies of each sub-batch.
        """
        id_inputs, id_labels, _id_smiles = batch["id"]
        ood_inputs, _ood_smiles = batch["ood"]
        m_in = self.params["m_in"]
        m_out = self.params["m_out"]

        id_logits = self.model(id_inputs)
        # as_tensor avoids a copy (and warning) when labels are already a
        # tensor; matching the logits' device keeps the loss working on GPU.
        id_targets = torch.as_tensor(id_labels, device=id_logits.device)
        id_loss = self.loss(id_logits, id_targets)

        ood_logits = self.model(ood_inputs)
        id_energy = -torch.logsumexp(id_logits, dim=1)
        ood_energy = -torch.logsumexp(ood_logits, dim=1)
        # Squared hinge: penalize ID energies above m_in and OOD energies
        # below m_out.
        ood_loss = (
            torch.pow(F.relu(id_energy - m_in), 2).mean()
            + torch.pow(F.relu(m_out - ood_energy), 2).mean()
        )

        losses = {
            "id": id_loss,
            "ood": ood_loss,
            "loss": id_loss + self.ood_factor * ood_loss,
        }
        ood_values = {"id": id_energy, "ood": ood_energy}
        return losses, ood_values

    def _log_stage(self, stage, losses, ood_values=None):
        """Log ``{stage}_loss``/``{stage}_id_loss``/``{stage}_ood_loss``.

        When ``ood_values`` is given, also add ``{stage}_id_values`` and
        ``{stage}_ood_values`` histograms at step ``current_epoch`` (assumes a
        logger whose ``experiment`` exposes ``add_histogram``, e.g. TensorBoard).
        """
        self.log(f"{stage}_loss", losses["loss"])
        self.log(f"{stage}_id_loss", losses["id"])
        self.log(f"{stage}_ood_loss", losses["ood"])
        if ood_values is None:
            return
        experiment = self.logger.experiment
        for key in ("id", "ood"):
            experiment.add_histogram(
                f"{stage}_{key}_values",
                ood_values[key].detach(),
                self.current_epoch,
            )

    def training_step(self, batch, batch_idx):
        """Run one training batch, log its losses and return the total loss."""
        losses, _ = self.predict(batch)
        self._log_stage("train", losses)
        return losses["loss"]

    # NOTE: no ``training_epoch_end`` override. The scheduler returned from
    # ``configure_optimizers`` is already stepped once per epoch by Lightning's
    # automatic optimization; stepping it manually as well decayed the LR twice.

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log losses and energy histograms."""
        losses, ood_values = self.predict(batch)
        self._log_stage("val", losses, ood_values)
        return losses["loss"]

    def test_step(self, batch, batch_idx):
        """Run one test batch, log losses and energy histograms."""
        losses, ood_values = self.predict(batch)
        self._log_stage("test", losses, ood_values)
        return losses["loss"]

    def configure_optimizers(self):
        """Adam (lr=1e-3) with a StepLR that multiplies the LR by 0.8 each epoch.

        Lightning steps the returned scheduler automatically (default interval:
        once per epoch), so it must not be stepped manually elsewhere.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=0.8
        )
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def _combined_loader(self, split):
        """Pair the ``{split}_id`` and ``{split}_ood`` loaders in one CombinedLoader.

        ``mode="min_size"`` ends the epoch when the shorter loader is exhausted.
        """
        loaders = {
            "id": self.dataloaders[f"{split}_id"],
            "ood": self.dataloaders[f"{split}_ood"],
        }
        return CombinedLoader(loaders, mode="min_size")

    def train_dataloader(self):
        """Combined ID/OOD training loader."""
        return self._combined_loader("train")

    def val_dataloader(self):
        """Combined ID/OOD validation loader."""
        return self._combined_loader("valid")

    def test_dataloader(self):
        """Combined ID/OOD test loader."""
        return self._combined_loader("test")
