import argparse
from typing import Sequence, Tuple

import matplotlib.pyplot as plt
import json
import os

import numpy as np
import torch

import sys

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'src')))
from data.dataset import DatasetMixin
from evaluation.ts_precision_recall import compute_window_indices
from utils.metadata import PROJECT_ROOT
from utils.sys_utils import check_path
from utils.utils import str2cls
from utils.plot_utils import save_figure
from data.statistics import compute_feature_statistics, compute_anomaly_positions, compute_anomaly_lengths, compute_total_time_steps


# Width of the document's text area, in inches (= 13 cm).
TEXT_WIDTH = 5.11811

# Figure aspect ratios as height / width: the default one and a wider variant.
ASPECT, ASPECT_WIDE = 3 / 4, 9 / 16

# Font sizes, from smallest to largest.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 9, 10, 11


def plot_distribution(ax, data_tensor: torch.Tensor, means: torch.Tensor, stds: torch.Tensor, color: str = 'tab:blue',
                      legend: bool = True):
    """Draw the per-feature value distribution of ``data_tensor`` on ``ax``.

    Features with a positive standard deviation are drawn as box plots (outliers hidden). Features whose standard
    deviation is exactly zero are drawn instead as a green 'x' at their mean. Feature ``i`` is placed at x-position
    ``i + 1``. The top, right and bottom spines and the x ticks are hidden.

    :param ax: Matplotlib axes to draw on.
    :param data_tensor: Data points indexed as ``[point, feature]`` (converted with ``.numpy()``).
    :param means: Per-feature means; only the entries of constant features are used.
    :param stds: Per-feature standard deviations, used to split features into constant (``== 0``) and
        non-constant (``> 0``) ones.
    :param color: Colour of the box plots.
    :param legend: Whether the constant-feature markers get the 'Constant feature' legend label.
    """
    # Explicit integer dtype: np.array([]) defaults to float64, which numpy refuses as an index array, so an
    # empty selection (e.g. all features constant) used to raise IndexError.
    constant_features = np.array([i for i, std in enumerate(stds) if std == 0], dtype=int)
    nonconstant_features = np.array([i for i, std in enumerate(stds) if std > 0], dtype=int)

    # A box plot over zero columns has nothing to draw.
    if nonconstant_features.size > 0:
        ax.boxplot(data_tensor.numpy()[:, nonconstant_features], positions=nonconstant_features + 1, showfliers=False,
                   boxprops=dict(color=color, facecolor=color), whiskerprops=dict(color=color),
                   capprops=dict(color=color), patch_artist=True)

    label = 'Constant feature' if legend else None
    ax.scatter(constant_features + 1, means[constant_features], marker='x', color='green', label=label)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.get_xaxis().set_ticks([])


def _save_histogram(values, resolution: int, value_range: Tuple[float, float], path: str):
    """Save a black, step-filled histogram of ``values`` over ``value_range`` to ``path`` using a fresh figure."""
    # A dedicated figure (instead of plt.gca()) guarantees the histogram never lands on whichever figure happens
    # to be current, e.g. one that save_figure left open.
    fig, ax = plt.subplots()
    ax.hist(values, resolution, value_range, histtype='stepfilled', color='k')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Only the two ends of the value range get a tick.
    ax.get_xaxis().set_ticks(list(value_range))
    save_figure(path)
    # No-op if save_figure already closed the figure; otherwise keeps figures from accumulating.
    plt.close(fig)


