import os.path as osp
import zipfile
from io import StringIO, BytesIO
from pathlib import Path
from typing import Union, List, Tuple, Optional, Callable

import numpy as np
import requests
from unlzw3 import unlzw

import pandas as pd
import torch
from pydgn.data.dataset import DatasetInterface
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch_geometric.data import Data


class UCIDatasetInterface(DatasetInterface):
    """Tabular UCI classification data set exposed as a single edgeless graph.

    Subclasses describe their table through the class attributes below and
    override :meth:`df_loader` when the raw data is not a header-less CSV at
    ``uci_url``. :meth:`process` turns the table into one ``Data`` object
    (features in ``x``, class indices in ``y``, no edges) and saves it; the
    constructor loads that object and :meth:`get` always returns it.
    """

    # Column labels used by the default ``df_loader`` when reading the CSV.
    column_names: List[str] = None
    # Feature columns one-hot encoded by ``pd.get_dummies`` (plus a NaN column).
    categorical_columns: List[str] = None
    # Columns discarded before the target is extracted.
    columns_to_drop: List[str] = None
    # Column holding the class label; it is removed from the features.
    target_column: str = None
    # Location the raw data is read from by ``df_loader``.
    uci_url: str = None

    def df_loader(self) -> pd.DataFrame:
        """Read the raw table from ``uci_url`` as a header-less CSV."""
        df = pd.read_csv(self.uci_url, header=None, names=self.column_names)
        return df

    def load_uci_dataset(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the raw table and convert it to features and class indices.

        Returns:
            ``(X, y)`` where ``X`` is a float64 array with one row per record
            and ``y`` is a 1-D int64 array in which the distinct target values,
            sorted in ascending order, are mapped to ``0 .. num_classes - 1``.
        """
        df = self.df_loader()

        # drop columns we do not want to use in these experiments
        df = df.drop(columns=self.columns_to_drop)

        # extract target from dataset
        y = df[self.target_column]
        df = df.drop(columns=[self.target_column])

        # replace categorical features with one hot (plus a NaN indicator)
        df = pd.get_dummies(
            data=df, columns=self.categorical_columns, dummy_na=True
        )

        # fill NaNs with column means in each column
        df = df.fillna(df.mean())

        # Features are deliberately used as they are (no scaling).

        # Map each target value to its rank among the sorted distinct values.
        # np.unique sorts ascending and ``return_inverse`` gives each row the
        # index of its value in that sorted array. Unlike chained
        # ``Series.replace`` calls, an already-assigned code can never be
        # remapped again (e.g. targets [-1, 0]), and the result is integer.
        _, y_codes = np.unique(y.to_numpy(), return_inverse=True)

        # Since pandas 2.0 one-hot columns are bool; mixing bool and numeric
        # columns makes ``to_numpy`` return an object array, which
        # ``torch.tensor`` cannot convert. Force a numeric dtype.
        return df.to_numpy(dtype=np.float64), y_codes.astype(np.int64)

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        **kwargs,
    ):
        """Set up the data set under ``<root>/<name>`` and load it.

        Extra keyword arguments are accepted and ignored.

        Raises:
            AssertionError: if a transform, pre-transform or pre-filter is
                given; this interface does not allow preprocessing.
        """
        self.root = root
        self.name = name

        assert transform is None, "no preprocessing allowed"
        assert pre_transform is None, "no preprocessing allowed"
        assert pre_filter is None, "no preprocessing allowed"

        # NOTE(review): presumably the base constructor runs ``process`` when
        # the processed file is missing — confirm against DatasetInterface.
        super().__init__(root, name, transform, pre_transform, pre_filter)
        self.data = self._load_processed()

    def _load_processed(self) -> Data:
        """Load the ``Data`` object written by :meth:`process`."""
        path = self.processed_paths[0]
        try:
            # The file is a pickled ``Data`` object produced locally by
            # ``process``; torch >= 2.6 defaults to ``weights_only=True``,
            # which refuses to unpickle it.
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept ``weights_only``.
            return torch.load(path)

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, "processed")

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        # Nothing is tracked as a raw file; loaders fetch data in ``process``.
        return []

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        # NOTE: despite the ``.py`` suffix this file holds a torch-serialized
        # object, not Python source; the name is kept so existing caches load.
        return ["data.py"]

    @property
    def raw_paths(self) -> List[str]:
        return []

    @property
    def processed_paths(self) -> List[str]:
        return [osp.join(self.processed_dir, f) for f in self.processed_file_names]

    def download(self):
        # Raw data is fetched by ``df_loader`` during ``process`` instead.
        pass

    def process(self):
        """Build the single ``Data`` object from the table and save it."""
        X, y = self.load_uci_dataset()

        d = Data(
            x=torch.tensor(X, dtype=torch.float),
            edge_index=None,
            edge_attr=None,
            # column vector of class indices, shape (num_rows, 1)
            y=torch.tensor(y, dtype=torch.long).unsqueeze(1),
            # NOTE(review): not one of the named fields above; presumably kept
            # as an extra attribute ``d.dtype`` — confirm it is still needed.
            dtype=torch.float,
        )

        torch.save(d, self.processed_paths[0])

    def get(self, idx: int) -> Data:
        # The whole data set is a single graph, whatever index is requested.
        return self.data

    @property
    def dim_node_features(self) -> int:
        return self.data.x.shape[1]

    @property
    def dim_edge_features(self) -> int:
        return 0

    @property
    def dim_target(self) -> int:
        # Number of classes: labels are 0-based contiguous indices.
        return int(self.data.y.max().item()) + 1

    def __len__(self) -> int:
        return 1  # single graph


class Abalone(UCIDatasetInterface):
    """UCI Abalone data set with ``Sex`` as the class label.

    No column is one-hot encoded or dropped, so every remaining column
    (including ``Rings``) is kept as a feature.
    """

    uci_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
        "abalone/abalone.data"
    )
    target_column = "Sex"
    categorical_columns = []
    columns_to_drop = []
    column_names = [
        "Sex", "Length", "Diameter", "Height",
        "Weight.whole", "Weight.shucked", "Weight.viscera", "Weight.shell",
        "Rings",
    ]


class Adult(UCIDatasetInterface):
    """UCI Adult data set with ``class`` as the class label.

    Attributes listed in ``categorical_columns`` are one-hot encoded; all
    other non-target columns are kept as they are.
    """

    uci_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
        "adult/adult.data"
    )
    target_column = "class"
    columns_to_drop = []
    categorical_columns = [
        "workclass", "education", "marital-status", "occupation",
        "relationship", "race", "sex", "native-country",
    ]
    column_names = [
        "age", "workclass", "fnlwgt", "education", "education-num",
        "marital-status", "occupation", "relationship", "race", "sex",
        "capital-gain", "capital-loss", "hours-per-week", "native-country",
        "class",
    ]


class ElectricalGrid(UCIDatasetInterface):
    """UCI electrical grid stability data set (``Data_for_UCI_named.csv``).

    ``stab`` is dropped and ``stabf`` is the class label.
    """

    uci_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
        "00471/Data_for_UCI_named.csv"
    )
    target_column = "stabf"
    categorical_columns = []
    columns_to_drop = ["stab"]
    # tau1..tau4, p1..p4, g1..g4, then stab and stabf. Listed for reference:
    # ``df_loader`` below does not pass these labels to ``read_csv``.
    column_names = [
        f"{prefix}{index}" for prefix in ("tau", "p", "g") for index in range(1, 5)
    ] + ["stab", "stabf"]

    def df_loader(self):
        """Read the CSV, letting pandas take the labels from its first row."""
        return pd.read_csv(self.uci_url)


class Musk(UCIDatasetInterface):
    """UCI Musk data set (``clean2`` file).

    The raw file is an LZW-compressed (``.Z``) header-less CSV read with
    positional column labels; columns 0 and 1 are dropped and column 168 is
    the class label.
    """

    # Positional labels 0..168 matching what ``read_csv(header=None)`` below
    # produces. Not passed to ``read_csv``; kept to describe the table.
    column_names = list(range(169))
    columns_to_drop = [0, 1]
    categorical_columns = []
    target_column = 168
    uci_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/musk/clean2.data.Z"

    def df_loader(self):
        """Download, decompress and parse the ``.Z`` file into a DataFrame.

        Raises:
            requests.HTTPError: if the download does not succeed.
        """
        response = requests.get(self.uci_url)
        # Fail on HTTP errors instead of handing an error page to ``unlzw``.
        response.raise_for_status()
        text = unlzw(response.content).decode("utf-8")
        return pd.read_csv(StringIO(text), sep=",", header=None)


class Waveform(UCIDatasetInterface):
    """UCI Waveform data set.

    The raw file is an LZW-compressed (``.Z``) header-less CSV read with
    positional column labels; column 21 is the class label.
    """

    # Labels of the feature columns 0..20 (the target column 21 is not listed).
    column_names = [i for i in range(21)]
    columns_to_drop = []
    categorical_columns = []
    target_column = 21
    uci_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/waveform/waveform.data.Z"

    def df_loader(self):
        """Download, decompress and parse the ``.Z`` file into a DataFrame.

        Raises:
            requests.HTTPError: if the download does not succeed.
        """
        response = requests.get(self.uci_url)
        # Fail on HTTP errors instead of handing an error page to ``unlzw``.
        response.raise_for_status()
        text = unlzw(response.content).decode("utf-8")
        return pd.read_csv(StringIO(text), sep=",", header=None)


class Isolet(UCIDatasetInterface):
    """UCI ISOLET data set, built from two LZW-compressed (``.Z``) files.

    Both files are header-less CSVs read with positional column labels and
    concatenated; column 617 is the class label.
    """

    # Labels of the feature columns 0..616 (the target column 617 is not listed).
    column_names = [i for i in range(617)]
    columns_to_drop = []
    categorical_columns = []
    target_column = 617
    uci_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/isolet/isolet1+2+3+4.data.Z"
    uci_url_2 = "https://archive.ics.uci.edu/ml/machine-learning-databases/isolet/isolet5.data.Z"

    @staticmethod
    def _read_compressed_csv(url: str) -> pd.DataFrame:
        """Download a ``.Z``-compressed header-less CSV into a DataFrame.

        Raises:
            requests.HTTPError: if the download does not succeed.
        """
        response = requests.get(url)
        # Fail on HTTP errors instead of handing an error page to ``unlzw``.
        response.raise_for_status()
        text = unlzw(response.content).decode("utf-8")
        return pd.read_csv(StringIO(text), sep=",", header=None)

    def df_loader(self):
        """Read ``uci_url`` then ``uci_url_2`` and stack them row-wise."""
        frames = [
            self._read_compressed_csv(url) for url in (self.uci_url, self.uci_url_2)
        ]
        # Fresh 0..N-1 index rather than two overlapping row ranges.
        return pd.concat(frames, ignore_index=True)


class OccupancyDetection(UCIDatasetInterface):
    """UCI Occupancy Detection data set.

    The zip archive is extracted into ``raw_dir`` and its three text files are
    concatenated; ``date`` is dropped and ``Occupancy`` is the class label.
    """

    column_names = [
        "date",
        "Temperature",
        "Humidity",
        "Light",
        "CO2",
        "HumidityRatio",
        "Occupancy",
    ]
    columns_to_drop = ["date"]
    categorical_columns = []
    target_column = "Occupancy"
    uci_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip"

    # Files read from the extracted archive, in concatenation order.
    _data_files = ("datatraining.txt", "datatest.txt", "datatest2.txt")

    def df_loader(self):
        """Download and extract the archive, then return the combined table.

        Raises:
            requests.HTTPError: if the download does not succeed.
        """
        response = requests.get(self.uci_url)
        # Fail on HTTP errors instead of surfacing them as ``BadZipFile``.
        response.raise_for_status()

        with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
            zip_ref.extractall(self.raw_dir)

        frames = [
            pd.read_csv(Path(self.raw_dir, file_name), sep=",")
            for file_name in self._data_files
        ]
        return pd.concat(frames)


class DryBean(UCIDatasetInterface):
    """UCI Dry Bean data set with ``Class`` as the class label.

    The zip archive is extracted into ``raw_dir`` and the Excel sheet inside
    it is read with pandas.
    """

    column_names = [
        "Area",
        "Perimeter",
        "MajorAxisLength",
        "MinorAxisLength",
        "AspectRation",
        "Eccentricity",
        "ConvexArea",
        "EquivDiameter",
        "Extent",
        "Solidity",
        "roundness",
        "Compactness",
        "ShapeFactor1",
        "ShapeFactor2",
        "ShapeFactor3",
        "ShapeFactor4",
        "Class",
    ]
    columns_to_drop = []
    categorical_columns = []
    target_column = "Class"
    uci_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00602/DryBeanDataset.zip"

    def df_loader(self):
        """Download and extract the archive, then read the Excel sheet.

        Raises:
            requests.HTTPError: if the download does not succeed.
        """
        response = requests.get(self.uci_url)
        # Fail on HTTP errors instead of surfacing them as ``BadZipFile``.
        response.raise_for_status()

        with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
            zip_ref.extractall(self.raw_dir)

        path = Path(self.raw_dir, "DryBeanDataset", "Dry_Bean_Dataset.xlsx")
        return pd.read_excel(path)
