import torch
import numpy as np

from uq_diagcfm.paths import DATA_DIR, SURROGATES_DIR
from uq_diagcfm.utils import get_device

# Identifier string for this dataset (not referenced in this module; presumably
# used as a registry/config key elsewhere in the package — verify).
UNIFOIL_DATASET_NAME = "unifoil"

# Numeric text tables read with np.loadtxt (whitespace-delimited by default);
# each row is one sample.
TRAIN_DATAFILE = DATA_DIR / "unifoil_training.dat"
VAL_DATAFILE = DATA_DIR / "unifoil_validation.dat"

# Column layout of every data row, in order:
#   [design params | physical params | physical performance]
# The surrogate consumes design + physical params (LEN_DESIGN_PARAMETERS +
# LEN_PHYSICAL_PARAMS inputs) and predicts LEN_PHYSICAL_PERFORMANCE outputs.
LEN_DESIGN_PARAMETERS = 14
LEN_PHYSICAL_PARAMS = 2
LEN_PHYSICAL_PERFORMANCE = 3


def split_unifoil_data(data):
    """Split Unifoil rows into their three column groups.

    Slicing is done along the last axis, so ``data`` may be a single row of
    shape ``(n_cols,)`` or a batch of shape ``(..., n_cols)``. Basic slicing is
    used, so for NumPy arrays and torch tensors the results are views of
    ``data``, not copies.

    Args:
        data: Array-like supporting ``[..., a:b]`` indexing whose last axis is
            laid out as ``[design params | physical params | performance]``.

    Returns:
        Tuple ``(design_params, physical_params, physical_performance)`` holding
        ``LEN_DESIGN_PARAMETERS``, ``LEN_PHYSICAL_PARAMS`` and
        ``LEN_PHYSICAL_PERFORMANCE`` columns respectively (fewer if the last
        axis is shorter than their sum).
    """
    # Derive the column boundaries from the module constants so the row layout
    # is defined in exactly one place instead of via duplicated magic numbers.
    physical_start = LEN_DESIGN_PARAMETERS
    performance_start = physical_start + LEN_PHYSICAL_PARAMS
    performance_end = performance_start + LEN_PHYSICAL_PERFORMANCE

    design_params = data[..., :physical_start]
    physical_params = data[..., physical_start:performance_start]
    physical_performance = data[..., performance_start:performance_end]

    return design_params, physical_params, physical_performance


class UnifoilDataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of the Unifoil data files.

    Each item is one row of the split's data file, returned as the tuple
    ``(design_params, physical_params, physical_performance)`` of float32
    NumPy arrays produced by ``split_unifoil_data``. For integer indices these
    arrays are views into the loaded table, so a transform that modifies them
    in place also modifies the dataset.

    Args:
        split: ``"train"`` or ``"val"``; selects ``TRAIN_DATAFILE`` or
            ``VAL_DATAFILE`` respectively.
        transform: Optional callable taking the three arrays
            ``(design_params, physical_params, physical_performance)`` and
            returning a replacement 3-tuple.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"``.
    """

    def __init__(self, split, transform=None):
        if split == "train":
            dat_file = TRAIN_DATAFILE
        elif split == "val":
            dat_file = VAL_DATAFILE
        else:
            raise ValueError(f"Unknown split: {split}")
        # ndmin=2 keeps the (n_rows, n_cols) shape even when the file holds a
        # single row; without it np.loadtxt returns a 1-D array and __len__ /
        # __getitem__ would count and index columns instead of samples.
        self.data = np.loadtxt(dat_file, dtype=np.float32, ndmin=2)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as three column groups, optionally transformed."""
        row = self.data[idx]
        design_params, physical_params, physical_performance = split_unifoil_data(row)

        # Compare against None explicitly so a transform object that defines
        # __len__ or __bool__ is never silently skipped.
        if self.transform is not None:
            design_params, physical_params, physical_performance = self.transform(
                design_params, physical_params, physical_performance
            )

        return design_params, physical_params, physical_performance


