import abc
import functools
import os

import torch
from torch.nn import Module

# Names of the classes created through ``RecorderMeta`` (all except one
# named "SDE"), appended in the order the classes are defined.
MODELS = list()

# Default location for serialized weights: a ``weights`` folder located next
# to this source file.
WEIGHT_DIR = os.path.join(os.path.split(__file__)[0], "weights")

class RecorderMeta(abc.ABCMeta):
    """Metaclass that registers created class names in ``MODELS``.

    It derives from ``abc.ABCMeta``, which is itself a subclass of ``type``,
    so abstract methods are still enforced on classes that use it. Each
    class it creates, including classes that inherit the metaclass, has its
    name appended to the module-level ``MODELS`` list. The only exception is
    a class named ``"SDE"``.
    """

    def __new__(cls, name, bases, class_dict):
        created = super().__new__(cls, name, bases, class_dict)
        # "SDE" is deliberately left out of the registry.
        if not name == "SDE":
            MODELS.append(name)
        return created

class Trained(Module, metaclass=RecorderMeta):
    """Abstract class acting as parent to all trained score models.

    Subclasses build their layers in ``__init__`` and implement
    :meth:`init_weights`. Every subclass ``__init__`` is wrapped (see
    :meth:`__init_subclass__`). When the most-derived constructor finishes,
    weights are loaded from ``WEIGHT_DIR/<ClassName>.<name>.pt`` if that file
    exists and is compatible. Otherwise they are initialized with
    :meth:`init_weights`.
    """

    def __init__(self, name,
                 randomize_weights=False,
                 *args, **kwargs):
        """Initialize the module and derive its weight filename.

        Args:
            name: Identifier for this model instance. It becomes part of the
                weight filename ``<ClassName>.<name>.pt``.
            randomize_weights: Not used here. The subclass ``__init__``
                wrapper reads it from the *keyword* arguments of the
                constructor call. It stays in the signature for backward
                compatibility.
            *args, **kwargs: Accepted and ignored.
        """
        # Run nn.Module.__init__ before assigning attributes so that
        # Module.__setattr__ sees its internal registries initialized.
        super().__init__()
        self.name = name
        self.trained = False

        model_type = self.__class__.__name__
        self._filename = model_type + "." + name + ".pt"

    def load_weights(self, dir=None):
        """Load a saved state dict into this module.

        Args:
            dir: Path to the weight *file*. Defaults to
                ``WEIGHT_DIR/<filename>``.

        Returns:
            True if the weights were loaded; the module is then marked as
            trained and switched to eval mode. False if the path failed
            validation, or if the stored weights do not match the module's
            structure; in that case the file is renamed to ``<path>.old``.
        """
        if dir is None:
            dir = os.path.join(WEIGHT_DIR, self._filename)
        if not Trained.validate_dir(dir):
            return False

        # map_location="cpu" allows weights saved from a GPU run to load on a
        # CPU-only machine. load_state_dict then copies them onto the device
        # that holds this module's parameters.
        loaded_dict = torch.load(dir, map_location="cpu", weights_only=True)
        try:
            self.load_state_dict(loaded_dict)
        except RuntimeError:
            print(
                f"Could not load module weights for module {self._filename}. "
                "The stored weights are incompatible with the current module "
                "structure. Did you train with other hyperparameters?\n"
                f"Renaming the weights in {dir} to {dir}.old"
            )
            # Move the stale file aside so it is not retried on the next run.
            os.rename(dir, dir + ".old")
            return False

        self.trained = True
        self.eval()
        print(f"Successfully loaded weights from {dir}.")
        return True

    def save_weights(self, dir=None):
        """Save this module's state dict with ``torch.save``.

        Args:
            dir: Destination *file* path. Defaults to
                ``WEIGHT_DIR/<filename>``. Missing parent directories are
                created.
        """
        if dir is None:
            dir = os.path.join(WEIGHT_DIR, self._filename)

        if Trained.validate_dir(dir):
            parent = os.path.dirname(dir)
            if parent:
                # torch.save does not create directories, and WEIGHT_DIR may
                # not exist yet, e.g. on a fresh checkout.
                os.makedirs(parent, exist_ok=True)
            torch.save(self.state_dict(), dir)

    def _init_weights(self, randomize_weights):
        """Load stored weights, or fall back to :meth:`init_weights`.

        Args:
            randomize_weights: If True, skip loading entirely and initialize
                randomly. ``trained`` then stays False.
        """
        if not randomize_weights:
            path = os.path.join(WEIGHT_DIR, self._filename)
            # isfile (unlike os.listdir) does not raise when WEIGHT_DIR
            # is missing.
            if not os.path.isfile(path):
                print(f"Could not find weights {self._filename} in directory {WEIGHT_DIR}.")
                randomize_weights = True
            elif not self.load_weights(path):
                # Incompatible or rejected weights: still apply the model's
                # own initialization instead of PyTorch defaults.
                randomize_weights = True

        if randomize_weights:
            print("Initializing weights randomly")
            self.init_weights()

    def get_filename(self):
        """Return the weight filename ``<ClassName>.<name>.pt``."""
        return self._filename

    # Makes sure weights are loaded after defining module structure
    def __init_subclass__(cls, **class_kwargs):
        """Wrap each subclass ``__init__`` so weights are set up afterwards.

        The wrapper runs the original constructor, then calls
        :meth:`_init_weights`, but only when ``cls`` is the exact type being
        instantiated. Without this check, a parent class's wrapper in a
        multi-level hierarchy would try to load weights before the child had
        created its layers. That load would fail on unexpected keys and
        rename the weight file to ``.old``.

        Only a *keyword* ``randomize_weights=...`` passed to the constructor
        is honored. A positional value is not seen by the wrapper.
        """
        old_init = cls.__init__

        @functools.wraps(old_init)
        def new_init(self, *args, **kwargs):
            randomize_weights = kwargs.get("randomize_weights", False)
            old_init(self, *args, **kwargs)
            if type(self) is cls:
                self._init_weights(randomize_weights)

        cls.__init__ = new_init

        super().__init_subclass__(**class_kwargs)

    # TODO: Implement this
    @staticmethod
    def validate_dir(dir):
        """Placeholder path check; currently accepts every path."""
        return True

    @abc.abstractmethod
    def init_weights(self):
        """Initialize the module's parameters when no usable saved weights exist."""
        pass