# trainer.py
# Minimal training loop for multi-task components.
# This is a skeleton: in practice you will supply dataset loaders, loss definitions and optimization schedules.

import os
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def save_checkpoint(state: Dict[str, Any], path: str) -> None:
    """Serialize ``state`` to ``path`` with ``torch.save``.

    Args:
        state: Checkpoint payload (e.g. model/optimizer state dicts, epoch).
        path: Destination file. Missing parent directories are created first
            so that saving into a fresh run directory does not fail.
    """
    # Only filesystem paths have a parent directory to create; anything else
    # (e.g. an open file object) is handed to torch.save untouched.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, path)

def load_checkpoint(path: str, device: Optional[str] = None) -> Dict[str, Any]:
    """Read a checkpoint from ``path`` and map its tensors onto ``device``.

    Args:
        path: Checkpoint file to read.
        device: Target device string passed as ``map_location``. When falsy,
            "cuda" is chosen if CUDA is available, otherwise "cpu".

    Returns:
        Whatever object ``torch.load`` reconstructs from the file.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(path, map_location=device)
    return checkpoint

class Trainer:
    """
    Trainer for multi-task components (AffectNet, NeedClassifier, Router, etc.)

    You must provide:
      - model_dict: dict of modules to train (e.g., {'affect': AffectNet(), 'need': NeedClassifier(), ...})
      - dataloaders: dict of batch iterables keyed by split name; 'train' is
        required by train_epoch, 'val' is used by evaluate when present.

    Loss functions are not passed to the constructor. Built-in defaults are
    used for the 'affect' (binary cross-entropy) and 'need' (cross-entropy)
    models; per-model overrides can be supplied through the ``task_losses``
    argument of ``train_epoch`` / ``evaluate``.

    All models share a single Adam optimizer.
    """

    def __init__(self, model_dict: Dict[str, nn.Module], dataloaders: Dict[str, DataLoader], lr: float = 1e-4, device: Optional[str] = None):
        """
        Move every model to ``device`` and build one Adam optimizer over them.

        Args:
            model_dict: Mapping from task key to the module to train.
            dataloaders: Mapping from split name ('train', 'val') to an
                iterable of batches.
            lr: Learning rate for Adam.
            device: Device string; defaults to "cuda" when available, else "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {k: v.to(self.device) for k, v in model_dict.items()}
        # Combine parameters, keeping each tensor once: a module registered
        # under two keys (or shared sub-modules) would otherwise hand
        # duplicate parameters to the optimizer.
        params = []
        seen = set()
        for m in self.models.values():
            for p in m.parameters():
                if id(p) not in seen:
                    seen.add(id(p))
                    params.append(p)
        self.opt = optim.Adam(params, lr=lr)
        self.dataloaders = dataloaders

    def _batch_loss(self, batch: Dict[str, Any], task_losses: Optional[Dict[str, Callable]] = None):
        """
        Sum the losses of every task that applies to ``batch``.

        For each model key, a callable in ``task_losses`` takes precedence and
        is invoked as ``loss_fn(batch, model_output)``. Otherwise the built-in
        defaults apply: 'affect' uses binary cross-entropy against
        ``batch['labels_affect']`` and 'need' uses cross-entropy against
        ``batch['labels_need']``, each only when its label key is in the batch.

        Returns:
            The summed loss tensor, or None when no task applies to the batch.
        """
        overrides = task_losses or {}

        text_embed = batch.get("text_embed")
        if isinstance(text_embed, (list, tuple)):
            text_embed = torch.stack(list(text_embed))
        if isinstance(text_embed, torch.Tensor):
            # Pre-collated tensors must be moved as well, not only stacked lists.
            text_embed = text_embed.to(self.device)

        losses = []
        for name, model in self.models.items():
            custom = overrides.get(name)
            if custom is not None:
                losses.append(custom(batch, model(text_embed)))
            elif name == "affect" and "labels_affect" in batch:
                # BCE is applied directly, so the model output is expected to
                # already be a probability in [0, 1].
                y_pred = model(text_embed).squeeze(-1)
                # as_tensor accepts lists and existing tensors alike, without
                # the copy-construct warning torch.tensor emits for tensors.
                y_true = torch.as_tensor(batch["labels_affect"], dtype=torch.float32, device=self.device)
                losses.append(nn.functional.binary_cross_entropy(y_pred, y_true))
            elif name == "need" and "labels_need" in batch:
                logits = model(text_embed)
                y_true = torch.as_tensor(batch["labels_need"], dtype=torch.long, device=self.device)
                losses.append(nn.functional.cross_entropy(logits, y_true))

        if not losses:
            return None
        total = losses[0]
        for extra in losses[1:]:
            total = total + extra
        return total

    def train_epoch(self, epoch: int, task_losses: Optional[Dict[str, Callable]] = None) -> float:
        """
        Single training epoch over training loader.

        Args:
            epoch: Epoch index, used only in the printed progress line.
            task_losses: optional mapping from model key to a callable
                loss_fn(batch, model_output) that replaces the default loss
                for that model.

        Returns:
            Average loss over the batches that produced a loss (0.0 if none did).

        Raises:
            ValueError: if no 'train' loader was supplied.
        """
        self._set_mode(train=True)
        loader = self.dataloaders.get("train")
        if loader is None:
            raise ValueError("train DataLoader missing in dataloaders")
        total_loss = 0.0
        n_batches = 0
        for batch in loader:
            # Each batch is a dict; the default tasks read 'text_embed',
            # 'labels_affect' and 'labels_need' from it.
            self.opt.zero_grad()
            loss = self._batch_loss(batch, task_losses)
            if loss is None:
                # No task applies to this batch, so there is nothing to
                # back-propagate; skip it rather than crash.
                continue
            loss.backward()
            self.opt.step()
            total_loss += float(loss.item())
            n_batches += 1
        avg_loss = total_loss / max(1, n_batches)
        print(f"Epoch {epoch}: train loss {avg_loss:.4f}")
        return avg_loss

    def evaluate(self, task_losses: Optional[Dict[str, Callable]] = None) -> Dict[str, float]:
        """
        Compute the average loss over the 'val' loader without gradients.

        Models are switched to eval mode and left in it.

        Args:
            task_losses: optional per-model loss overrides, same contract as
                in train_epoch.

        Returns:
            ``{'val_loss': average}`` over the batches that produced a loss,
            or ``{}`` when there is no 'val' loader or no batch could be scored.
        """
        self._set_mode(train=False)
        loader = self.dataloaders.get("val")
        if loader is None:
            return {}
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for batch in loader:
                loss = self._batch_loss(batch, task_losses)
                if loss is None:
                    continue
                total_loss += float(loss.item())
                n_batches += 1
        if n_batches == 0:
            return {}
        return {"val_loss": total_loss / n_batches}

    def _set_mode(self, train: bool = True) -> None:
        """Put every model into training mode (``train=True``) or eval mode."""
        for m in self.models.values():
            m.train(train)
