import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.datasets import make_moons
from torch.utils.data import TensorDataset

def _moons_labels(n, outlier_size, near_duplicate_size):
    """Return labels in stacking order: inliers, then near-duplicates, then outliers."""
    return (
        ["Inlier"] * (n - outlier_size - near_duplicate_size)
        + ["Duplicates"] * near_duplicate_size
        + ["Outlier"] * outlier_size
    )


def _moons_inliers(n_samples, random_state):
    """Sample ``n_samples`` two-moons points (noise=0.03) and rescale/shift them."""
    X, _ = make_moons(n_samples=n_samples, random_state=random_state, noise=0.03)
    X[:, 0] = (X[:, 0] + 0.3) * 2 - 1
    X[:, 1] = (X[:, 1] + 0.3) * 3
    return X


def _moons_near_duplicates(X, labels, count, max_offset):
    """Return a (count, 2) array of jittered copies of randomly chosen rows of ``X``.

    For every copy, the label of the chosen source row in ``labels`` is
    overwritten in place with "Duplicates". Both coordinates are shifted by an
    independent uniform offset in [-max_offset, max_offset]. Row choice and
    offsets use the global, unseeded ``random`` / ``np.random`` generators, so
    results are not reproducible unless the caller seeds them.
    """
    points = []
    for _ in range(count):
        idx = random.randint(0, len(X) - 1)
        labels[idx] = "Duplicates"
        point = X[idx].copy()
        point[0] += np.random.uniform(-max_offset, max_offset)
        point[1] += np.random.uniform(-max_offset, max_offset)
        points.append(point)
    # reshape keeps an empty result 2-D so it can still be stacked.
    return np.asarray(points, dtype=float).reshape(-1, 2)


def _moons_ring_outliers(count, seed):
    """Sample ``count`` points close to a unit circle centred at (0, -4)."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-0.5, 0.5, count)
    y = rng.uniform(-0.5, 0.5, count)
    # Project onto the unit circle; the epsilon avoids division by zero.
    norm = np.sqrt(x**2 + y**2) + 1e-10
    x /= norm
    y /= norm
    # Displace each point by at most 0.03 in a uniformly random direction.
    theta = 2 * np.pi * rng.uniform(0, 1, count)
    r = rng.uniform(0, 0.03, count)
    x += r * np.cos(theta)
    y += r * np.sin(theta)
    points = np.stack((x, y), axis=1)
    points[:, 1] -= 4
    return points


def _stack_points(*parts):
    """Vertically stack the non-empty (k, 2) point arrays in ``parts``."""
    non_empty = [np.asarray(p, dtype=float).reshape(-1, 2) for p in parts if len(p)]
    if not non_empty:
        return np.empty((0, 2))
    return np.vstack(non_empty)


def _save_scatter(points, path):
    """Scatter-plot ``points`` on a fixed [-6, 6] x [-6, 6] window and save it to ``path``."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    pts = points.astype(np.float32)
    plt.figure(figsize=(10, 10))
    plt.scatter(pts[:, 0], pts[:, 1])
    plt.xlim(-6, 6)
    plt.ylim(-6, 6)
    plt.savefig(path)
    plt.close()


def moons_dataset(dups_dist, dups, outlier_size, test_outliers, n=8000):
    """Build train/test two-moons point sets with near-duplicates and ring outliers.

    Each split holds ``n`` points, stacked as inliers, then near-duplicates,
    then outliers. The label lists follow the same order, except that the
    inlier each near-duplicate was copied from is also relabelled "Duplicates".

    Args:
        dups_dist: Maximum absolute per-coordinate offset of a near-duplicate
            from its source inlier.
        dups: Number of near-duplicate points in each split.
        outlier_size: Number of ring outliers in the training split (and in
            the test split when ``test_outliers`` is truthy).
        test_outliers: When falsy, the test split contains no outliers.
        n: Number of points per split.

    Returns:
        A 5-tuple of:
        (TensorDataset of float32 training points, float32 training array,
        training labels, float32 test array, test labels).

    Raises:
        ValueError: If ``dups`` or ``outlier_size`` is negative, or if their
            sum exceeds ``n``.

    Side effects:
        Writes ``exps/train_dataset.png`` and ``exps/test_dataset.png``,
        creating ``exps/`` if needed. Also prints ``test_outliers``.
    """
    near_duplicate_size = dups
    if outlier_size < 0 or near_duplicate_size < 0 or outlier_size + near_duplicate_size > n:
        raise ValueError(
            f"invalid sizes: outlier_size={outlier_size}, dups={near_duplicate_size}, n={n}"
        )

    # Training split: make_moons seed 2, outlier seed 1.
    data_labels = _moons_labels(n, outlier_size, near_duplicate_size)
    X = _moons_inliers(n - outlier_size - near_duplicate_size, random_state=2)
    near_duplicates = _moons_near_duplicates(X, data_labels, near_duplicate_size, dups_dist)
    outlier_X = _moons_ring_outliers(outlier_size, seed=1)
    combined_X = _stack_points(X, near_duplicates, outlier_X)
    _save_scatter(combined_X, "exps/train_dataset.png")

    print(f"==>> test_outliers: {test_outliers}")
    test_outlier_size = outlier_size if test_outliers else 0

    # Test split: make_moons seed 77, outlier seed 123.
    data_labels_test = _moons_labels(n, test_outlier_size, near_duplicate_size)
    X_test = _moons_inliers(n - test_outlier_size - near_duplicate_size, random_state=77)
    near_duplicates_test = _moons_near_duplicates(
        X_test, data_labels_test, near_duplicate_size, dups_dist
    )
    outlier_X_test = _moons_ring_outliers(test_outlier_size, seed=123)
    combined_X_test = _stack_points(X_test, near_duplicates_test, outlier_X_test)
    _save_scatter(combined_X_test, "exps/test_dataset.png")

    # Separate float32 copies: the tensor and the returned array do not share memory.
    return (
        TensorDataset(torch.from_numpy(combined_X.astype(np.float32))),
        combined_X.astype(np.float32),
        data_labels,
        combined_X_test.astype(np.float32),
        data_labels_test,
    )