def make_unifoil_surrogate(hidden_dim: int = 512, depth: int = 5, activation: str = "LeakyReLU"):
    """Load the pre-trained Unifoil surrogate model.

    The surrogate maps (design_params, physical_params) -> performance_labels
    - Input: LEN_DESIGN_PARAMETERS + LEN_PHYSICAL_PARAMS (14 + 2 = 16) features,
      ordered as ``[design_params, physical_params]``
    - Output: LEN_PHYSICAL_PERFORMANCE (3) performance labels

    The architecture hyperparameters must match those used when the checkpoint
    was trained; otherwise ``load_state_dict`` rejects the weights.

    Args:
        hidden_dim: Hidden dimension used during training (default: 512)
        depth: Network depth used during training (default: 5)
        activation: Activation function used during training (default: LeakyReLU)

    Returns:
        The MLP with the checkpoint weights loaded, moved to ``get_device()``
        and switched to eval mode.

    Raises:
        FileNotFoundError: If ``SURROGATES_DIR / "Surrogate_Unifoil.pth"`` is missing.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match the
            model built from the given hyperparameters.
    """
    # Kept as a function-local import, as in the original module layout
    # (presumably to avoid an import cycle with uq_diagcfm.models — verify).
    from uq_diagcfm.models import MLP

    device = get_device()

    input_dim = LEN_DESIGN_PARAMETERS + LEN_PHYSICAL_PARAMS  # 16
    output_dim = LEN_PHYSICAL_PERFORMANCE  # 3

    model = MLP(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dim=hidden_dim,
        depth=depth,
        dropout=0.0,
        activation=activation,
    )

    # The checkpoint is consumed by load_state_dict, i.e. it is a plain state
    # dict of tensors. weights_only=True therefore restricts unpickling to
    # tensors/primitive containers, so no arbitrary pickled code runs, and it
    # silences the FutureWarning torch >= 2.4 emits for the implicit default.
    # Requires torch >= 1.13.
    state_dict = torch.load(
        SURROGATES_DIR / "Surrogate_Unifoil.pth",
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


if __name__ == "__main__":
    import sys

    # Manual smoke checks, selected by the single command-line argument:
    #   test            -> load the training file and print its shape
    #   test_surrogate  -> score the pre-trained surrogate on validation rows
    n_surrogate_samples = 100
    command = sys.argv[1] if len(sys.argv) == 2 else None

    if command == "test":
        # ndmin=2 keeps a (rows, cols) shape even for a single-row file.
        data = np.loadtxt(TRAIN_DATAFILE, dtype=np.float32, ndmin=2)

        print("Data shape:", data.shape)

    elif command == "test_surrogate":
        # Test the surrogate model on validation data
        surrogate = make_unifoil_surrogate()
        device = get_device()

        # Take up to the first n rows; ndmin=2 so even a single-row file yields
        # 2-D column groups for the axis=1 concatenate below.
        val_rows = np.loadtxt(VAL_DATAFILE, dtype=np.float32, ndmin=2)[:n_surrogate_samples]
        design_params, physical_params, performance = split_unifoil_data(val_rows)

        # Surrogate input is [design_params | physical_params] along the feature axis.
        x = np.concatenate([design_params, physical_params], axis=1)
        x_tensor = torch.as_tensor(x, dtype=torch.float32, device=device)
        y_tensor = torch.as_tensor(performance, dtype=torch.float32, device=device)

        with torch.no_grad():
            preds = surrogate(x_tensor)

        mse = torch.nn.functional.mse_loss(preds, y_tensor)
        # Report the number of rows actually scored; it is smaller than
        # n_surrogate_samples when the validation file is short.
        print(f"Surrogate MSE on {len(val_rows)} validation samples: {mse.item():.6f}")
        print(f"Sample prediction: {preds[0].cpu().numpy()}")
        print(f"Ground truth:      {performance[0]}")

    else:
        # Previously an unknown or missing command silently did nothing.
        print(f"usage: {sys.argv[0]} {{test|test_surrogate}}", file=sys.stderr)
        sys.exit(2)
