import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import sys
import os


# Put the parent of this file's directory on the import path; presumably this is
# what lets the `utils.*` import below resolve when the file is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from utils.Kendall import calculate_kendall_tau
from loss_function import ListNetLoss, RankCosineLoss, build_loss


class ListRankDataset(Dataset):
    """Dataset of randomly sampled fixed-size lists drawn from a pool.

    Every item is a pair ``(X_list, y_list)`` of shapes ``[list_size, D]`` and
    ``[list_size]``, made of rows picked at random from ``X_pool`` / ``y_pool``
    with the same indices for both.  The requested ``index`` is ignored, so
    ``n_lists`` only controls how many lists make up one pass over the dataset.

    Args:
        X_pool: Feature pool of shape ``[N, D]``.
        y_pool: Target pool of shape ``[N]`` or ``[N, 1]``.
        list_size: Number of rows in every sampled list.
        n_lists: Reported dataset length.

    Raises:
        TypeError: If either pool is not a ``torch.Tensor``.
        ValueError: If shapes are inconsistent, the pool is empty, or
            ``list_size`` / ``n_lists`` is not positive.
    """

    def __init__(
        self, X_pool: torch.Tensor, y_pool: torch.Tensor, list_size: int, n_lists: int
    ):
        if not isinstance(X_pool, torch.Tensor) or not isinstance(y_pool, torch.Tensor):
            raise TypeError("X_pool and y_pool must be torch.Tensor")
        if X_pool.dim() != 2:
            raise ValueError("X_pool must be a 2D tensor of shape [N, D]")
        if y_pool.dim() == 2 and y_pool.size(-1) == 1:
            y_pool = y_pool.reshape(-1)
        if y_pool.dim() != 1:
            raise ValueError("y_pool must be a 1D tensor of shape [N] or [N, 1]")
        if X_pool.size(0) != y_pool.size(0):
            raise ValueError("X_pool and y_pool must have the same number of rows")
        # Fail here rather than inside __getitem__, where an empty pool would
        # surface as an opaque torch.randint(0, 0, ...) error.
        if X_pool.size(0) == 0:
            raise ValueError("X_pool must contain at least one row")
        if list_size <= 0:
            raise ValueError("list_size must be positive")
        if n_lists <= 0:
            raise ValueError("n_lists must be positive")

        self.X_pool = X_pool
        self.y_pool = y_pool
        self.list_size = int(list_size)
        self.n_lists = int(n_lists)

    def __len__(self) -> int:
        return self.n_lists

    def __getitem__(self, index: int):
        # `index` is intentionally unused: each access draws a fresh list.
        n = self.X_pool.size(0)
        device = self.X_pool.device
        if n >= self.list_size:
            # Enough rows: sample without replacement.
            idx = torch.randperm(n, device=device)[: self.list_size]
        else:
            # Pool smaller than a list: fall back to sampling with replacement.
            idx = torch.randint(0, n, (self.list_size,), device=device)
        X_list = self.X_pool.index_select(0, idx)
        y_list = self.y_pool.index_select(0, idx)
        return X_list, y_list


def set_seed(seed: int) -> None:
    """Seed the global PyTorch and NumPy random generators with ``seed``."""
    # PyTorch first, then NumPy.
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)


def rel_l2_error(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-12) -> float:
    """Return the relative L2 error ``||pred - target||_2 / ||target||_2``.

    Both tensors are flattened first, so ``[N]`` and ``[N, 1]`` inputs can be
    mixed.  The denominator is clamped to at least ``eps`` to avoid dividing by
    zero when ``target`` is all zeros.

    Args:
        pred: Predicted values.
        target: Reference values with the same number of elements.
        eps: Lower bound applied to the norm of ``target``.

    Returns:
        The relative error as a Python float.
    """
    # reshape (not view) so non-contiguous inputs such as transposes also work.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    num = torch.norm(pred - target, p=2)
    den = torch.norm(target, p=2).clamp_min(eps)
    return (num / den).item()


