import os as os
import numpy as np
import json
import pickle
import gzip

import matplotlib.pyplot as plt


class Benchmark():
    """API for TabularBench.

    Loads a benchmark JSON file into a nested dict and offers helpers to query
    and plot it. As accessed by the methods below, ``self.data`` is indexed as
    ``data[dataset_name][config_id]`` with string config ids, and every run
    holds the sub-dicts ``"log"``, ``"results"`` and ``"config"``.
    """

    def __init__(self, data_dir, cache=False, cache_dir="cached/"):
        """Initialize dataset (will take a few seconds-minutes).

        Keyword arguments:
        data_dir -- str, path to the benchmark ``.json`` file
        cache -- bool, if True load from a gzipped pickle in ``cache_dir`` when
                 one exists, otherwise read the JSON and write such a pickle
        cache_dir -- str, directory holding the cached ``.pkl.gz`` file

        Raises:
        ValueError -- if ``data_dir`` is not an existing file ending in ".json"
        """
        if not os.path.isfile(data_dir) or not data_dir.endswith(".json"):
            raise ValueError("Please specify path to the bench json file.")

        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.cache = cache

        print("==> Loading data...")
        self.data = self._read_data(data_dir)
        self.dataset_names = list(self.data.keys())
        print("==> Done.")

    def query(self, dataset_name, tag, config_id):
        """Query a run.

        ``tag`` is looked up in the run's "log", then "results", then "config"
        dict and the first match is returned. The special tag "config" (when
        not shadowed by a key of those dicts) returns the whole config dict.

        Keyword arguments:
        dataset_name -- str, the name of the dataset in the benchmark
        tag -- str, the tag you want to query
        config_id -- int or str, identifier of the run (converted with str())

        Raises:
        ValueError -- if the dataset, the config id or the tag does not exist
        """
        config_id = str(config_id)
        if dataset_name not in self.dataset_names:
            raise ValueError("Dataset name not found.")

        runs = self.data[dataset_name]
        if config_id not in runs:
            raise ValueError("Config nr %s not found for dataset %s." % (config_id, dataset_name))

        run = runs[config_id]
        # Lookup order matters: a tag present in several sections resolves to the first.
        for section in ("log", "results", "config"):
            if tag in run[section]:
                return run[section][tag]

        if tag == "config":
            return run["config"]

        raise ValueError("Tag %s not found for config %s for dataset %s" % (tag, config_id, dataset_name))

    def query_best(self, dataset_name, tag, criterion, position=0):
        """Query the n-th best run. "Best" here means achieving the largest
        value of ``criterion`` at any epoch/step.

        Keyword arguments:
        dataset_name -- str, the name of the dataset in the benchmark
        tag -- str, the tag you want to query
        criterion -- str, the tag used for the ranking; must hold a sequence
                     of comparable values (its max() is the ranking score)
        position -- int, index into the ranking (0 = best)

        Raises:
        KeyError -- if ``dataset_name`` is unknown
        IndexError -- if ``position`` is not a valid index into the ranking
        ValueError -- if ``criterion``/``tag`` cannot be queried or a
                      criterion sequence is empty
        """
        # sorted() is stable, so ties keep the dataset's config order.
        ranking = sorted(
            ((config_id, max(self.query(dataset_name, criterion, config_id)))
             for config_id in self.data[dataset_name]),
            key=lambda item: item[1],
            reverse=True,
        )
        best_config_id = ranking[position][0]
        return self.query(dataset_name, tag, best_config_id)

    def get_queriable_tags(self, dataset_name=None, config_id=None):
        """Returns a list of all queriable tags.

        Keyword arguments:
        dataset_name -- str or None, dataset to inspect (default: first dataset)
        config_id -- int, str or None, run to inspect (default: first run of
                     the chosen dataset)
        """
        if dataset_name is None:
            dataset_name = next(iter(self.data))
        if config_id is None:
            config_id = next(iter(self.data[dataset_name]))
        else:
            config_id = str(config_id)
        run = self.data[dataset_name][config_id]
        # "config" is always queriable: query() returns the whole config dict for it.
        return list(run["log"]) + list(run["results"]) + list(run["config"]) + ["config"]

    def get_dataset_names(self):
        """Returns a list of all available dataset names as defined on OpenML."""
        return self.dataset_names

    def get_openml_task_ids(self):
        """Returns a list of OpenML task ids, one per dataset, read from config id 1."""
        return [self.query(dataset_name, "OpenML_task_id", 1) for dataset_name in self.dataset_names]

    def get_number_of_configs(self, dataset_name):
        """Returns the number of configurations for a dataset.

        Raises:
        ValueError -- if the dataset name is unknown
        """
        if dataset_name not in self.dataset_names:
            raise ValueError("Dataset name not found.")
        return len(self.data[dataset_name])

    def get_config(self, dataset_name, config_id):
        """Returns the configuration of a run specified by dataset name and config id.

        ``config_id`` may be an int or a str; like in query() it is converted
        with str() before the lookup.

        Raises:
        ValueError -- if the dataset name is unknown
        KeyError -- if the config id is unknown
        """
        if dataset_name not in self.dataset_names:
            raise ValueError("Dataset name not found.")
        return self.data[dataset_name][str(config_id)]["config"]

    def plot_by_name(self, dataset_names, x_col, y_col, n_configs=10, show_best=False, xscale='linear', yscale='linear', criterion=None):
        """Plot multiple datasets and multiple runs, one subplot row per dataset.

        Keyword arguments:
        dataset_names -- str, or list/tuple/np.ndarray of str
        x_col -- str, tag to plot on x-axis
        y_col -- str, tag to plot on y-axis
        n_configs -- int, number of configs to plot for each dataset
        show_best -- bool, whether to show the n_configs best (according to query_best());
                     otherwise config ids 1..n_configs are plotted
        xscale -- str, set xscale, options as in matplotlib: "linear", "log", "symlog", "logit", ...
        yscale -- str, set yscale, options as in matplotlib: "linear", "log", "symlog", "logit", ...
        criterion -- str, tag used as criterion for query_best() (default: y_col)

        Raises:
        ValueError -- if ``dataset_names`` is not a name or a non-empty
                      sequence of names, or if a scale name is invalid
        """
        if isinstance(dataset_names, str):
            dataset_names = [dataset_names]
        if not isinstance(dataset_names, (list, tuple, np.ndarray)) or len(dataset_names) == 0:
            raise ValueError("Please specify a dataset name or a list of dataset names.")

        if criterion is None:
            criterion = y_col

        n_rows = len(dataset_names)
        # squeeze=False always yields a 2-D (n_rows, 1) array of axes, even for one row.
        _, axes = plt.subplots(n_rows, 1, sharex=False, sharey=False, figsize=(10, 7 * n_rows), squeeze=False)

        for dataset_name, ax in zip(dataset_names, axes[:, 0]):
            for ind in range(n_configs):
                # show_best ranks by 0-based position; otherwise config ids start at 1.
                run = ind if show_best else ind + 1
                try:
                    if show_best:
                        x = self.query_best(dataset_name, x_col, criterion, ind)
                        y = self.query_best(dataset_name, y_col, criterion, ind)
                    else:
                        x = self.query(dataset_name, x_col, run)
                        y = self.query(dataset_name, y_col, run)
                except (ValueError, IndexError):
                    # IndexError: show_best asked for a position past the last config.
                    print("Run %i not found for dataset %s" % (run, dataset_name))
                    continue
                ax.plot(x, y, 'p-')

            # Axis decoration is outside the try so bad scale names surface to the caller.
            ax.set_xscale(xscale)
            ax.set_yscale(yscale)
            ax.set(xlabel=x_col, ylabel=y_col)
            ax.title.set_text(self._plot_title(dataset_name))

    def _plot_title(self, dataset_name):
        """Build a subplot title from the metadata of config "0"; missing values show as "n/a"."""
        parts = [str(dataset_name)]
        for key in ("features", "classes", "instances"):
            try:
                value = str(int(self.query(dataset_name, key, 0)))
            except (ValueError, TypeError):
                value = "n/a"
            parts.append("%s: %s" % (key, value))
        return ", ".join(parts)

    def _cache_data(self, data, cache_file):
        """Pickle ``data`` into the gzip-compressed ``cache_file``, creating ``self.cache_dir`` first."""
        os.makedirs(self.cache_dir, exist_ok=True)
        with gzip.open(cache_file, 'wb') as f:
            pickle.dump(data, f)

    def _read_cached_data(self, cache_file):
        """Load the object pickled by ``_cache_data`` from the gzip-compressed ``cache_file``."""
        # SECURITY: pickle.load can execute arbitrary code. Only point cache_dir
        # at files written by this class, never at untrusted downloads.
        with gzip.open(cache_file, 'rb') as f:
            return pickle.load(f)

    def _read_file_string(self, path):
        """Reads a large json string from path in 64 MB chunks.

        A single read() on a huge file can fail with OSError (errno 22) on some
        platforms, so the file is read piecewise.
        """
        # Shoutout to https://stackoverflow.com/questions/48122798/oserror-errno-22-invalid-argument-when-reading-a-huge-file
        chunk_size = 64 * (1 << 20)  # 64 MB
        chunks = []
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            while True:
                block = f.read(chunk_size)
                if not block:  # Reached EOF
                    break
                chunks.append(block)
        # join() avoids re-copying an ever-growing string on every chunk.
        return ''.join(chunks)

    def _read_data(self, path):
        """Reads cached data if available. If not, reads json and caches the data as .pkl.gz

        The cache file is ``<cache_dir>/<basename of path without .json>.pkl.gz``
        and is only read or written when ``self.cache`` is truthy.
        """
        base_name = os.path.basename(path)
        if base_name.endswith(".json"):
            base_name = base_name[:-len(".json")]
        cache_file = os.path.join(self.cache_dir, base_name + ".pkl.gz")

        if self.cache and os.path.exists(cache_file):
            print("==> Found cached data, loading...")
            data = self._read_cached_data(cache_file)
        else:
            print("==> No cached data found or cache set to False.")
            print("==> Reading json data...")
            data = json.loads(self._read_file_string(path))
            if self.cache:
                print("==> Caching data...")
                self._cache_data(data, cache_file)
        return data
