import os
import zipfile
import pandas as pd
from pathlib import Path
from urllib.request import urlretrieve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from source.constants import UCI_PATH as _UCI_PATH_ORIGINAL

# Local cache directory shared by every loader below. It is created at import
# time when missing; the exists() guard is kept deliberately so that a path
# which already exists (whatever its type) is left untouched.
UCI_PATH = Path(_UCI_PATH_ORIGINAL)
if not UCI_PATH.exists():
    UCI_PATH.mkdir(parents=True, exist_ok=True)

# Remote source for each dataset, keyed by loader id. The ".zip" entries are
# archives unpacked by their loader; "uci_casp" points at a bare CSV.
URLS = {
    "uci_ymsd": "https://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip",
    "uci_sgemm": "https://archive.ics.uci.edu/static/public/440/sgemm+gpu+kernel+performance.zip",
    "uci_ccpp": "https://archive.ics.uci.edu/static/public/294/combined+cycle+power+plant.zip",
    "uci_casp": "https://archive.ics.uci.edu/ml/machine-learning-databases/00265/CASP.csv",
    "uci_news": "https://archive.ics.uci.edu/static/public/332/online+news+popularity.zip",
    "uci_blog": "https://archive.ics.uci.edu/static/public/304/blogfeedback.zip",
}

def _apply_scaler(X_train, X_test, scaler):
    scaler = scaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    return X_train, X_test

def _transform_target(y_train, y_test, target_min, target_max):
    return (y_train - target_min) / (target_max - target_min), \
          (y_test - target_min) / (target_max - target_min) 


def load_ymsd(scaler=StandardScaler):
    """Load the YearPredictionMSD regression dataset.

    On first use the zip archive is downloaded into ``UCI_PATH``, unpacked
    and deleted. The first column is the target and the remaining columns
    are features. UCI's predefined split is used: the first 463,715 rows
    are train and the rest are test.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are mapped with
        ``(y - 1922) / (2012 - 1922)``.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    txt_path = UCI_PATH / "YearPredictionMSD.txt"
    if not txt_path.exists():
        archive = UCI_PATH / "YearPredictionMSD.txt.zip"
        urlretrieve(URLS["uci_ymsd"], archive)
        with zipfile.ZipFile(archive, 'r') as zf:
            zf.extractall(UCI_PATH)
        # The extracted text file is the cache; the archive is no longer needed.
        os.remove(archive)

    frame = pd.read_csv(txt_path, header=None)
    features = frame.iloc[:, 1:].values
    years = frame.iloc[:, 0].values

    n_train = 463715  # size of UCI's fixed training portion
    X_train, X_test = features[:n_train], features[n_train:]
    y_train, y_test = years[:n_train], years[n_train:]

    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, 1922, 2012)
    return X_train, y_train, X_test, y_test

def load_sgemm(scaler=StandardScaler):
    """Load the SGEMM GPU kernel performance regression dataset.

    On first use only ``sgemm_product.csv`` is extracted from the downloaded
    archive into ``UCI_PATH``, and the archive is deleted. All columns except
    the last four are features. The target is the row-wise mean of the last
    four columns (the four runtime runs). The data is split 80/20 with
    ``random_state=2357``.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are mapped with
        ``(y - 13.25) / (3397.08 - 13.25)``.

    Raises:
        FileNotFoundError: the downloaded archive lacks ``sgemm_product.csv``.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    csv_path = UCI_PATH / "sgemm_product.csv"
    if not csv_path.exists():
        archive = UCI_PATH / "sgemm+gpu+kernel+performance.zip"
        urlretrieve(URLS["uci_sgemm"], archive)
        with zipfile.ZipFile(archive, 'r') as zf:
            if "sgemm_product.csv" not in zf.namelist():
                raise FileNotFoundError("sgemm_product.csv not found in the zip file.")
            zf.extract("sgemm_product.csv", UCI_PATH)
        os.remove(archive)

    frame = pd.read_csv(csv_path)
    features = frame.iloc[:, :-4].values
    mean_runtime = frame.iloc[:, -4:].mean(axis=1).values

    X_train, X_test, y_train, y_test = train_test_split(
        features, mean_runtime, test_size=0.2, random_state=2357
    )

    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, 13.25, 3397.08)
    return X_train, y_train, X_test, y_test

