import argparse
import random

import numpy as np
from lightning.pytorch import seed_everything

from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_synthetic


def get_valmatrix_graph_strongest_all(coefficients: np.ndarray, keep_top: float, n_features: int,
                                      max_lag: int) -> np.ndarray:
    """Turn per-timestep coefficients into a signed, thresholded link matrix.

    For every coefficient series (last axis) the 95% and 5% quantiles are
    computed and the one with the larger magnitude is kept as the link
    strength. The strengths are rearranged into a
    ``[n_samples, n_features, n_features, max_lag + 1]`` array whose lag-0
    slice is all zeros, and then thresholded globally (over all samples at
    once): entries strictly above the ``1 - keep_top`` quantile become ``1.``,
    entries strictly below the ``keep_top`` quantile become ``-1.``, and
    everything else becomes ``0.``.

    Args:
        coefficients: Array expected to have shape
            ``[n_samples, n_features, n_features * max_lag, seq_len - max_lag]``.
        keep_top: Quantile fraction used for both tails of the global
            threshold (e.g. ``0.01``).
        n_features: Number of features; used to reshape the strengths.
        max_lag: Maximum lag; used to reshape the strengths.

    Returns:
        Array of shape ``[n_samples, n_features, n_features, max_lag + 1]``
        containing only ``-1.``, ``0.`` and ``1.``.
    """
    n_samples = coefficients.shape[0]

    # Summarise each coefficient series by its upper and lower tail.
    c_pos = np.quantile(coefficients, q=.95, axis=-1)
    c_neg = np.quantile(coefficients, q=.05, axis=-1)

    # Keep the tail with the larger magnitude; ties fall back to the lower tail.
    c = np.where(c_pos > np.abs(c_neg), c_pos, c_neg)

    # Split the flattened (feature, lag) axis, swap the two feature axes and
    # reverse the lag axis.
    c = c.reshape(n_samples, n_features, n_features, max_lag).transpose((0, 2, 1, 3))
    c = np.flip(c, axis=-1)

    # Prepend an all-zero slice for lag 0.
    z = np.zeros((n_samples, n_features, n_features, 1))
    val_matrix = np.concatenate([z, c], axis=-1)

    # Thresholds are computed over every entry of every sample, including the
    # zero-filled lag-0 slice.
    threshold_pos = np.quantile(val_matrix, q=1 - keep_top)
    threshold_neg = np.quantile(val_matrix, q=keep_top)

    # val_matrix.shape = [n_samples, n_features, n_features, max_lag + 1]
    return np.where(val_matrix > threshold_pos, 1., np.where(val_matrix < threshold_neg, -1., 0.))


def _collect_links_we_dont_want(dataset, n_features):
    """Return the ``(i, j, lag)`` index triples that count as errors when predicted.

    The triples index a ``(n_features, n_features, max_lag + 1)`` matrix. The
    returned list may contain the same triple more than once.
    """
    if 'rivers' in dataset:
        return [
            (0, 0, 1), (0, 0, 2), (0, 1, 1), (0, 1, 2), (0, 2, 1), (0, 2, 2),
            (1, 1, 1), (1, 1, 2), (1, 0, 2), (1, 2, 1), (1, 2, 2),
            (2, 2, 1), (2, 2, 2), (2, 1, 1), (2, 1, 2), (2, 0, 2), (2, 0, 1),
        ]

    # dataset == 'henon'
    # NOTE(review): the two passes below together cover every (i, j) pair at
    # both lag 1 and lag 2 -- confirm that this union is really intended.
    links = []
    for f1 in range(n_features):
        for f2 in range(n_features):
            links.append((f1, f2, 1))
            if f1 != f2:
                links.append((f1, f2, 2))

    for f1 in range(n_features):
        for f2 in range(n_features):
            if f1 == f2 or f1 + 1 == f2:
                links.append((f1, f2, 2))
            else:
                links.append((f1, f2, 1))
                links.append((f1, f2, 2))
    return links


def main():
    """Score how often a run predicts links that should not exist.

    Reads ``-r/--run`` and ``-d/--dataset`` from the command line, loads the
    coefficients for seeds 0-9, converts them to thresholded link matrices,
    and for a random subset of samples computes the fraction of unwanted
    links that were predicted (with either sign). Prints the mean and standard
    deviation of the per-seed scores.
    """
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--run', required=True)
    parser.add_argument('-d', '--dataset', required=True)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset

    n_features, seq_len, max_lag = get_nfeatures_seqlen_maxlag(dataset)

    n_samples_evaluation = 1024
    keep_top = .01

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    links_we_dont_want = _collect_links_we_dont_want(dataset, n_features)

    # Boolean mask of the unwanted links; duplicates in the list collapse here.
    rows, cols, lags = np.asarray(links_we_dont_want).T
    unwanted_mask = np.zeros((n_features, n_features, max_lag + 1), dtype=bool)
    unwanted_mask[rows, cols, lags] = True

    # Normalise by the number of *distinct* unwanted links. The list can hold
    # the same triple twice, and dividing by its length would make an error
    # rate of 1 unreachable.
    n_unwanted = np.count_nonzero(unwanted_mask)

    scores = list()
    for seed in range(10):
        random.seed(seed)
        np.random.seed(seed)

        coefficients = get_data_synthetic(run_id, epoch, seed, seq_len, n_features, coefficients=True)
        all_val_matrix = get_valmatrix_graph_strongest_all(coefficients, keep_top, n_features, max_lag)

        # Evaluate on a random subset of samples, drawn with replacement.
        random_indexes = np.random.randint(0, len(coefficients), n_samples_evaluation)
        all_val_matrix = all_val_matrix[random_indexes]

        # A link counts as an error when it is predicted with either sign
        # (|value| == 1) at a position flagged as unwanted.
        predicted = np.abs(all_val_matrix) == 1
        n_errors = np.count_nonzero(predicted & unwanted_mask, axis=(1, 2, 3))
        error_rates = n_errors / n_unwanted

        scores.append(np.mean(error_rates))

    mean = np.mean(scores)
    std = np.std(scores)
    print(f'Mean: {str(mean).replace(".", ",")}\tStd: {str(std).replace(".", ",")}')

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
