import matplotlib.pyplot as plt
import numpy as np
import time
import itertools
import os
import pandas as pd
from tqdm import tqdm

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, classification_report

from model import MLP, GCN, GIN, EGConv, REGConv, DistConv
from utils import bench_clustering, calculate_centroids, kmeans_loss, cosine_similarity_loss
from evaluation import evaluate
import torch
from torch.nn import Softmax
from torch.utils.data import TensorDataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from torch_geometric.nn.inits import reset
from torch_geometric.utils import subgraph
from torch_geometric.datasets import Planetoid
from torch.nn import ModuleList, ReLU, Sequential
from torch.nn.parameter import Parameter
from torch.nn.functional import normalize
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.nn.dense.linear import Linear

class EncoderModel(torch.nn.Module):
    """Wrap a four-layer ``MLP`` encoder behind a plain ``forward``.

    Args:
        input_channels: forwarded to ``MLP`` as its first positional argument.
        hidden_channels: forwarded to ``MLP`` as its second positional argument.
        output_channels: forwarded to ``MLP`` as its third positional argument.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # Widths handed to the MLP through ``varied_hidden_channels``.
        layer_widths = [500, 2000, 2000]
        self.encoder = MLP(
            input_channels,
            hidden_channels,
            output_channels,
            layers=4,
            varied_hidden_channels=layer_widths,
        )
        # Softmax over dim 1; exposed as an attribute but not applied in forward().
        self.softmax = Softmax(dim=1)

    def forward(self, x):
        """Return the encoder output for ``x`` (no softmax is applied)."""
        return self.encoder(x)
def symmetric(X):
    """Return a symmetric matrix built from the upper triangle of ``X``.

    The diagonal and upper-triangular entries of ``X`` are kept, and the
    strictly-upper entries are mirrored into the lower triangle; whatever
    ``X`` holds below the diagonal is ignored. The mirroring acts on the
    last two dimensions.
    """
    upper_with_diag = X.triu()
    strictly_upper = X.triu(1)
    mirrored_lower = strictly_upper.transpose(-1, -2)
    return upper_with_diag + mirrored_lower
class AttentionLayer(torch.nn.Module):
    """One soft-assignment and centroid-update step between samples and centroids.

    Each sample is softly assigned to every centroid with a softmax over the
    negated, temperature-scaled squared norm of the ``W_Q``-projected
    difference ``x - mu``. Centroids are then recomputed as the
    assignment-weighted mean of the ``W_V``-projected samples.

    Args:
        in_channels: stored on the instance; not otherwise used by this layer.
        out_channels: feature width ``C``; all projection matrices are
            ``C x C`` and ``forward`` reshapes its inputs to that width.
        heads: stored on the instance; not used by ``forward``.
        negative_slope: slope passed to ``F.leaky_relu`` on the raw scores.
        tau: temperature the scores are divided by before the softmax.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 heads=1,
                 negative_slope=.1,
                 tau=1.
                ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        # All projections start as the identity; only W_Q is trainable.
        # W_K is registered but is not read by forward().
        self.W_Q = Parameter(data=torch.eye(self.out_channels), requires_grad=True)
        self.W_K = Parameter(data=torch.eye(self.out_channels), requires_grad=False)
        self.W_V = Parameter(data=torch.eye(self.out_channels), requires_grad=False)
        self.softmax_2D = torch.nn.Softmax(dim=1)
        # softmax_1D and sim_loss are kept as attributes; forward() does not use them.
        self.softmax_1D = torch.nn.Softmax(dim=0)
        self.sim_loss = torch.nn.MSELoss()
        self.tau = tau

    def forward(self, x, mu):
        """Run one assignment / update step.

        Args:
            x: samples, reshaped internally to ``(N, 1, C)``.
            mu: current centroids, reshaped internally to ``(1, K, C)``.

        Returns:
            A tuple ``(x_transformed, new_mu, s, pi_k, loss)`` where
            ``x_transformed`` is ``x`` projected by the symmetrised ``W_V``,
            ``new_mu`` are the updated centroids ``(K, C)``, ``s`` is the
            ``(N, K)`` soft-assignment matrix (rows sum to one), ``pi_k`` is
            the normalised per-centroid assignment mass, and ``loss`` is the
            sum of all pairwise cosine similarities between updated centroids
            (diagonal included).
        """
        C = self.out_channels
        x_transformed = torch.matmul(x, symmetric(self.W_V))

        diff_x_mu = x.view(-1, 1, C) - mu.view(1, -1, C)  # N * K * d

        # Both sides of the score use the same W_Q projection, so project once;
        # the score is the squared norm of the projected difference.
        projected = torch.matmul(diff_x_mu, symmetric(self.W_Q))  # N * K * d

        s = (projected * projected).sum(dim=-1)
        s = F.leaky_relu(s, self.negative_slope)
        # Negate so that a smaller distance yields a larger assignment weight.
        s = (-1) * s
        s = s / self.tau
        s = self.softmax_2D(s)

        N_k = s.sum(dim=0)
        pi_k = N_k / (N_k.sum() + 1e-12)
        new_mu = torch.matmul(s.T, x_transformed) / (N_k.view(-1, 1) + 1e-12)
        # F.normalize clamps the norm at eps, so an all-zero centroid gives
        # zeros instead of NaN; for non-degenerate centroids it is plain
        # division by the L2 norm.
        new_mu_normalized = F.normalize(new_mu, p=2, dim=1, eps=1e-12)
        sim_x = torch.matmul(new_mu_normalized, new_mu_normalized.T)
        loss = sim_x.sum()

        return x_transformed, new_mu, s, pi_k, loss

