import numpy as np
from tqdm import tqdm
import argparse
import logging
import time

import torch
from torch_geometric.data import NeighborSampler

from data_loading import get_dataset
from data_utils import set_train_val_test_split, get_feature_mask
from models import get_model
from seeds import val_seeds
from filling_strategies import filling
from evaluation import test
from train import train
from reconstruction import spatial_reconstruction

# Command-line interface for the missing-feature experiments; consumed by run().
parser = argparse.ArgumentParser('GNN-Missing-Features')
parser.add_argument('--dataset_name', 
                    type=str, 
                    help='Name of dataset', 
                    default="Photo", 
                    choices=["Cora", "CiteSeer", "PubMed", "Photo", "Computers", 
                             "OGBN-Arxiv", "OGBN-Products","Twitch", "Deezer-Europe", "FB100", "Actor", "Syn-Cora", "MixHopSynthetic"])
parser.add_argument('--mask_type', type=str, help='Type of missing feature mask', default="uniform", choices=["uniform", "structural"])
parser.add_argument('--filling_method', type=str, help='Method to solve the missing feature problem', default="dirichlet_diffusion", choices=["random", "zero", "mean", "neighborhood_mean", "learnable", "dirichlet_diffusion", "total_variation_diffusion", "learnable_diffusion"])
parser.add_argument('--model', type=str, help='Prediction model', default="gcn", choices=["mlp", "sgc", "sage", "gcn", "gat", "gcnmf", "pagnn" , "missing_gat", "lp", "cs"])
parser.add_argument('--missing_rate', type=float, help='Rate of nodes with missing features', default=0.1)
parser.add_argument('--patience', type=int, help='Patience for early stopping', default=200)
parser.add_argument('--lr', type=float, help='Learning Rate', default=0.005)
parser.add_argument('--epochs', type=int, help='Max number of epochs', default=10000)
parser.add_argument('--n_runs', type=int, help='Max number of runs', default=10)
parser.add_argument('--hidden_dim', type=int, help='Hidden dimension of model', default=64)
# The help texts of the two attention options below were previously swapped.
parser.add_argument('--attention_type', type=str, help='Attention type of model', default="transformer", choices=["transformer", "restricted"])
parser.add_argument('--attention_dim', type=int, help='Attention dimension of model', default=64)
parser.add_argument('--num_layers', type=int, help='Number of GNN layers', default=2)
parser.add_argument('--num_iterations', type=int, help='Number of diffusion iterations for feature reconstruction', default=40)
parser.add_argument('--lp_alpha', type=float, help='Alpha parameter of label propagation', default=0.9)
parser.add_argument('--dropout', type=float, help='Feature dropout', default=0.5)
parser.add_argument('--batch_size', type=int, help='Batch size for models trained with neighborhood sampling', default=1024)
parser.add_argument('--reconstruction_only', help='Only reconstruct the missing features and report the reconstruction error, without training a model', action='store_true')
parser.add_argument('--graph_sampling', help='Set if you want to use graph sampling (always true for large graphs)', action='store_true')
parser.add_argument('--log', type=str, help='Log Level', default="WARNING", choices=["DEBUG", "INFO", "WARNING"])
parser.add_argument('--homophily', type=float, help='Level of homophily for Syn-Cora dataset', default=None)
parser.add_argument('--gpu_idx', type=int, help='Index of the GPU to run the program on', default=0)

