import argparse
import subprocess

import torch
import wandb

from sound_rule_extraction import find_weight_cutoff_for_ratio_rule_channels, model_stats, nabn, up_down, UpDownStates, \
    neg_inf_line, check_given_rules, RuleCaptureStates, check_all_rules, model_weight_distribution
import gnn_architectures
from model_sparsity import weight_cutoff_model
from monotonic_rule_extraction import check_monotonic_rules, explain_predictions

# Link-prediction benchmarks, one row per family (versions v1-v4).
# Not listed for now:
#   'grail'   - not an established benchmark, and its file structure is different
#   'kinship' - different file structure
link_prediction_datasets = [
    'WN18RRv1', 'WN18RRv2', 'WN18RRv3', 'WN18RRv4',
    'fb237v1', 'fb237v2', 'fb237v3', 'fb237v4',
    'nellv1', 'nellv2', 'nellv3', 'nellv4',
]

# Node-classification benchmarks.
node_classification_datasets = ['aifb', 'mutag', 'lubm']

# Rule patterns available for the LogInfer benchmarks.
# 'comp' is left out: it is not tree-like and cannot be checked
# (triangle and diamond are also not tree-like).
log_infer_patterns = [
    'hier',
    'inter',  # inter not supported for LogInfer-WN
    'inver',
    'sym',
    'fork',  # Only for WN so far (FB generation taking too much memory)
    'cup',
    'nmhier',  # Only for WN so far, non-monotonic dataset
    'hier_nmhier',
    'cup_nmhier',
    'superhier',
]

# Full LogInfer dataset names have the form '<base>-<pattern>', grouped by base.
log_infer_datasets = [
    f'{base}-{pattern}'
    for base in ('LogInfer-FB', 'LogInfer-WN')
    for pattern in log_infer_patterns
]

# Accepted values for --negative-sampling-method.
negative_sampling_methods = ['rb', 'rc', 'pc', 'nm', 'pc_nm']

# Command-line interface for the experiment driver.
parser = argparse.ArgumentParser(description="Main file for running experiments")

# Model
# --dataset is required: the model name and every data path are derived from it, so a
# missing value is reported as a usage error at parse time instead of failing later on.
parser.add_argument('--dataset',
                    required=True,
                    choices=link_prediction_datasets + node_classification_datasets + log_infer_datasets,
                    help='Name of the dataset')
parser.add_argument('--layers',
                    default=2,
                    type=int,
                    help='Number of layers in the model')
parser.add_argument('--seed',
                    default=-1,  # -1 seed means seed will be chosen at random
                    type=int,
                    help='Seed used to init RNG')
parser.add_argument('--aggregation',
                    default='mean',
                    choices=['max', 'sum', 'mean'],
                    help='Aggregation function to be used by the model')

# Training options (most are forwarded to train.py on its command line).
parser.add_argument(
    '--early-stop',
    type=int,
    default=50,  # -1 means no early stopping
    help='Number of epochs with worse loss than best epoch before early stopping',
)
parser.add_argument(
    '--lr',
    type=float,
    default=0.01,
    help='Learning rate',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=10000,
    help='Number of epochs to train for',
)
parser.add_argument(
    '--checkpoint-interval',
    type=int,
    default=9999999,
    help='How many epochs between model checkpoints',
)
# Integer 0/1 switches.
parser.add_argument(
    '--train',
    choices=[0, 1],
    type=int,
    default=0,
    help='If 0, the script will not train a new model, but fetch an existing trained model',
)
parser.add_argument(
    '--non-negative-weights',
    choices=[0, 1],
    type=int,
    default=0,
    help='Restrict matrix weights during training so that they are all non-negative',
)
parser.add_argument(
    '--weight-clamping-interval',
    type=int,
    default=-1,
    help='Number of epochs between weight clamping for rule channels. -1 for no weight clamping.',
)

# Testing
parser.add_argument('--test',
                    type=int,
                    choices=[0, 1],
                    default=0,
                    help='If 0, the script will not test the model, merely train it')
parser.add_argument('--evaluation-set',
                    default='valid',
                    choices=['valid', 'test'],
                    help='Whether you should evaluate on the validation or test set')
parser.add_argument('--negative-sampling-method',
                    default='pc',
                    choices=negative_sampling_methods,
                    help='Negative sampling method for evaluation')
