import argparse
import inspect
import json
import os
import subprocess
import sys
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from sacred.serializer import restore
from sklearn.metrics import precision_recall_fscore_support, precision_recall_curve

from exathlon.ad_evaluators import RangeEvaluator

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, 'src')))
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from evaluation import inverse_proportional_cardinality_fn, Evaluator
from evaluation.ts_precision_recall import improved_cardinality_fn, constant_bias_fn, front_bias_fn, middle_bias_fn, \
    back_bias_fn, ts_precision_and_recall
from experiment_utils import data_ingredient, make_experiment, \
    make_experiment_tempfile
from utils.plot_utils import setup_matplotlib, TEXT_WIDTH, ASPECT
from utils.torch_utils import set_threads


# Experiment object built by the project helper; the `config` and `automain` decorators used
# further down in this file are attributes of it.
experiment = make_experiment(ingredients=[data_ingredient])

# Directory of the external reference evaluator. The relative path is resolved against the
# current working directory at import time, not against this file's location.
EVALUATOR_PATH = os.path.realpath('reference_implementation')
# File name of the evaluator binary; on Windows it carries an `.exe` suffix.
EXE = 'evaluate' + ('.exe' if sys.platform == 'win32' else '')

# Maps the bias names used on the command line of the reference evaluator to the
# corresponding bias functions of this project's implementation.
BIAS_FN_MAP = dict(
    flat=constant_bias_fn,
    front=front_bias_fn,
    middle=middle_bias_fn,
    back=back_bias_fn,
)

# Maps cardinality-function names to this project's cardinality functions.
CARDINALITY_FN_MAP = dict(
    reciprocal=inverse_proportional_cardinality_fn,
    recall_consistent=improved_cardinality_fn,
)