def run(args):
    """Run the missing-feature node-classification experiment described by ``args``.

    For each of the first ``args.n_runs`` seeds in ``val_seeds`` this:

    1. builds a train/val/test split;
    2. hides node features using a mask from ``get_feature_mask``
       (``args.mask_type`` / ``args.missing_rate``);
    3. fills the hidden entries with ``args.filling_method`` and records the
       reconstruction error.

    Unless ``args.reconstruction_only`` is set, it then trains ``args.model``
    with early stopping on validation accuracy. It records the validation and
    test accuracy of the soft predictions from the best-validation epoch, plus
    the wall-clock time of the run.

    Side effect: ``args.num_iterations`` is overwritten with 20 when
    ``args.filling_method == "learnable_diffusion"``.

    Args:
        args: ``argparse.Namespace`` produced by ``parser``.

    Returns:
        dict with ``relative_reconstruction_error_mean`` / ``_std``. Unless
        ``args.reconstruction_only`` is set, it also holds ``test_acc_*``,
        ``val_acc_*`` and ``train_time_*`` mean/std entries. The same summary
        is printed to stdout.

    Raises:
        ValueError: if any of the following holds:
            - neighborhood sampling is requested for a model other than
              ``"sage"``;
            - sampling is requested for a dataset without a predefined split;
            - ``args.epochs < 1`` while a model has to be trained;
            - label propagation is run with a non-zero missing rate.
    """
    # Same logger object the __main__ block configures (getLogger is keyed by
    # name), but it no longer depends on a global that exists only under __main__.
    logger = logging.getLogger(__name__)

    graph_sampling = args.graph_sampling or args.dataset_name == "OGBN-Products"
    if graph_sampling and args.model != "sage":
        # Previously this only printed the message and carried on anyway.
        raise ValueError(f"{args.model} model does not support training with neighborhood sampling")
    if args.model != "lp" and not args.reconstruction_only and args.epochs < 1:
        # At least one epoch is needed to obtain the best-epoch predictions below.
        raise ValueError("--epochs must be at least 1 to train a model")

    if args.filling_method == "learnable_diffusion":
        # NOTE(review): presumably fewer iterations because the learnable
        # diffusion is re-propagated every epoch below — confirm.
        args.num_iterations = 20

    device = torch.device(f'cuda:{args.gpu_idx}' if torch.cuda.is_available() else 'cpu')
    dataset, evaluator = get_dataset(name=args.dataset_name, homophily=args.homophily)

    split_idx = dataset.get_idx_split() if hasattr(dataset, 'get_idx_split') else None
    if graph_sampling and split_idx is None:
        raise ValueError(f"Neighborhood sampling requires a predefined train split, which {args.dataset_name} does not provide")
    n_nodes, n_features = dataset.data.x.shape
    num_classes = dataset.num_classes
    test_accs, best_val_accs, relative_reconstruction_errors, train_times = [], [], [], []

    train_loader, inference_loader = None, None
    if graph_sampling:
        train_loader = NeighborSampler(dataset.data.edge_index, node_idx=split_idx['train'],
                                       sizes=[15, 10], batch_size=args.batch_size,
                                       shuffle=True, num_workers=12)
        # NOTE(review): the inference batch size is fixed at 1024 rather than
        # args.batch_size — confirm this is intentional.
        inference_loader = NeighborSampler(dataset.data.edge_index, node_idx=None, sizes=[-1],
                                           batch_size=1024, shuffle=False,
                                           num_workers=12)

    for seed in tqdm(val_seeds[:args.n_runs]):
        data = set_train_val_test_split(
                    seed=seed,
                    data=dataset.data,
                    split_idx=split_idx,
                    dataset_name=args.dataset_name
                ).to(device)
        train_start = time.time()
        if args.model == "lp":
            if args.missing_rate != 0.0:
                raise ValueError("Label Propagation is independent of missing feature rate, so please set it to 0")
            model = get_model(model_name=args.model, num_features=data.num_features, num_classes=num_classes, edge_index=data.edge_index, x=None, args=args).to(device)
            logits = model(y=data.y, edge_index=data.edge_index, mask=data.train_mask)
            (_, val_acc, test_acc), _ = test(model=None, x=None, data=data, logits=logits, evaluator=evaluator)
        else:
            feature_mask = get_feature_mask(rate=args.missing_rate, n_nodes=n_nodes, n_features=n_features, type=args.mask_type).to(device)
            x = data.x.clone()
            x[~feature_mask] = float('nan')

            logger.info("Starting feature filling")
            start = time.time()
            if args.model in ["gcnmf", "pagnn", "missing_gat"]:
                # These models receive the masked x directly (see get_model below);
                # the filling is left as all-NaN.
                filled_features, lfp = torch.full_like(x, float('nan')), None
            else:
                filled_features, lfp = filling(args.filling_method, data.edge_index, x, feature_mask, args.num_iterations, args.attention_dim, args.attention_type)
            logger.info(f"Feature filling completed. It took: {time.time() - start:.2f}s")

            relative_reconstruction_errors.append(spatial_reconstruction(data.x, filled_features, feature_mask))
            if args.reconstruction_only:
                continue

            model = get_model(model_name=args.model, num_features=data.num_features, num_classes=num_classes, edge_index=data.edge_index, x=x, mask=feature_mask, args=args).to(device)
            params = list(model.parameters())
            if args.filling_method == "learnable":
                # The filled features themselves are optimized.
                params.append(filled_features)
            if lfp is not None:
                # Parameters of the learnable diffusion are optimized jointly with the model.
                params.extend(lfp.parameters())
            optimizer = torch.optim.Adam(params, lr=args.lr)
            criterion = torch.nn.NLLLoss()

            val_accs = []
            best_val_acc, y_soft = None, None
            for epoch in range(args.epochs):
                if args.filling_method == 'learnable_diffusion':
                    # Re-run the diffusion so this epoch uses the current lfp parameters.
                    x = data.x.clone()
                    x[~feature_mask] = float('nan')
                    filled_features = lfp.propagate(x=x, edge_index=data.edge_index, mask=feature_mask)

                start = time.time()
                # Observed entries come from data.x; only missing ones come from the filling.
                x = torch.where(feature_mask, data.x, filled_features)

                train(model, x, data, optimizer, criterion, train_loader=train_loader, device=device)
                (train_acc, val_acc, tmp_test_acc), out = test(model, x=x, data=data, evaluator=evaluator, inference_loader=inference_loader, device=device)
                if best_val_acc is None or val_acc > best_val_acc:
                    best_val_acc = val_acc
                    y_soft = out.softmax(dim=-1)

                val_accs.append(val_acc)
                logger.info(f"Epoch {epoch + 1} - Train acc: {train_acc:.3f}, Val acc: {val_acc:.3f}, Test acc: {tmp_test_acc:.3f}. It took {time.time() - start:.2f}s")

                # Stop once the last `patience` epochs failed to beat the best earlier
                # epoch. patience <= 0 disables early stopping; it used to crash on max([]).
                if (args.patience > 0 and epoch > args.patience
                        and max(val_accs[-args.patience:]) <= max(val_accs[:-args.patience])):
                    break

            # Score the soft predictions saved at the best-validation epoch.
            (_, val_acc, test_acc), _ = test(model, x=x, data=data, logits=y_soft, evaluator=evaluator)
        best_val_accs.append(val_acc)
        test_accs.append(test_acc)
        train_times.append(time.time() - train_start)

    relative_reconstruction_error_mean, relative_reconstruction_error_std = np.mean(relative_reconstruction_errors), np.std(relative_reconstruction_errors)
    results = {"relative_reconstruction_error_mean": relative_reconstruction_error_mean, "relative_reconstruction_error_std": relative_reconstruction_error_std}
    print(f'Reconstruction error: {relative_reconstruction_error_mean} +- {relative_reconstruction_error_std}')

    if not args.reconstruction_only:
        test_acc_mean, test_acc_std = np.mean(test_accs), np.std(test_accs)
        val_acc_mean, val_acc_std = np.mean(best_val_accs), np.std(best_val_accs)
        train_time_mean, train_time_std = np.mean(train_times), np.std(train_times)
        print(f'{test_acc_mean} +- {test_acc_std}')
        results = {**results, **{"test_acc_mean": test_acc_mean, "test_acc_std": test_acc_std, "val_acc_mean": val_acc_mean, "val_acc_std": val_acc_std, "train_time_mean": train_time_mean, "train_time_std": train_time_std}}

    # Previously the summary was built and then discarded.
    return results

if __name__ == "__main__":
    args = parser.parse_args()
    # Without any handler, records go to logging.lastResort, which only emits
    # WARNING and above, so --log INFO/DEBUG printed nothing. A root handler
    # with NOTSET level forwards whatever this module's logger lets through.
    # The root logger keeps its default WARNING level, so third-party INFO
    # messages stay quiet.
    logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    # Kept as a module-level global for code that reads `logger` directly.
    logger = logging.getLogger(__name__)
    # args.log is restricted by argparse `choices` to valid level names.
    logger.setLevel(level=getattr(logging, args.log.upper()))
    run(args)