def line_dataset(n=8000):
    """Return ``n`` points drawn uniformly from the rectangle [-2, 2) x [-4, 4).

    The generator is seeded with 42, so the output is deterministic.
    """
    gen = np.random.default_rng(42)
    xs = gen.uniform(-0.5, 0.5, n) * 4
    ys = gen.uniform(-1, 1, n) * 4
    points = np.column_stack((xs, ys)).astype(np.float32)
    return TensorDataset(torch.from_numpy(points))


def circle_dataset(n=8000):
    """Return ``n`` points lying near a circle of radius 3.

    Each point's direction comes from coordinates snapped to a coarse grid
    (multiples of 0.2), so points cluster at a small set of angles. After
    normalisation, each point is displaced by at most 0.03 in a random
    direction, and the whole set is scaled by 3. Points whose snapped
    coordinates are both zero stay near the origin. Seeded with 42.
    """
    gen = np.random.default_rng(42)

    def _snap(values):
        # Halve, round to one decimal, then double: quantise to steps of 0.2.
        return np.round(values / 2, 1) * 2

    px = _snap(gen.uniform(-0.5, 0.5, n))
    py = _snap(gen.uniform(-0.5, 0.5, n))
    length = np.sqrt(px**2 + py**2) + 1e-10  # epsilon guards the (0, 0) case
    px = px / length
    py = py / length

    angle = 2 * np.pi * gen.uniform(0, 1, n)
    radius = gen.uniform(0, 0.03, n)
    pts = np.stack(
        (px + radius * np.cos(angle), py + radius * np.sin(angle)),
        axis=1,
    ) * 3
    return TensorDataset(torch.from_numpy(pts.astype(np.float32)))


def dino_dataset(n=8000):
    """Return ``n`` noisy samples of the "dino" shape from the Datasaurus Dozen.

    Reads the tab-separated file ``static/DatasaurusDozen.tsv`` and keeps the
    rows whose ``dataset`` column equals "dino". It then resamples ``n`` rows
    with replacement (seed 42) and adds Gaussian noise with standard deviation
    0.15 to each coordinate. Finally it maps x -> (x/54 - 1)*4 and
    y -> (y/48 - 1)*4. NOTE(review): 54/48 presumably approximate the data
    centre; confirm.
    """
    frame = pd.read_csv("static/DatasaurusDozen.tsv", sep="\t")
    dino = frame[frame["dataset"] == "dino"]

    gen = np.random.default_rng(42)
    rows = gen.integers(0, len(dino), n)

    # Draw the x noise before the y noise so the random stream order is fixed.
    x_raw = np.array(dino["x"].iloc[rows].tolist())
    x_noisy = x_raw + gen.normal(size=len(x_raw)) * 0.15
    y_raw = np.array(dino["y"].iloc[rows].tolist())
    y_noisy = y_raw + gen.normal(size=len(x_raw)) * 0.15

    coords = np.column_stack(((x_noisy / 54 - 1) * 4, (y_noisy / 48 - 1) * 4))
    return TensorDataset(torch.from_numpy(coords.astype(np.float32)))


def get_dataset(name, outliers, dups, dups_dist, test_outliers, n=8000):
    """Build the dataset called ``name``.

    Only "moons" uses ``outliers``, ``dups``, ``dups_dist`` and
    ``test_outliers``; it returns a 5-tuple, whereas "dino", "line" and
    "circle" return a bare TensorDataset.

    Raises:
        ValueError: If ``name`` does not match a known dataset.
    """
    # Builders are lambdas so nothing runs until a name has matched.
    registry = (
        ("moons", lambda: moons_dataset(dups_dist, dups, outliers, test_outliers, n)),
        ("dino", lambda: dino_dataset(n)),
        ("line", lambda: line_dataset(n)),
        ("circle", lambda: circle_dataset(n)),
    )
    for key, build in registry:
        if name == key:
            return build()
    raise ValueError(f"Unknown dataset: {name}")
if __name__ == "__main__":
    # Smoke run: 10 outliers per split and no near-duplicates (so dups_dist is
    # unused). moons_dataset takes four required arguments.
    moons_dataset(dups_dist=0.0, dups=0, outlier_size=10, test_outliers=True)