import argparse
import random

import numpy as np
from lightning.pytorch import seed_everything

from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_synthetic


def get_valmatrix_graph_strongest_all(
        coefficients: np.ndarray, keep_top: float, n_features: int, max_lag: int
) -> np.ndarray:
    """Turn per-time-step coefficients into a signed, thresholded link matrix.

    Steps:
      1. Reduce the last axis of ``coefficients`` to its 95th and 5th
         percentiles and keep, per entry, the upper percentile when it is
         strictly larger than the absolute value of the lower one, otherwise
         the lower percentile.
      2. Reshape to ``[n_samples, n_features, n_features, max_lag]``, swap the
         two feature axes and reverse the lag axis.
      3. Prepend an all-zero slot on the lag axis.
      4. Threshold globally (over all samples at once): entries above the
         ``1 - keep_top`` quantile become ``1.``, entries below the
         ``keep_top`` quantile become ``-1.``, everything else ``0.``.

    Args:
        coefficients: Array of shape
            ``[n_samples, n_features, n_features * max_lag, seq_len - max_lag]``.
        keep_top: Quantile fraction (between 0 and 1) used for both tails of
            the global threshold.
        n_features: Number of features; must be consistent with the second
            and third axes of ``coefficients``.
        max_lag: Number of lags encoded in the third axis of ``coefficients``.

    Returns:
        A single array of shape ``[n_samples, n_features, n_features,
        max_lag + 1]`` with entries in ``{-1., 0., 1.}``.
    """
    n_samples = coefficients.shape[0]

    upper = np.quantile(coefficients, q=.95, axis=-1)
    lower = np.quantile(coefficients, q=.05, axis=-1)
    # Vectorized pick of the dominant tail; ties go to the lower percentile.
    strongest = np.where(upper > np.abs(lower), upper, lower)

    strongest = strongest.reshape(n_samples, n_features, n_features, max_lag).transpose((0, 2, 1, 3))
    # NOTE(review): presumably the lags are stored in reverse order upstream,
    # hence the flip — confirm against the producer of `coefficients`.
    strongest = np.flip(strongest, axis=-1)

    # The first lag slot carries no coefficient and is filled with zeros.
    lag_zero = np.zeros((n_samples, n_features, n_features, 1))
    val_matrix = np.concatenate([lag_zero, strongest], axis=-1)

    # Quantiles are taken over the whole array, so the padded zeros take part
    # in the threshold computation as well.
    threshold_pos = np.quantile(val_matrix, q=1 - keep_top)
    threshold_neg = np.quantile(val_matrix, q=keep_top)

    # val_matrix.shape = [n_samples, n_features, n_features, max_lag + 1]
    return np.where(val_matrix > threshold_pos, 1., np.where(val_matrix < threshold_neg, -1., 0.))


def _false_positive_rate(gc_synthetic: np.ndarray, gc_real: np.ndarray) -> float:
    """Return the mean false-positive rate of predicted graphs against a reference.

    Args:
        gc_synthetic: Array of shape ``[n_graphs, n_links]``; an entry equal
            to 1 marks a predicted link.
        gc_real: Flat array of length ``n_links``; assumed to be binary with
            1 marking a real link — TODO confirm for graphs loaded from disk.

    Returns:
        The number of predicted links that are absent from ``gc_real``,
        divided by the number of absent links, averaged over all graphs.
    """
    not_gc_real = 1 - gc_real
    n_negatives = np.sum(not_gc_real)
    # A false positive is a predicted link (== 1) where the reference has
    # none. Comparing the prediction to `not_gc_real` for plain equality
    # would also count missed real links (0 == 0), i.e. false negatives.
    false_positives = ((gc_synthetic == 1) & (not_gc_real == 1)).sum(axis=1)
    return np.mean(false_positives / n_negatives)


def main():
    """Score a run's coefficient graphs against the dataset's reference graph.

    Reads ``--run`` and ``--dataset`` from the command line, then for seeds
    0..9 loads the coefficients, thresholds them into link matrices, draws
    a random subset of samples (with replacement) and computes the
    false-positive rate against the reference graph. Prints the mean and
    standard deviation over the seeds, using a comma as decimal separator.
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset

    n_features, seq_len, max_lag = get_nfeatures_seqlen_maxlag(dataset)

    n_samples_evaluation = 1024
    keep_top = .01

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    # Reference adjacency matrix for the selected dataset.
    if 'rivers' in dataset:
        gc_real = np.asarray([[0, 0, 0], [1, 0, 0], [0, 0, 0]])
    elif dataset == 'henon6':
        gc_real = np.asarray(
            [
                [1, 1, 0, 0, 0, 0],
                [0, 1, 1, 0, 0, 0],
                [0, 0, 1, 1, 0, 0],
                [0, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 1, 1],
                [0, 0, 0, 0, 0, 1],
            ]
        )
    else:
        # Every other dataset name falls back to the air-quality graph.
        gc_real = np.load('data/airquality/gc_real.npy')
    gc_real = gc_real.flatten()

    scores = []
    for seed in range(10):
        random.seed(seed)
        np.random.seed(seed)

        coefficients = get_data_synthetic(run_id, epoch, seed, seq_len, n_features, coefficients=True)
        all_val_matrix = get_valmatrix_graph_strongest_all(coefficients, keep_top, n_features, max_lag)

        # Sample with replacement so every seed is scored on the same number
        # of graphs regardless of how many were loaded.
        random_indexes = np.random.randint(0, len(coefficients), n_samples_evaluation)
        all_val_matrix = all_val_matrix[random_indexes]

        # Collapse the lag axis: a pair is linked if any lag is marked 1.
        # NOTE(review): entries marked -1 are collapsed to 0 by the max
        # whenever another lag slot is 0 — confirm that negative links are
        # meant to be ignored here.
        gc_synthetic = all_val_matrix.max(axis=-1).reshape(n_samples_evaluation, n_features * n_features)

        scores.append(_false_positive_rate(gc_synthetic, gc_real))

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {str(mean).replace(".", ",")}\tStd: {str(std).replace(".", ",")}')

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