def load_ccpp(scaler=StandardScaler):
    """Load the Combined Cycle Power Plant regression dataset.

    On first use ``CCPP/Folds5x2_pp.xlsx`` is copied out of the downloaded
    archive directly into ``UCI_PATH``, dropping the ``CCPP/`` folder, and
    the archive is deleted. The last column is the target and the others are
    features. The data is split 80/20 with ``random_state=2357``.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are mapped with
        ``(y - 420.26) / (495.76 - 420.26)``.

    Raises:
        FileNotFoundError: the downloaded archive lacks the expected member.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    xlsx_path = UCI_PATH / "Folds5x2_pp.xlsx"
    if not xlsx_path.exists():
        archive = UCI_PATH / "combined+cycle+power+plant.zip"
        urlretrieve(URLS["uci_ccpp"], archive)
        with zipfile.ZipFile(archive, 'r') as zf:
            member = "CCPP/Folds5x2_pp.xlsx"
            if member not in zf.namelist():
                raise FileNotFoundError("Folds5x2_pp.xlsx not found in the zip file.")
            # Stream the member into a flat path rather than extract(), which
            # would recreate the CCPP/ sub-directory.
            with zf.open(member) as src, open(xlsx_path, "wb") as dst:
                dst.write(src.read())
        os.remove(archive)

    frame = pd.read_excel(xlsx_path)
    features = frame.iloc[:, :-1].values
    target = frame.iloc[:, -1].values

    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.2, random_state=2357
    )

    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, 420.26, 495.76)
    return X_train, y_train, X_test, y_test

def load_casp(scaler=StandardScaler):
    """Load the UCI CASP regression dataset.

    On first use ``CASP.csv`` is downloaded into ``UCI_PATH``. The first
    column is the target and the remaining columns are features. The data is
    split 80/20 with ``random_state=2357``.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are min-max
        rescaled using the training target's min and max.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    csv_path = UCI_PATH / "CASP.csv"
    if not csv_path.exists():
        # Download to a side file and move it into place only after the
        # transfer has finished. Writing straight to CASP.csv would leave a
        # truncated file behind on an interrupted or short download
        # (urlretrieve keeps partial data on ContentTooShortError). The
        # exists() check above would then accept that corrupt file as a
        # valid cache forever.
        partial_path = UCI_PATH / "CASP.csv.part"
        try:
            urlretrieve(URLS["uci_casp"], partial_path)
            partial_path.replace(csv_path)  # atomic rename on the same filesystem
        finally:
            # Only still present if the download or rename failed.
            if partial_path.exists():
                partial_path.unlink()

    df = pd.read_csv(csv_path)

    X = df.iloc[:, 1:].values
    y = df.iloc[:, 0].values

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=2357
    )

    # Scaling: feature statistics and target bounds both come from the training split.
    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, y_train.min(), y_train.max())

    return X_train, y_train, X_test, y_test

def load_news(scaler=StandardScaler):
    """Load the UCI Online News Popularity regression dataset.

    On first use ``OnlineNewsPopularity/OnlineNewsPopularity.csv`` is copied
    out of the downloaded archive into ``UCI_PATH``, and the archive is
    deleted. The first column (the article URL) and the last column (the
    ``shares`` target) are excluded from the features. The data is split
    80/20 with ``random_state=2357``.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are min-max
        rescaled using the training target's min and max.

    Raises:
        FileNotFoundError: the downloaded archive lacks the expected CSV.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    csv_path = UCI_PATH / "OnlineNewsPopularity.csv"
    if not csv_path.exists():
        zip_path = UCI_PATH / "OnlineNewsPopularity.zip"
        urlretrieve(URLS["uci_news"], zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            member = "OnlineNewsPopularity/OnlineNewsPopularity.csv"
            if member not in zip_ref.namelist():
                # Fail here, as load_ccpp/load_sgemm do. Otherwise the archive
                # would be deleted and read_csv would fail later with a
                # misleading "file not found" on the cache path.
                raise FileNotFoundError("OnlineNewsPopularity.csv not found in the zip file.")
            # Copy into a flat path (no OnlineNewsPopularity/ sub-directory).
            with zip_ref.open(member) as src, open(csv_path, "wb") as dst:
                dst.write(src.read())

        # Remove the zip file after extraction
        os.remove(zip_path)

    df = pd.read_csv(csv_path, sep=',', header=0)

    # Features & target (shares)
    X = df.iloc[:, 1:-1].values   # drop URL, keep all other features
    y = df.iloc[:, -1].values

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=2357
    )

    # Scaling: feature statistics and target bounds both come from the training split.
    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, y_train.min(), y_train.max())

    return X_train, y_train, X_test, y_test

def load_blog(scaler=StandardScaler):
    """Load the UCI BlogFeedback regression dataset with its provided splits.

    On first use the whole archive is extracted into ``UCI_PATH`` and then
    deleted. ``blogData_train.csv`` is the training set. Every file in
    ``UCI_PATH`` whose name starts with ``blogData_test`` is read and
    concatenated, in sorted file-name order, to form the test set. In all
    files the last column is the target and the others are features.

    Args:
        scaler: ``StandardScaler`` or ``MinMaxScaler`` class, fitted on the
            training features only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are min-max
        rescaled using the training target's min and max.

    Raises:
        FileNotFoundError: no ``blogData_test*`` files exist in ``UCI_PATH``.
    """
    assert scaler in [StandardScaler, MinMaxScaler], "Scaler must be either StandardScaler or MinMaxScaler"

    train_path = UCI_PATH / "blogData_train.csv"
    if not train_path.exists():
        # If the local path does not exist, download the dataset
        zip_path = UCI_PATH / "blogfeedback.zip"
        urlretrieve(URLS["uci_blog"], zip_path)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Extract all blogData*.csv files (train + test splits)
            zip_ref.extractall(UCI_PATH)

        # Remove the zip file after extraction
        os.remove(zip_path)

    # os.listdir returns entries in arbitrary order. Sort them so the
    # concatenated test set has the same row order on every machine and
    # filesystem.
    test_files = sorted(f for f in os.listdir(UCI_PATH) if f.startswith("blogData_test"))
    if not test_files:
        # Without this, pd.concat([]) fails with an opaque "No objects to concatenate".
        raise FileNotFoundError(f"No blogData_test* files found in {UCI_PATH}.")

    # Load train and test data (the files have no header row)
    df_train = pd.read_csv(train_path, header=None)
    df_test = pd.concat([pd.read_csv(UCI_PATH / f, header=None) for f in test_files], ignore_index=True)

    X_train = df_train.iloc[:, :-1].values
    y_train = df_train.iloc[:, -1].values
    X_test = df_test.iloc[:, :-1].values
    y_test = df_test.iloc[:, -1].values

    # Scaling: feature statistics and target bounds both come from the training split.
    X_train, X_test = _apply_scaler(X_train, X_test, scaler)
    y_train, y_test = _transform_target(y_train, y_test, y_train.min(), y_train.max())

    return X_train, y_train, X_test, y_test
