import copy
import torch
import torch.nn.functional as F
import lightning as l
from torchmetrics import Accuracy, F1Score, MetricCollection # Import MetricCollection
from argparse import Namespace # For load_from_checkpoint example

class BaseModel(l.LightningModule):
    """
    Base LightningModule for models trained with cross-entropy loss.

    Subclasses implement ``forward``. This class provides the train/val/test
    steps, per-stage ``MetricCollection`` bookkeeping (without a shared
    ``_common_step``), epoch-end metric logging and the optimizer setup.

    Hyperparameters read from ``config``:
        lr: learning rate for AdamW (required).
        num_classes: when present and > 0, multiclass Accuracy and macro
            F1Score are tracked for every stage (optional).
        weight_decay: AdamW weight decay, defaults to 1e-5 (optional).
    """

    def __init__(self, config):
        """
        Store hyperparameters and build the per-stage metric collections.

        Args:
            config: an ``argparse.Namespace``-style object (anything exposing
                its fields through ``__dict__``) or a mapping of
                hyperparameter names to values.

        Raises:
            AttributeError: if the config does not provide ``lr``.
        """
        super().__init__()
        # Ensure all expected hparams are in config for save_hyperparameters
        self.save_hyperparameters(self._config_to_dict(config))

        # Initialize (possibly empty) collections first so the step functions
        # and epoch-end hooks can use them unconditionally.
        self.train_metrics = MetricCollection({}, prefix="metric/train/")
        self.val_metrics = MetricCollection({}, prefix="metric/val/")
        self.test_metrics = MetricCollection({}, prefix="metric/test/")

        # Conditionally add classification metrics if num_classes is valid
        num_classes = getattr(self.hparams, "num_classes", 0)  # Default to 0 if not present
        if num_classes and num_classes > 0:
            print(f"Initializing classification metrics for {num_classes} classes.")
            classification_metrics = {
                'Accuracy': Accuracy(task="multiclass", num_classes=num_classes),
                'F1Score': F1Score(task="multiclass", num_classes=num_classes, average="macro"),
            }
            # Each stage needs its own metric objects so their accumulated
            # state is never shared between train/val/test.
            self.train_metrics.add_metrics(classification_metrics)
            self.val_metrics.add_metrics(copy.deepcopy(classification_metrics))
            self.test_metrics.add_metrics(copy.deepcopy(classification_metrics))
        else:
            print("Skipping classification metric initialization (num_classes <= 0 or not provided).")

        # Validate required config parameters (like lr)
        if not hasattr(self.hparams, "lr"):
            raise AttributeError("Config object must have a 'lr' attribute for the optimizer.")

    @staticmethod
    def _config_to_dict(config):
        """Return the hyperparameters held by ``config`` as a plain dict.

        Mappings are checked first because dict instances either have no
        ``__dict__`` at all or expose an empty one that would silently drop
        every hyperparameter.
        """
        from collections.abc import Mapping  # local import: only needed here

        if isinstance(config, Mapping):
            return dict(config)
        return dict(config.__dict__)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, map_location=None, **kwargs):
        """
        Load a model from a checkpoint, rebuilding its config from the saved hparams.

        Args:
            checkpoint_path: path (or file-like object) accepted by ``torch.load``.
            map_location: forwarded to ``torch.load``.
            **kwargs: hyperparameter overrides; they take precedence over the
                values stored in the checkpoint (e.g. for fine-tuning).

        Returns:
            A ``(model, config)`` tuple, where ``config`` is the ``Namespace``
            the model was constructed from.

        Raises:
            KeyError: if the checkpoint lacks ``"hyper_parameters"`` or
                ``"state_dict"``.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location, weights_only=False
        )
        hparams = checkpoint["hyper_parameters"]

        # Allow overriding hparams from kwargs if needed for fine-tuning etc.
        final_hparams = {**hparams, **kwargs}

        config = Namespace(**final_hparams)

        model = cls(config)  # Pass the reconstructed config
        model.load_state_dict(checkpoint["state_dict"])
        print(f"Loaded model from {checkpoint_path}")
        # Returning config is useful for inspection or resuming trainer state
        return model, config

    def forward(self, x):
        """Subclasses must implement the forward pass."""
        raise NotImplementedError("This method should be overridden by subclasses")

    # --- Step Functions (No _common_step) ---

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        ``batch`` is unpacked as ``(x, y)``; ``self(x)`` must return logits
        that ``F.cross_entropy`` accepts together with ``y``.

        Returns:
            The loss tensor used for backpropagation.
        """
        x, y = batch  # Assuming batch is (features, labels) for the default case
        logits = self(x)  # Assumes forward takes only x by default
        loss = F.cross_entropy(logits, y)

        # Accumulate metric state; values are computed once per epoch in
        # on_train_epoch_end.
        self.train_metrics.update(logits, y)

        # Log loss with standard convention
        self.log("loss/train", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        # Slash-free key for ModelCheckpoint monitoring; logger=False keeps it
        # out of TensorBoard/WandB, where loss/train is already logged.
        self.log("loss_train_monitor", loss, on_step=False, on_epoch=True, logger=False, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate validation metrics and log the epoch-level validation loss."""
        x, y = batch
        logits = self(x)  # Assumes forward takes only x by default
        loss = F.cross_entropy(logits, y)

        # Update validation metrics using the collection
        self.val_metrics.update(logits, y)

        # Log validation loss with standard convention
        self.log("loss/val", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Accumulate test metrics and log the epoch-level test loss."""
        x, y = batch
        logits = self(x)  # Assumes forward takes only x by default
        loss = F.cross_entropy(logits, y)

        # Update test metrics using the collection
        self.test_metrics.update(logits, y)

        # Log test loss with standard convention
        self.log("loss/test", loss, on_step=False, on_epoch=True, logger=True, sync_dist=True)

    # --- Epoch End Callbacks for Logging Computed Metrics ---

    def on_train_epoch_end(self):
        """Compute, log and reset the training metrics for the finished epoch."""
        # Keys carry the collection prefix, e.g. 'metric/train/Accuracy'.
        train_metric_dict = self.train_metrics.compute()

        # Log the entire dictionary - keys already have the desired format
        self.log_dict(train_metric_dict, on_epoch=True, logger=True, sync_dist=True)

        # Log F1 separately to prog bar if desired
        f1_key = 'metric/train/F1Score'
        if f1_key in train_metric_dict:
            # A distinct key with logger=False ensures it only shows in the prog bar
            self.log('train_f1_prog', train_metric_dict[f1_key], prog_bar=True, logger=False, sync_dist=True)

        # Computed tensors (not the Metric objects) are logged, so Lightning
        # will not reset the state for us; without this the next epoch's
        # values would also include this epoch's batches.
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        """Compute, log and reset the validation metrics for the finished epoch."""
        # e.g. {'metric/val/Accuracy': ..., 'metric/val/F1Score': ...}
        val_metric_dict = self.val_metrics.compute()

        # Log the entire dictionary
        self.log_dict(val_metric_dict, on_epoch=True, logger=True, sync_dist=True)

        # Log F1 separately to prog bar
        f1_key = 'metric/val/F1Score'  # Adjust if needed
        if f1_key in val_metric_dict:
            self.log("metric_val_f1", val_metric_dict[f1_key], prog_bar=True, on_epoch=True, logger=False, sync_dist=True)

        # Manual reset: keeps epochs independent and stops sanity-check
        # batches from leaking into the first real validation epoch.
        self.val_metrics.reset()

    def on_test_epoch_end(self):
        """Compute, log and reset the test metrics."""
        # e.g. {'metric/test/Accuracy': ..., 'metric/test/F1Score': ...}
        test_metric_dict = self.test_metrics.compute()

        # Log the entire dictionary
        self.log_dict(test_metric_dict, on_epoch=True, logger=True, sync_dist=True)

        # Manual reset so repeated test runs do not accumulate state.
        self.test_metrics.reset()

    def configure_optimizers(self):
        """Configure the AdamW optimizer from ``hparams.lr`` and optional ``hparams.weight_decay``."""
        # Access hyperparameters via self.hparams
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=getattr(self.hparams, 'weight_decay', 1e-5),  # Use getattr for optional wd
        )

        return optimizer