import matplotlib.pyplot as plt
import numpy as np
import time
import itertools
import os
from tqdm import tqdm

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, classification_report

from model import MLP, GCN, GIN, EGConv, REGConv, DistConv
from utils import bench_clustering, calculate_centroids, kmeans_loss, cosine_similarity_loss
#from dataloader import loadOfficeImage
from evaluation import evaluate
import torch
from torch.nn import Softmax
from torch.utils.data import TensorDataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from torch_geometric.nn.inits import reset
from torch_geometric.utils import subgraph
from torch_geometric.datasets import Planetoid

dataset = 'synthetic'
# Other dataset names used elsewhere: 'synthetic', 'handwritten', 'Cora',
# 'MNIST', 'CIFAR10'.

# Synthetic domain-shift benchmark generator.
# For each class count and each repetition id:
#   * k = num_classes // 2 Gaussian clusters are sampled around random centres
#     at distance R from the origin in R^k -> the "train" (source) split;
#   * the same k clusters, with every centre displaced by 5*r in a random
#     direction, are sampled again -> the "test" (target) split.
# Target labels are shifted back to 0..k-1 so both splits share label ids.
# Both splits are saved as TensorDatasets under PATH.
# NOTE(review): the global NumPy RNG is never seeded here, so every run writes
# different data under the same file names.

# Settings that do not depend on the loop variables.
dim = 2  # NOTE(review): not used below; the feature dimension is num_classes // 2
N = 1000  # samples drawn per cluster
R = 50  # norm of every source-cluster centre
E = 30000  # NOTE(review): not used in this section
homo_ratio = 0.3  # NOTE(review): not used in this section
batch_size = 256
PATH = './data/synthetic'
# torch.save does not create missing parent directories.
os.makedirs(PATH, exist_ok=True)

for num_classes in tqdm([4, 6, 10, 20]):
    print('num_classes=', num_classes)
    # Clusters per split; this is also the dimensionality of the samples.
    k = num_classes // 2
    # Cluster spread shrinks as more classes are packed into the space.
    r = 10 / (num_classes/4)
    # Isotropic covariance r^2 * I shared by every cluster.
    cov_matrix = (r ** 2) * np.eye(k)

    for ids in tqdm(range(10)):
        # Source centres: random directions in [-1, 1]^k, normalised, scaled by R.
        center_points = []
        for _ in range(k):
            direction = np.random.random(k) * 2 - 1
            direction = direction / np.linalg.norm(direction, 2)
            center_points.append(R * direction)

        # Source samples: N points per source centre.
        sampled_points = [
            np.random.multivariate_normal(mean=center, cov=cov_matrix, size=N)
            for center in center_points
        ]

        # Target samples: each centre moved by 5*r along a fresh random unit
        # direction, then N points drawn around the moved centre.
        for center in center_points:
            shift = np.random.random(k) * 2 - 1
            shift = shift / np.linalg.norm(shift, 2)
            sampled_points.append(
                np.random.multivariate_normal(
                    mean=center + 5*r*shift,
                    cov=cov_matrix,
                    size=N,
                )
            )

        data = np.concatenate(sampled_points)
        # One label per sampled cluster (0..2k-1).  Derived from the number of
        # clusters actually generated so labels stay aligned with `data` even
        # when num_classes is odd.
        labels = np.repeat(np.arange(2 * k), N)

        # First k clusters -> train (source); last k clusters -> test (target).
        split = k * N
        train_data, test_data = data[:split], data[split:]
        train_labels, test_labels = labels[:split], labels[split:] - k

        # Basic statistics (both splits share n_features and n_classes).
        n_samples_train, n_features = train_data.shape
        n_samples_test = test_data.shape[0]
        n_classes = np.unique(test_labels).size

        my_train_dataset = TensorDataset(
            torch.tensor(train_data, dtype=torch.float),
            torch.tensor(train_labels, dtype=torch.long),
        )
        my_test_dataset = TensorDataset(
            torch.tensor(test_data, dtype=torch.float),
            torch.tensor(test_labels, dtype=torch.long),
        )

        # NOTE(review): these loaders are never iterated in this section; they
        # are kept in case later code relies on the module-level names.
        train_data_loader = torch.utils.data.DataLoader(
            my_train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=8,
        )
        test_data_loader = torch.utils.data.DataLoader(
            my_test_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )

        torch.save(my_train_dataset, os.path.join(PATH, 'ttrain_{}_{}.pt'.format(num_classes//2, ids)))
        torch.save(my_test_dataset, os.path.join(PATH, 'ttest_{}_{}.pt'.format(num_classes//2, ids)))