def generate_statistics(dataset: DatasetMixin, logpath: str, resolution: int = 100, train: bool = True):
    """Compute summary statistics of ``dataset`` and write them, together with diagnostic plots, to ``logpath``.

    Written files (prefix ``train_`` if ``train`` else ``test_``):
      * ``<prefix>_feature_distribution.png``: per-feature distributions of normal (blue) and anomalous (red) points;
      * ``<prefix>_anomaly_positions.png`` and ``<prefix>_anomaly_lengths.png``: histograms of anomaly positions
        and anomaly interval lengths (only if anomaly statistics are available);
      * ``<prefix>_stats.json``: number of features/samples/time steps, number of constant features per class and,
        if available, the number of anomalous points and anomalies.

    :param dataset: Dataset whose items are ``(inputs, targets)`` pairs; ``inputs[0]`` holds one time series and
        ``targets[0]`` its per-step labels, where 0 marks a normal point and anything else an anomalous one.
    :param logpath: Output directory; this function does not create it.
    :param resolution: Number of bins of the anomaly histograms.
    :param train: Only selects the ``train_``/``test_`` file-name prefix.
    """
    prefix = 'train' if train else 'test'
    stats = {}

    # --- General statistics ---
    stats['n_features'] = dataset.num_features
    stats['n_samples'] = len(dataset)
    stats['total_time_steps'] = compute_total_time_steps(dataset)

    # --- Means and standard deviations ---
    normal_means, normal_stds, anomaly_means, anomaly_stds, _, _ = compute_feature_statistics(dataset)
    has_normal = normal_means is not None
    has_anomalies = anomaly_means is not None

    width = 0.9 * TEXT_WIDTH
    if has_normal and has_anomalies:
        # One row per class; shared axes make the two distributions directly comparable.
        fig, axes = plt.subplots(2, 1, figsize=(width, ASPECT_WIDE * width), sharex=True, sharey=True)
        ax = axes[0]
    else:
        fig = plt.figure(figsize=(width, 0.6 * ASPECT_WIDE * width))
        ax = fig.gca()
        axes = ax, ax

    # Collect the normal and anomalous points of all time series.
    normal_data = []
    anomaly_data = []
    for i in range(len(dataset)):
        inputs, targets = dataset[i]
        inputs = inputs[0]
        targets = targets[0]
        if has_normal:
            normal_data.append(inputs[targets == 0])
        if has_anomalies:
            anomaly_data.append(inputs[targets != 0])

    if has_normal:
        plot_distribution(ax, torch.cat(normal_data, dim=0), normal_means, normal_stds)
        ax = axes[1]
    if has_anomalies:
        plot_distribution(ax, torch.cat(anomaly_data, dim=0), anomaly_means, anomaly_stds, color='red')

    # The legend's upper-right corner is anchored at the bottom-right corner of the last axes, i.e. below it.
    ax.legend(loc='upper right', bbox_to_anchor=(1, 0))
    fig.tight_layout()
    save_figure(os.path.join(logpath, f'{prefix}_feature_distribution.png'))
    plt.close(fig)

    # --- Constant features ---
    stats['n_constant_features_normal'] = 0 if normal_stds is None else sum(1 for s in normal_stds if s == 0)
    stats['n_constant_features_anomaly'] = 0 if anomaly_stds is None else sum(1 for s in anomaly_stds if s == 0)

    # --- Anomalies ---
    if has_anomalies:
        # Distribution of relative anomaly positions in each time series (range 0..1)
        positions = compute_anomaly_positions(dataset)
        _save_histogram(positions, resolution, (0, 1), os.path.join(logpath, f'{prefix}_anomaly_positions.png'))

        # Distribution of the lengths of anomaly intervals in the dataset
        lengths = compute_anomaly_lengths(dataset)
        _save_histogram(lengths, resolution, (0, max(lengths)),
                        os.path.join(logpath, f'{prefix}_anomaly_lengths.png'))

        stats['total_anomalous_points'] = len(positions)
        stats['total_anomalies'] = len(lengths)

    # --- Write statistics to file ---
    with open(os.path.join(logpath, f'{prefix}_stats.json'), 'w') as fp:
        json.dump(stats, fp)


