import pathlib
import json
import warnings
from typing import Optional, List, Dict

import numpy as np
from natsort import natsort

from flask import Flask, render_template, request, redirect, url_for
import argparse

parser = argparse.ArgumentParser()
# Experiments are discovered anywhere in the directory tree rooted at --path:
# every directory that contains a "raw_data" subdirectory counts as one experiment.
parser.add_argument(
    '--path',
    action='store',
    dest='path',
    required=True,  # without it, pathlib.Path(None) below fails with an opaque TypeError
    help='root directory searched recursively for experiments '
         '(directories containing a "raw_data" subdirectory)',
)
args = parser.parse_args()
root = pathlib.Path(args.path)
if not root.is_dir():
    # Fail early with a usage message instead of a later "no experiments found" error.
    parser.error(f'--path {str(root)!r} is not an existing directory')


class Dataset:
    """A single experiment directory and the numpy arrays loaded from it.

    Layout read below ``path``::

        raw_data/dla_particles/*.npy   one array per step, stacked in natural sort order
        raw_data/flow_samples/*.npy    one array per step, stacked in natural sort order
        raw_data/static/**/*.npy       extra arrays, stored under their file stem

    Nothing is read until ``load()`` is called.
    """

    def __init__(self, path: pathlib.Path):
        self.path = path
        self.raw_data_path = self.path / 'raw_data'
        self.dla_path = self.raw_data_path / 'dla_particles'
        self.flow_path = self.raw_data_path / 'flow_samples'
        self.static_path = self.raw_data_path / 'static'

        # Every slot starts as None ("not loaded"); ``data`` serialises None as [].
        # (Previously 'dla' started as np.empty(()), an uninitialised scalar whose
        # garbage value leaked into ``data`` when no DLA files were present.)
        self.numpy_data: Dict[str, Optional[np.ndarray]] = dict(
            dla=None,
            dla_minima=None,
            dla_maxima=None,
            flow=None,
            flow_minima=None,
            flow_maxima=None,
            likelihood=None
        )

        self.info = dict(
            nSteps=0,
            nDimensions=0,
            nDLAParticles=0,
            nFlowSamples=0
        )

    @staticmethod
    def compute_limits(data: np.ndarray):
        """Return ``(minima, maxima)`` per dimension of a ``(steps, dims, items)`` array.

        Reduces over axes 0 and 2, so both results have shape ``(dims,)``.
        Raises ValueError for an array with no elements (as ``np.min`` does).
        """
        data = np.asarray(data)
        return data.min(axis=(0, 2)), data.max(axis=(0, 2))

    def print_data_stats(self):
        """Print a human-readable summary of ``info`` and the loaded arrays to stdout."""
        dla = self.numpy_data["dla"]
        flow = self.numpy_data["flow"]
        likelihood = self.numpy_data["likelihood"]

        print('=' * 100)
        print('Experiment info')
        print(f'- Path: {self.path.absolute()}')
        print(f'- Number of steps: {self.info["nSteps"]}')
        print(f'- Number of dimensions: {self.info["nDimensions"]}')
        print('-' * 100)
        if dla is not None:
            print(f'- DLA data size: {dla.nbytes / 10 ** 6} MB')
            print(f'- DLA particles: {self.info["nDLAParticles"]}')
            print(f'- DLA data shape: {dla.shape}')
            print(f'- DLA data minima: {self.numpy_data["dla_minima"]}')
            print(f'- DLA data maxima: {self.numpy_data["dla_maxima"]}')
        else:
            print('- No DLA data')
        print('-' * 100)
        if flow is not None:
            print(f'- Flow data size: {flow.nbytes / 10 ** 6} MB')
            print(f'- Flow samples: {self.info["nFlowSamples"]}')
            print(f'- Flow data shape: {flow.shape}')
            print(f'- Flow data minima: {self.numpy_data["flow_minima"]}')
            print(f'- Flow data maxima: {self.numpy_data["flow_maxima"]}')
        else:
            print('- No flow data')
        print('-' * 100)
        if likelihood is not None:
            print(f'- Likelihood data size: {likelihood.nbytes / 10 ** 6} MB')
            print(f'- Likelihood data shape: {likelihood.shape}')
        else:
            print('- No likelihood data')
        print('=' * 100)

    def load_static_data(self):
        """Load every ``*.npy`` below ``static_path`` into ``numpy_data[<file stem>]``.

        Each array is stored transposed (``.T``). NOTE(review): a static file whose
        stem equals another key (e.g. ``dla``) overwrites that slot — confirm intended.
        """
        # TODO "hardcode" the likelihood path
        for file_path in self.static_path.rglob('*.npy'):
            print(f'Loading static data: {file_path.stem}')
            self.numpy_data[file_path.stem] = np.load(str(file_path)).T

    def _load_stacked(self, directory: pathlib.Path) -> Optional[np.ndarray]:
        """Stack the ``*.npy`` files in ``directory`` into one ``(steps, dims, items)`` array.

        Files are ordered with natural sorting. Each file must be 2-D ``(items, dims)``;
        the stacked ``(steps, items, dims)`` array is transposed to ``(steps, dims, items)``.
        Returns None, after emitting a warning, when the directory has no ``.npy`` files.
        """
        numpy_files = natsort.natsorted(list(directory.glob('*.npy')))
        if not numpy_files:
            warnings.warn(f'No .npy files in {str(directory.absolute())}')
            return None

        data = np.stack([np.load(str(f)) for f in numpy_files])
        return np.transpose(data, (0, 2, 1))

    def _store_series(self, key: str, data: np.ndarray) -> None:
        """Store a ``(steps, dims, items)`` array and its per-dimension limits under ``key``."""
        minima, maxima = self.compute_limits(data)

        # Shared by DLA and flow data; whichever is loaded last wins.
        self.info['nSteps'] = data.shape[0]
        self.info['nDimensions'] = data.shape[1]

        self.numpy_data[key] = data
        self.numpy_data[f'{key}_minima'] = minima
        self.numpy_data[f'{key}_maxima'] = maxima

    def load_dla_data(self):
        """Load the per-step DLA particle arrays from ``dla_path`` (no-op if none exist)."""
        data = self._load_stacked(self.dla_path)
        if data is None:
            return
        self._store_series('dla', data)
        self.info['nDLAParticles'] = data.shape[2]

    def load_flow_data(self):
        """Load the per-step flow sample arrays from ``flow_path`` (no-op if none exist)."""
        data = self._load_stacked(self.flow_path)
        if data is None:
            return
        self._store_series('flow', data)
        self.info['nFlowSamples'] = data.shape[2]

    def load(self):
        """Load static, DLA and flow data (in that order), then print a summary."""
        self.load_static_data()
        self.load_dla_data()
        self.load_flow_data()

        self.print_data_stats()

    @property
    def data(self):
        """All ``numpy_data`` entries as nested Python lists (unloaded entries become [])."""
        return {k: (v.tolist() if v is not None else []) for k, v in self.numpy_data.items()}


