import os
import sys
import torch
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy

"""
Task data management utilities.
"""
# NOTE: this string follows the import statements, so Python does not treat it
# as the module docstring (``__doc__`` stays None). Move it above the imports
# if it is meant to document the module.

# Allow running as a script: put the parent directory of this file on the
# import path (presumably so the lazily imported ``utils.Kendall`` resolves --
# confirm). Guarded so that re-executing the module (e.g. importlib.reload or
# runpy) does not keep appending duplicate entries.
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)


# Accepted tensor-like inputs (tensors, NumPy arrays, or nested float lists).
# ``numpy`` is imported at module level, so no string forward reference is needed.
TensorLike = Union[torch.Tensor, numpy.ndarray, List[float], List[List[float]]]


@dataclass
class TaskData:
    """
    Paired inputs and targets for a single task.

    This is a plain container: it performs no conversion or shape checking
    itself (TaskManager converts raw inputs to tensors before building one).
    """

    # Inputs; TaskManager concatenates these along dim 0, so rows are samples.
    X: torch.Tensor
    # Targets aligned with ``X`` row-for-row (presumably (n,) or (n, 1) -- confirm).
    y: torch.Tensor


class TaskManager:
    """
    Store history-task data, target-task data, and per-history-task state.

    Holds:
      * named history tasks (``TaskData``), each with an optional Kendall value,
      * an optional aggregated history evaluation set,
      * the target task data, which can be grown with ``append_target_samples``,
      * a ``(mse_model, rank_model)`` pair bound to each history task name.

    Every input passed through this class is converted to a tensor on
    ``self.device`` with ``self.dtype``.
    """

    def __init__(
        self, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32
    ):
        """
        Initialize an empty manager.

        Args:
            device: Device for stored tensors. Defaults to CUDA when available,
                otherwise CPU.
            dtype: dtype every stored tensor is cast to.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype

        # History task data keyed by name. Insertion order is preserved and
        # defines the order of the weights returned by
        # compute_kendall_weights_on_target.
        self.history_tasks: Dict[str, TaskData] = {}

        # Kendall value per history task; None until computed or set.
        self.history_kendalls: Dict[str, Optional[float]] = {}

        # Aggregated history eval data. Initialized here so that
        # get_history_eval_data() raises its intended RuntimeError (rather than
        # AttributeError) when nothing has been set yet.
        self.history_eval: Optional[TaskData] = None

        # Target task data.
        self.target: Optional[TaskData] = None

        # (mse_model, rank_model) bound to each history task name; either
        # element may be None.
        self.history_models: Dict[
            str, Tuple[Optional[object], Optional[object]]
        ] = {}

    def _to_tensor(self, value: TensorLike) -> torch.Tensor:
        """
        Convert ``value`` to a tensor on ``self.device`` with ``self.dtype``.

        Tensors go through ``Tensor.to`` (which returns the same object when no
        device/dtype change is needed, so the caller's tensor may be stored
        as-is). NumPy arrays go through ``torch.from_numpy``; anything else goes
        through ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.to(self.device, dtype=self.dtype)

        import numpy as np

        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).to(self.device, dtype=self.dtype)

        return torch.tensor(value, device=self.device, dtype=self.dtype)

    def _make_task_data(self, X: TensorLike, y: TensorLike) -> TaskData:
        """Convert ``X`` then ``y`` with ``_to_tensor`` and wrap them in a ``TaskData``."""
        return TaskData(X=self._to_tensor(X), y=self._to_tensor(y))

    def add_history_task(self, name: str, X: TensorLike, y: TensorLike) -> TaskData:
        """
        Register (or replace) the history task ``name``.

        The Kendall value stored for ``name`` is reset to None. Models bound
        with ``set_history_models`` are left untouched.

        Returns:
            The stored ``TaskData``.
        """
        td = self._make_task_data(X, y)
        self.history_tasks[name] = td
        self.history_kendalls[name] = None
        return td

    def list_history_names(self) -> List[str]:
        """
        Return history task names in registration order.
        """
        return list(self.history_tasks.keys())

    def get_history_task(self, name: str) -> TaskData:
        """
        Return the data of history task ``name``.

        Raises:
            KeyError: If ``name`` is not a registered history task.
        """
        return self.history_tasks[name]

    def set_history_eval_data(self, X: TensorLike, y: TensorLike) -> TaskData:
        """
        Set (replace) the aggregated history eval data and return it.
        """
        td = self._make_task_data(X, y)
        self.history_eval = td
        return td

    def get_history_eval_data(self) -> TaskData:
        """
        Return the aggregated history eval data.

        Raises:
            RuntimeError: If ``set_history_eval_data`` has not been called.
        """
        if self.history_eval is None:
            raise RuntimeError("history eval data is not set")
        return self.history_eval

    def set_target_data(self, X: TensorLike, y: TensorLike) -> TaskData:
        """
        Set (replace) the target task data and return it.
        """
        td = self._make_task_data(X, y)
        self.target = td
        return td

    def get_target_data(self) -> TaskData:
        """
        Return the target task data.

        Raises:
            RuntimeError: If ``set_target_data`` has not been called.
        """
        if self.target is None:
            raise RuntimeError("target data is not set")
        return self.target

    def append_target_samples(self, X_new: TensorLike, y_new: TensorLike) -> TaskData:
        """
        Append samples to the target task and return the new target data.

        Both concatenations are computed before ``self.target`` is replaced, so
        a shape mismatch leaves the existing target data unchanged.

        Raises:
            RuntimeError: If target data has not been set.
        """
        if self.target is None:
            raise RuntimeError("target data is not set")

        X_new_tensor = self._to_tensor(X_new)
        y_new_tensor = self._to_tensor(y_new)

        # Concatenate along the batch (first) dimension; all other dimensions
        # must match the existing target tensors.
        X_cat = torch.cat([self.target.X, X_new_tensor], dim=0)
        y_cat = torch.cat([self.target.y, y_new_tensor], dim=0)

        self.target = TaskData(X=X_cat, y=y_cat)
        return self.target

    def set_history_models(
        self, history_name: str, mse_model: object, rank_model: object
    ) -> None:
        """
        Bind ``(mse_model, rank_model)`` to ``history_name``.

        The name is not checked against the registered history tasks.
        """
        self.history_models[history_name] = (mse_model, rank_model)

    def get_history_models(
        self, history_name: str
    ) -> Tuple[Optional[object], Optional[object]]:
        """
        Return ``(mse_model, rank_model)`` for ``history_name``, or
        ``(None, None)`` if nothing is bound.
        """
        return self.history_models.get(history_name, (None, None))

    def set_history_kendall(self, history_name: str, tau: float) -> None:
        """
        Store ``float(tau)`` as the Kendall value of a history task.

        Raises:
            KeyError: If ``history_name`` is not a registered history task.
        """
        if history_name not in self.history_tasks:
            raise KeyError(f"history task '{history_name}' not found")
        self.history_kendalls[history_name] = float(tau)

    def get_history_kendall(self, history_name: str) -> Optional[float]:
        """
        Return the Kendall value of a history task (None if not computed yet).

        Raises:
            KeyError: If ``history_name`` is not a registered history task.
        """
        if history_name not in self.history_tasks:
            raise KeyError(f"history task '{history_name}' not found")
        return self.history_kendalls.get(history_name)

    def get_all_history_kendalls(self) -> Dict[str, Optional[float]]:
        """
        Return a shallow copy of the name -> Kendall value mapping.
        """
        return dict(self.history_kendalls)

    def compute_kendall_weights_on_target(
        self,
        history_mse_models: Dict[str, object],
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, float], torch.Tensor]:
        """
        Score each history task's MSE model on the target data with Kendall's tau.

        For every history task, in registration order, the matching model's
        ``predict(X_target)`` is called. Its first and third outputs (used as
        mean and covariance) are passed with the target data to
        ``utils.Kendall.calculate_kendall_tau``.

        Stored Kendall values are only updated after every task has been
        scored. A missing model or a failing ``predict`` therefore leaves
        ``history_kendalls`` unchanged.

        Args:
            history_mse_models: Model per history task name. Every registered
                history task needs an entry; extra entries are ignored.
            device: Device passed to ``calculate_kendall_tau`` and used for the
                returned weights tensor. Defaults to ``self.device``.

        Returns:
            ``(taus_by_name, weights)``. ``weights`` is a 1-D tensor of the tau
            values in history-task order, with dtype ``self.dtype``.

        Raises:
            RuntimeError: If target data has not been set.
            KeyError: If a registered history task has no model in
                ``history_mse_models``.
        """
        from utils.Kendall import calculate_kendall_tau

        target = self.get_target_data()
        X_target = target.X
        y_target = target.y
        # Drop a trailing size-1 dimension, e.g. (n, 1) -> (n,).
        if y_target.dim() > 1:
            y_target = y_target.squeeze(-1)

        if device is None:
            device = self.device

        names = self.list_history_names()

        # Validate up front so a missing model cannot leave the stored Kendall
        # values partially updated. Report the first missing name, as before.
        missing = [name for name in names if name not in history_mse_models]
        if missing:
            raise KeyError(
                f"history_mse_models missing model for history task '{missing[0]}'"
            )

        taus_by_name: Dict[str, float] = {}
        for name in names:
            mean, _std, cov = history_mse_models[name].predict(X_target)
            tau = calculate_kendall_tau(X_target, y_target, mean, cov, device=device)
            taus_by_name[name] = float(tau)

        # Commit only once every task has been scored successfully.
        self.history_kendalls.update(taus_by_name)

        # Dict insertion order matches `names`, i.e. history-task order.
        weights_tensor = torch.tensor(
            list(taus_by_name.values()), device=device, dtype=self.dtype
        )
        return taus_by_name, weights_tensor


