import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import torch
from torch.nn import Linear, LayerNorm
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.logging import log
from torch_geometric.nn import GCNConv
from torch_geometric.utils import degree, remove_self_loops, add_self_loops, dense_to_sparse

from layer import Propagate
from dataset_utils import DataLoader
from utils import random_class_balance_splits, cal_heterophilous_ratio_torchversion


def get_cm_of_MLP_with_GCs(num_GCs, data, input_x, input_edge_index):
    """Train an MLP followed by ``num_GCs`` propagation steps and evaluate it.

    The model is three Linear+ReLU+dropout layers (with LayerNorm in between),
    then ``num_GCs`` applications of ``Propagate`` on the hidden features, and
    a final LayerNorm + Linear classifier.  It is trained full-batch for a
    fixed number of epochs; model selection uses validation accuracy.

    Args:
        num_GCs: Number of ``Propagate`` steps applied after the MLP layers
            (``0`` yields a plain MLP).
        data: Graph data object providing ``y``, ``train_mask``, ``val_mask``,
            ``test_mask`` and ``edge_attr`` (must be ``None``).  It is moved
            to the training device.
        input_x: Node feature tensor fed to the model.
        input_edge_index: Edge index tensor handed to ``Propagate``.

    Returns:
        A tuple ``(cm, test_acc)``: the confusion matrix of the test nodes and
        the test accuracy, both taken from the epoch with the best validation
        accuracy.

    Note:
        The input and output sizes of the model are read from the
        module-level ``dataset`` object, not from the arguments.
    """
    # The code below may break if data.edge_attr is not None.
    assert data.edge_attr is None
    hidden = 512
    lr = 0.005
    epochs = 1000
    dropout = 0.2

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: an alternative variant that used two GCNConv layers in place of
    # lin1/lin2 used to be kept here as commented-out code.
    class GCN(torch.nn.Module):
        """MLP encoder followed by ``num_GCs`` propagation steps."""

        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.in_lin = Linear(in_channels, hidden_channels)
            self.out_lin = Linear(hidden_channels, out_channels)
            self.lin1 = Linear(hidden_channels, hidden_channels)
            self.lin2 = Linear(hidden_channels, hidden_channels)
            self.norm1 = LayerNorm(hidden_channels)
            self.norm2 = LayerNorm(hidden_channels)
            self.out_norm = LayerNorm(hidden_channels)

        def forward(self, x, edge_index, edge_weight=None):
            # edge_weight is accepted for call compatibility but not used.
            x = self.in_lin(x).relu()
            x = F.dropout(x, p=dropout, training=self.training)

            x = self.norm1(x)
            x = self.lin1(x).relu()
            x = F.dropout(x, p=dropout, training=self.training)

            x = self.norm2(x)
            x = self.lin2(x).relu()
            x = F.dropout(x, p=dropout, training=self.training)

            # A fresh Propagate is built for every step, as in the original.
            for _ in range(num_GCs):
                x = Propagate()(x, edge_index)

            x = self.out_norm(x)
            return self.out_lin(x)

    num_classes = dataset.num_classes
    model = GCN(dataset.num_features, hidden, num_classes)
    model, data = model.to(device), data.to(device)
    input_x, input_edge_index = input_x.to(device), input_edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    def train():
        """Run one full-batch optimisation step and return the loss value."""
        model.train()
        optimizer.zero_grad()
        out = model(input_x, input_edge_index, data.edge_attr)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test():
        """Return ``([train_acc, val_acc, test_acc], predictions)``."""
        model.eval()
        pred = model(input_x, input_edge_index, data.edge_attr).argmax(dim=-1)

        accs = []
        for mask in [data.train_mask, data.val_mask, data.test_mask]:
            accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
        return accs, pred

    # test_acc is initialised up front so logging/printing cannot hit an
    # unbound name when validation accuracy never rises above zero.
    best_val_acc = 0.0
    test_acc = 0.0
    best_pred = None
    pred = None
    for epoch in range(1, epochs + 1):
        loss = train()
        accs, pred = test()
        train_acc, val_acc, tmp_test_acc = accs
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
            # Keep the predictions that belong to the reported accuracy.
            best_pred = pred
        if epoch % 200 == 0:
            log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    print(f"Best Test: {test_acc:.4f}")

    if best_pred is None:
        # Validation accuracy never improved; fall back to the last epoch.
        best_pred = pred

    # Index while tensors and mask share a device, then move to the CPU.
    # Passing explicit labels keeps the matrix shape fixed across runs so the
    # matrices of different models can be subtracted from each other.
    test_mask = data.test_mask
    cm = confusion_matrix(
        data.y[test_mask].cpu().numpy(),
        best_pred[test_mask].cpu().numpy(),
        labels=list(range(num_classes)),
    )
    return cm, test_acc

# Load the benchmark and print a short summary of what was loaded.
dname = 'amazon_ratings'
dataset, data = DataLoader(dname)
print(dataset, data, dataset.num_classes, data.y, sep="\n")

# One-hot encode the node labels: one row per node, one column per class.
y_one_hot = torch.zeros(data.y.size(0), dataset.num_classes)
y_one_hot.scatter_(1, data.y.view(-1, 1), 1.0)
print(y_one_hot)

# Fractions of the labelled nodes used for training and validation.
train_rate, val_rate = 0.6, 0.2
# Number of training samples per class (averaged over the classes).
percls_trn = int(round(len(data.y) * train_rate / dataset.num_classes))
# Number of validation samples.
val_lb = int(round(len(data.y) * val_rate))