class DatasetManager:
    """Discovers experiments under ``path`` and keeps exactly one of them loaded.

    An experiment is any directory below ``path`` that contains a ``raw_data``
    subdirectory. Experiments are kept in natural sort order and the first one is
    loaded on construction.

    Raises:
        ValueError: if no experiment directories are found.
    """

    def __init__(self, path: pathlib.Path):
        self.path = path
        # is_dir() skips stray *files* that happen to be named "raw_data".
        self.experiment_directories = natsort.natsorted(
            [p.parent for p in self.path.rglob('raw_data') if p.is_dir()]
        )
        if not self.experiment_directories:
            raise ValueError(f"No experiment directories found in {str(self.path)}")
        print(f'Total experiments: {len(self.experiment_directories)}')

        self.active_index = 0
        self.dataset: Dataset = self._load_dataset(self.active_index)

    def _load_dataset(self, dataset_index: int) -> 'Dataset':
        """Build and fully load the Dataset for ``experiment_directories[dataset_index]``."""
        dataset = Dataset(self.experiment_directories[dataset_index])
        dataset.load()
        return dataset

    def set_active_dataset(self, dataset_index):
        """Make experiment ``dataset_index`` (0-based) the active one, loading its data.

        Selecting the already-active index is a no-op. Negative indices are rejected
        rather than counting from the end of the list. ``active_index`` and ``dataset``
        are only updated after the new dataset has loaded successfully.

        Raises:
            IndexError: if ``dataset_index`` is outside ``0 .. len(experiments) - 1``.
        """
        if dataset_index == self.active_index:
            return
        if not 0 <= dataset_index < len(self.experiment_directories):
            raise IndexError(
                f'Experiment index {dataset_index} out of range '
                f'(0..{len(self.experiment_directories) - 1})'
            )

        dataset = self._load_dataset(dataset_index)
        self.active_index = dataset_index
        self.dataset = dataset

    def reload_active_dataset(self):
        """Re-run ``load()`` on the currently active dataset."""
        self.dataset.load()

    @property
    def relative_paths(self):
        """Experiment directories as strings relative to ``path``, in list order."""
        return [str(p.relative_to(self.path)) for p in self.experiment_directories]

    @property
    def meta_info(self):
        """Active experiment id, both 0-based and 1-based (the form's numbering)."""
        return {
            "experiment_id": self.active_index,
            "default_experiment_id_flask": self.active_index + 1
        }


