# Taken from: https://github.com/locuslab/edge-of-stability
# Original paper: Cohen, J. M., Kaur, S., Li, Y., Kolter, J. Z., & Talwalkar, A. (2021).
# "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability", ICLR 2021.
# If you use this code, please cite the original work.

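# Our modifications: the scikit-learn imports, and everything from make_reverse_sequence_dataset onward
# (make_reverse_sequence_dataset, load_california_housing, load_breastcancer).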

import torch
from torch.utils.data import TensorDataset
from math import sqrt
from numpy.polynomial import chebyshev
import numpy as np
from sklearn.datasets import fetch_california_housing, load_breast_cancer as sk_load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def make_chebyshev_dataset(k, n=10000):
    """
    Build a 1-D regression dataset whose targets are the Chebyshev polynomial T_k.

    The inputs are ``n`` evenly spaced points on [-1, 1]. The same dataset object is
    returned for both the train and test split.

    Returns:
        (dataset, dataset): a TensorDataset of (inputs, targets), each of shape (n, 1), float32.
    """
    grid = torch.linspace(-1, 1, n)

    # A coefficient vector (0, ..., 0, 1) of length k + 1 selects T_k alone.
    coeffs = np.zeros(k + 1)
    coeffs[k] = 1

    values = chebyshev.chebval(grid.numpy(), coeffs)
    targets = torch.from_numpy(values).float()

    dataset = TensorDataset(grid[:, None], targets[:, None])
    return dataset, dataset


def make_linear_dataset(n, d, seed=0, device="cuda"):
    """
    Create a dataset for training a deep linear network with n datapoints of dimension d.

    The inputs X are the orthonormal Q factor of an n x d Gaussian matrix, scaled by sqrt(n).
    When n >= d, this gives X^T X = n * I. The targets are Y = X A^T for a Gaussian d x d
    matrix A. The same (X, Y) pair is returned as both the train and test split.

    Args:
        n: number of datapoints. If n < d, the reduced Q factor is n x n and the product
            with A^T fails on a shape mismatch.
        d: input and output dimension.
        seed: passed to ``torch.manual_seed``. Note that this reseeds the global RNG.
        device: device the tensors are moved to. The default "cuda" preserves the
            original behaviour.

    Returns:
        (train, test): two TensorDatasets wrapping the same (X, Y) tensors.
    """
    torch.manual_seed(seed)
    # torch.qr is deprecated; torch.linalg.qr with mode="reduced" is its documented replacement
    # (equivalent to torch.qr(..., some=True)).
    q, _ = torch.linalg.qr(torch.randn(n, d), mode="reduced")
    X = (q * sqrt(n)).to(device)
    A = torch.randn(d, d).to(device)
    Y = X.mm(A.t())
    return TensorDataset(X, Y), TensorDataset(X, Y)


def make_reverse_sequence_dataset(seq_len=10, vocab_size=10, num_samples=10000, num_test=1000):
    """
    Generate a sequence-reversal dataset of random token sequences.

    Each input row holds ``seq_len`` tokens drawn uniformly from ``1 .. vocab_size - 1``.
    Token 0 is never produced, which leaves it free for padding. The target is the same
    row reversed. Sampling uses torch's global RNG, so no seed is set here.

    Args:
        seq_len: length of each sequence.
        vocab_size: exclusive upper bound on token ids. Must be >= 2, otherwise
            ``torch.randint`` has an empty range.
        num_samples: number of training sequences.
        num_test: number of sequences in the evaluation split. The evaluation split is the
            first ``num_test`` training examples, not held-out data. If it exceeds
            ``num_samples``, the whole training set is used.

    Returns:
        (train, test): TensorDatasets of int64 (X, Y) pairs, each of shape (count, seq_len).
    """
    X = torch.randint(1, vocab_size, (num_samples, seq_len))  # skip 0 for padding
    Y = torch.flip(X, dims=[1])
    return TensorDataset(X, Y), TensorDataset(X[:num_test], Y[:num_test])


def load_california_housing(loss: str, test_size: float = 0.2,
                            random_state: int = 42) -> "tuple[TensorDataset, TensorDataset]":
    """
    Load the California housing regression dataset as standardized float32 TensorDatasets.

    Features and target are each standardized with a StandardScaler. Each scaler is fit
    on the training split only and then applied to the test split. The data comes from
    scikit-learn's ``fetch_california_housing``, which may download it on first use.

    Args:
        loss: not used by this function. Presumably it is kept so all loaders share a
            signature; verify against callers.
        test_size: fraction of samples held out for the test split (default 0.2).
        random_state: seed passed to ``train_test_split`` (default 42).

    Returns:
        (train_dataset, test_dataset): each wraps (X, y) float32 tensors, with y of shape (m, 1).
    """
    data = fetch_california_housing()
    X = data.data
    # Keep the target 2-D, because StandardScaler expects (n_samples, n_features).
    y = data.target.reshape(-1, 1)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    scaler_X = StandardScaler()
    scaler_y = StandardScaler()

    X_train = scaler_X.fit_transform(X_train)
    X_test = scaler_X.transform(X_test)

    y_train = scaler_y.fit_transform(y_train)
    y_test = scaler_y.transform(y_test)

    def to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32)

    train_dataset = TensorDataset(to_tensor(X_train), to_tensor(y_train))
    test_dataset = TensorDataset(to_tensor(X_test), to_tensor(y_test))

    return train_dataset, test_dataset


def load_breastcancer(loss: str, test_size: float = 0.2,
                      random_state: int = 42) -> "tuple[TensorDataset, TensorDataset]":
    """
    Load scikit-learn's breast cancer dataset as a standardized regression task.

    The 0/1 class labels are cast to float and treated as a continuous target. Features
    and target are each standardized with a StandardScaler. Each scaler is fit on the
    training split only and then applied to the test split.

    Args:
        loss: not used by this function. Presumably it is kept so all loaders share a
            signature; verify against callers.
        test_size: fraction of samples held out for the test split (default 0.2).
        random_state: seed passed to ``train_test_split`` (default 42).

    Returns:
        (train_dataset, test_dataset): each wraps (X, y) float32 tensors, with y of shape (m, 1).
    """
    data = sk_load_breast_cancer()
    X = data.data
    # Targets are 0/1; cast them to float so they can be standardized like a regression target.
    y = data.target.reshape(-1, 1).astype(float)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    scaler_X = StandardScaler()
    scaler_y = StandardScaler()

    X_train = scaler_X.fit_transform(X_train)
    X_test = scaler_X.transform(X_test)

    y_train = scaler_y.fit_transform(y_train)
    y_test = scaler_y.transform(y_test)

    def to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32)

    train_dataset = TensorDataset(to_tensor(X_train), to_tensor(y_train))
    test_dataset = TensorDataset(to_tensor(X_test), to_tensor(y_test))

    return train_dataset, test_dataset