from dataclasses import dataclass
from typing import Optional, Any, Dict, List, Tuple
import numpy as np
import torch
@dataclass
class EmbeddingPair:
    """Corpus and query embedding matrices from two embedding sources.

    The four embedding fields are indexed as 2-D (``shape[0]`` is used as the
    item count and ``shape[1]`` as the embedding width).  On construction,
    any ``torch.Tensor`` passed for an embedding field is detached, moved to
    CPU and converted to a numpy array; anything that is neither a numpy
    array nor a tensor raises ``TypeError``.

    The only way to obtain an instance that holds tensors is :meth:`to_torch`,
    which deliberately bypasses that conversion.

    Attributes:
        corpus_emb_1: Corpus embeddings from the first source.
        corpus_emb_2: Corpus embeddings from the second source.
        query_emb_1: Query embeddings from the first source.
        query_emb_2: Query embeddings from the second source.
        p_index_list: Optional list of integer pairs.  This class never
            interprets it; it is only carried along to derived instances.
    """

    corpus_emb_1: np.ndarray
    corpus_emb_2: np.ndarray
    query_emb_1: np.ndarray
    query_emb_2: np.ndarray
    p_index_list: Optional[List[Tuple[int, int]]] = None

    # Names of the fields holding embedding matrices.  Left unannotated on
    # purpose so that @dataclass does not turn it into a constructor field.
    _EMBEDDING_FIELDS = (
        "corpus_emb_1",
        "corpus_emb_2",
        "query_emb_1",
        "query_emb_2",
    )

    def __post_init__(self):
        """Normalise embedding fields to numpy and warn on count mismatches.

        Raises:
            TypeError: If an embedding field is neither a numpy array nor a
                torch tensor.
        """
        # Iterate over an explicit tuple of names rather than self.__dict__
        # so that attributes are never rebound while the dict is iterated.
        for name in self._EMBEDDING_FIELDS:
            emb = getattr(self, name)
            if isinstance(emb, np.ndarray):
                continue
            if isinstance(emb, torch.Tensor):
                setattr(self, name, emb.detach().cpu().numpy())
            else:
                raise TypeError(f"{name} must be a numpy array or torch tensor, got {type(emb)}")

        # Mismatched counts are reported but not treated as errors.
        if self.corpus_emb_1.shape[0] != self.corpus_emb_2.shape[0]:
            print(f"Warning:Corpus embedding counts don't match: {self.corpus_emb_1.shape[0]} vs {self.corpus_emb_2.shape[0]}")
        if self.query_emb_1.shape[0] != self.query_emb_2.shape[0]:
            print(f"Warning: Query embedding counts don't match: {self.query_emb_1.shape[0]} vs {self.query_emb_2.shape[0]}")

    @staticmethod
    def _as_numpy(emb):
        """Return ``emb`` as a numpy array, converting torch tensors."""
        if isinstance(emb, torch.Tensor):
            return emb.detach().cpu().numpy()
        return emb

    def pad_to_max_dim(self) -> 'EmbeddingPair':
        """Return a new pair whose four matrices share the largest width.

        Narrower matrices are zero-padded on the right along axis 1.
        Matrices that already have the maximum width are reused as-is, not
        copied.  ``p_index_list`` is carried over to the result.  The result
        always holds numpy arrays, even if this instance holds tensors.
        """
        corpus_1, corpus_2, query_1, query_2 = (
            self._as_numpy(getattr(self, name)) for name in self._EMBEDDING_FIELDS
        )
        max_dim = max(
            corpus_1.shape[1],
            corpus_2.shape[1],
            query_1.shape[1],
            query_2.shape[1],
        )

        def pad_if_needed(emb: np.ndarray) -> np.ndarray:
            if emb.shape[1] < max_dim:
                return np.pad(emb, ((0, 0), (0, max_dim - emb.shape[1])))
            return emb

        return EmbeddingPair(
            corpus_emb_1=pad_if_needed(corpus_1),
            corpus_emb_2=pad_if_needed(corpus_2),
            query_emb_1=pad_if_needed(query_1),
            query_emb_2=pad_if_needed(query_2),
            p_index_list=self.p_index_list,
        )

    @property
    def corpus_size(self) -> int:
        """Number of rows in ``corpus_emb_1``."""
        return self.corpus_emb_1.shape[0]

    @property
    def query_size(self) -> int:
        """Number of rows in ``query_emb_1``."""
        return self.query_emb_1.shape[0]

    @property
    def dim_1(self) -> int:
        """Width (axis 1) of ``corpus_emb_1``."""
        return self.corpus_emb_1.shape[1]

    @property
    def dim_2(self) -> int:
        """Width (axis 1) of ``corpus_emb_2``."""
        return self.corpus_emb_2.shape[1]

    def to_torch(self, device: Optional[torch.device] = None) -> 'EmbeddingPair':
        """Return a copy whose embedding fields are torch tensors on ``device``.

        Args:
            device: Target device; defaults to CPU when omitted.

        Returns:
            A new instance of the same class holding tensors.  All other
            attributes, including ``p_index_list``, are carried over.
        """
        device = device or torch.device('cpu')

        # Going through the constructor would run __post_init__, which turns
        # every tensor straight back into a CPU numpy array.  Build the copy
        # without invoking it so the tensors (and their device) survive.
        clone = object.__new__(type(self))
        clone.__dict__.update(self.__dict__)
        for name in self._EMBEDDING_FIELDS:
            emb = getattr(self, name)
            # Accept fields that are already tensors so that calling
            # to_torch() on a torch-backed pair simply moves it to `device`.
            tensor = emb if isinstance(emb, torch.Tensor) else torch.from_numpy(emb)
            setattr(clone, name, tensor.to(device))
        return clone

    def get_stats(self) -> Dict[str, Any]:
        """Return sizes, widths and the global mean/std of each matrix.

        Mean and standard deviation are taken over all elements of each
        matrix and returned as Python floats.
        """
        # Convert up front so the statistics also work on a pair returned by
        # to_torch().
        corpus_1, corpus_2, query_1, query_2 = (
            self._as_numpy(getattr(self, name)) for name in self._EMBEDDING_FIELDS
        )
        return {
            'corpus_size': self.corpus_size,
            'query_size': self.query_size,
            'dim_1': self.dim_1,
            'dim_2': self.dim_2,
            'corpus_1_mean': float(np.mean(corpus_1)),
            'corpus_2_mean': float(np.mean(corpus_2)),
            'query_1_mean': float(np.mean(query_1)),
            'query_2_mean': float(np.mean(query_2)),
            'corpus_1_std': float(np.std(corpus_1)),
            'corpus_2_std': float(np.std(corpus_2)),
            'query_1_std': float(np.std(query_1)),
            'query_2_std': float(np.std(query_2)),
        }
