import os
import argparse
import time
import psutil
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from laplace.curvature import AsdlGGN

from models import GCN, GraphSAGE
from utils import setup_logger, unused_gpu, edge_homophily, get_graph_prior
from parse import parse_add_args, LAPLACE_CLASSES
from data_utils import load_data
from marglik import marglik_optimization


def _select_device(args):
    """Pick the torch device for this run.

    Returns the CPU when CUDA is unavailable. Otherwise uses ``args.gpu`` when
    it is set -- including GPU index 0 -- and falls back to ``unused_gpu()``;
    if that yields no index, the default ``cuda`` device is used.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # NOTE(review): assumes `--gpu` defaults to None when not given. The
    # previous `args.gpu or unused_gpu()` silently ignored an explicit 0.
    gpu_num = args.gpu if args.gpu is not None else unused_gpu()
    return torch.device(f"cuda:{gpu_num}" if gpu_num is not None else "cuda")


def _get_model_type(name):
    """Return the model class for the ``--model`` name ('gcn' or 'graphsage')."""
    if name == 'gcn':
        return GCN
    if name == 'graphsage':
        return GraphSAGE
    raise ValueError(f"Unknown model type: {name}")


def _count_iterations(args):
    """Return the step count used to normalise training time per iteration.

    Counts ``n_epochs`` plus ``n_hypersteps`` for every ``marglik_frequency``-th
    epoch after burn-in. The post-burn-in epoch count is clamped at zero so a
    burn-in longer than training cannot make the total negative.
    """
    n_post_burnin = max(args.n_epochs - args.n_epochs_burnin, 0)
    n_marglik_updates = n_post_burnin // args.marglik_frequency
    return args.n_epochs + n_marglik_updates * args.n_hypersteps


def _log_summary(logger, valid_losses, test_losses, valid_accs, test_accs,
                 log_margliks, n_edges):
    """Log mean (std) of every per-run metric collected in ``main``.

    Skips the summary when no run completed: numpy returns NaN for empty
    lists, and ``int(nan)`` on the edge count would raise ValueError.
    """
    if not test_accs:
        logger.info("No runs completed; skipping summary statistics.")
        return
    logger.info(f"Mean valid loss: {np.mean(valid_losses):.6f} ({np.std(valid_losses):.6f})")
    logger.info(f"Mean test loss: {np.mean(test_losses):.6f} ({np.std(test_losses):.6f})")
    logger.info(f"Mean valid accuracy: {np.mean(valid_accs)*100:.2f}% ({np.std(valid_accs)*100:.2f})")
    logger.info(f"Mean test accuracy: {np.mean(test_accs)*100:.2f}% ({np.std(test_accs)*100:.2f})")
    logger.info(f"Mean marginal likelihood: {np.mean(log_margliks):.4f} ({np.std(log_margliks):.4f})")
    logger.info(f"Number of edges: {int(np.mean(n_edges))} ({np.std(n_edges):.2f})")


def main():
    """Run marglik graph learning over data splits and repeats, logging results.

    Parses CLI arguments, sets up a log file under
    ``logs/<dataset>/<model>/<laplace>/<job_id>.log`` and loads the dataset.
    The graph prior is built once. Then, for every selected split and every
    repeat, a fresh model is trained with ``marglik_optimization`` and the
    following are logged: timing, memory, best-model loss/accuracy, statistics
    of the learned discrete graph, and the log marginal likelihood. Mean and
    std over all runs are logged at the end.
    """
    parser = argparse.ArgumentParser("Marglik optimization on graphs.")
    parse_add_args(parser)
    args = parser.parse_args()

    log_dir = f"logs/{args.dataset}/{args.model}/{args.laplace}"
    os.makedirs(log_dir, exist_ok=True)

    log_file = f"{log_dir}/{args.job_id}.log"
    logger = setup_logger(log_file, console=args.console)

    # log hyperparameters
    logger.info("Hyperparameters:")
    for arg, value in vars(args).items():
        logger.info(f"HYPERPARAM: {arg}={value}")

    device = _select_device(args)
    # CUDA memory/sync APIs are only valid for CUDA devices; guard every use.
    use_cuda = device.type == "cuda"

    data = load_data(args.dataset, n_rand_splits=0,
                     feat_norm=args.feat_norm)
    data = data.to(device)

    if args.early_stop_crit is None:
        args.early_stop_crit = 'valid'
        logger.info(f"Early stopping criterion: {args.early_stop_crit}.")

    # Resolve everything that is identical across runs once, up front. This
    # also makes an unknown model/laplace name fail before any training.
    model_type = _get_model_type(args.model)
    laplace_cls = LAPLACE_CLASSES[args.laplace.lower()]
    n_nodes = data.x.size(0)
    n_classes = data.y.unique().size(-1)
    n_total_splits = data.train_mask.size(1)
    n_iterations = _count_iterations(args)
    process = psutil.Process(os.getpid())

    # Full-batch loader: a single batch containing every node. The split masks
    # are passed separately to marglik_optimization (presumably used there to
    # select train/valid/test nodes).
    dataset = TensorDataset(data.x, data.y)
    data_loader = DataLoader(dataset, batch_size=n_nodes, shuffle=False)

    valid_losses, test_losses = [], []
    valid_accs, test_accs, log_margliks, n_edges = [], [], [], []

    # get graph prior based on observed edges (and optionally kNN)
    prior_edge_index, prior_edge_probs = get_graph_prior(
        obs_edge_index=data.edge_index,
        obs_edge_prob=args.obs_prior_edge_prob,
        non_edge_prob=args.prior_non_edge_prob,
        n_nodes=n_nodes,
        knn_prior_edge_k=args.knn_prior_edge_k,
        x=data.x,
        knn_prior_edge_prob=args.knn_prior_edge_prob,
        knn_prior_edge_dist_metric=args.knn_prior_edge_dist_metric)

    splits = args.split if args.split is not None else range(n_total_splits)
    for j in splits:
        # Column j of each mask selects the nodes of split j.
        train_mask = data.train_mask[:, j]
        valid_mask = data.val_mask[:, j]
        test_mask = data.test_mask[:, j]

        for i in range(args.n_repeats):
            logger.info(f"Split {j+1}/{n_total_splits}.", extra={'repeat': i})

            # initialize a fresh model for every run
            model = model_type(
                in_channels=data.x.size(-1),
                hidden_channels=args.hidden_channels,
                num_layers=args.num_layers,
                out_channels=n_classes,
                dropout=args.dropout,
                act=args.act,
                norm=args.norm,
                jk=args.jk,
                res=args.res,
                graph_builder_kwargs={
                    'prior_edge_index': prior_edge_index,
                    'prior_edge_probs': prior_edge_probs,
                    'n_nodes': n_nodes,},
                bin_conc_kwargs={'temperature': args.cont_relax_temp},
                )
            model.to(device)

            if use_cuda:
                # Release cached blocks and reset the peak counter so the
                # reported peak GPU memory covers this run only.
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats(device)

            # perf_counter is monotonic and meant for measuring durations.
            start_time = time.perf_counter()

            # learn graph
            la, model, best_model_stats = marglik_optimization(
                model=model,
                train_loader=data_loader,
                valid_loader=data_loader,
                train_mask=train_mask,
                valid_mask=valid_mask,
                test_mask=test_mask,
                lr=args.lr,
                weight_decay=args.weight_decay,
                n_epochs=args.n_epochs,
                n_epochs_burnin=args.n_epochs_burnin,
                n_hypersteps=args.n_hypersteps,
                marglik_frequency=args.marglik_frequency,
                laplace=laplace_cls,
                backend=AsdlGGN,
                lr_graph=args.lr_graph,
                graph_grad_norm=args.graph_grad_norm,
                lr_graph_min=args.lr_graph_min,
                early_stop_crit=args.early_stop_crit,
                graph_prior=args.graph_prior,
                prior_edge_index=prior_edge_index,
                prior_edge_probs=prior_edge_probs,
                graph_kl_weight=args.graph_kl_weight,
                n_samples=args.n_samples,
                log_frequency=args.log_frequency,
                log_det_weight=args.log_det_weight,
                checkpoint_dir=args.checkpoint_dir,
                repeat=i,
            )
            if use_cuda:
                # CUDA kernels run asynchronously; wait for them so the
                # measured wall-clock time includes all queued GPU work.
                torch.cuda.synchronize(device)
            total_time = time.perf_counter() - start_time

            # Current resident set size of this process (not a peak value).
            cpu_mem_mb = process.memory_info().rss / 1024 ** 2

            # Print results
            time_per_iter = (total_time / n_iterations if n_iterations > 0
                             else float('nan'))
            logger.info(f"Training time: {total_time:.4f} sec", extra={'repeat': i})
            logger.info(f"Time per iteration: {time_per_iter:.6f} sec", extra={'repeat': i})
            if use_cuda:
                gpu_mem_mb = torch.cuda.max_memory_allocated(device) / 1024 ** 2
                logger.info(f"Peak GPU memory: {gpu_mem_mb:.2f} MB", extra={'repeat': i})
            logger.info(f"CPU memory usage: {cpu_mem_mb:.2f} MB", extra={'repeat': i})

            valid_loss = best_model_stats['valid_loss']
            valid_acc = best_model_stats['valid_acc']
            test_loss = best_model_stats['test_loss']
            test_acc = best_model_stats['test_acc']
            valid_losses.append(valid_loss)
            valid_accs.append(valid_acc)
            test_losses.append(test_loss)
            test_accs.append(test_acc)

            logger.info(f"VALID LOSS: {valid_loss:.4f}", extra={'repeat': i})
            logger.info(f"VALID ACC: {valid_acc*100:.2f}%", extra={'repeat': i})
            logger.info(f"TEST LOSS: {test_loss:.4f}", extra={'repeat': i})
            logger.info(f"TEST ACC: {test_acc*100:.2f}%", extra={'repeat': i})

            # Approximate edge count: sum of the learned discrete adjacency.
            discrete_adj = model.get_discrete_graph().detach().cpu()
            _n_edges = discrete_adj.sum().item()
            homophily = edge_homophily(discrete_adj, data.y)
            logger.info(f"HOMOPHILY: {homophily:.4f}", extra={'repeat': i})
            logger.info(f"NUM EDGES: {_n_edges}.", extra={'repeat': i})
            n_edges.append(_n_edges)

            # Log marginal likelihood reported as log_lik - log_det.
            log_lik, log_det = la.log_marginal_likelihood()
            marglik = (log_lik - log_det).item()
            logger.info(f"MARGLIK: {marglik:.2f}", extra={'repeat': i})
            log_margliks.append(marglik)

            logger.info(f"Test accs: {test_accs}")

    # average across all splits and repeats
    _log_summary(logger, valid_losses, test_losses, valid_accs, test_accs,
                 log_margliks, n_edges)


if __name__ == "__main__":
    main()
