import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import torch
from torch.nn import Linear
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.logging import log
from torch_geometric.nn import GCNConv
from torch_geometric.utils import degree, remove_self_loops, add_self_loops, dense_to_sparse

from layer import Propagate
from dataset_utils import DataLoader
from utils import random_class_balance_splits
import argparse
import numpy as np

# Command-line interface. Every option has a default, so the script can be
# launched without arguments.
parser = argparse.ArgumentParser(description='General Training Pipeline')

# (flag, value type, default) for each supported option:
#   --dataset : name handed to DataLoader
#   --num_gc  : number of Propagate applications in GCN.forward
#   --epochs  : training epochs per hyper-parameter configuration
for _flag, _type, _default in (
    ('--dataset', str, 'cora'),
    ('--num_gc', int, 1),
    ('--epochs', int, 1000),
):
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()
print(args)

class GCN(torch.nn.Module):
    """Feature MLP followed by repeated ``Propagate`` steps.

    The node features first go through dropout -> linear -> batch norm ->
    ReLU -> dropout -> linear. The resulting per-node outputs are then passed
    through ``Propagate`` together with ``edge_index`` ``args.num_gc`` times.

    Args:
        in_channels: Size of each input feature vector.
        hidden_channels: Width of the hidden linear layer / batch norm.
        out_channels: Size of each output vector.
        dropout: Dropout probability used before both linear layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()
        self.dropout = dropout
        # Layers are created in the order lin1, lin2, bn1 so that parameter
        # ordering and random initialisation stay stable.
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)

    def forward(self, x, edge_index):
        """Return the model output for features ``x`` and graph ``edge_index``."""
        # Dropout is only active in training mode (self.training).
        hidden = self.lin1(F.dropout(x, p=self.dropout, training=self.training))
        hidden = torch.relu(self.bn1(hidden))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        out = self.lin2(hidden)

        # The number of propagation steps comes from the module-level CLI
        # arguments; a new Propagate instance is built for every step.
        for _ in range(args.num_gc):
            out = Propagate()(out, edge_index)
        return out
    

def train(model, optimizer, data):
    """Perform a single optimisation step and return the loss as a float.

    The model is put in training mode, evaluated on ``data.x`` /
    ``data.edge_index``, and the cross-entropy over the entries selected by
    ``data.train_mask`` is back-propagated before ``optimizer.step()``.

    Args:
        model: Callable module taking ``(x, edge_index)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.

    Returns:
        The training loss computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x, data.edge_index)
    selected = data.train_mask
    objective = F.cross_entropy(logits[selected], data.y[selected])

    objective.backward()
    optimizer.step()
    return objective.item()

@torch.no_grad()
def test(model, data):
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs, pred

# Load the dataset selected on the command line and show what was loaded.
dataset, data = DataLoader(args.dataset)
print(dataset)
print(data)

# Dense one-hot label matrix: one row per entry of data.y, one column per class.
# (Not referenced again in the visible part of this script.)
y_one_hot = torch.zeros(data.y.shape[0], dataset.num_classes)
y_one_hot.scatter_(1, data.y.view(-1, 1), 1.0)

# Fractions of the labelled nodes used for training and validation.
train_rate, val_rate = 0.6, 0.2
# number of train samples per class (averaged)
percls_trn = int(round(train_rate * len(data.y) / dataset.num_classes))
# number of validation samples
val_lb = int(round(val_rate * len(data.y)))

# Replace `data` with the version returned by the project's split helper.
data = random_class_balance_splits(data, dataset.num_classes, train_rate, val_rate)

# Hyper-parameter grid explored by the search loop below.
hidden_s = [16, 32, 64, 128, 256]
lr_s = [0.001, 0.005, 0.01]
wd_s = [0, 1e-5, 5e-4, 1e-4]
dropout_s = [0, 0.2, 0.5]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# Exhaustive grid search over (hidden size, learning rate, weight decay, dropout).
#
# For every configuration a fresh model is trained for `args.epochs` epochs.
# Within a run, the test accuracy and predictions are taken from the epoch with
# the best validation accuracy. The configuration with the highest such test
# accuracy has its confusion matrix and accuracy written to a .npz file, and
# every configuration appends one summary line to a .csv file.
#
# NOTE(review): the winning configuration is chosen by *test* accuracy, which
# makes the saved number an optimistic estimate — confirm this is intended.

# np.savez and open() do not create missing parent directories.
os.makedirs('results', exist_ok=True)
npz_path = f'results/{args.dataset}-gc{args.num_gc}.npz'
filename = f'results/{args.dataset}-gc{args.num_gc}.csv'

# Moving the data to the device is loop-invariant, so do it once.
data = data.to(device)

best_test_acc = -1
best_cm = None
for hidden in hidden_s:
    for lr in lr_s:
        for wd in wd_s:
            for dropout in dropout_s:
                model = GCN(dataset.num_features, hidden, dataset.num_classes, dropout).to(device)
                optimizer = torch.optim.Adam([
                    dict(params=model.parameters(), weight_decay=wd),
                ], lr=lr)

                best_val_acc = final_test_acc = 0
                # Predictions from the epoch that produced `final_test_acc`.
                # Stays None only when no epoch is run (args.epochs <= 0).
                best_pred = None
                for epoch in range(1, args.epochs + 1):
                    loss = train(model, optimizer, data)
                    accs, pred = test(model, data)
                    train_acc, val_acc, tmp_test_acc = accs
                    # The first epoch always becomes the baseline so that
                    # accuracy and predictions come from the same epoch.
                    if best_pred is None or val_acc > best_val_acc:
                        best_val_acc = val_acc
                        final_test_acc = tmp_test_acc
                        best_pred = pred

                if best_pred is not None and best_test_acc < final_test_acc:
                    best_test_acc = final_test_acc
                    # Index on the data's device first, then move to the CPU:
                    # a CPU tensor cannot be indexed with a CUDA mask.
                    y_true = data.y[data.test_mask].cpu().numpy()
                    y_pred = best_pred[data.test_mask].cpu().numpy()
                    best_cm = confusion_matrix(y_true, y_pred)
                    print("Update Results!")
                    np.savez(npz_path, cm=best_cm, acc=best_test_acc)

                logs = (
                    f"hd:{hidden},lr:{lr},wd:{wd},dp:{dropout},"
                    f"Val:{best_val_acc:.3f},Test:{final_test_acc:.3f}"
                )
                print(logs)
                # Append so earlier configurations and runs are kept.
                with open(filename, 'a') as write_obj:
                    write_obj.write(logs + '\n')