# Adjacent string literals are concatenated verbatim, so each fragment ends with an
# explicit space; otherwise the sentences run together in the --help output.
parser.add_argument('--rule-channels-min-ratio',
                    type=float,
                    default=-1,
                    help='Weight cutoff will be chosen to give a number of channels corresponding to rules '
                         'strictly greater than the ratio given, which should be in [0, 1). '
                         'Such channels are either UP or 0 (i.e. monotonic increasing or do not depend on input). '
                         'If -1, then no weight cutoff is used.')

# Rule extraction
def _add_rule_extraction_switch(flag, help_text):
    """Register an integer 0/1 switch (default 0) on the module-level parser."""
    parser.add_argument(flag, type=int, choices=[0, 1], default=0, help=help_text)


_add_rule_extraction_switch(
    '--extract',
    'Run and log the outputs of the rule extraction algorithms?',
)
_add_rule_extraction_switch(
    '--log-infer-rule-check',
    'Check if LogInfer rules are captured by the model.',
)
_add_rule_extraction_switch(
    '--search-rule-check',
    'Search space of possible rules to see which ones are captured.',
)
_add_rule_extraction_switch(
    '--reduce-rule-size',
    'Reduce the size of rules used for explanations, using a greedy algorithm.',
)

# Logging options.
parser.add_argument(
    '--use-wandb',
    choices=[0, 1],
    type=int,
    default=0,
    help='Log to wandb?',
)
parser.add_argument(
    '--log-interval',
    type=int,
    default=1,
    help='How many epochs between model logs',
)

# Parse the command line once; the rest of the script reads this namespace.
args = parser.parse_args()

# init logging
if args.use_wandb:
    wandb.init(project='mean-gnns')

#
# TRAINING
#

# The base model name encodes the hyperparameters that identify a trained model on disk.
model_name = f'{args.dataset}_layers_{args.layers}_agg_{args.aggregation}_epochs_{args.epochs}_lr_{args.lr}_seed_{args.seed}'
model_folder = '../models'
encoder_folder = '../encoders'

# Only the inputs relevant to the selected dataset family are filled in below.
train_graph, train_examples, predicates, train_file_full = None, None, None, None

# Resolve the encoding scheme and input files for the selected dataset family.
# Invalid input raises explicitly rather than via `assert`, which is stripped under
# `python -O` and would otherwise let the script continue with unset variables.
if args.dataset in node_classification_datasets:
    encoding_scheme = 'canonical'
    path_to_dataset = f'../data/node_classification/{args.dataset}'
    train_graph = f'{path_to_dataset}/graph.nt'
    if args.dataset == 'lubm':
        # lubm uses a different file for its training graph
        train_graph = f'{path_to_dataset}/train_input.tsv'
    train_examples = f'{path_to_dataset}/train.tsv'
    predicates = f'{path_to_dataset}/predicates.csv'
elif args.dataset in link_prediction_datasets:
    encoding_scheme = 'iclr22'
    path_to_dataset = f'../data/link_prediction/{args.dataset}'
    train_graph = f'{path_to_dataset}/train_graph.tsv'
    train_examples = f'{path_to_dataset}/train_pos.tsv'
    predicates = f'{path_to_dataset}/predicates.csv'
elif args.dataset in log_infer_datasets:
    if args.dataset == 'LogInfer-WN-inter':
        raise ValueError('LogInfer pattern "inter" not supported for dataset "LogInfer-WN"')
    encoding_scheme = 'iclr22'
    path_to_dataset = f'../data/LogInfer/LogInfer-benchmark/{args.dataset}'
    train_file_full = f'{path_to_dataset}/train.txt'
    # note that negative dataset not specified in model, since it does not affect training
    # TODO: however, it does affect hyperparameter tuning
    #  so include in name if want to use different negative sets during training
else:
    raise ValueError(f'Dataset "{args.dataset}" not recognized')

# Training options that change the resulting model are appended to its name.
if args.non_negative_weights:
    model_name += '_non_negative_weights'

if args.weight_clamping_interval != -1:
    model_name += (
        f'_rule_channels_{args.rule_channels_min_ratio}'
        f'_clamp_interval_{args.weight_clamping_interval}'
    )

