import gc
import os
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.data import Data

from laplace.curvature import AsdlGGN

from models.opt_gnn import MPNNs
from utils import setup_logger, unused_gpu, edge_homophily, get_graph_prior
from parse import parse_add_args, LAPLACE_CLASSES
from data_utils import load_data
from marglik import marglik_optimization


def _select_device(args):
    """Return the torch device to run on: a CUDA device if available, else CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # NOTE(review): `or` treats `--gpu 0` as "not set" and falls through to
    # unused_gpu(). Confirm the default of args.gpu before switching this to
    # an explicit `is not None` check.
    gpu_num = args.gpu or unused_gpu()
    return torch.device(f"cuda:{gpu_num}" if gpu_num is not None else "cuda")


def _load_graph_data(args):
    """Load the graph selected by ``args.dataset`` (not yet moved to a device).

    For 'cora', 'citeseer' and 'pubmed' the graph is read from
    ``../../data_opt/<dataset>.npz`` (relative to the working directory) and
    wrapped in a ``Data`` object with a single split, stored as boolean masks
    of shape (n_nodes, 1). Every other dataset is delegated to ``load_data``.
    """
    if args.dataset not in ['cora', 'citeseer', 'pubmed']:
        return load_data(args.dataset, n_rand_splits=0,
                         feat_norm=args.feat_norm)

    # NpzFile keeps the archive open and reads arrays lazily on access, so
    # everything is materialised inside the `with` block, which then closes
    # the underlying file handle.
    with np.load(f'../../data_opt/{args.dataset}.npz') as npz:
        x = torch.tensor(npz['x'], dtype=torch.float)
        edge_index = torch.tensor(npz['edge_index'], dtype=torch.long)
        y = torch.tensor(npz['y'], dtype=torch.long)

        masks = {}
        for mask_name, split_key in (('train_mask', 'train'),
                                     ('val_mask', 'valid'),
                                     ('test_mask', 'test')):
            # One column per split; these files provide exactly one split.
            mask = torch.zeros((x.size(0), 1), dtype=torch.bool)
            mask[npz[split_key], 0] = True
            masks[mask_name] = mask

    return Data(x=x, edge_index=edge_index, y=y, **masks)


def _log_summary(logger, valid_losses, test_losses, valid_accs, test_accs,
                 n_edges):
    """Log mean (std) of the collected per-run metrics.

    If no run completed, a warning is logged instead: the mean of an empty
    list is NaN and converting it to ``int`` would raise.
    """
    if not test_accs:
        logger.warning(
            "No runs were completed; nothing to summarise.")
        return

    logger.info(f"Mean valid loss: {np.mean(valid_losses):.6f} ({np.std(valid_losses):.6f})")
    logger.info(f"Mean test loss: {np.mean(test_losses):.6f} ({np.std(test_losses):.6f})")
    logger.info(f"Mean valid accuracy: {np.mean(valid_accs)*100:.2f}% ({np.std(valid_accs)*100:.2f})")
    logger.info(f"Mean test accuracy: {np.mean(test_accs)*100:.2f}% ({np.std(test_accs)*100:.2f})")
    logger.info(f"Number of edges: {int(np.mean(n_edges))} ({np.std(n_edges):.2f})")


def main():
    """Run marginal-likelihood graph learning over repeats and data splits.

    Parses the command line, sets up logging under
    ``logs/<dataset>/<model>/<laplace>/<job_id>.log``, loads the graph, builds
    the edge prior once, and then, for every repeat and every selected split,
    trains a fresh ``MPNNs`` model with ``marglik_optimization``. Per-run
    validation/test metrics, homophily and edge count of the learned discrete
    graph are logged, followed by a summary over all runs.

    Raises:
        ValueError: if ``--splits`` contains an index that does not address a
            column of the dataset's split masks.
    """
    # The text must be passed as `description`; the first positional
    # parameter of ArgumentParser is `prog` (the program name in usage).
    parser = argparse.ArgumentParser(
        description="Marglik optimization on graphs with GCN based on Luo, 2024.")
    parse_add_args(parser)
    parser.add_argument('--splits', type=int, nargs='+', default=None,
                        help='List of splits to use. Provide as space-separated integers, e.g., --splits 0 1 2')
    args = parser.parse_args()

    log_dir = f"logs/{args.dataset}/{args.model}/{args.laplace}"
    os.makedirs(log_dir, exist_ok=True)

    log_file = f"{log_dir}/{args.job_id}.log"
    logger = setup_logger(log_file, console=args.console)

    # log hyperparameters
    logger.info("Hyperparameters:")
    for arg, value in vars(args).items():
        logger.info(f"HYPERPARAM: {arg}={value}")

    device = _select_device(args)

    # dataset
    data = _load_graph_data(args).to(device)

    # Full-batch loader: a single batch containing every node.
    dataset = TensorDataset(data.x, data.y)
    data_loader = DataLoader(
        dataset, batch_size=data.x.size(0), shuffle=False)

    # Split masks hold one column per split.
    n_total_splits = data.train_mask.size(1)
    splits = args.splits if args.splits is not None else range(n_total_splits)
    # Fail before any training starts rather than with an IndexError after
    # earlier splits have already been trained.
    invalid_splits = [j for j in splits
                      if not -n_total_splits <= j < n_total_splits]
    if invalid_splits:
        raise ValueError(
            f"--splits contains indices {invalid_splits} that are out of "
            f"range for a dataset with {n_total_splits} split(s).")

    # get graph prior based on observed edges (and optionally kNN)
    prior_edge_index, prior_edge_probs = get_graph_prior(
        obs_edge_index=data.edge_index,
        obs_edge_prob=args.obs_prior_edge_prob,
        non_edge_prob=args.prior_non_edge_prob,
        n_nodes=data.x.size(0),
        knn_prior_edge_k=args.knn_prior_edge_k,
        x=data.x,
        knn_prior_edge_prob=args.knn_prior_edge_prob,
        knn_prior_edge_dist_metric=args.knn_prior_edge_dist_metric)

    # Output width is the number of distinct label values; it does not depend
    # on the repeat or split, so it is computed once.
    n_classes = data.y.unique().size(-1)

    valid_losses, test_losses = [], []
    valid_accs, test_accs, n_edges = [], [], []

    for i in range(args.n_repeats):
        for j in splits:
            logger.info("-" * 20 + f"Split {j+1}/{n_total_splits}." + "-" * 20, extra={'repeat': i})

            train_mask = data.train_mask[:, j]
            valid_mask = data.val_mask[:, j]
            test_mask = data.test_mask[:, j]

            # A fresh model is built for every (repeat, split) run.
            model = MPNNs(
                in_channels=data.x.size(-1),
                hidden_channels=args.hidden_channels,
                out_channels=n_classes,
                num_layers=args.num_layers,
                dropout=args.dropout,
                act=args.act,
                ln=args.norm == 'layer',
                bn=args.norm == 'batch',
                jk=args.jk,
                res=args.res,
                graph_builder_kwargs={'prior_edge_index': prior_edge_index,
                                      'prior_edge_probs': prior_edge_probs,
                                      'n_nodes': data.x.size(0)},
                bin_conc_kwargs={'temperature': args.cont_relax_temp},
                gnn=args.model,
            )
            model.to(device)

            # learn graph
            # The same full-graph loader is passed for training and
            # validation; the node subsets are passed separately as masks.
            _, _, best_model_stats = marglik_optimization(
                model=model,
                train_loader=data_loader,
                valid_loader=data_loader,
                train_mask=train_mask,
                valid_mask=valid_mask,
                test_mask=test_mask,
                lr=args.lr,
                weight_decay=args.weight_decay,
                n_epochs=args.n_epochs,
                n_epochs_burnin=args.n_epochs_burnin,
                n_hypersteps=args.n_hypersteps,
                marglik_frequency=args.marglik_frequency,
                laplace=LAPLACE_CLASSES[args.laplace.lower()],
                backend=AsdlGGN,
                lr_graph=args.lr_graph,
                graph_grad_norm=args.graph_grad_norm,
                lr_graph_min=args.lr_graph_min,
                early_stop_crit=args.early_stop_crit,
                graph_prior=args.graph_prior,
                prior_edge_index=prior_edge_index,
                prior_edge_probs=prior_edge_probs,
                graph_kl_weight=args.graph_kl_weight,
                n_samples=args.n_samples,
                log_frequency=args.log_frequency,
                log_det_weight=args.log_det_weight,
                checkpoint_dir=args.checkpoint_dir,
                repeat=i,
            )

            valid_loss = best_model_stats['valid_loss']
            valid_acc = best_model_stats['valid_acc']
            test_loss = best_model_stats['test_loss']
            test_acc = best_model_stats['test_acc']
            valid_losses.append(valid_loss)
            valid_accs.append(valid_acc)
            test_losses.append(test_loss)
            test_accs.append(test_acc)

            logger.info(f"VALID LOSS: {valid_loss:.4f}", extra={'repeat': i})
            logger.info(f"VALID ACC: {valid_acc*100:.2f}%", extra={'repeat': i})
            logger.info(f"TEST LOSS: {test_loss:.4f}", extra={'repeat': i})
            logger.info(f"TEST ACC: {test_acc*100:.2f}%", extra={'repeat': i})

            # approx num edges
            discrete_adj = model.get_discrete_graph().detach().cpu()
            _n_edges = discrete_adj.sum().item()
            # NOTE(review): discrete_adj is on CPU while data.y is still on
            # `device`; confirm edge_homophily handles mixed devices.
            homophily = edge_homophily(discrete_adj, data.y)
            logger.info(f"HOMOPHILY: {homophily:.4f}", extra={'repeat': i})
            logger.info(f"NUM EDGES: {_n_edges}.", extra={'repeat': i})
            n_edges.append(_n_edges)

            # Drop the model before the next run builds a new one.
            del model
            gc.collect()
            logger.info(f"test: {test_accs} ({np.mean(test_accs):.4f})", extra={'repeat': i})

    # average across repeats
    _log_summary(logger, valid_losses, test_losses, valid_accs, test_accs,
                 n_edges)

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