# Module-level state read by the route handlers below. Constructing the manager
# scans ``root`` for experiments and loads the first one at import time; the
# ``reload`` route rebinds this global to a freshly built manager.
dataset_manager = DatasetManager(root)

app = Flask(__name__)


@app.route('/reload', methods=['POST'])
def reload():
    """Rescan ``root`` for experiments, then redirect to the dashboard.

    A brand-new DatasetManager is built, which re-discovers experiment directories
    and loads the first one (the previous selection is not restored). The global is
    rebound only after construction succeeds, so a failed rescan (e.g. ValueError
    when no experiments are found) leaves the current manager in place.
    """
    global dataset_manager

    print('Reloading experiment list')
    # The new manager loads its active dataset itself; no separate data reload needed.
    dataset_manager = DatasetManager(root)

    return redirect(url_for('home'))


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the dashboard, switching experiments if the form selects a new one.

    The form field ``experimentRadio`` carries a 1-based experiment id; the template
    receives ``meta_info['default_experiment_id_flask']`` (active index + 1) in the
    same numbering. A missing, non-integer or out-of-range id keeps the current
    experiment instead of failing the request or wrapping to the end of the list.
    """
    print('Going home')

    current_index = dataset_manager.active_index
    raw_id = request.form.get('experimentRadio')
    new_experiment_id = current_index
    if raw_id is not None:
        try:
            # -1 because the form's ids start at 1 while the manager indexes from 0.
            candidate = int(raw_id) - 1
        except ValueError:
            print(f'Ignoring non-integer experiment id: {raw_id!r}')
        else:
            if 0 <= candidate < len(dataset_manager.experiment_directories):
                new_experiment_id = candidate
            else:
                print(f'Ignoring out-of-range experiment id: {raw_id!r}')
    dataset_manager.set_active_dataset(new_experiment_id)

    return render_template(
        'dla_dashboard_home.html',
        data=json.dumps(dataset_manager.dataset.data),
        experiment_info=dataset_manager.dataset.info,
        experiment_paths=dataset_manager.relative_paths,
        meta_info=dataset_manager.meta_info
    )


if __name__ == "__main__":
    app.run()
