import numpy as np
import os.path as osp
from utils.data import get_dx_dy_datapaths, load_dataset
from utils.paths import get_split_dataset_path
from data.function import TestFunction
from utils.types import NestedFloatList
from typing import Optional
from utils.plot import plot_optimization
from utils.config import get_train_x_range, get_train_y_range
from torch import Tensor
import torch
from matplotlib.figure import Figure
from botorch.utils.transforms import unnormalize
from data.sampler import scale_y
from data.laser_plasma.data_handler import LaserPlasmaDataLoader


def plot_optimization_batch(
    test_function: TestFunction,
    x_query: Tensor,
    y_query: Tensor,
    input_range_list: Optional[NestedFloatList] = None,
    grid_res: int = 2000,
) -> Figure:
    """Plot a batch of query points against random samples of a test function.

    ``grid_res`` random inputs are mapped into ``test_function.x_bounds`` with
    ``unnormalize``, evaluated once, and then shared (as expanded views) by
    every batch entry as the reference data passed to ``plot_optimization``.

    Args:
        test_function: Function to sample. Its ``x_dim`` / ``y_dim`` must
            equal the last dimension of ``x_query`` / ``y_query``.
        x_query: 3-D tensor of query inputs; the first dimension is the batch
            size and the last dimension is the input dimension.
        y_query: 3-D tensor of query outputs; the last dimension is the
            output dimension.
        input_range_list: When given, ``x_query`` is passed through
            ``test_function.scale_inputs`` with these bounds before plotting.
        grid_res: Number of random reference points to sample.

    Returns:
        The figure produced by ``plot_optimization``.

    Raises:
        AssertionError: If the query dimensions do not match the function's.
    """
    batch_size, _, x_dim = x_query.shape
    _, _, y_dim = y_query.shape
    assert x_dim == test_function.x_dim and y_dim == test_function.y_dim

    dev = x_query.device
    bounds = test_function.x_bounds.to(dev)

    # All-ones masks exist only to satisfy the plot_optimization signature.
    x_mask = (
        torch.ones(x_dim, dtype=torch.bool, device=dev)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    y_mask = (
        torch.ones(y_dim, dtype=torch.bool, device=dev)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )

    # Reference datapoints: random unit-cube samples mapped into the bounds.
    unit_samples = torch.rand(grid_res, x_dim, device=dev)
    x_ref = unnormalize(unit_samples, bounds.transpose(0, 1))
    y_ref = test_function.evaluate(x=x_ref, input_bounds=bounds)

    # Optionally rescale the queries before plotting.
    x_query_scaled = x_query
    if input_range_list is not None:
        x_query_scaled = test_function.scale_inputs(
            x_query,
            input_bounds=input_range_list,
        )

    # The same reference samples are broadcast to every batch entry.
    return plot_optimization(
        x=x_ref.unsqueeze(0).expand(batch_size, -1, -1),
        y=y_ref.unsqueeze(0).expand(batch_size, -1, -1),
        x_query=x_query_scaled,
        y_query=y_query,
        x_mask=x_mask,
        y_mask=y_mask,
    )


# Maps a synthetic function name "dx{i}_dy{j}" to its (x_dim, y_dim) pair,
# for every combination of 1 to 3 input and 1 to 3 output dimensions.
dim_dict = {
    f"dx{x_dim}_dy{y_dim}": (x_dim, y_dim)
    for x_dim in range(1, 4)
    for y_dim in range(1, 4)
}


def from_function_name_to_datapaths(function_name, split, subfolder=None):
    """Resolve the dataset paths for a synthetic ``dx{i}_dy{j}`` function.

    Args:
        function_name: Key looked up in ``dim_dict``.
        split: Dataset split forwarded to ``get_split_dataset_path``.
        subfolder: Optional directory joined onto the split path.

    Returns:
        The datapaths returned by ``get_dx_dy_datapaths`` for the function's
        single (x_dim, y_dim) pair, or ``None`` when ``function_name`` is not
        a key of ``dim_dict``.
    """
    dims = dim_dict.get(function_name)
    if dims is None:
        return None

    x_dim, y_dim = dims[0], dims[1]
    split_root = get_split_dataset_path(split=split)
    dataset_dir = (
        split_root if subfolder is None else osp.join(split_root, subfolder)
    )

    # Only the first element of the returned pair is needed here.
    datapaths, _ = get_dx_dy_datapaths(
        path=dataset_dir,
        x_dim_list=[x_dim],
        y_dim_list=[y_dim],
    )
    return datapaths


def from_seed_to_dataset(datapaths, seed):
    """Load the ``dataset_{seed}`` group from the first file in ``datapaths``.

    Args:
        datapaths: Sequence of HDF5 paths; only ``datapaths[0]`` is read.
        seed: Identifier used to build the group name ``dataset_{seed}``.

    Returns:
        The first two values returned by ``load_dataset`` (loaded with
        ``device="cpu"``); the remaining two are discarded.
    """
    inputs, targets, _, _ = load_dataset(
        hdf5_path=datapaths[0],
        grp_name=f"dataset_{seed}",
        device="cpu",
    )
    return inputs, targets


def get_opt_dataset(
    function_name,
    seed,
    device,
    subfolder=None,
):
    """Return the optimization dataset for ``function_name``.

    Args:
        function_name: Either ``"LaserPlasma"`` or a key of ``dim_dict``
            (expected to be a string).
        seed: Dataset seed, forwarded to ``from_seed_to_data`` for the
            synthetic ``dim_dict`` functions.
        device: Forwarded to ``get_laser_plasma_data`` only; it is not used
            for the synthetic datasets.
        subfolder: Optional sub-directory of the ``"test"`` split to search.

    Returns:
        A 4-tuple ``(train_x, train_y, train_x_bounds, train_y_bounds)``.
        All four entries are ``None`` when the function name is unknown or
        no datapaths could be resolved for it.
    """
    if function_name == "LaserPlasma":
        return get_laser_plasma_data(device=device)

    no_data = (None, None, None, None)

    # Direct dict membership; no need to materialise the keys as a list.
    if function_name not in dim_dict:
        return no_data

    datapaths = from_function_name_to_datapaths(
        function_name=function_name,
        split="test",
        subfolder=subfolder,
    )
    if datapaths is None:
        return no_data

    train_x, train_y, train_x_bounds, train_y_bounds = from_seed_to_data(
        datapaths=datapaths,
        seed=seed,
    )
    return train_x, train_y, train_x_bounds, train_y_bounds


def from_seed_to_data(datapaths, seed):
    """Load a seeded dataset together with the training bounds.

    Args:
        datapaths: Forwarded to ``from_seed_to_dataset``.
        seed: Forwarded to ``from_seed_to_dataset``.

    Returns:
        A 4-tuple ``(train_x, train_y, train_x_bounds, train_y_bounds)`` where
        ``train_y`` has been passed through ``scale_y`` with the y bounds and
        the bounds come from ``get_train_x_range`` / ``get_train_y_range``.
    """
    x, y = from_seed_to_dataset(datapaths=datapaths, seed=seed)
    x_bounds = get_train_x_range()
    y_bounds = get_train_y_range()

    # NOTE: the targets are always scaled; the inputs are returned untouched.
    y_scaled = scale_y(y=y, domains=y_bounds)
    return x, y_scaled, x_bounds, y_bounds


def get_laser_plasma_data(device):
    """Load the laser-plasma dataset through ``LaserPlasmaDataLoader``.

    Args:
        device: Forwarded to the loader's constructor, which is also given
            ``negate=True``.

    Returns:
        The 4-tuple ``(train_x, train_y, train_x_bounds, train_y_bounds)``
        unpacked from ``LaserPlasmaDataLoader.get_data()``.
    """
    loader = LaserPlasmaDataLoader(device=device, negate=True)
    # Unpacking enforces that the loader yields exactly four values.
    x, y, x_bounds, y_bounds = loader.get_data()
    return x, y, x_bounds, y_bounds


def expand_metrics(metric, q_x_next):
    """Repeat a per-batch metric along the second dimension of ``q_x_next``.

    Args:
        metric: Per-batch values as a tensor or NumPy array. An array is
            converted to a tensor with the device and dtype of ``q_x_next``;
            a tensor is used as-is.
        q_x_next: Tensor whose second dimension gives the repeat count ``q``.

    Returns:
        An expanded view of ``metric`` with a new second dimension of size
        ``q`` (shape ``(B, q)`` for a 1-D ``metric`` of length ``B``).
    """
    num_candidates = q_x_next.size(1)

    values = metric
    if isinstance(values, np.ndarray):
        values = torch.from_numpy(values).to(
            device=q_x_next.device, dtype=q_x_next.dtype
        )

    # Insert a singleton column, then broadcast it without copying.
    return values.unsqueeze(1).expand(-1, num_candidates)