# Command line for train.py; every value is stringified for the subprocess call.
train_command = [
    'python',
    'train.py',
    '--model-name', model_name,
    '--model-folder', model_folder,
    '--encoding-scheme', encoding_scheme,
    '--encoder-folder', encoder_folder,
    '--non-negative-weights', str(args.non_negative_weights),
    '--layers', str(args.layers),
    '--early-stop', str(args.early_stop),
    '--lr', str(args.lr),
    '--seed', str(args.seed),
    '--epochs', str(args.epochs),
    '--checkpoint-interval', str(args.checkpoint_interval),
    '--log-interval', str(args.log_interval),
    '--use-wandb', str(args.use_wandb),
    '--weight-clamping-interval', str(args.weight_clamping_interval),
    '--rule-channels-min-ratio', str(args.rule_channels_min_ratio),
    '--aggregation', str(args.aggregation),
]

# LogInfer datasets pass a single training file; the other families pass a graph,
# training examples and a predicates file.
if args.dataset in log_infer_datasets:
    train_command += [
        '--train-file-full', train_file_full,
    ]
else:
    train_command += [
        '--train-graph', train_graph,
        '--train-examples', train_examples,
        '--predicates', predicates,
    ]

if args.train:
    print('Training...')
    print('Running command:', train_command)
    # check=True raises CalledProcessError when train.py exits with a non-zero status,
    # so a failed training run stops the script instead of silently continuing.
    subprocess.run(train_command, check=True)

#
# TESTING
#
load_model_name = f'{model_folder}/{model_name}.pt'

# Evaluation inputs depend on the dataset family.
if args.dataset in node_classification_datasets:
    # The training graph is reused as the evaluation graph.
    test_graph = train_graph
    test_positive_examples = f'{path_to_dataset}/{args.evaluation_set}_pos.tsv'
    test_negative_examples = f'{path_to_dataset}/{args.evaluation_set}_neg.tsv'
elif args.dataset in link_prediction_datasets:
    # Different graph given as input for testing
    test_graph = f'{path_to_dataset}/{args.evaluation_set}_graph.tsv'
    test_positive_examples = f'{path_to_dataset}/{args.evaluation_set}_pos.tsv'
    test_negative_examples = f'{path_to_dataset}/{args.evaluation_set}_neg.tsv'
else:
    # log_infer_datasets: negatives are selected by the negative sampling method.
    test_graph = train_file_full
    test_positive_examples = f'{path_to_dataset}/{args.evaluation_set}.txt'
    test_negative_examples = f'{path_to_dataset}/{args.evaluation_set}_neg_{args.negative_sampling_method}.txt'

# Metrics output file and encoder files, all keyed by the model name.
output = f'../metrics/{model_name}_rule_channels_{args.rule_channels_min_ratio}.txt'
canonical_encoder_file = f'../encoders/{model_name}_canonical.tsv'
iclr22_encoder_file = f'../encoders/{model_name}_iclr22.tsv'

test_command = [
    'python',
    'test.py',
    '--load-model-name', load_model_name,
    '--test-graph', test_graph,
    '--test-positive-examples', test_positive_examples,
    '--test-negative-examples', test_negative_examples,
    '--output', output,
    '--encoding-scheme', encoding_scheme,
    '--canonical-encoder-file', canonical_encoder_file,
    '--iclr22-encoder-file', iclr22_encoder_file,
    '--use-wandb', str(args.use_wandb),
    '--eval-threshold-key', str(args.rule_channels_min_ratio),
]

# The validation set passes the "set" flag and the test set the "use" flag
# (presumably test.py stores the threshold on valid and reuses it on test; confirm there).
threshold_flag_by_set = {
    'valid': '--set-optimal-threshold',
    'test': '--use-optimal-threshold',
}
if args.evaluation_set in threshold_flag_by_set:
    test_command += [
        threshold_flag_by_set[args.evaluation_set], '1',
    ]