if __name__ == "__main__":
    # Walk through the TaskManager API end to end with random data.
    manager = TaskManager()

    # Two history tasks with 3 features each (10 and 12 samples).
    for hist_name, n_rows in (("history_1", 10), ("history_2", 12)):
        hist_X = torch.randn(n_rows, 3)
        hist_y = torch.randn(n_rows, 1)
        manager.add_history_task(hist_name, hist_X, hist_y)

    # Target task: start with 8 samples, then append 4 more.
    X_target, y_target = torch.randn(8, 3), torch.randn(8, 1)
    manager.set_target_data(X_target, y_target)
    X_new, y_new = torch.randn(4, 3), torch.randn(4, 1)
    manager.append_target_samples(X_new, y_new)

    print("history tasks:", manager.list_history_names())
    tgt = manager.get_target_data()
    print("target X shape:", tgt.X.shape)
    print("target y shape:", tgt.y.shape)

    class DummyMSEModel:
        """Toy model whose mean is the row sum of X plus a fixed bias."""

        def __init__(self, bias: float):
            self.bias = bias

        def predict(self, X: torch.Tensor):
            # Returns (mean, std, cov) with a constant per-sample variance of 0.1.
            X = X.to(manager.device)
            mean = X.sum(dim=1) + self.bias
            var = torch.full_like(mean, 0.1)
            return mean, var.sqrt(), torch.diag(var)

    mse_models = {
        hist_name: DummyMSEModel(bias=offset)
        for hist_name, offset in (("history_1", 0.0), ("history_2", 1.0))
    }

    # Bind each MSE model (with no rank model) to its history task.
    for hist_name, mse_model in mse_models.items():
        manager.set_history_models(hist_name, mse_model, None)

    # Read the MSE models back through the manager, in history-task order.
    mse_models_from_manager = {}
    for hist_name in manager.list_history_names():
        mse_model, _rank_model = manager.get_history_models(hist_name)
        mse_models_from_manager[hist_name] = mse_model

    taus_by_name, weights = manager.compute_kendall_weights_on_target(
        mse_models_from_manager
    )
    print("taus_by_name:", taus_by_name)
    print("weights tensor:", weights)