def get_reference_results(labels, predictions, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Run the external reference evaluator on ``labels``/``predictions`` and parse its output.

    Both arrays are written to temporary text files (one integer per line) which are handed to
    the executable ``EVALUATOR_PATH/src/EXE``. The temporary files are removed again even when
    writing them or launching the executable fails.

    Note that, unlike most other ``get_*_results`` functions in this module, this one has no
    ``scores`` parameter.

    :param labels: Ground-truth labels, written with ``fmt='%d'``.
    :param predictions: Predicted labels, written with ``fmt='%d'``.
    :param alpha: Forwarded to the executable as a string.
    :param cardinality: Name of the cardinality function understood by the executable.
    :param r_bias: Name of the bias function for recall.
    :param p_bias: Name of the bias function for precision.
    :return: ``(precision, recall, runtime)``. The runtime is the value reported by the
        executable divided by 1e6 (presumably microseconds converted to seconds).
    :raises RuntimeError: If the executable exits with a non-zero return code.
    """
    program = os.path.join(EVALUATOR_PATH, 'src', EXE)

    temp_paths = []
    try:
        for data, suffix in ((labels, '.real'), (predictions, '.pred')):
            # delete=False: the file must survive the `with` block so the subprocess can read it.
            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=suffix) as tmp:
                # Register the path first so it is cleaned up even if writing fails.
                temp_paths.append(tmp.name)
                np.savetxt(tmp, data, delimiter='\n', fmt='%d')

        real_path, pred_path = temp_paths
        result = subprocess.run(
            [
                program, '-t',  # We want the TS score
                real_path,  # First comes the ground-truth file
                pred_path,  # Then the prediction file
                '1',  # We don't care about the F-score, so beta does not matter
                str(alpha),  # Specify alpha
                cardinality,  # Cardinality function to use
                p_bias,  # And finally the bias functions for precision and recall
                r_bias,
            ],
            capture_output=True, text=True,
        )
    finally:
        for path in temp_paths:
            os.remove(path)

    if result.returncode != 0:
        raise RuntimeError(
            f'Some Error occurred while running the reference implementation: {result.stderr}'
        )

    # The output is expected to consist of exactly four lines: precision, recall, one line that
    # is ignored here, and the runtime. Each parsed line has the form "<name> = <value>".
    p_line, r_line, _, runtime_line = result.stdout.splitlines(keepends=False)
    precision = float(p_line.split(' = ')[1])
    recall = float(r_line.split(' = ')[1])
    runtime = float(runtime_line.split(' = ')[1]) / 1000000

    return precision, recall, runtime


def get_our_results(labels, predictions, scores, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Time this project's ``ts_precision_and_recall`` on binary labels and predictions.

    ``labels`` and ``predictions`` are NumPy arrays that get converted to int64 tensors before
    the clock starts. ``scores`` is accepted but not used by this function. ``cardinality``,
    ``r_bias`` and ``p_bias`` are keys of ``CARDINALITY_FN_MAP`` / ``BIAS_FN_MAP``; the same
    cardinality function is passed twice (presumably once for recall and once for precision).

    :return: ``(precision, recall, seconds)``, where ``seconds`` covers only the metric call.
    """
    label_tensor = torch.from_numpy(labels).to(torch.int64)
    prediction_tensor = torch.from_numpy(predictions).to(torch.int64)

    cardinality_fn = CARDINALITY_FN_MAP[cardinality]
    recall_bias_fn = BIAS_FN_MAP[r_bias]
    precision_bias_fn = BIAS_FN_MAP[p_bias]

    tic = time.perf_counter()
    precision, recall = ts_precision_and_recall(
        label_tensor, prediction_tensor, alpha,
        recall_bias_fn, cardinality_fn,
        precision_bias_fn, cardinality_fn,
    )
    elapsed = time.perf_counter() - tic

    return precision, recall, elapsed


def get_our_best_results(labels, predictions, scores, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Time ``Evaluator.best_ts_f1_score`` on labels and continuous anomaly scores.

    ``predictions`` is accepted but not used by this function. The evaluator is constructed and
    the inputs are converted to tensors before the clock starts.

    :return: ``(best_f1, best_f1, seconds)``. The best F1 score is returned in both the
        precision and the recall slot so that the result has the same shape as that of the
        other ``get_*_results`` functions.
    """
    label_tensor = torch.from_numpy(labels).to(torch.int64)
    score_tensor = torch.from_numpy(scores)

    cardinality_fn = CARDINALITY_FN_MAP[cardinality]
    recall_bias_fn = BIAS_FN_MAP[r_bias]
    precision_bias_fn = BIAS_FN_MAP[p_bias]

    evaluator = Evaluator()

    tic = time.perf_counter()
    # The second element of the returned pair is not needed here.
    best_f1, _ = evaluator.best_ts_f1_score(
        label_tensor, score_tensor, alpha=alpha,
        recall_bias_fn=recall_bias_fn, recall_cardinality_fn=cardinality_fn,
        precision_bias_fn=precision_bias_fn, precision_cardinality_fn=cardinality_fn,
    )
    elapsed = time.perf_counter() - tic

    return best_f1, best_f1, elapsed


def get_exathlon_results(labels, predictions, scores, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Time the Exathlon ``RangeEvaluator`` on binary labels and predictions.

    The evaluator is configured through an ``argparse.Namespace``. ``scores`` and ``cardinality``
    are accepted but not used: the gamma (cardinality) setting is hard-coded to ``'inv.poly'``
    for both precision and recall. Constructing the evaluator is not part of the timing.

    :return: ``(precision, recalls[1], seconds)``. ``recalls`` is indexed with key ``1``
        (presumably the recall for label 1 -- confirm against ``RangeEvaluator``).
    """
    evaluator = RangeEvaluator(argparse.Namespace(
        evaluation_type='range',
        recall_alpha=alpha,
        recall_omega='default',
        recall_delta=r_bias,
        recall_gamma='inv.poly',
        precision_omega='default',
        precision_delta=p_bias,
        precision_gamma='inv.poly',
        f_score_beta=1.0,
    ))

    tic = time.perf_counter()
    _f_scores, precision, recalls = evaluator.compute_period_metrics(labels, predictions)
    elapsed = time.perf_counter() - tic

    return precision, recalls[1], elapsed


def get_pointwise_results(labels, predictions, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Time scikit-learn's point-wise binary precision and recall.

    ``alpha``, ``cardinality``, ``r_bias`` and ``p_bias`` are accepted but not used. Note that,
    unlike most other ``get_*_results`` functions in this module, this one has no ``scores``
    parameter.

    :return: ``(precision, recall, seconds)``.
    """
    tic = time.perf_counter()
    precision, recall, _f_score, _support = precision_recall_fscore_support(
        labels, predictions, average='binary'
    )
    elapsed = time.perf_counter() - tic

    return precision, recall, elapsed


def get_pointwise_best_results(labels, predictions, scores, alpha: float, cardinality: str, r_bias: str, p_bias: str):
    """Time the search for the best point-wise F1 score over all score thresholds.

    Uses scikit-learn's ``precision_recall_curve`` on ``labels`` and ``scores``. ``predictions``,
    ``alpha``, ``cardinality``, ``r_bias`` and ``p_bias`` are accepted but not used.

    :return: ``(best_f1, best_f1, seconds)``. The best F1 score fills both the precision and the
        recall slot so that the result has the same shape as that of the other functions.
    """
    tic = time.perf_counter()
    precision, recall, _thresholds = precision_recall_curve(labels, scores)
    harmonic_mean = 2 * precision * recall / (precision + recall)
    # Points where precision + recall == 0 produce NaN; count them as an F1 of 0.
    f_score = np.nan_to_num(harmonic_mean, nan=0)
    best_index = np.argmax(f_score)
    elapsed = time.perf_counter() - tic

    best_f1 = f_score[best_index].item()
    return best_f1, f_score[best_index].item(), elapsed


def generate_random_labels(size: int, anomaly_percentage: float):
    """Create a random 0/1 label series of length ``size`` containing anomaly windows.

    The anomaly budget is ``int(anomaly_percentage * size)`` points. A random number of windows
    is drawn, each with a binomially distributed length whose mean is roughly the budget divided
    by the number of windows, and placed at a random position. Windows may overlap, so the
    number of points actually labelled 1 can be smaller than the budget.

    Uses the global NumPy random state (``np.random``).

    :param size: Length of the generated series.
    :param anomaly_percentage: Fraction of ``size`` used as the anomaly budget.
    :return: Float array of shape ``(size,)`` holding only 0.0 and 1.0.
    """
    anomalous_points = int(anomaly_percentage * size)
    result = np.zeros(size)

    if anomalous_points < 1:
        # No budget at all: there is nothing to place.
        return result

    # randint's upper bound is exclusive, so it needs a budget of at least 2 to have something to
    # draw from. A budget of a single point can only form a single window.
    n_windows = np.random.randint(1, anomalous_points) if anomalous_points > 1 else 1
    average_window_length = anomalous_points / n_windows

    # The binomial `n` has to be an integer; truncate explicitly.
    w_lengths = np.random.binomial(n=int(2 * average_window_length), p=0.5, size=n_windows)
    for length in w_lengths:
        # Keep every window shorter than the series so that a valid start position exists.
        length = min(int(length), size - 1)
        start = np.random.randint(0, size - length)
        result[start:start + length] = 1

    return result


# Configuration registered on `data_ingredient`. The local variables assigned in the body are the
# configuration entries, so their names must not be changed. None of these values is read
# directly by `main` in this file; their exact meaning is defined by the project's data
# ingredient, which is not visible here.
@data_ingredient.config
def data_config():
    # Dataset name (presumably selects the SWaT dataset -- confirm against the data ingredient).
    name = 'SWaT'
    # Keyword arguments for the dataset (presumably for its constructor).
    ds_args = dict(
        training=False
    )

    use_dataset_pipeline = False

    split = (1,)


# Legend labels for the runtime plot, keyed by implementation name. The "best" variants share
# the label of their plain counterpart.
DISPLAY_NAMES = dict(
    reference='Reference',
    our='Ours',
    our_best='Ours',
    exathlon='Jacob et al.',
    pointwise='Point-wise',
    pointwise_best='Point-wise',
)


# Configuration of the benchmark experiment. The local variables assigned in the body are the
# configuration entries that `main` receives as parameters of the same name, so their names must
# not be changed.
@experiment.config
def config():
    # Exathlon on 100000 (one run!) already takes more than half an hour
    # Each name <imp> is resolved by `main` to the module-level function `get_<imp>_results` and
    # labelled in the plot via DISPLAY_NAMES[<imp>].
    implementations = ['reference', 'our', 'exathlon', 'pointwise']
    # Lengths of the generated time series; also used as the x values of the plot.
    sizes = [1000, 10000, 100000, 500000, 1000000]
    # Passed to `generate_random_labels` as `anomaly_percentage`.
    anom_prob = 0.1
    # In each run, the scores are derived from the labels plus noise if a uniform random draw is
    # below this value; otherwise scores and predictions are random.
    detector_skill  = 0.8
    # Number of repetitions per size.
    runs = 50

    # Optional path to a JSON file with stored results. If set, `main` skips the benchmark and
    # plots the 'mean_times' and 'std_times' entries of that file instead.
    info_file = None


def _resolve_implementation(name):
    """Return ``(func, takes_scores)`` for the module-level function ``get_<name>_results``.

    ``takes_scores`` tells whether that function declares a ``scores`` parameter; not every
    ``get_*_results`` function in this module does. Raises ``KeyError`` for unknown names.
    """
    func = globals()[f'get_{name}_results']
    return func, 'scores' in inspect.signature(func).parameters


def _reduce_runtimes(times, reducer):
    """Apply ``reducer`` to every per-size runtime list; sizes without measurements give NaN."""
    return {
        imp: np.array([reducer(runtimes) if runtimes else np.nan for runtimes in per_size.values()])
        for imp, per_size in times.items()
    }


@experiment.automain
def main(implementations, sizes, anom_prob, detector_skill, runs, info_file, _run, _log):
    """Benchmark the configured implementations on random data and plot runtime against size.

    If ``info_file`` is given, no benchmark is run; the mean and standard deviation of the
    runtimes are loaded from that JSON file instead. Otherwise, for every size and run, random
    labels plus scores and predictions of a synthetic detector are generated, every implementation
    is timed on them, and a warning is logged if the implementations disagree on precision or
    recall. The raw and aggregated runtimes are stored in ``_run.info``.

    In both cases a plot of mean runtime (with a band of one standard deviation) over ``sizes``
    is written as ``plot.pdf`` through ``make_experiment_tempfile``.

    An interrupt via Ctrl-C stops the benchmark early; whatever was measured up to that point is
    still aggregated and plotted.
    """
    set_threads(1)

    # Work on a plain list copy of the configured sizes.
    sizes = list(sizes)

    if info_file is not None:
        with open(info_file, 'r') as f:
            info = restore(json.load(f))

        mean_runtimes = info['mean_times']
        std_runtimes = info['std_times']
    else:
        # Resolve all implementations up front so that a misspelled name fails before any
        # time-consuming benchmarking has been done.
        runners = [(imp, *_resolve_implementation(imp)) for imp in implementations]
        times = {imp: {size: [] for size in sizes} for imp in implementations}

        try:
            for size in sizes:
                print('Size:', size)
                for run in range(runs):
                    print(f'Starting run {run + 1} of {runs}!')
                    labels = generate_random_labels(size, anomaly_percentage=anom_prob)
                    if np.random.rand() >= detector_skill:
                        # Unskilled detector: scores and predictions are unrelated to the labels.
                        scores = np.random.rand(size)
                        predictions = generate_random_labels(size, anomaly_percentage=anom_prob)
                    else:
                        # Skilled detector: noisy labels as scores, thresholded at a random score.
                        scores = labels + np.random.randn(size) * 0.5
                        threshold = np.random.choice(scores)
                        predictions = (scores > threshold).astype(np.int32)

                    precision, recall = [], []
                    for imp, func, takes_scores in runners:
                        # Pass the settings by keyword and hand over `scores` only to functions
                        # that declare it; a fixed positional call fails with a TypeError for
                        # the functions without a `scores` parameter.
                        kwargs = dict(alpha=0, cardinality='reciprocal', r_bias='flat', p_bias='flat')
                        if takes_scores:
                            kwargs['scores'] = scores
                        p, r, runtime = func(labels, predictions, **kwargs)
                        precision.append(p)
                        recall.append(r)
                        times[imp][size].append(runtime)

                    if precision and not np.all(np.isclose(precision, precision[0])):
                        _log.warning(f'Results for precision @ size {size} are inconsistent! Results are {precision} '
                                  f'for {implementations}, respectively.')

                    if recall and not np.all(np.isclose(recall, recall[0])):
                        _log.warning(f'Results for recall @ size {size} are inconsistent! Results are {recall} '
                                  f'for {implementations}, respectively.')
        except KeyboardInterrupt:
            # Deliberately best effort: keep the partial measurements, but say so.
            _log.warning('Benchmark interrupted; continuing with the runtimes measured so far.')

        mean_runtimes = _reduce_runtimes(times, np.mean)
        std_runtimes = _reduce_runtimes(times, np.std)

        _run.info['times'] = times
        _run.info['mean_times'] = mean_runtimes
        _run.info['std_times'] = std_runtimes

    setup_matplotlib(8, 9)

    width = 0.5 * TEXT_WIDTH
    fig = plt.figure(figsize=(width, ASPECT * width))
    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_xlabel('$T$')
    ax.set_ylabel('Runtime [s]')
    for imp in implementations:
        ax.plot(sizes, mean_runtimes[imp], label=DISPLAY_NAMES[imp])
        ax.fill_between(sizes, mean_runtimes[imp] - std_runtimes[imp], mean_runtimes[imp] + std_runtimes[imp],
                        alpha=0.2)
    ax.legend(loc='upper left')

    fig.tight_layout()
    with make_experiment_tempfile('plot.pdf', _run) as f:
        plt.savefig(f, format='pdf')
    plt.close(fig)
