#!/usr/bin/env python
#
# Base experiment class for Flow Matching experiments

import os
import torch
import torch.nn as nn
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple
from torch.utils.data import DataLoader
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher, 
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
import sys
sys.path.append('../models')

class BaseExperiment(ABC):
    """
    Abstract base class defining the contract for an experiment.

    This class uses the Strategy Pattern, where each concrete experiment
    (e.g., for Kitchen, PushT) provides a specific implementation (strategy)
    for setting up the model, data, and environment. The PyTorch Lightning
    module acts as the context that uses these strategies.

    Subclasses must implement ``setup_dataset``, ``setup_model`` and
    ``process_batch``; because they are declared with ``@abstractmethod``,
    a subclass cannot be instantiated until all three are overridden.
    """

    def __init__(self, config):
        """
        Initialize the experiment with configuration.

        Args:
            config: Experiment configuration object. It must expose a
                ``device`` attribute accepted by ``torch.device`` (for example
                ``"cpu"`` or ``"cuda:0"``). If it also exposes a callable
                ``validate()`` method, that method is invoked before any
                configuration field is read.

        Raises:
            Whatever ``config.validate()`` raises for an invalid
            configuration, and whatever ``torch.device`` raises for an
            unrecognized device specification.
        """
        # Validate before consuming any field. Otherwise an invalid config
        # would surface as an opaque torch.device error rather than the
        # validator's message, and any adjustment validate() may make to the
        # config would be ignored below. A non-callable ``validate``
        # attribute is not a validation hook, so it is skipped.
        validate = getattr(config, 'validate', None)
        if callable(validate):
            validate()

        self.config = config
        self.device = torch.device(config.device)

        # These are populated by the concrete experiment class
        self.dataloader: "DataLoader | None" = None
        self.model: "nn.Module | None" = None

    @abstractmethod
    def setup_dataset(self) -> DataLoader:
        """
        Set up the dataset and return a DataLoader.

        This is specific to each experiment.

        Returns:
            The DataLoader that yields raw training batches for this
            experiment.
        """

    @abstractmethod
    def setup_model(self) -> nn.Module:
        """
        Set up the neural network model and return it.

        This is specific to each experiment's architecture.

        Returns:
            The ``nn.Module`` to be trained for this experiment.
        """

    @abstractmethod
    def process_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process a raw batch from the DataLoader into the format required for training.

        Args:
            batch: One raw batch as yielded by the DataLoader. The key
                layout is experiment-specific (NOTE(review): defined by each
                subclass's dataset; confirm per experiment).

        Returns:
            A tuple of ``(observations, actions)`` tensors.
        """
