import torch
import pandas as pd
from sklearn.model_selection import train_test_split

from uq_diagcfm.utils import get_device
from uq_diagcfm.paths import DATA_DIR, SURROGATES_DIR

GAS_TURBINE_DATASET_NAME = "gas_turbine"

# Feature columns; GasTurbineDataset returns these, in this order, as ``x``.
PARAMETERS = [
    "R_FB",
    "L_N_H",
    "M_D",
    "R_DL",
    "M_LD",
    "P_L",
]
# Target columns; GasTurbineDataset returns these, in this order, as ``y``.
LABELS = [
    "Unmix_O",
    "IO_PD",
    "G_IFD1",
]
LEN_PARAMETERS = len(PARAMETERS)
LEN_LABELS = len(LABELS)

# File names of the raw CSV and of the train/val/test splits that
# prepare_and_save_datasets writes from it.
RAW_DATASET_FILE = "Dataset_SU_200k_Normalized.csv"
TRAIN_DATASET_FILE = "Dataset_SU_200k_Normalized_train.csv"
VAL_DATASET_FILE = "Dataset_SU_200k_Normalized_val.csv"
TEST_DATASET_FILE = "Dataset_SU_200k_Normalized_test.csv"

# The same files resolved inside DATA_DIR.
RAW_DATASET_FILE_PATH = DATA_DIR / RAW_DATASET_FILE
TRAIN_DATASET_FILE_PATH = DATA_DIR / TRAIN_DATASET_FILE
VAL_DATASET_FILE_PATH = DATA_DIR / VAL_DATASET_FILE
TEST_DATASET_FILE_PATH = DATA_DIR / TEST_DATASET_FILE


def prepare_and_save_datasets(
    test_size: float = 0.01, val_size: float = 0.01, random_state: int = 1
) -> None:
    """Split the raw CSV into train/validation/test CSV files.

    Reads ``RAW_DATASET_FILE_PATH`` and keeps only the ``PARAMETERS`` and
    ``LABELS`` columns. It then writes the three splits, without the index,
    to ``TRAIN_DATASET_FILE_PATH``, ``VAL_DATASET_FILE_PATH`` and
    ``TEST_DATASET_FILE_PATH``.

    Args:
        test_size: Fraction of all rows assigned to the test split. May be 0,
            in which case an empty test file (header only) is written.
        val_size: Fraction of all rows assigned to the validation split. May
            be 0, in which case an empty validation file is written.
        random_state: Seed passed to every ``train_test_split`` call so the
            splits are reproducible.

    Raises:
        ValueError: If either size is negative, or both are zero.
    """
    if test_size < 0 or val_size < 0:
        raise ValueError(
            f"test_size and val_size must be non-negative, "
            f"got test_size={test_size}, val_size={val_size}"
        )
    holdout_size = test_size + val_size
    if holdout_size <= 0:
        raise ValueError("test_size and val_size cannot both be zero")

    df = pd.read_csv(RAW_DATASET_FILE_PATH)
    print(f"Dataset loaded from {RAW_DATASET_FILE_PATH}")
    print(f"Dataset shape: {df.shape}")

    cleaned_df = df.loc[:, PARAMETERS + LABELS]
    print(f"Cleaned dataset shape: {cleaned_df.shape}")

    # First carve off everything that is not training data...
    train_df, tmp_df = train_test_split(
        cleaned_df, test_size=holdout_size, random_state=random_state
    )

    # ...then divide the hold-out between validation and test. sklearn rejects
    # a split fraction of exactly 0 or 1, so a zero-sized side is handled
    # explicitly with an empty frame that keeps the column layout.
    if val_size == 0:
        test_df, val_df = tmp_df, tmp_df.iloc[0:0]
    elif test_size == 0:
        test_df, val_df = tmp_df.iloc[0:0], tmp_df
    else:
        # The second output of this split (its "test" side) becomes the
        # validation set, sized relative to the hold-out.
        val_relative_size = val_size / holdout_size
        test_df, val_df = train_test_split(
            tmp_df, test_size=val_relative_size, random_state=random_state
        )

    print(f"Train dataset size: {train_df.shape}")
    print(f"Validation dataset size: {val_df.shape}")
    print(f"Test dataset size: {test_df.shape}")

    train_df.to_csv(TRAIN_DATASET_FILE_PATH, index=False)
    val_df.to_csv(VAL_DATASET_FILE_PATH, index=False)
    test_df.to_csv(TEST_DATASET_FILE_PATH, index=False)


class GasTurbineDataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of the gas-turbine CSV files.

    For an integer index, each item is a pair ``(x, y)`` of 1-D ``float32``
    tensors holding the ``PARAMETERS`` and ``LABELS`` columns of that row.
    The pair is optionally passed through ``transform``.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable ``transform(x, y) -> (x, y)`` applied to
            every item.

    Raises:
        ValueError: If ``split`` is not one of the known split names.
    """

    def __init__(self, split, transform=None):
        if split == "train":
            csv_file = TRAIN_DATASET_FILE_PATH
        elif split == "val":
            csv_file = VAL_DATASET_FILE_PATH
        elif split == "test":
            csv_file = TEST_DATASET_FILE_PATH
        else:
            raise ValueError(f"Unknown split: {split}")
        # Kept for inspection and backward compatibility; items are served
        # from the tensors built below, not from this frame.
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform
        # Convert the whole split to tensors once. Building a pandas row per
        # sample (iloc + label selection + conversion) is slow, and
        # DataLoaders call __getitem__ once per row.
        self._x = torch.tensor(
            self.data_frame[PARAMETERS].to_numpy(), dtype=torch.float32
        )
        self._y = torch.tensor(
            self.data_frame[LABELS].to_numpy(), dtype=torch.float32
        )

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Clone so callers (and in-place transforms) cannot mutate the cached
        # split tensors through the returned views.
        x = self._x[idx].clone()
        y = self._y[idx].clone()

        if self.transform is not None:
            x, y = self.transform(x, y)

        return x, y


def _build_surrogate_mlp(
    in_features=6, hidden_sizes=(50, 100, 200, 100, 50), out_features=1
):
    """Build the Linear/ReLU MLP architecture shared by all surrogate models.

    ``Linear`` layers alternate with ``ReLU`` and there is no activation after
    the last ``Linear``. The layers must stay in this order:
    ``load_state_dict`` (strict by default) requires the resulting
    ``Sequential`` parameter keys ("0.weight", "2.weight", ...) to match those
    stored in the ``Surrogate_*.pth`` checkpoints.
    """
    widths = [in_features, *hidden_sizes, out_features]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(in_features=fan_in, out_features=fan_out))
        layers.append(torch.nn.ReLU())
    layers.pop()  # no activation on the output layer
    return torch.nn.Sequential(*layers)


def make_surrogates():
    """Load the three pretrained surrogate models.

    The models are moved to the device returned by ``get_device()``. Their
    weights are loaded from ``Surrogate_Unmix_O.pth``,
    ``Surrogate_IO_PD.pth`` and ``Surrogate_IFD1.pth`` in ``SURROGATES_DIR``.

    Returns:
        Tuple ``(model_Unmix_O, model_IO_PD, model_IFD1)`` of
        ``torch.nn.Sequential`` MLPs, each mapping 6 input features to 1 output.
    """
    device = get_device()
    checkpoint_names = (
        "Surrogate_Unmix_O.pth",
        "Surrogate_IO_PD.pth",
        "Surrogate_IFD1.pth",
    )

    models = []
    for checkpoint_name in checkpoint_names:
        model = _build_surrogate_mlp()
        model.to(device)
        model.load_state_dict(
            torch.load(
                SURROGATES_DIR / checkpoint_name,
                map_location=device,
            )
        )
        models.append(model)

    model_Unmix_O, model_IO_PD, model_IFD1 = models
    return model_Unmix_O, model_IO_PD, model_IFD1


if __name__ == "__main__":
    import sys

    # Command-line entry point: exactly one sub-command is expected.
    commands = ("prepare_datasets", "test_dataset", "test_surrogates")
    command = sys.argv[1] if len(sys.argv) == 2 else None

    if command == "prepare_datasets":
        prepare_and_save_datasets()
    elif command == "test_dataset":
        dataset = GasTurbineDataset(split="train")
        print(f"Dataset length: {len(dataset)}")
        x, y = dataset[0]
        print(f"First sample - x: {x}, y: {y}")
    elif command == "test_surrogates":
        model_Unmix_O, model_IO_PD, model_IFD1 = make_surrogates()
        x_sample = torch.randn(1, LEN_PARAMETERS).to(get_device())
        # Inference only: no autograd graph is needed for a smoke test.
        with torch.no_grad():
            y1 = model_Unmix_O(x_sample)
            y2 = model_IO_PD(x_sample)
            y3 = model_IFD1(x_sample)
        print(f"Surrogate outputs for sample\n{x_sample}")
        print(f"Unmix_O:\n\t{y1}\nIO_PD:\n\t{y2}\nIFD1:\n\t{y3}")
    else:
        # Previously a missing/unknown command exited silently with status 0.
        choices = "|".join(commands)
        sys.exit(f"usage: {sys.argv[0]} {{{choices}}}")