# Alternative: draw a fresh class-balanced random split instead of using the
# splits shipped with the data object.
# data = random_class_balance_splits(data, dataset.num_classes, train_rate, val_rate)
# Use the first of the splits stored on the data object.
data.train_mask, data.val_mask, data.test_mask = (
    data.train_masks[0],
    data.val_masks[0],
    data.test_masks[0],
)

# Node degrees, counted over the source endpoints of the edges after
# self-loops have been removed.
edge_index, _ = remove_self_loops(data.edge_index)
row, col = edge_index
deg = degree(row, data.x.size(0))

# Average degree and node count of every class.
c_degs = []
c_nodes = []
for c in range(dataset.num_classes):
    class_mask = data.y == c
    c_deg = deg[class_mask]
    # Tensor.sum() counts the members in one vectorised call; the Python
    # builtin sum() would iterate over the mask element by element.
    c_nodes.append(class_mask.sum().view(-1).cpu())
    c_degs.append(c_deg.mean(dim=0).view(-1).cpu())
c_degs = torch.cat(c_degs, dim=0)
c_nodes = torch.cat(c_nodes, dim=0)
print(c_degs)
print(c_nodes)

# Propagate the one-hot labels over the graph (no self-loops added) to get a
# per-node distribution built from the labels of its neighbours.
label_propagator = Propagate(add_self_loop=False)
node_local_distribution = label_propagator(y_one_hot.cpu(), data.edge_index.cpu())

# For every class, collect the mean and standard deviation of that
# distribution over the nodes of the class, as one row each.
mean_rows, std_rows = [], []
for cls in range(dataset.num_classes):
    members = node_local_distribution[data.y == cls]
    mean_rows.append(members.mean(dim=0).reshape(1, -1))
    std_rows.append(members.std(dim=0).reshape(1, -1))

# Stack the per-class rows into (num_classes x ...) matrices.
m_hat_means = torch.cat(mean_rows, dim=0)
m_hat_stds = torch.cat(std_rows, dim=0)

# Figure 1: heterophily ratio matrix (left) and the mean / standard deviation
# of the per-class neighbourhood label distribution (middle / right).
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

cldm_matrix = cal_heterophilous_ratio_torchversion(m_hat_means, c_degs, 1)
sns.heatmap(data=cldm_matrix, vmin=0, annot=True, fmt=".2f", square=True,
            cmap="Blues", ax=ax[0])
ax[0].set_title(f'{dname} M_hat x D')

sns.heatmap(data=m_hat_means, vmin=0, vmax=1, annot=True, fmt=".2f",
            square=True, cmap="Blues", ax=ax[1])
ax[1].set_title(f'{dname} M_hat mean')

sns.heatmap(data=m_hat_stds, vmin=0, vmax=1, annot=True, fmt=".2f",
            square=True, cmap="Blues", ax=ax[2])
ax[2].set_title(f'{dname} M_hat std')

# Save before showing: with interactive backends the figure is gone once
# plt.show() returns, so saving afterwards writes an empty image.
fig.savefig(f"{dname}.jpg")
plt.show()

def _draw_confusion_heatmap(axis, matrix, title, **heatmap_kwargs):
    """Draw ``matrix`` as an annotated heatmap on ``axis`` with shared labels."""
    sns.heatmap(data=matrix, annot=True, fmt=".2f", square=True, ax=axis,
                **heatmap_kwargs)
    axis.set_xlabel("Predicted Labels")
    axis.set_ylabel("True Labels")
    axis.set_title(title)


# Figure 2: confusion matrices of the MLP and of the models with one and two
# propagation steps (top row) and their pairwise differences (bottom row).
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

MLP_cm, MLP_acc = get_cm_of_MLP_with_GCs(0, data, data.x, data.edge_index)
GCN1_cm, GCN1_acc = get_cm_of_MLP_with_GCs(1, data, data.x, data.edge_index)
GCN2_cm, GCN2_acc = get_cm_of_MLP_with_GCs(2, data, data.x, data.edge_index)

_draw_confusion_heatmap(ax[0][0], MLP_cm, f'MLP: {MLP_acc:.4f}', cmap="Blues")
_draw_confusion_heatmap(ax[0][1], GCN1_cm, f'GCN1: {GCN1_acc:.4f}', cmap="Blues")
_draw_confusion_heatmap(ax[0][2], GCN2_cm, f'GCN2: {GCN2_acc:.4f}', cmap="Blues")

# The difference plots share one fixed colour scale so they are comparable.
diff_style = dict(vmin=-150, vmax=150, cmap="viridis")
_draw_confusion_heatmap(ax[1][0], GCN1_cm - MLP_cm, 'GCN1 MLP Diff', **diff_style)
_draw_confusion_heatmap(ax[1][1], GCN2_cm - MLP_cm, 'GCN2 MLP Diff', **diff_style)
_draw_confusion_heatmap(ax[1][2], GCN2_cm - GCN1_cm, 'GCN2 GCN1 Diff', **diff_style)

# Save before showing: with interactive backends the figure is gone once
# plt.show() returns, so saving afterwards writes an empty image.
fig.savefig(f"{dname}_res.jpg")
plt.show()

# Extra runs with 3, 4 and 5 propagation steps. The results are not plotted;
# each call rebinds GCN2_cm / GCN2_acc, so only the last run's values remain.
GCN2_cm, GCN2_acc = get_cm_of_MLP_with_GCs(
    num_GCs=3, data=data, input_x=data.x, input_edge_index=data.edge_index)
GCN2_cm, GCN2_acc = get_cm_of_MLP_with_GCs(
    num_GCs=4, data=data, input_x=data.x, input_edge_index=data.edge_index)
GCN2_cm, GCN2_acc = get_cm_of_MLP_with_GCs(
    num_GCs=5, data=data, input_x=data.x, input_edge_index=data.edge_index)
