from pathlib import Path
import os
from typing import Final

import torch
from torch import nn
from torch_geometric.data import Data, Batch

from src.models.gnn_interface import InterfaceGNN
from src.models.mappings.model_class_mapping import MODEL_NAME_MAPPING
from src.utils.path_io import get_path_up_to

# Directory for saved model weights, relative to ROOT_DIR (joined in GraphGnnWrapper.save_model).
MODEL_OUTPUT_PATH: Final[Path] = Path("data", "output", "trained_models")
# Base path derived from this file's absolute location via get_path_up_to.
# NOTE(review): presumably the directory that contains "src" (the project root) - confirm
# against get_path_up_to, since GraphGnnWrapper prepends this to relative model paths.
ROOT_DIR: Final[str] = get_path_up_to(os.path.abspath(__file__), "src")

class GraphGnnWrapper:
    """
    Wraps a GNN model together with its configuration, weights file and device.

    The wrapper builds the model from ``MODEL_NAME_MAPPING``, saves and loads its
    weights, forwards train/eval mode switches and runs the model on batches.

    Attributes:
        model: The wrapped model; ``None`` until ``create_model`` has been called.
        device: Device that batches are moved to in ``calc_batch``; CPU until ``to`` is called.
        name: Run name; used as the file name of the saved weights.
        model_class: Key into ``MODEL_NAME_MAPPING`` that selects the model class.
        model_attributes: Keyword arguments for the model constructor. The entry
            ``last_activation`` is passed to the model as ``final_activation``.
        path: Location of the weights file, or ``None`` if there is none yet.
        training: ``True`` while in training mode, ``False`` in evaluation mode.
    """

    model: InterfaceGNN | None
    device: torch.device | str
    name: str
    model_class: str
    model_attributes: dict
    path: Path | str | None
    training: bool

    def __init__(self, model_class: str, run_name: str, **kwargs):
        """
        Stores the model configuration; the model itself is built by ``create_model``.

        Args:
            model_class (str): Key into ``MODEL_NAME_MAPPING``.
            run_name (str): Name of the run; used as file name by ``save_model``.
            **kwargs: Constructor arguments for the model. An optional ``model_path``
                entry is removed from them and stored as ``ROOT_DIR + model_path``.
        """
        self.model_class = model_class
        self.name = run_name
        self.model_attributes = kwargs
        self.model = None
        self.training = True
        # Default device so that calc_batch also works when to() was never called.
        # Assumes model constructors create their parameters on the CPU (PyTorch default).
        self.device = torch.device("cpu")

        file_path = self.model_attributes.pop("model_path", None)
        # NOTE(review): plain string concatenation - model_path presumably has to start
        # with a path separator; confirm with callers before switching to pathlib.
        self.path = ROOT_DIR + file_path if file_path is not None else None

    def _require_model(self) -> InterfaceGNN:
        """Returns the wrapped model or raises a clear error if none has been created yet."""
        if self.model is None:
            raise RuntimeError(
                "No model available: call create_model() before using the wrapper."
            )
        return self.model

    def _build_model(self) -> nn.Module:
        """Instantiates the configured model class without modifying ``self.model_attributes``."""
        attributes = dict(self.model_attributes)
        if 'last_activation' in attributes:
            # The model constructors take this setting under the name `final_activation`.
            attributes['final_activation'] = attributes.pop('last_activation')
        return MODEL_NAME_MAPPING[self.model_class](**attributes)

    def calc_batch(self, batch: Batch) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
    ):
        """
        Runs the model on one batch and returns its output together with the targets.

        The batch is moved to ``self.device`` before the forward pass.

        Args:
            batch (Batch): Batched graphs providing ``x``, ``edge_index``, ``batch`` and ``y``.

        Returns:
            In training mode ``(output, target)``, where ``output`` is whatever the model
            returned. In evaluation mode ``(prediction, target, extra)``: if the model
            returned a tuple, ``prediction`` and ``extra`` are its first two elements,
            otherwise ``extra`` is ``None``.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        model = self._require_model()
        batch = batch.to(self.device)

        # NOTE(review): PyG's conventional attribute name is `edge_attr`; confirm that the
        # datasets really store edge features as `edge_attributes`, otherwise the model
        # always receives None here.
        edge_attr = batch.edge_attributes if hasattr(batch, 'edge_attributes') else None

        # Call the module itself rather than .forward() so that registered hooks run.
        output = model(
            x=batch.x,
            edge_index=batch.edge_index,
            edge_attr=edge_attr,
            batch=batch.batch,
        )
        target = batch.y

        # TODO: Move target transformation here
        if self.training:
            return output, target

        if isinstance(output, tuple):
            return output[0], target, output[1]
        return output, target, None

    def create_model(self) -> nn.Module:
        """
        Creates a new model instance from ``model_class`` and ``model_attributes``.

        If ``self.path`` is set, the weights stored there are loaded into the new model.
        ``model_attributes`` is left untouched, so the method can be called repeatedly and
        ``load_model`` builds the model with the same final activation.

        Returns:
            nn.Module: The created model instance, also stored in ``self.model``.

        Raises:
            KeyError: If ``model_attributes`` has no ``last_activation`` entry.
        """
        if 'last_activation' not in self.model_attributes:
            raise KeyError('last_activation')

        self.model = self._build_model()
        if self.path is not None:
            # map_location keeps checkpoints written on a GPU loadable on CPU-only machines.
            # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
            self.model.load_state_dict(torch.load(self.path, map_location="cpu"))

        return self.model

    def save_model(self) -> str:
        """
        Saves the model's state dict to '{run_name}.pt' below ``MODEL_OUTPUT_PATH``.

        The output directory is created if it does not exist, an existing file with the
        same name is overwritten, and ``self.path`` is updated to the written file.

        Returns:
            str: The absolute path to the saved model file.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        model = self._require_model()

        self.path = Path(ROOT_DIR, MODEL_OUTPUT_PATH, f'{self.name}.pt')
        self.path.parent.mkdir(parents=True, exist_ok=True)

        # Save the model state dict to the specified path
        torch.save(model.state_dict(), self.path)
        return str(self.path.absolute())

    def load_model(self) -> nn.Module:
        """
        Builds a fresh model and loads the weights stored at ``self.path`` into it.

        The returned model is in evaluation mode. It is neither stored in ``self.model``
        nor moved to ``self.device``.

        Returns:
            nn.Module: The loaded model.

        Raises:
            ValueError: If ``self.path`` is not set.
        """
        if self.path is None:
            raise ValueError(
                "No model path set: pass model_path to the constructor or call save_model() first."
            )

        # map_location keeps checkpoints written on a GPU loadable on CPU-only machines.
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        state_dict = torch.load(self.path, map_location="cpu")
        model = self._build_model()
        model.load_state_dict(state_dict)
        model.eval()

        return model

    def to(self, device: torch.device | str) -> None:
        """
        Moves the model to the specified device and remembers it for incoming batches.

        Args:
            device (torch.device | str): The device to move the model to (e.g., 'cpu' or 'cuda').

        Raises:
            RuntimeError: If no model has been created yet.
        """
        self._require_model().to(device)
        self.device = device

    def train(self):
        """
        Sets the wrapper and the model to training mode.

        In training mode dropout is active and batch normalization uses batch statistics;
        ``calc_batch`` returns ``(output, target)``.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        model = self._require_model()
        self.training = True
        model.train()

    def eval(self):
        """
        Sets the wrapper and the model to evaluation mode.

        In evaluation mode dropout is disabled and batch normalization uses its running
        statistics; ``calc_batch`` returns a 3-tuple.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        model = self._require_model()
        self.training = False
        model.eval()
