
import pytorch_lightning as pl
import torch.nn as nn
import torch
from datetime import datetime
from os.path import dirname, abspath, join, exists
from os import makedirs
import yaml


class ModelClassifier(pl.LightningModule):
    """Lightning module that trains a wrapped classifier with cross-entropy and SGD.

    Args:
        model: Network to wrap. Its forward pass may return either a single
            tensor or a tuple whose first element is the main output.
        lr: Learning rate handed to the SGD optimizer.
        momentum: Momentum handed to the SGD optimizer.
    """

    def __init__(self,model: nn.Module, lr: float, momentum: float):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr
        self.momentum = momentum
        # Only lr and momentum are recorded as hyperparameters; the wrapped
        # model and the loss function are explicitly excluded.
        self.save_hyperparameters(ignore=["loss_fn","model"])

    def forward(self, data_input):
        """Run the wrapped model and return only its main output.

        Supports both the old format (single output) and the new format
        (output, pre_activations); in the latter case everything after the
        first tuple element is discarded.
        """
        model_output = self.model(data_input)
        if isinstance(model_output, tuple):
            return model_output[0]
        return model_output

    def _step(self, batch):
        """Shared forward/loss computation for train, validation and test.

        Args:
            batch: A pair ``(inputs, targets)``.

        Returns:
            Tuple ``(loss, pred, trg)`` where ``pred`` is the index of the
            largest model output along the last dimension.
        """
        X, trg = batch
        out = self.forward(data_input=X)
        # argmax is taken on the raw outputs: softmax is monotonic, so it
        # cannot change the winner, and skipping it avoids both the extra
        # work and rounding-induced ties between nearly equal outputs.
        pred = torch.argmax(out, dim=-1)
        loss = self.loss_fn(out, trg)
        return loss, pred, trg

    def training_step(self,batch,*args, **kwargs):
        """Compute the training loss, log it per epoch and return it."""
        loss, _, _ = self._step(batch=batch)
        self.log("train_loss",loss,prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self,batch,*args, **kwargs):
        """Compute the validation loss, log it per epoch and return it."""
        loss, _, _ = self._step(batch=batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self,batch,*args, **kwargs):
        """Compute test loss and accuracy, log both per epoch and return them.

        Returns:
            Dict with keys ``"test_loss"`` and ``"test_accuracy"``.
        """
        loss, pred, trg = self._step(batch=batch)
        # Mean of the element-wise match: identical to sum()/size(0) for 1-D
        # predictions, and still a valid fraction if pred has extra dimensions.
        test_accuracy = (pred == trg).float().mean()
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("test_accuracy", test_accuracy, prog_bar=True, on_step=False, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": test_accuracy}

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters using the stored lr and momentum."""
        return torch.optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum)
    





def mk_fname(filename: str,label: str,suffix: str):
    """Build a timestamped file name of the form ``<filename>_<label>_<timestamp><suffix>``.

    Args:
        filename: Leading part of the name.
        label: Middle part of the name; converted with ``str()`` before use.
        suffix: Text appended directly after the timestamp (no separator is
            inserted).

    Returns:
        The assembled name, with the current local time rendered as
        ``YYYYMMDD_HHMMSS``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = "_".join((filename, str(label), stamp))
    return stem + suffix

def mk_missing_folders(folders):
    """Create every folder in *folders* that does not exist yet.

    Args:
        folders: Either a single path or a collection of paths (list, tuple,
            set or frozenset). Any other value is treated as one path.

    Prints one ``Created folder: ...`` line for each folder it creates.
    Intermediate directories are created as needed.
    """
    # isinstance (rather than an exact type check) lets tuples, sets and list
    # subclasses be iterated instead of being mistaken for a single path.
    if not isinstance(folders, (list, tuple, set, frozenset)):
        folders = [folders]
    for folder in folders:
        if not exists(folder):
            # exist_ok guards against the folder appearing between the
            # existence check and the creation call.
            makedirs(folder, exist_ok=True)
            print(f"Created folder: {folder}")