def plot_anomalies(dataset: DatasetMixin, logpath: str, plot_title: bool = False, plot_feature_name: bool = False,
                   padding: Tuple[int, int] = (1800, 3600),
                   selected_windows: Sequence[int] = None, selected_features: Sequence[int] = None):
    """Save one plot per (time series, anomaly window, feature) showing the anomaly in its context.

    Each selected anomaly window is plotted together with ``padding[0]`` time steps before and ``padding[1]`` time
    steps after it, clipped to the time series. Every anomaly window overlapping the plotted range is shaded red.
    Plots are written to ``<logpath>/time_series_<ts>/anomaly_<window>/feature_<feature>.png``, and missing
    directories are created.

    :param dataset: Dataset whose items are ``(inputs, targets)`` pairs; ``inputs[0]`` is indexed as
        ``[time step, feature]`` and ``targets[0]`` is passed to ``compute_window_indices`` to find the anomaly
        windows as ``(start, end)`` pairs.
    :param logpath: Root output directory.
    :param plot_title: Whether to title each plot with the 1-based anomaly number.
    :param plot_feature_name: Whether to show a legend containing the feature name.
    :param padding: Number of time steps shown before and after the anomaly window.
    :param selected_windows: Indices of the anomaly windows to plot in every time series; defaults to all. Indices
        that a time series does not have are skipped for that series instead of raising ``IndexError``.
    :param selected_features: Indices of the features to plot; defaults to all ``dataset.num_features`` features.
    """
    # TODO: maybe don't hardcode default padding but set relative to window size
    width = 0.9 * TEXT_WIDTH
    # Leave extra vertical room for the title when one is drawn.
    factor = 0.75 if plot_title else 0.6
    feature_names = dataset.get_feature_names()

    if selected_features is None:
        selected_features = list(range(dataset.num_features))

    for ts in range(len(dataset)):
        inputs, targets = dataset[ts]
        inputs = inputs[0]
        targets = targets[0]

        anomaly_windows = compute_window_indices(targets)
        n_windows = len(anomaly_windows)

        if selected_windows is None:
            windows_ts = range(n_windows)
        else:
            # Time series can differ in their number of anomalies; ignore indices this one does not have.
            windows_ts = [w for w in selected_windows if -n_windows <= w < n_windows]

        for window in windows_ts:
            start, end = anomaly_windows[window]
            plot_start = max(0, start - padding[0])
            plot_end = min(inputs.shape[0], end + padding[1])
            data = inputs[plot_start:plot_end]

            for feature in selected_features:
                feature_name = feature_names[feature]
                fig = plt.figure(figsize=(width, factor * width * ASPECT_WIDE))
                ax = plt.gca()
                if plot_title:
                    ax.set_title(f'Anomaly {window + 1}')
                ax.set_xlabel('Time step')
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)

                # Underscores are escaped, presumably for LaTeX text rendering -- TODO confirm usetex is enabled.
                ax.plot(range(plot_start, plot_end), data[:, feature].numpy(), label=feature_name.replace('_', r'\_'))
                for other_start, other_end in anomaly_windows:
                    # Standard interval-overlap test. Unlike only checking whether an endpoint lies inside the
                    # plotted range, it also catches a window that covers the entire range.
                    if other_start < plot_end and other_end > plot_start:
                        ax.axvspan(max(other_start, plot_start), min(other_end, plot_end), color='red', alpha=0.5)
                if plot_feature_name:
                    ax.legend(loc='lower left')

                fig.tight_layout()
                save_path = os.path.join(logpath, f'time_series_{ts}', f'anomaly_{window}', f'feature_{feature}.png')
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                plt.savefig(save_path)
                plt.close(fig)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute statistics of a dataset (train and test split) and plot the anomalies of its test split.')
    parser.add_argument('--dataset', type=str, default='data.smd_dataset.SMDDataset',
                        help='Dotted path of the dataset class, e.g. data.smd_dataset.SMDDataset or '
                             'data.swat_dataset.SWaTDataset.')
    args = parser.parse_args()

    dataset_cls = str2cls(args.dataset)
    # Outputs go to <PROJECT_ROOT>/resources/datasets/<dataset class name>
    logpath = os.path.join(PROJECT_ROOT, 'resources', 'datasets', args.dataset.split('.')[-1])
    check_path(logpath)

    generate_statistics(dataset_cls(training=True), logpath)
    test_data = dataset_cls(training=False)
    generate_statistics(test_data, logpath, train=False)

    plot_anomalies(test_data, logpath)
