import torch
from abc import ABC, abstractmethod
from tqdm import tqdm
from collections import defaultdict

class BaseTrainer(ABC):
    """
    A base class for supervised training. Subclasses should implement the abstract methods.

    ``self.model`` may be either a single ``torch.nn.Module`` or a ``dict`` whose
    values expose ``train()`` / ``eval()``; both forms are handled when switching
    between training and evaluation mode.
    """

    def __init__(self, model, optimizer, criterion):
        """
        Initialize the trainer with a model, optimizer, and loss function.

        Args:
            model (torch.nn.Module | dict): The model to train, or a dict of models.
            optimizer (torch.optim.Optimizer): The optimizer for training.
            criterion (torch.nn.Module): The loss function.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    @abstractmethod
    def train_step(self, batch):
        """
        Perform a single training step.

        Args:
            batch: A batch of data as yielded by the training loader.

        Returns:
            dict: Loss name -> numeric loss value for the step. ``train_epoch``
            accumulates every entry and reports the ``"policy_loss"`` average.
        """
        pass

    @abstractmethod
    def validate_step(self, batch):
        """
        Perform a single validation step.

        Args:
            batch: A batch of data as yielded by the validation loader.

        Returns:
            dict: Loss name -> numeric loss value for the step. ``validate_epoch``
            requires the keys ``"policy_loss"``, ``"v_loss"`` and ``"q_loss"``.
        """
        pass

    def update(self, loss):
        """
        Run one optimization step for ``loss``.

        Clears the optimizer's gradients, back-propagates ``loss`` and steps the
        optimizer.

        Args:
            loss (torch.Tensor): The loss to minimize.

        Returns:
            float: The value of ``loss`` obtained through ``loss.item()``.
        """
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _set_mode(self, training):
        """Put the model (or every model in a dict) into train or eval mode."""
        if isinstance(self.model, torch.nn.Module):
            models = [self.model]
        elif isinstance(self.model, dict):
            models = self.model.values()
        else:
            # Any other model container is left untouched.
            models = []

        for model in models:
            if training:
                model.train()
            else:
                model.eval()

    def train_epoch(self, train_loader):
        """
        Train the model for one epoch.

        Args:
            train_loader (torch.utils.data.DataLoader): The training data loader.
                Any iterable of batches works; ``len()`` is not required by this
                method itself.

        Returns:
            float: The average ``"policy_loss"`` over the epoch. This is 0.0 when
            the loader yields no batches or when ``train_step`` never reports a
            ``"policy_loss"`` entry.
        """
        self._set_mode(training=True)

        total_loss = defaultdict(float)
        num_batches = 0
        for batch in tqdm(train_loader, desc="Training Epoch"):
            loss = self.train_step(batch)
            for key, value in loss.items():
                total_loss[key] += value
            num_batches += 1

        # Average over the batches actually consumed. With no batches the dict
        # is empty, so no division by zero can occur here.
        for key in total_loss:
            total_loss[key] /= num_batches
        # defaultdict lookup: a missing "policy_loss" yields 0.0 instead of raising.
        return total_loss["policy_loss"]

    def validate_epoch(self, val_loader):
        """
        Validate the model for one epoch.

        Args:
            val_loader (torch.utils.data.DataLoader): The validation data loader.
                Any iterable of batches works; ``len()`` is not required.

        Returns:
            float: The average over batches of
            ``policy_loss + v_loss + q_loss``, or 0.0 when the loader yields no
            batches (matching ``train_epoch``).

        Raises:
            KeyError: If ``validate_step`` omits one of the three loss keys.
        """
        self._set_mode(training=False)

        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                loss = self.validate_step(batch)
                # This is a temporary solution; the loss interface should be
                # unified with train_epoch later.
                total_loss += loss["policy_loss"] + loss["v_loss"] + loss["q_loss"]
                num_batches += 1

        if num_batches == 0:
            # Nothing was validated; avoid dividing by zero.
            return 0.0
        return total_loss / num_batches

    @abstractmethod
    def save_checkpoint(self, filepath):
        """
        Save the model checkpoint.

        Args:
            filepath (str): The path to save the checkpoint.
        """
        pass

    @abstractmethod
    def load_checkpoint(self, filepath):
        """
        Load the model checkpoint.

        Args:
            filepath (str): The path to load the checkpoint from.
        """
        pass