def init_xavier_for_gelu(m: nn.Module) -> None:
    """Re-initialise an ``nn.Linear`` in place; other modules are left untouched.

    Weights get Xavier-uniform initialisation (default gain) and the bias, if
    the layer has one, is set to zero.  Intended for ``module.apply(...)``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # Linear layers built with bias=False expose `bias` as None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)


def target_function(X: torch.Tensor) -> torch.Tensor:
    """Evaluate the synthetic target on a batch ``X`` of shape ``[N, D]``.

    The value is ``sin(x0) + 0.5 * x1**2`` (the quadratic term is dropped when
    there is only one column), plus ``0.1 * sum(sin(xk)) + 0.05 * sum(xk)`` over
    all columns ``k >= 2`` when present.  The result has shape ``[N, 1]``.
    """
    n_features = X.size(1)
    first = X[:, 0]
    second = X[:, 1] if n_features > 1 else 0.0
    out = torch.sin(first) + 0.5 * (second**2)

    if n_features > 2:
        tail = X[:, 2:]
        out = out + 0.1 * torch.sin(tail).sum(dim=1) + 0.05 * tail.sum(dim=1)

    return out.unsqueeze(-1)


class MLPRegressor(nn.Module):
    """Fully connected network: ``Linear -> act -> ... -> Linear``.

    Args:
        input_dim: Size of the input features.
        hidden_dims: Sizes of the hidden layers; ``None`` or an empty sequence
            selects the default ``[128, 128]``.
        out_dim: Size of the output.
        activation: ``"relu"`` or ``"tanh"`` (case-insensitive); any other
            string selects GELU.  A non-string value is called with no
            arguments to create each activation module.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims,
        out_dim: int = 1,
        activation: str = "gelu",
    ):
        super().__init__()

        use_default = hidden_dims is None or len(hidden_dims) == 0
        widths = [128, 128] if use_default else list(hidden_dims)

        if isinstance(activation, str):
            named = {"relu": nn.ReLU, "tanh": nn.Tanh}
            # Unrecognised names fall back to GELU.
            make_act = named.get(activation.lower(), nn.GELU)
        else:
            make_act = activation

        sizes = [input_dim, *widths, out_dim]
        last_linear = len(sizes) - 2
        modules = []
        for pos, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(n_in, n_out))
            # No activation after the output layer.
            if pos < last_linear:
                modules.append(make_act())
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class DeepEnsemble(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dims,
        activation="gelu",
        num_models: int = 5,
        seeds=None,
        out_dim: int = 1,
        device=None,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype

        if seeds is None:
            seeds = list(range(num_models))
        assert len(seeds) >= num_models

        modellist = []
        for s in seeds[:num_models]:
            set_seed(s)
            m = MLPRegressor(
                input_dim=input_dim,
                hidden_dims=hidden_dims,
                out_dim=out_dim,
                activation=activation,
            ).to(self.device, dtype=self.dtype)
            m.apply(init_xavier_for_gelu)
            modellist.append(m)

        self.modellist = nn.ModuleList(modellist)
        self.num_models = num_models

        self.to(self.device, dtype=self.dtype)

        self._calib_mean: torch.Tensor | None = None
        self._calib_std: torch.Tensor | None = None
        self._use_calib: bool = False

    def _forward_raw(self, X: torch.Tensor) -> torch.Tensor:
        preds = []
        for m in self.modellist:
            preds.append(m(X).unsqueeze(0))
        return torch.cat(preds, dim=0)

    @torch.no_grad()
    def set_prediction_calibration(self, X_calib: torch.Tensor):
        X_calib = X_calib.to(self.device, dtype=self.dtype)

        mean_c, _, _ = self.predict(X_calib, normalize=False)

        mu_calib = mean_c.mean()
        var_calib = mean_c.var(unbiased=True)

        eps = torch.tensor(1e-12, device=self.device, dtype=self.dtype)
        var_calib = torch.clamp(var_calib, min=eps)
        sigma_calib = torch.sqrt(var_calib)

        self._calib_mean = mu_calib.detach()
        self._calib_std = sigma_calib.detach()
        self._use_calib = True

    def predict(self, X: torch.Tensor, normalize: bool = True):
        X = X.to(self.device, dtype=self.dtype)
        for m in self.modellist:
            m.eval()

        preds = self._forward_raw(X)
        preds_2d = preds.squeeze(-1)

        mean = preds_2d.mean(dim=0)
        std = preds_2d.std(dim=0, unbiased=True)

        if preds_2d.size(0) > 1:
            centered = preds_2d - mean.unsqueeze(0)
            m = preds_2d.size(0)
            B = preds_2d.size(1)
            cov = centered.transpose(0, 1).matmul(centered) / (m - 1)
        else:
            B = preds_2d.size(1)
            cov = torch.zeros(B, B, device=self.device, dtype=self.dtype)

        mean = mean.to(dtype=self.dtype)
        std = std.to(dtype=self.dtype)
        cov = cov.to(dtype=self.dtype)

        if normalize and self._use_calib:
            mu_calib = self._calib_mean.to(self.device, dtype=self.dtype)
            sigma_calib = self._calib_std.to(self.device, dtype=self.dtype)

            var_calib = sigma_calib**2
            mean = (mean - mu_calib) / sigma_calib
            std = std / sigma_calib
            cov = cov / var_calib

        return mean, std, cov

    def _build_loss(self, loss_type: str):
        """Return whatever ``build_loss`` produces for ``loss_type``."""
        criterion = build_loss(loss_type)
        return criterion

    def fit(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        steps: int,
        lr: float,
        batch_size: int,
        weight_decay: float = 0.0,
        loss_type: str = "mse",
        list_size: int | None = None,
        lists_per_step: int | None = None,
        use_amp: bool = True,
        log_every: int = 100,
    ):
        X = X.detach().to(dtype=self.dtype)
        y = y.detach().to(dtype=self.dtype)
        if weight_decay < 0:
            raise ValueError("weight_decay must be >= 0")

        if loss_type.lower() == "mse":
            dataset = TensorDataset(X, y)
            loader = DataLoader(
                dataset,
                batch_size=min(batch_size, X.size(0)),
                shuffle=True,
                drop_last=False,
            )
        else:
            assert list_size is not None and lists_per_step is not None
            X_pool = X.to(self.device, dtype=self.dtype)
            y_pool = y.to(self.device, dtype=self.dtype)
            total_lists_per_epoch = batch_size * lists_per_step
            dataset = ListRankDataset(X_pool, y_pool, list_size, total_lists_per_epoch)
            loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=0,
            )

        optimizer = torch.optim.AdamW(
            self.parameters(), lr=lr, weight_decay=float(weight_decay)
        )
        criterion = self._build_loss(loss_type)
        amp_enabled = (
            use_amp and self.device.type == "cuda" and self.dtype == torch.float32
        )
        scaler = torch.amp.GradScaler("cuda", enabled=amp_enabled)

        for m in self.modellist:
            m.train()

        for step in range(1, steps + 1):
            total = 0.0
            n = 0

            for batch in loader:
                optimizer.zero_grad(set_to_none=True)

                if loss_type.lower() == "mse":
                    Xb, yb = batch
                    Xb = Xb.to(self.device, dtype=self.dtype)
                    yb = yb.to(self.device, dtype=self.dtype)

                    use_cuda_amp = amp_enabled
                    with torch.amp.autocast(device_type="cuda", enabled=use_cuda_amp):
                        preds = self._forward_raw(Xb)
                        err = preds - yb.unsqueeze(0)
                        per_model_mse = (err**2).mean(dim=(1, 2))
                        loss = per_model_mse.mean()
                        batch_size_effective = Xb.size(0)
                else:
                    Xb, yb = batch
                    Xb = Xb.to(self.device, dtype=self.dtype)
                    yb = yb.to(self.device, dtype=self.dtype)
                    B, L, D = Xb.shape
                    X_flat = Xb.view(B * L, D)

                    use_cuda_amp = amp_enabled
                    with torch.amp.autocast(device_type="cuda", enabled=use_cuda_amp):
                        preds_flat = self._forward_raw(X_flat)
                        preds_list = preds_flat.view(self.num_models, B, L)
                        per_model_losses = []
                        for i in range(self.num_models):
                            per_model_losses.append(criterion(preds_list[i], yb))
                        per_model_losses = torch.stack(per_model_losses)
                        loss = per_model_losses.mean()
                        batch_size_effective = B

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total += loss.item() * batch_size_effective
                n += batch_size_effective

            if step % log_every == 0:
                print(f"step {step:4d}/{steps} | loss={total / n:.6f}")
