import matplotlib.pyplot as plt
import numpy as np
import time
import itertools
import os
import pandas as pd
from tqdm import tqdm

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, classification_report

from model import MLP, GCN, GIN, EGConv, REGConv, DistConv
from utils import bench_clustering, calculate_centroids, kmeans_loss, cosine_similarity_loss
#from dataloader import loadOfficeImage
from evaluation import evaluate
import torch
from torch.nn import Softmax
from torch.utils.data import TensorDataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from torch_geometric.nn.inits import reset
from torch_geometric.utils import subgraph
from torch_geometric.datasets import Planetoid
from torch.nn import ModuleList, ReLU, Sequential
from torch.nn.parameter import Parameter
from torch.nn.functional import normalize
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.nn.dense.linear import Linear
from TDCM import UnrollAttention

dataset = 'synthetic'
# 'synthetic' 'handwritten' 'Cora'
# 'MNIST' 'CIFAR10' 'CIFAR100'
batch_size = 256

# One row is appended per run: [num_classes // 2, ids, nmi, ari, f, acc].
train_scores, test_scores = [], []
for num_classes in [4]:
    for ids in range(1):
        print('num_classes, ids ', num_classes, ids)
        PATH = './data/synthetic'
        # The saved objects are expected to expose `.tensors` (features, labels),
        # as they are unpacked that way below.
        # NOTE(review): torch.load unpickles the file contents; only load
        # files from a trusted source.
        my_train_dataset = torch.load(os.path.join(PATH, 'ttrain_{}_{}.pt'.format(num_classes//2, ids)))
        my_test_dataset = torch.load(os.path.join(PATH, 'ttest_{}_{}.pt'.format(num_classes//2, ids)))
        # Training batches are shuffled and the last incomplete batch is dropped.
        train_data_loader = torch.utils.data.DataLoader(
            my_train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=8,
        )
        # The test loader keeps every sample (no drop_last).
        test_data_loader = torch.utils.data.DataLoader(
            my_test_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )

        train_data, train_labels = my_train_dataset.tensors[0].numpy(), my_train_dataset.tensors[1].numpy()
        test_data, test_labels = my_test_dataset.tensors[0].numpy(), my_test_dataset.tensors[1].numpy()

        # Model dimensions are taken from the training split only, so the
        # test labels never decide how many clusters the model is built with.
        n_samples_train, n_features = train_data.shape
        n_classes = np.unique(train_labels).size
        n_samples_test = test_data.shape[0]
        # Fail early (before any training) if the two splits are incompatible.
        if test_data.shape[1:] != train_data.shape[1:]:
            raise ValueError(
                'train/test feature dimensions differ: {} vs {}'.format(
                    train_data.shape[1:], test_data.shape[1:]))


        # baselines 
        for runs in range(1):
            # Use the first GPU when one is present; otherwise fall back to
            # the CPU so the script still runs on machines without CUDA.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

            input_channels = n_features
            hidden_channels = 16
            output_channels = n_classes
            # Bias-free linear map from the raw features to `output_channels`
            # values, Glorot-initialised and handed to the model as its encoder.
            encoder = Linear(input_channels, output_channels, bias=False)
            glorot(encoder)
            heads = 1
            num_layers = 5
            unrolled_model = UnrollAttention(
                in_channels=hidden_channels,
                hidden_channels=hidden_channels,
                out_channels=output_channels,
                encoder=encoder,
                heads=heads,
                num_layers=num_layers,
                tau=2.,
            ).to(device)


            optimizer = torch.optim.Adam(unrolled_model.parameters(), lr=1e-3, weight_decay=5e-4)  # Define optimizer.
            # Weight applied to the orthogonality regulariser in train_unroll.
            beta = 5.
            def entropy_loss(pi_k):
                """Return sum(pi_k * log(pi_k)), the negative Shannon entropy of pi_k.

                The argument of the logarithm is clamped to the smallest positive
                normal value of pi_k's dtype, so entries that are exactly zero
                contribute 0 to the sum (and a finite gradient) instead of
                turning the whole loss into NaN. Values above that threshold are
                unaffected.

                Args:
                    pi_k: floating-point tensor of non-negative weights
                        (presumably mixing proportions -- confirm with the model).

                Returns:
                    A scalar tensor.
                """
                tiny = torch.finfo(pi_k.dtype).tiny
                return (pi_k * torch.log(pi_k.clamp(min=tiny))).sum()
            def orthogonal_loss(W):
                """Return the squared Frobenius norm of ``W @ W.T - I``.

                Every entry of the deviation from the identity is squared before
                summing, so positive and negative deviations cannot cancel each
                other: the result is zero only when the rows of W are
                orthonormal.

                Args:
                    W: weight tensor; assumed to be 2-D (rows x features) --
                        TODO confirm that only matrices are passed in.

                Returns:
                    A non-negative scalar tensor on W's device.
                """
                identity = torch.eye(W.shape[0], device=W.device)
                deviation = torch.matmul(W, W.T) - identity
                return deviation.pow(2).sum()
            def orthogonal_loss2(W1, W2):
                """Penalise alignment between matching rows of two matrices.

                Forms the product ``W1 @ W2.T``, keeps only its diagonal (the dot
                product of row i of W1 with row i of W2), squares those values
                and returns their mean as a scalar tensor.
                """
                cross_gram = W1 @ W2.T
                paired_dots = cross_gram.diag()
                return torch.mean(paired_dots ** 2)

            def train_unroll(model, loader):
                """Train `model` for one pass over `loader` and collect per-batch outputs.

                Uses `optimizer`, `device`, `n_classes` and `beta` from the
                enclosing scope. The objective for each batch is
                ``kmeans_loss(new_x, s, n_classes) + entropy_loss(pi_k)`` plus
                ``beta`` times the sum of ``orthogonal_loss`` over every
                parameter of every layer in ``model.attention_list``.

                Args:
                    model: module whose forward returns a 5-tuple
                        ``(new_x, new_centroids, s, pi_k, loss_mu)``.
                    loader: iterable of ``(x, y)`` batches.

                Returns:
                    Tuple ``(mean_loss, x, s, y, new_x)``: the mean batch loss,
                    then the concatenated inputs (numpy), assignments ``s``
                    (torch tensor on CPU), labels (numpy) and transformed
                    features ``new_x`` (numpy), in batch order.
                """
                model.train()
                loss_list = []
                x_list, s_list, y_list, new_x_list = [], [], [], []
                for x, y in loader:
                    optimizer.zero_grad()  # Clear gradients.
                    x, y = x.to(device), y.to(device)
                    # NOTE(review): the fifth output (the model's own loss terms)
                    # and the centroids are not part of the objective below;
                    # confirm whether the auxiliary loss was meant to be added.
                    new_x, _, s, pi_k, _ = model(x)  # Perform a single forward pass.
                    x_list.append(x.detach().cpu().numpy())
                    s_list.append(s.detach().cpu())
                    y_list.append(y.detach().cpu().numpy())
                    new_x_list.append(new_x.detach().cpu().numpy())

                    loss = kmeans_loss(new_x, s, n_classes) + entropy_loss(pi_k)
                    # Orthogonality regulariser over all attention-layer parameters.
                    # NOTE(review): every parameter is passed, whatever its rank.
                    regularization_loss = 0
                    for attention_layer in model.attention_list:
                        for param in attention_layer.parameters():
                            regularization_loss += orthogonal_loss(param)
                    loss += beta * regularization_loss
                    loss_list.append(float(loss))
                    loss.backward()  # Derive gradients.
                    optimizer.step()  # Update parameters based on gradients.
                return np.mean(loss_list), np.concatenate(x_list), torch.concat(s_list), np.concatenate(y_list), np.concatenate(new_x_list)

            def test_unroll(model, loader):
                """Predict a cluster index for every sample yielded by `loader`.

                Puts `model` in eval mode, runs it batch by batch on `device`
                (taken from the enclosing scope) and takes the argmax over
                dimension 1 of the third model output ``s`` as the prediction.

                Args:
                    model: module whose forward returns a 5-tuple with the
                        assignment matrix ``s`` in position 2.
                    loader: iterable of ``(x, y)`` batches.

                Returns:
                    Tuple ``(pred, y_list)``: ``pred`` is one concatenated numpy
                    array of predictions; ``y_list`` is a list with one numpy
                    label array per batch, in the same order (callers
                    concatenate it themselves).
                """
                model.eval()
                pred_list, y_list = [], []
                # Inference only: disable autograd so no graph is kept per batch.
                # NOTE(review): assumes model.forward does not itself rely on
                # autograd being enabled -- confirm against UnrollAttention.
                with torch.no_grad():
                    for x, y in loader:
                        x = x.to(device)
                        _, _, s, _, _ = model(x)  # Perform a single forward pass.
                        pred_list.append(s.argmax(dim=1).detach().cpu().numpy())
                        y_list.append(y.detach().cpu().numpy())
                return np.concatenate(pred_list), y_list

            # Train the model for 50 epochs, then score it on the training and
            # test loaders and append one row to each results list.
            # The model was already moved to `device` when it was constructed.
            unrolled_model = unrolled_model.to(device)
            start_time = time.time()
            # NOTE(review): per_list is never appended to or read below.
            per_list = []
            for epoch in tqdm(range(1, 51)):
                loss, x, s, y, new_x = train_unroll(unrolled_model, train_data_loader)
                # Seconds elapsed since training started; overwritten every epoch.
                execution_time = time.time() - start_time


            # Training-set scores. The train loader shuffles and drops the last
            # incomplete batch, so the labels come back from test_unroll in the
            # same order as the predictions.
            pred, y_list = test_unroll(unrolled_model, train_data_loader)
            y_true = np.concatenate(y_list)
            nmi, ari, f, acc = evaluate(y_true, pred)
            train_scores.append([num_classes//2, ids, nmi, ari, f, acc]) 

            # NOTE(review): gm_test / centroids_init_test are computed here but
            # not passed to anything below -- confirm whether this KMeans fit on
            # the test data is still needed.
            gm_test = KMeans(init="random", n_clusters=n_classes, n_init=5).fit(test_data)
            centroids_init_test = torch.tensor(gm_test.cluster_centers_, dtype=torch.float).to(device)
            # Test-set scores.
            pred, y_list = test_unroll(unrolled_model, test_data_loader)
            y_true = np.concatenate(y_list)
            nmi, ari, f, acc = evaluate(y_true, pred)
            test_scores.append([num_classes//2, ids, nmi, ari, f, acc]) 
            # Sum of orthogonal_loss over all attention-layer parameters of the
            # trained model.
            # NOTE(review): the value is not printed, stored or returned here.
            regularization_loss = 0
            for i, attention_layer in enumerate(unrolled_model.attention_list):
                for name, param in attention_layer.named_parameters():
                    regularization_loss += orthogonal_loss(param) 


# Collect the per-run scores into tables and write them out as CSV files.
score_columns = ['num_classes', 'ids', 'nmi', 'ari', 'f', 'acc']
df_train = pd.DataFrame(data=train_scores, columns=score_columns)
df_test = pd.DataFrame(data=test_scores, columns=score_columns)
save_dir = './res/synthetic'
# Create the output directory first: without it to_csv fails and the
# scores from the finished training run would be lost.
os.makedirs(save_dir, exist_ok=True)
df_train.to_csv(os.path.join(save_dir, 'model_train.csv'), index=False)
df_test.to_csv(os.path.join(save_dir, 'model_test.csv'), index=False)


