# coding=utf-8
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.sparse")
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import wandb
import os
from tqdm import tqdm, trange

from util import load_data, separate_data
from ginpool.graphcnn import GraphCNN

# Shared classification loss used by both train() and test(). Per
# nn.CrossEntropyLoss, it takes unnormalized class scores plus integer class
# labels and averages the per-sample losses ("mean" is the default reduction).
criterion = nn.CrossEntropyLoss(reduction="mean")
from torch.utils.data import DataLoader, Dataset

class GraphDataset(Dataset):
    """Map-style dataset wrapping an in-memory sequence of graph objects.

    Items are returned exactly as stored. ``collate_fn`` is an identity
    collation, so a DataLoader built with it yields each batch as the plain
    list of graphs it gathered rather than applying its default collation.
    """

    def __init__(self, graphs):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.graphs = graphs

    def __len__(self):
        """Return how many graphs the dataset holds."""
        graph_count = len(self.graphs)
        return graph_count

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        sample = self.graphs[idx]
        return sample

    @staticmethod
    def collate_fn(batch):
        """Identity collation: hand the batch back untouched."""
        collated = batch
        return collated

def train(args, model, device, dataloader, optimizer, epoch):
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    Args:
        args: Parsed CLI namespace; ``args.fold_idx`` and
            ``args.special_pooling_type`` name the logged wandb metric.
        model: Module called with a list of graphs. Its output is scored
            against the graph labels with the module-level ``criterion``.
        device: Device the label tensor is moved to.
        dataloader: Iterable yielding lists of graphs, each with a ``label``
            attribute.
        optimizer: Optimizer stepped after every batch, or ``None`` to only
            compute the loss without updating parameters.
        epoch: Epoch number, logged alongside the loss.

    Returns:
        The average batch loss as a Python float, or ``nan`` if the
        dataloader yielded no batches.
    """
    model.train()

    loss_accum = 0.0
    num_batches = 0
    for batch_graph in dataloader:
        output = model(batch_graph)
        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

        loss = criterion(output, labels)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # .item() returns a Python float, so the running sum is kept in double
        # precision instead of accumulating float32 numpy scalars.
        loss_accum += loss.item()
        num_batches += 1

    # Count the batches actually seen. This avoids a ZeroDivisionError on an
    # empty loader and does not require the iterable to support len().
    average_loss = loss_accum / num_batches if num_batches else float("nan")

    # Log training loss to wandb
    wandb.log({f"fold_{args.fold_idx}_{args.special_pooling_type}/train_loss": average_loss, "epoch": epoch})

    return average_loss


def pass_data_iteratively(model, graphs, minibatch_size=64):
    """Run ``model`` over ``graphs`` in minibatches and concatenate the outputs.

    Intended for evaluation. Feeding the graphs in chunks of ``minibatch_size``
    bounds peak memory, and no backpropagation is performed.

    Side effect: switches ``model`` to eval mode and leaves it there.

    Args:
        model: Module called with a list of graphs. It should return a tensor
            whose first dimension indexes the graphs in that list.
        graphs: Indexable, sized sequence of graphs.
        minibatch_size: Number of graphs per forward pass.

    Returns:
        The per-minibatch outputs concatenated along dim 0, with no autograd
        history attached.

    Raises:
        RuntimeError: Raised by ``torch.cat`` when ``graphs`` is empty.
    """
    model.eval()
    num_graphs = len(graphs)
    output = []
    # no_grad stops autograd from recording the forward pass, so each
    # minibatch's activations can be freed right away. Calling .detach()
    # afterwards alone would still build, and briefly hold, the graph.
    with torch.no_grad():
        for start in range(0, num_graphs, minibatch_size):
            stop = min(start + minibatch_size, num_graphs)
            batch = [graphs[j] for j in range(start, stop)]
            output.append(model(batch).detach())
    return torch.cat(output, 0)


def _evaluate_split(model, graphs, device):
    """Run ``model`` over ``graphs`` and score its arg-max predictions.

    Returns:
        ``(output, labels, accuracy)``:
        - ``output``: the stacked model output from ``pass_data_iteratively``.
        - ``labels``: a LongTensor of graph labels on ``device``.
        - ``accuracy``: the fraction of graphs whose highest-scoring class
          equals their label.
    """
    output = pass_data_iteratively(model, graphs)
    labels = torch.LongTensor([graph.label for graph in graphs]).to(device)
    pred = output.argmax(dim=1)
    correct = pred.eq(labels).sum().item()
    accuracy = correct / len(graphs)
    return output, labels, accuracy


def test(args, model, device, train_graphs, test_graphs, epoch):
    """Evaluate ``model`` on both splits and log the metrics to wandb.

    Args:
        args: Parsed CLI namespace; ``args.fold_idx`` and
            ``args.special_pooling_type`` name the logged metrics.
        model: Module called with a list of graphs (see ``pass_data_iteratively``).
        device: Device the label tensors are moved to.
        train_graphs: Training graphs, used only for training accuracy.
        test_graphs: Held-out graphs, used for test accuracy and test loss.
        epoch: Epoch number, logged alongside the metrics.

    Returns:
        ``(acc_train, acc_test)`` as Python floats.
    """
    model.eval()

    _, _, acc_train = _evaluate_split(model, train_graphs, device)
    test_output, test_labels, acc_test = _evaluate_split(model, test_graphs, device)

    # Held-out loss, computed with the same criterion used for training.
    test_loss = criterion(test_output, test_labels).item()

    prefix = f"fold_{args.fold_idx}_{args.special_pooling_type}"
    wandb.log({
        f"{prefix}/acc_train": acc_train,
        f"{prefix}/acc_test": acc_test,
        f"{prefix}/test_loss": test_loss,
        "epoch": epoch,
    })

    return acc_train, acc_test


def _build_arg_parser():
    """Create the command-line parser for the graph-classification experiment.

    The switches ``--use_fc``, ``--add_zero_attn``, ``--attn_mask_dropout``,
    ``--use_layer_norm`` and ``--skip_connection`` are declared as strings.
    ``_parse_args`` converts them to booleans.
    """
    # Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
    parser = argparse.ArgumentParser(
        description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='name of dataset (default: MUTAG)')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    # NOTE(review): not read in this file; train() iterates the whole
    # DataLoader each epoch. Confirm no other module uses it before removing.
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per each epoch (default: 50)')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of epochs to train (default: 350; forced to 500 for NCI1 and REDDITBINARY)')
    parser.add_argument('--lr', type=float, default=0.02,
                        help='learning rate (default: 0.02)')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of layers INCLUDING the input one (default: 5)')
    parser.add_argument('--num_mlp_layers', type=int, default=2,
                        help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
    parser.add_argument('--hidden_dim', type=int, default=16,
                        help='number of hidden units (default: 16)')
    parser.add_argument('--final_dropout', type=float, default=0.5,
                        help='final layer dropout (default: 0.5)')
    parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
                        help='Pooling for over nodes in a graph: sum or average')
    parser.add_argument('--special_pooling_type', type=str, default="temporal", choices=["temporal", "supra", "mean", "max"],
                        help='Pooling for over nodes in a graph: temporal, supra, mean or max')
    parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                        help='Pooling for over neighboring nodes: sum, average or max')
    parser.add_argument('--learn_eps', action="store_true",
                        help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.')
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
    # NOTE(review): not read in this file; per-fold results are appended to
    # result/<config>_results.txt instead (see _append_result).
    parser.add_argument('--filename', type=str, default="10flod.txt",
                        help='output file')

    # Temporal Global Pooling parameters
    parser.add_argument('--use_fc', type=str, default='false',
                        help='use fully connected layer in temporal global pooling')
    parser.add_argument('--mha_dropout', type=float, default=0.1,
                        help='dropout rate for multi-head attention (default: 0.1)')
    parser.add_argument('--num_head', type=int, default=8,
                        help='number of attention heads (default: 8)')
    parser.add_argument('--add_zero_attn', type=str, default='false',
                        help='add a zero attention token to the temporal attention')
    parser.add_argument('--attn_mask_dropout', type=str, default='false',
                        help='if true the mha_dropout is applied to the attention mask')
    parser.add_argument('--alpha_type', type=str, default='learnable',
                        help='alpha type for temporal global pooling (default: learnable)')
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help='weight decay (default: 0.0)')
    parser.add_argument('--use_layer_norm', type=str, default='false',
                        help="'true' to use layer normalization in temporal global pooling (default: false)")
    parser.add_argument('--skip_connection', type=str, default='false',
                        help="'true' to use a skip connection in temporal global pooling (default: false)")
    parser.add_argument('--task', type=str, default='graph_classification')
    parser.add_argument('--learn_query', action='store_true',
                        help='Whether to learn the query for the temporal attention')
    return parser


def _parse_args():
    """Parse the command line and normalize string-valued switches."""
    args = _build_arg_parser().parse_args()
    # These flags are declared as strings. Only a case-insensitive 'true'
    # enables them; any other value means False.
    for flag in ("use_fc", "add_zero_attn", "attn_mask_dropout", "use_layer_norm", "skip_connection"):
        setattr(args, flag, getattr(args, flag).lower() == 'true')
    args.alpha_type = args.alpha_type.lower()

    # These datasets always get a longer schedule, overriding --epochs.
    if args.dataset in ("NCI1", "REDDITBINARY"):
        args.epochs = 500
    return args


def _run_fold(args, graphs, num_classes, device, fold_idx):
    """Train a fresh model on one cross-validation fold.

    Side effect: sets ``fold_idx``, ``num_classes``, ``device``, ``input_dim``
    and ``output_dim`` on ``args``.

    Returns:
        The best test accuracy seen across epochs. Training stops early
        once it reaches 1.
    """
    import time

    args.fold_idx = fold_idx
    # Re-seed every fold so each one starts from the same RNG state.
    torch.manual_seed(0)
    np.random.seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # 10-fold cross validation: take the split for the fold in args.fold_idx.
    train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)
    args.num_classes = num_classes
    args.device = device
    args.input_dim = train_graphs[0].node_features.shape[1]
    args.output_dim = num_classes
    model = GraphCNN(args).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train_dataloader = DataLoader(
        GraphDataset(train_graphs),
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=GraphDataset.collate_fn,
    )

    use_cuda = torch.cuda.is_available()
    max_acc = 0.0
    pbar = trange(1, args.epochs + 1, desc="Epochs", leave=True)
    for epoch in pbar:
        # CUDA work is asynchronous, so synchronize around the timed region.
        # On CPU-only machines this is skipped, because torch.cuda.Event and
        # torch.cuda.synchronize are unusable there.
        if use_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        train(args, model, device, train_dataloader, optimizer, epoch)
        _, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)
        if use_cuda:
            torch.cuda.synchronize(device)
        epoch_time = time.perf_counter() - start

        max_acc = max(max_acc, acc_test)
        pbar.set_description(f"Epoch {epoch} took {epoch_time:.2f} seconds, max acc: {max_acc:.2f}")
        if max_acc == 1:
            break  # perfect test accuracy cannot be improved on
    return max_acc


def _append_result(args, max_acc):
    """Append one fold's best test accuracy to this configuration's results file."""
    os.makedirs('result', exist_ok=True)
    filename = (
        f"{args.dataset}_{args.batch_size}_{args.epochs}_{args.lr}"
        f"_{args.hidden_dim}_{args.special_pooling_type}_results.txt"
    )
    with open(os.path.join('result', filename), 'a') as f:
        f.write(f"{max_acc}\n")


def main():
    """Run 10-fold cross-validation and log per-fold and summary accuracy to wandb."""
    args = _parse_args()

    # Initialize wandb
    experiment_name = f"{args.dataset}_bs{args.batch_size}_ep{args.epochs}_lr{args.lr}_hd{args.hidden_dim}_{args.special_pooling_type}"
    wandb.init(project="Histograph", config=args, name=experiment_name)

    graphs, num_classes = load_data(args.dataset, args.degree_as_tag)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    # Move every graph's tensors to the target device once, up front.
    for graph in graphs:
        graph.node_features = graph.node_features.to(device)
        graph.edge_mat = graph.edge_mat.to(device)

    # Set once; the original re-applied this same value at the start of every fold.
    torch.set_num_threads(10)

    summary_prefix = f"summary/{args.special_pooling_type}"
    all_acc_test = []
    for fold_idx in trange(10, desc="Folds", leave=False):
        max_acc = _run_fold(args, graphs, num_classes, device, fold_idx)
        all_acc_test.append(max_acc)
        _append_result(args, max_acc)

        # Running mean/std over the folds finished so far.
        avg_acc = np.mean(all_acc_test)
        std_acc = np.std(all_acc_test)
        wandb.log({f"{summary_prefix}/average_acc": avg_acc, f"{summary_prefix}/std_acc": std_acc})

    wandb.log({f"{summary_prefix}/final_acc": np.mean(all_acc_test)})
    wandb.finish()

if __name__ == "__main__":
    # Start the experiment only when run as a script, never on import.
    main()
