"""
Base classes for TensorGalerkin trainers
"""

import os
import time
import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Any

from ..utils import manual_seed, save_ckpt, load_ckpt


class TrainerBase(ABC):
    """
    Abstract base class for all trainers.

    Subclasses provide the dataset, model and optimizer construction
    (``init_dataset``, ``init_model``, ``init_optimizer``) as well as the
    time-stepping, loss, training and testing routines (``step``,
    ``multisteps``, ``compute_loss``, ``fit``, ``test``). This base class
    supplies the construction sequence, device/dtype conversion helpers and
    checkpoint load/save wrappers.

    Attributes:
    -----------
        device: torch.device
            Device to run the training on (taken from ``config.device``)
        config: object
            Configuration object containing training parameters
        model: torch.nn.Module
            Neural network model returned by ``init_model``
        optimizer: optimizer object
            Optimizer for training (expected to be set up by
            ``init_optimizer``; this base class never reads it)
    """

    def __init__(self, config):
        """
        Initialize trainer with configuration

        The random seed is set first, then the dataset, model and optimizer
        hooks run in that order, and finally everything is moved to
        ``config.device`` through ``self.to``.

        Parameters:
        -----------
            config: object
                Configuration object with training parameters. Must expose
                ``device`` and ``seed``; if it has a ``datarow`` mapping,
                the model parameter count and byte size are recorded in it
                under ``'nparams'`` and ``'nbytes'``.
        """
        self.device = config.device
        self.config = config
        manual_seed(config.seed)

        # Initialize components
        self.init_dataset()
        self.model = self.init_model()
        self.init_optimizer()

        # Move model and data to device
        self.to(self.device)

        # Model statistics
        params = list(self.model.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.numel() * p.element_size() for p in params)

        if hasattr(config, 'datarow'):
            config.datarow['nparams'] = nparams
            config.datarow['nbytes'] = nbytes

    @abstractmethod
    def init_dataset(self) -> None:
        """Initialize dataset - must be implemented by subclasses"""
        raise NotImplementedError()

    @abstractmethod
    def init_model(self) -> torch.nn.Module:
        """
        Initialize model - must be implemented by subclasses

        Returns:
        --------
            torch.nn.Module
                The model; ``__init__`` stores it as ``self.model``
        """
        raise NotImplementedError()

    @abstractmethod
    def init_optimizer(self) -> None:
        """Initialize optimizer - must be implemented by subclasses"""
        raise NotImplementedError()

    @staticmethod
    def _moved(obj, device):
        """
        Call ``obj.to(device)`` and return the object that should be kept.

        Some ``.to`` implementations move in place and return ``None`` or
        ``self``; others return a moved copy and leave the original
        untouched. Keeping the returned object (or the original when
        nothing is returned) is correct in both cases.
        """
        result = obj.to(device)
        return obj if result is None else result

    def to(self, device: torch.device) -> None:
        """
        Move model and data to device

        Only attributes that exist on the instance are moved: ``graph``,
        ``ndata``, ``vdata``, ``vlabel``, ``equation``, plus ``model``
        (always).

        Parameters:
        -----------
            device: torch.device
                Target device
        """
        # NOTE(review): `ndata_test` / `label_test` are handled by `type`
        # but not here - confirm whether that asymmetry is intentional.
        if hasattr(self, 'graph'):
            # Rebind: the result of `.to` may be a new object rather than
            # an in-place move, so discarding it could lose the transfer.
            self.graph = self._moved(self.graph, device)
        if hasattr(self, 'ndata'):
            self.ndata = self.ndata.to(device)
        if hasattr(self, 'vdata'):
            self.vdata = self.vdata.to(device)
        if hasattr(self, 'vlabel'):
            self.vlabel = self.vlabel.to(device)
        self.model.to(device)
        if hasattr(self, 'equation'):
            self.equation = self._moved(self.equation, device)

    def type(self, dtype: torch.dtype) -> None:
        """
        Convert model and data to specified dtype

        Only attributes that exist on the instance are converted:
        ``ndata``, ``ndata_test``, ``label_test``, plus ``model`` (always).

        Parameters:
        -----------
            dtype: torch.dtype
                Target dtype
        """
        # NOTE(review): `vdata` / `vlabel` are moved by `to` but not
        # converted here - confirm whether that asymmetry is intentional.
        if hasattr(self, 'ndata'):
            self.ndata = self.ndata.type(dtype)
        self.model.type(dtype)
        if hasattr(self, 'ndata_test'):
            self.ndata_test = self.ndata_test.type(dtype)
        if hasattr(self, 'label_test'):
            self.label_test = self.label_test.type(dtype)

    @abstractmethod
    def step(self, U_t1: torch.Tensor) -> torch.Tensor:
        """
        Predict the next time step given the current time step.

        Parameters:
        -----------
            U_t1: torch.Tensor
                The current time step

        Returns:
        --------
            torch.Tensor
                The next time step
        """
        raise NotImplementedError()

    @abstractmethod
    def multisteps(self, U_t0: torch.Tensor, steps: int) -> torch.Tensor:
        """
        Predict multiple time steps given the initial time step.

        Parameters:
        -----------
            U_t0: torch.Tensor
                The initial time step
            steps: int
                Number of steps to predict

        Returns:
        --------
            torch.Tensor
                The predicted time steps
        """
        raise NotImplementedError()

    @abstractmethod
    def compute_loss(self, start_idx: Optional[int] = None,
                    end_idx: Optional[int] = None) -> torch.Tensor:
        """
        Compute loss for training

        Parameters:
        -----------
            start_idx: Optional[int]
                Start index for batch
            end_idx: Optional[int]
                End index for batch

        Returns:
        --------
            torch.Tensor
                Computed loss
        """
        raise NotImplementedError()

    @abstractmethod
    def fit(self, verbose: bool = False) -> None:
        """Train the model"""
        raise NotImplementedError()

    @abstractmethod
    def test(self) -> None:
        """Test the model"""
        raise NotImplementedError()

    def load_ckpt(self) -> 'TrainerBase':
        """
        Load checkpoint from config.ckpt_path

        Only ``self.model`` is handed to the ``load_ckpt`` helper.

        Returns:
        --------
            TrainerBase
                ``self``, to allow call chaining
        """
        load_ckpt(self.config.ckpt_path, model=self.model)
        return self

    def save_ckpt(self) -> 'TrainerBase':
        """
        Save checkpoint to config.ckpt_path

        The parent directory of the checkpoint path is created when it is
        missing. Only ``self.model`` is handed to the ``save_ckpt`` helper.

        Returns:
        --------
            TrainerBase
                ``self``, to allow call chaining
        """
        ckpt_dir = os.path.dirname(self.config.ckpt_path)
        # A bare filename has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError; in that case there is nothing to create.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        save_ckpt(self.config.ckpt_path, model=self.model)
        return self