def get_model_and_weight_cutoff():
    """Load the trained model and, if requested, search for a weight cutoff.

    The model is read from ``load_model_name`` and moved to CUDA when available,
    otherwise to the CPU. When ``args.rule_channels_min_ratio`` is not -1, a cutoff
    is computed with ``find_weight_cutoff_for_ratio_rule_channels`` and, if wandb
    logging is enabled, logged under ``weight_cutoff``.

    Returns:
        A ``(model, weight_cutoff)`` tuple; ``weight_cutoff`` is -1 when no cutoff
        was requested.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location remaps stored tensors onto the selected device while loading, so a
    # checkpoint saved on a GPU machine can still be loaded on a CPU-only machine.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    model: gnn_architectures.GNN = torch.load(
        load_model_name,
        map_location=device,
        weights_only=False,
    ).to(device)
    weight_cutoff = -1

    if args.rule_channels_min_ratio != -1:
        print(f'Searching for a weight cutoff to obtain a ratio of >{args.rule_channels_min_ratio} rule channels')
        # Only the first of the three returned values (the cutoff) is used here.
        weight_cutoff, _, _ = find_weight_cutoff_for_ratio_rule_channels(
            model,
            args.rule_channels_min_ratio,
        )
        print(f'Cutoff {weight_cutoff} found')
        if args.use_wandb:
            wandb.log({
                'weight_cutoff': weight_cutoff,
            })

    return model, weight_cutoff


# Evaluate the trained model by running test.py in a subprocess.
if args.test:
    print('Testing...')

    if args.rule_channels_min_ratio != -1:
        # Only the cutoff is needed here; it is forwarded to test.py as a flag.
        _, weight_cutoff = get_model_and_weight_cutoff()
        test_command = test_command + [
            '--weight-cutoff', str(weight_cutoff),
        ]

    print('Running command:', test_command)
    # check=True raises CalledProcessError when test.py exits with a non-zero status,
    # so a failed evaluation is not silently ignored by the stages that follow.
    subprocess.run(test_command, check=True)

# Report weight statistics of the (optionally cutoff-sparsified) model.
if args.extract:
    print('Running model stat extraction algorithms...')

    model, weight_cutoff = get_model_and_weight_cutoff()
    # Apply the cutoff to the model only when a rule-channel ratio was requested.
    if args.rule_channels_min_ratio != -1:
        weight_cutoff_model(model, weight_cutoff)

    print('-----\nModel stats:')
    positive_weights, negative_weights, zero_weights = model_stats(model)

    print('-----\nModel weight distribution:')
    (
        weight_mean,
        weight_std,
        weight_min,
        weight_25,
        weight_50,
        weight_75,
        weight_max,
    ) = model_weight_distribution(model)
    print(weight_mean, weight_std, weight_min, weight_25, weight_50, weight_75, weight_max)

    if args.use_wandb:
        wandb.log(dict(
            positive_weights=positive_weights,
            negative_weights=negative_weights,
            zero_weights=zero_weights,
            weight_mean=weight_mean,
            weight_std=weight_std,
            weight_min=weight_min,
            weight_25=weight_25,
            weight_50=weight_50,
            weight_75=weight_75,
            weight_max=weight_max,
        ))

def _fraction_equal(values, target):
    """Return the fraction of entries in ``values`` equal to ``target``.

    An empty list yields 0.0 rather than raising ZeroDivisionError.
    """
    if not values:
        return 0.0
    return values.count(target) / len(values)


# Check which of the dataset's given LogInfer rules the model captures.
if args.log_infer_rule_check:  # note: this is not sound for MAGNNs
    model, weight_cutoff = get_model_and_weight_cutoff()
    if args.rule_channels_min_ratio != -1:
        weight_cutoff_model(model, weight_cutoff)

    # Thresholds are looked up under the same ratio value used for evaluation.
    eval_threshold_key = args.rule_channels_min_ratio
    assert eval_threshold_key in model.eval_thresholds, 'Optimal threshold must first be set on valid dataset'

    captured, rule_head_predicates_checkable = check_given_rules(
        model,
        f'{path_to_dataset}/final-rules-{args.dataset}.txt',
        canonical_encoder_file,
        iclr22_encoder_file,
        model.eval_thresholds[eval_threshold_key],
    )

    print('Rule capture states:', captured)
    print('Heads that are UP:', rule_head_predicates_checkable)

    # Share of rules in each capture state (0.0 for every state if no rules came back).
    captured_values = list(captured.values())
    log_infer_ratio_captured = _fraction_equal(captured_values, RuleCaptureStates.Yes)
    log_infer_ratio_no_neg_inf = _fraction_equal(captured_values, RuleCaptureStates.NoNegInf)
    log_infer_ratio_no_body_not_entail = _fraction_equal(captured_values, RuleCaptureStates.NoBodyNotEntail)
    log_infer_ratio_cannot_check = _fraction_equal(captured_values, RuleCaptureStates.CannotCheck)

    # Share of rule head predicates whose entry is True.
    rule_head_predicates_checkable_values = list(rule_head_predicates_checkable.values())
    log_infer_ratio_up_heads = _fraction_equal(rule_head_predicates_checkable_values, True)

    if args.use_wandb:
        wandb.log({
            'log_infer_ratio_captured': log_infer_ratio_captured,
            'log_infer_ratio_no_neg_inf': log_infer_ratio_no_neg_inf,
            'log_infer_ratio_no_body_not_entail': log_infer_ratio_no_body_not_entail,
            'log_infer_ratio_cannot_check': log_infer_ratio_cannot_check,
            'log_infer_ratio_up_heads': log_infer_ratio_up_heads,
        })

# Search the space of candidate rules and report which ones the model captures.
if args.search_rule_check:
    model, weight_cutoff = get_model_and_weight_cutoff()
    if args.rule_channels_min_ratio != -1:
        weight_cutoff_model(model, weight_cutoff)

    eval_threshold_key = args.rule_channels_min_ratio
    assert eval_threshold_key in model.eval_thresholds, 'Optimal threshold must first be set on valid dataset'

    if 'lubm' in args.dataset:
        # rule explanations
        print('\n\n')
        print('Computing explanatory rules')
        avg_concepts_per_rule, sample_rules = explain_predictions(
            model,
            canonical_encoder_file,
            model.eval_thresholds[eval_threshold_key],
            test_graph,
            test_positive_examples,
            reduce_rule_size=args.reduce_rule_size,
        )

        print('avg_concepts_per_rule', avg_concepts_per_rule)
        if args.use_wandb:
            wandb.log({'avg_concepts_per_rule': avg_concepts_per_rule})
        print('Sample explanatory rules:')
        for prediction, rule in sample_rules:
            print('prediction:\n', prediction)
            print('rule:\n', rule)

    # 76 binary predicates in NELL -> 76 unary, 4 binary = choose 2 = 2850 bodies = 216600 rules
    # 237 binary in FB -> 237 unary, 4 binary = choose 1 = 237 bodies = 56169 rules
    # 11 binary in WN -> 11 unary, 4 binary = check all = 32768 bodies = 360448 rules
    # the above 3 use the iclr encoding
    # so for simplicity, we just check unary predicates, which correspond to binary predicates on the same pairs

    # 14 unary and 16 binary in lubm = choose 4 = 27405 bodies = 383670 rules
    # Maximum number of body atoms per dataset family; the first marker that occurs
    # in the dataset name wins, so the order of this table matters.
    body_atom_limits = (
        ('LogInfer-WN', 15),
        ('WN18RRv', 15),
        ('nellv', 2),
        ('fb237v', 1),
        ('lubm', 4),
    )
    for dataset_marker, atom_limit in body_atom_limits:
        if dataset_marker in args.dataset:
            max_body_atoms_to_check = atom_limit
            break
    else:
        assert False, 'Dataset not yet supported for rule extraction'

    sound_counts, sample_rules, arity_counts = check_monotonic_rules(
        model,
        canonical_encoder_file,
        iclr22_encoder_file,
        encoding_scheme,
        model.eval_thresholds[eval_threshold_key],
        max_body_atoms_to_check,
    )

    count_unary, count_binary, count_mixed = arity_counts
    print('\n\n')
    print('Number of rules with only unary atoms:', count_unary)
    print('Number of rules with only binary atoms:', count_binary)
    print('Number of rules with a mix of unary and binary atoms:', count_mixed)

    if args.use_wandb:
        wandb.log({'count_unary': count_unary, 'count_binary': count_binary, 'count_mixed': count_mixed})

    # Report counts and one sample rule for every body size from 0 up to the limit.
    for num_body_atoms in range(max_body_atoms_to_check + 1):
        print('Found', sound_counts[num_body_atoms], 'sound rules with', num_body_atoms, 'body atoms')
        print('Sample sound rule with', num_body_atoms, 'body atoms:', sample_rules[num_body_atoms])
        if args.use_wandb:
            wandb.log({'body_atoms': num_body_atoms, 'sound_count': sound_counts[num_body_atoms]})