@dataclass
class EmbeddingDataset:
    """An :class:`EmbeddingPair` bundled with optional metadata.

    Attributes:
        pair: The embeddings this dataset wraps.
        metadata: Arbitrary caller-supplied object; passed through unchanged
            to derived datasets.
        p_index_list: Optional list of integer pairs.  This class never
            interprets it; it is only passed through to derived datasets.
    """

    pair: EmbeddingPair
    metadata: Optional[Any] = None
    p_index_list: Optional[List[Tuple[int, int]]] = None

    def pad_to_max_dim(self) -> 'EmbeddingDataset':
        """Return a new dataset whose pair has been padded to a common width.

        The padding itself is delegated to ``pair.pad_to_max_dim()``; since
        that already brings every matrix to the same width, no further check
        is needed here.
        """
        return EmbeddingDataset(
            pair=self.pair.pad_to_max_dim(),
            metadata=self.metadata,
            p_index_list=self.p_index_list,
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return the wrapped pair's statistics under the ``'pair'`` key.

        The dataset has a single ``pair`` field (there are no ``train`` or
        ``test`` attributes), so the statistics are keyed by that field name.
        """
        return {
            'pair': self.pair.get_stats(),
        }

    def to_torch(self, device: Optional[torch.device] = None) -> 'EmbeddingDataset':
        """Return a new dataset built from ``pair.to_torch(device)``.

        Args:
            device: Forwarded unchanged to ``pair.to_torch``.

        Returns:
            A new dataset with the same ``metadata`` and ``p_index_list``.
        """
        return EmbeddingDataset(
            pair=self.pair.to_torch(device),
            metadata=self.metadata,
            p_index_list=self.p_index_list,
        )