class UnrollAttention(torch.nn.Module):
    """Unrolled stack of ``AttentionLayer`` steps starting from identity centroids.

    The input is standardised per feature, then passed through
    ``num_layers`` attention layers. After each layer the samples are
    re-standardised and the centroids are shifted and scaled with the same
    statistics.

    Args:
        in_channels: stored on the instance.
        hidden_channels: stored on the instance.
        out_channels: feature width; also the number of initial centroids,
            since the centroids start as an ``out_channels`` identity matrix.
        encoder: module stored as ``self.encoder``; ``forward`` does not call it.
        heads: forwarded to every ``AttentionLayer``.
        num_layers: number of unrolled attention layers (must be >= 1).
        negative_slope: forwarded to every ``AttentionLayer``.
        tau: forwarded to every ``AttentionLayer``.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 encoder,
                 heads=1,
                 num_layers=1,
                 negative_slope=.1,
                 tau=1.
                ):
        super().__init__()
        # Validate before num_layers is used, and without relying on `assert`
        # (which is stripped when Python runs with -O).
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.encoder = encoder
        self.heads = heads
        self.num_layers = num_layers
        # One learnable scalar per layer; forward() does not read it.
        self.weights = Parameter(
            data=torch.ones(self.num_layers, dtype=torch.float),
            requires_grad=True
        )
        # Registered as a buffer so it follows .to()/.cuda()/.cpu() with the
        # module instead of being pinned to a hard-coded CUDA device.
        # Non-persistent keeps the state_dict keys unchanged.
        self.register_buffer(
            "centroids_init",
            torch.eye(self.out_channels, dtype=torch.float),
            persistent=False,
        )
        self.attention_list = ModuleList([
            AttentionLayer(
                out_channels,
                out_channels,
                heads=heads,
                negative_slope=negative_slope,
                tau=tau,
            )
            for _ in range(num_layers)
        ])

    @staticmethod
    def _standardize(x):
        """Standardise ``x`` per column; return the result plus the mean and std rows."""
        mean_x = x.mean(dim=0).view(1, -1)
        std_x = x.std(dim=0).view(1, -1)
        return (x - mean_x) / (std_x + 1e-12), mean_x, std_x

    def forward(self, x):
        """Run the unrolled attention stack on ``x``.

        Returns:
            A tuple ``(x, new_centroids, s, pi_k, loss_list)``: the
            standardised samples and centroids after the last layer, the
            last layer's soft assignments ``s`` and mixing weights ``pi_k``,
            and a list holding each layer's loss reshaped to ``(1, 1)``.
        """
        # NOTE: self.encoder is not applied here; presumably callers pass
        # already-encoded features -- confirm against the training script.
        x, _, _ = self._standardize(x)
        # No-op when the module already lives on the input's device.
        new_centroids = self.centroids_init.to(x.device)
        loss_list = []
        for attention_layer in self.attention_list:
            x, new_centroids, s, pi_k, loss = attention_layer(x, new_centroids)
            x, mean_x, std_x = self._standardize(x)
            # Keep centroids in the same standardised space as the samples.
            new_centroids = (new_centroids - mean_x) / (std_x + 1e-12)
            loss_list.append(loss.view(-1, 1))
        return x, new_centroids, s, pi_k, loss_list