"""Utility functions for the ACEv2 model."""

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DataAttr:
    """Neural-process data split into context, buffer and target sets.

    Each set is an (x, y) pair of optional tensors, giving six slots in total;
    any slot may be ``None``.
    """

    # Context inputs / outputs.
    xc: Optional[torch.Tensor] = None
    yc: Optional[torch.Tensor] = None
    # Buffer inputs / outputs.
    xb: Optional[torch.Tensor] = None
    yb: Optional[torch.Tensor] = None
    # Target inputs / outputs.
    xt: Optional[torch.Tensor] = None
    yt: Optional[torch.Tensor] = None

    def to(self, device):
        """Return a new DataAttr whose non-None tensors are moved to ``device``.

        ``None`` slots stay ``None``; this instance itself is not modified.
        """
        relocated = {}
        for slot in ("xc", "yc", "xb", "yb", "xt", "yt"):
            tensor = getattr(self, slot)
            relocated[slot] = None if tensor is None else tensor.to(device)
        return DataAttr(**relocated)


class LossAttr(dict):
    """Dictionary with attribute-style access for loss computation results.

    Every instance starts with the keys ``loss``, ``log_likelihood``, ``means``,
    ``sds`` and ``weights`` set to ``None``. Positional/keyword arguments are then
    applied on top of those defaults exactly like ``dict.update``. Attribute
    reads, writes and deletes are forwarded to the corresponding dict items.
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            "loss": None,
            "log_likelihood": None,
            "means": None,
            "sds": None,
            "weights": None,
        }
        # Initialize with defaults first
        super().__init__(defaults)
        # Then update with any user-supplied dictionary or keyword arguments
        self.update(*args, **kwargs)

    @staticmethod
    def _missing(key):
        """Build the AttributeError raised for a key that is not present."""
        return AttributeError(f"'LossAttr' object has no attribute '{key}'")

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. A missing key must
        # surface as AttributeError (not KeyError): hasattr(), getattr() with a
        # default, and library code probing optional hooks (e.g. copy's
        # __deepcopy__ lookup) only handle AttributeError.
        try:
            return self[key]
        except KeyError:
            raise self._missing(key) from None

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise self._missing(key) from None


def create_context_buffer_datapoint(query: DataAttr, yhat: torch.Tensor) -> DataAttr:
    """Turn a query's target inputs plus a prediction into a new datapoint.

    During autoregressive inference the predicted point must be visible both to
    the context (used for updating embeddings) and to the buffer (consumed by
    the autoregressive embedder). The pair ``(query.xt, yhat)`` is therefore
    written into both the context and the buffer slots. The target slots are
    left empty.
    """
    x_new, y_new = query.xt, yhat
    return DataAttr(
        xt=None,
        yt=None,
        xc=x_new,
        yc=y_new,
        xb=x_new,
        yb=y_new,
    )


def fetch_next_query(batch: DataAttr, k: int) -> DataAttr:
    """Fetch the k-th query (target point) from the batch.

    The result keeps the batch's context (``xc``/``yc``) unchanged and has an
    empty buffer. Its targets hold only point ``k``, selected along axis 1 with
    a length-1 slice so that axis is kept. ``k`` may be negative to count from
    the end (``k=-1`` is the last target). An out-of-range ``k`` yields an empty
    slice along axis 1. ``yt`` stays ``None`` when the batch has no target outputs.
    """
    # For k == -1, ``k + 1`` is 0 and slice(-1, 0) would be empty; an
    # open-ended stop makes negative indices select exactly one point.
    point = slice(k, None if k == -1 else k + 1)
    return DataAttr(
        xc=batch.xc,
        yc=batch.yc,
        xb=None,
        yb=None,
        xt=batch.xt[:, point, :],  # single target point at index k
        yt=batch.yt[:, point, :] if batch.yt is not None else None,
    )


def concatenate_batches(batch1: DataAttr, batch2: DataAttr) -> DataAttr:
    """Concatenate the context sets of two DataAttr batches along dim 1.

    In practice this is only used for concatenating context predictions, so
    only ``xc`` and ``yc`` are combined. The buffer and target fields of the
    result are always ``None``.

    A ``None`` field is treated as an empty set. If only one batch has the
    field, that tensor is returned as-is. If neither batch has it, the result
    is ``None``. This lets an accumulator that starts out empty be extended
    without losing the new data.
    """

    def _cat_optional(first, second):
        # Concatenate two optional tensors along dim 1, treating None as empty.
        if first is None:
            return second
        if second is None:
            return first
        return torch.cat([first, second], dim=1)

    return DataAttr(
        xc=_cat_optional(batch1.xc, batch2.xc),
        yc=_cat_optional(batch1.yc, batch2.yc),
        xb=None,
        yb=None,
        xt=None,
        yt=None,
    )


def fetch_next_query_batch(batch: DataAttr, start_idx: int, K: int) -> DataAttr:
    """Return a DataAttr holding K consecutive target points of ``batch``.

    Targets are taken along axis 1 with the slice ``start_idx:start_idx + K``.
    Standard slice semantics apply, so the window is truncated at the end of
    the targets. The context is passed through unchanged, the buffer is
    cleared, and ``yt`` is ``None`` when the batch carries no target outputs.
    """
    window = slice(start_idx, start_idx + K)
    xt_window = batch.xt[:, window, :]
    if batch.yt is None:
        yt_window = None
    else:
        yt_window = batch.yt[:, window, :]
    return DataAttr(
        xc=batch.xc,
        yc=batch.yc,
        xb=None,
        yb=None,
        xt=xt_window,
        yt=yt_window,
    )
