from typing import Callable, Literal

import torch
from torch import nn
import os

import pandas as pd

OUTPUT_DIR = './output/'

class Model(torch.nn.Module):
    def __init__(
        self, 
        num_classes: int
    ):
        super(Model, self).__init__()
        self.num_classes = num_classes

        self.name: str = self.__class__.__name__
        self.model: torch.nn.Module = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.params = None

    def forward(self, x):
        return self.model(x)

    def mkfolders(
        self, 
        model_dir: str, 
        model_name: str
    ) -> None:
        '''
        Creates the output and log folders for the model.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
        '''
        if not os.path.exists(f'{model_dir}{OUTPUT_DIR}'):
            os.makedirs(f'{model_dir}{OUTPUT_DIR}')

        if not os.path.exists(f'{model_dir}{OUTPUT_DIR}{model_name}'):
            os.makedirs(f'{model_dir}{OUTPUT_DIR}{model_name}')

    def save(
        self, 
        config, 
        model_dir: str, 
        model_name: str,
        df: pd.DataFrame,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None
    ) -> None:
        '''
        Saves the model and metrics to disk.

        Args:
            config (Config): Config object.
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
        '''

        checkpoint = {
            "model": self.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler
        }
        self.remove_checkpoints(model_dir, model_name)

        self._save(checkpoint, f"{model_dir}{OUTPUT_DIR}{model_name}/{model_name}")

        df.to_csv(f'{model_dir}{OUTPUT_DIR}{model_name}/{model_name}.csv', index=False)
        config.save(f'{model_dir}{OUTPUT_DIR}{model_name}/{model_name}.json')
        if config.is_snn:
            self.save_params(f'{model_dir}{OUTPUT_DIR}{model_name}/params.json')
            

    def remove_checkpoints(self, model_dir: str, model_name: str) -> None:
        '''
        Removes the checkpoint file from disk.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
        '''
        try:
            for file in os.listdir(f'{model_dir}{OUTPUT_DIR}{model_name}'):
                if file.endswith('.pth') and 'checkpoint' in file:
                    os.remove(f'{model_dir}{OUTPUT_DIR}{model_name}/{file}')
        except FileNotFoundError:
            print('No checkpoints found.')

    def save_checkpoint(
        self,
        model_dir: str,
        model_name: str,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None,
        epoch: int,
        df: pd.DataFrame
    ) -> None:
        checkpoint = {
            "model": self.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler
        }

        self._save(checkpoint, f"{model_dir}{OUTPUT_DIR}{model_name}/{model_name}_checkpoint_{epoch}")
        df.to_csv(f'{model_dir}{OUTPUT_DIR}{model_name}/{model_name}.csv', index=False)

    def _save(
        self,
        checkpoint: dict[str, nn.Module],
        path: str
    ) -> None:
        torch.save(checkpoint, f"{path}.pth")

    def load(
        self, 
        model_dir: str, 
        model_name: str,
        optimizer: torch.optim.Optimizer | None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None,
        resume: bool = False
    ) -> tuple[torch.optim.lr_scheduler._LRScheduler | None, int, pd.DataFrame | None]:
        '''
        Loads a model from disk.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Scheduler.
        '''
        path = f"{model_dir}{OUTPUT_DIR}{model_name}/{model_name}"

        checkpoint = self._load(path)
        self.load_state_dict(checkpoint['model'])

        if optimizer is not None and resume:
            optimizer.load_state_dict(checkpoint['optimizer'])

        if scheduler is not None and resume:
            scheduler = checkpoint['scheduler']
            scheduler.last_epoch = checkpoint['scheduler'].last_epoch
            if optimizer is None:
                raise ValueError('Optimizer must be provided if scheduler is provided.')
            scheduler.optimizer = optimizer

        start_epoch: int = 0
        df: pd.DataFrame | None = None
        if os.path.exists(f'{model_dir}{OUTPUT_DIR}{model_name}/{model_name}.csv') and resume:
            df = pd.read_csv(f'{model_dir}{OUTPUT_DIR}{model_name}/{model_name}.csv')
            start_epoch = df['Epoch'].iloc[-1]

        return scheduler, start_epoch, df

    def load_by_path(
        self, 
        path: str,
        optimizer: torch.optim.Optimizer | None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None
    ) -> torch.optim.lr_scheduler._LRScheduler | None:
        '''
        Loads a model from disk via an absolute path to the weight file.

        Args:
            config (Config): Config object.
            model_name (str): Name of the model.
        '''
        checkpoint = self._load(path)

        self.load_state_dict(checkpoint['model'])

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])

        if scheduler is not None:
            scheduler = checkpoint['scheduler']

        return scheduler

    def _load(self, path: str) -> dict[str, nn.Module]:
        '''
        Loads a module from disk via a relative path to the weight file.

        Args:
            path (str): Path to the weight file.
        '''
        try:
            checkpoint = torch.load(f"{path}.pth", weights_only=False)
            print(f'Loaded from {path}.pth')
            return checkpoint
        except Exception as e:
            print(f'Could not find {path}.pth')
            print(e)
            exit(1)