import glob
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from db.agents.alg_1 import AgentAlg1
from db.agents.alg_2 import AgentAlg2
from db.agents.random import AgentRandom
from db.agents.versatile_db import AgentVersatileDB
from db.agents.wr_exp3_ix import AgentWrExp3IX
from db.agents.ws_w import AgentWSW
from db.envs.env import Env


class BaseExperiment:
    """Run agents against one environment, cache the outcome on disk and plot it.

    Everything that is recorded lives in ``self.history`` (a ``History``
    object) which is pickled to ``self.files['data']``.  Figures are written to
    the ``pdf`` / ``png`` locations stored in ``self.files``.
    """

    # Position of each plottable metric on the last axis of
    # ``history.results[agent]`` (same order in which ``run`` stacks them).
    _METRIC_INDEX = {'weak regret': 0, 'strong regret': 1, 'observations': 2}

    def __init__(self, experiment_id=-1, P=None, use_history=True, debug=False):
        """Create the experiment and load (or initialise) its history file.

        Args:
            experiment_id: Identifier used in the output file names; ``-1``
                is mapped by ``get_experiment_id``.
            P: Matrix handed to ``Env``.  ``None`` means a fresh 5x5 matrix
                filled with 0.5 (built per instance so that instances never
                share one array).
            use_history: When true and a data file already exists, it is
                loaded and its ``P`` replaces the one passed here.
            debug: Enables extra progress printing.
        """
        self.debug = debug
        if P is None:
            P = np.zeros((5, 5)) + 0.5
        self.env = Env(P)
        self.id = self.get_experiment_id(experiment_id)
        self.files = self._build_file_paths(self.id)
        self.colors = ['blue', 'orange', 'green', 'red']

        self.history = self.get_history_object(use_history)

    @staticmethod
    def _build_file_paths(experiment_id):
        """Return the data/pdf/png output paths for *experiment_id*.

        ``os.path.join`` is used so the paths are valid on every platform; on
        Windows the result is the same ``output\\<kind>\\exp_<id>`` as before.
        The ``pdf`` and ``png`` entries carry no extension because the plot
        methods append a suffix and the extension themselves.
        """
        stem = 'exp_' + str(experiment_id)
        return {
            'data': os.path.join('output', 'data', stem + '.pkl'),
            'pdf': os.path.join('output', 'pdf', stem),
            'png': os.path.join('output', 'png', stem),
        }

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain *path* if it is missing."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _resolve_limit(requested, available):
        """Clamp a user supplied size to what is stored.

        Non-positive or too large values mean "use everything available".
        """
        if requested <= 0 or requested > available:
            return available
        return requested

    @staticmethod
    def _count_arm_pulls(actions, n_arms):
        """Average, per iteration, how often each arm appears in each slot.

        Args:
            actions: Array-like of shape ``(iterations, T, 2)`` whose entries
                are arm indices (truncated to ``int`` like the stored floats).
            n_arms: Number of rows of the returned table.

        Returns:
            ``(n_arms, 2)`` float array; column ``j`` holds the mean number of
            times each arm was chosen in slot ``j`` of a round.  All zeros
            when there are no iterations.
        """
        actions = np.asarray(actions)
        n_iterations = actions.shape[0]
        count = np.zeros((n_arms, 2))
        if n_iterations == 0:
            return count

        indices = actions.astype(int)
        for slot in range(2):
            # np.add.at accumulates repeated indices (plain fancy-index ``+=``
            # would count each arm at most once).
            np.add.at(count[:, slot], indices[..., slot].ravel(), 1)
        return count / n_iterations

    @staticmethod
    def _apply_plot_style(tick_labelsize):
        """Set the global matplotlib fonts and sizes used by the figures."""
        matplotlib.rcParams['mathtext.fontset'] = 'stix'
        matplotlib.rcParams['font.family'] = 'STIXGeneral'

        plt.rc('axes', titlesize=16)     # fontsize of the axes title
        plt.rc('axes', labelsize=16)     # fontsize of the x and y labels
        plt.rc('xtick', labelsize=tick_labelsize)    # fontsize of the tick labels
        plt.rc('ytick', labelsize=tick_labelsize)    # fontsize of the tick labels
        plt.rc('legend', fontsize=12)    # legend fontsize

    def _save_figure(self, fig, file_suffix, save_pdf, save_png, **png_options):
        """Write *fig* to the pdf and/or png location of this experiment."""
        if save_pdf:
            target = self.files['pdf'] + file_suffix + '.pdf'
            self._ensure_parent_dir(target)
            fig.savefig(target)
        if save_png:
            target = self.files['png'] + file_suffix + '.png'
            self._ensure_parent_dir(target)
            fig.savefig(target, dpi=300, **png_options)

    def reset(self):
        """Placeholder; currently does nothing."""
        pass

    def get_experiment_id(self, id):
        """Return the id to use; ``-1`` is currently mapped to ``0``."""
        if id == -1:
            return 0    # todo: return next possible id that is not used
        return id

    def get_history_object(self, use_history):
        """Load the pickled history if allowed and present, else start a new one.

        When a history is loaded, ``self.env`` is rebuilt from the stored
        ``P``.  When a new one is created it is written to disk immediately.
        """
        data_path = self.files['data']
        if use_history and os.path.isfile(data_path):
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # data files produced by this project.
            with open(data_path, 'rb') as inp:
                history = pickle.load(inp)
            # Unpickling restores attributes without calling __init__, so a
            # file written before an attribute existed would lack it.
            for attribute in ('results', 'actions'):
                if not hasattr(history, attribute):
                    setattr(history, attribute, {})
            self.env = Env(history.P)
            # todo: display warning if env.P does not match history.P
        else:
            history = History()
            history.P = self.env.P
            history.results = {}
            self._ensure_parent_dir(data_path)
            with open(data_path, 'wb') as outp:
                pickle.dump(history, outp, -1)

        return history

    def get_env(self):
        """Return a new ``Env`` built from the ``P`` stored in the history."""
        return Env(self.history.P)

    def print(self):
        """Print the stored ``P`` and the stored results."""
        print(self.history.P)
        print(self.history.results)

    def get_agent_algorithm(self, agent):
        """Instantiate the agent registered under the name *agent*.

        Unknown names fall back to ``AgentRandom`` after printing a warning.
        Known agents receive their lookup key as ``name``.
        """
        agent_classes = {
            'WS-W': AgentWSW,
            'WR-UCB': AgentAlg1,
            'WR-TINF': AgentAlg2,
            'Versatile-DB': AgentVersatileDB,
            'WR-EXP3-IX': AgentWrExp3IX,
        }
        agent_class = agent_classes.get(agent)
        if agent_class is None:
            print(f'Warning: {agent} not found, AgentRandom used instead')
            return AgentRandom(self.env)
        return agent_class(self.env, name=agent)

    def run(self, T=100, agents=None, iterations=1, use_history=True):
        """Play every agent until ``iterations`` runs of length ``T`` are stored.

        Runs already present in the history are reused unless ``use_history``
        is false or the stored horizon differs from ``T``; in those cases the
        agent's stored data is discarded first.  The history is pickled once
        all agents are done.

        Args:
            T: Number of rounds per run.
            agents: Iterable of agent names (see ``get_agent_algorithm``).
            iterations: Total number of runs wanted per agent.
            use_history: Whether previously stored runs may be reused.
        """
        if agents is None:
            agents = []

        for agent in agents:
            if self.debug:
                print(f'Running experiment for {agent}')
            agent_algorithm = self.get_agent_algorithm(agent)

            stored_results = self.history.results.get(agent)
            stored_actions = self.history.actions.get(agent)
            # Start over when the cache is unusable: disabled, absent, recorded
            # for another horizon, or results and actions out of step.
            if (not use_history
                    or stored_results is None
                    or stored_actions is None
                    or stored_results.shape[1] != T
                    or stored_actions.shape[:2] != stored_results.shape[:2]):
                self.history.results[agent] = np.empty((0, T, 3))   # 3 metrics
                self.history.actions[agent] = np.empty((0, T, 2))   # 2 actions

            already_done = self.history.results[agent].shape[0]
            for iteration in range(already_done, iterations):
                print(f'Iteration: {iteration}')

                agent_algorithm.play(T=T, reset=True)

                temp_results = np.vstack((agent_algorithm.weak_regret, agent_algorithm.strong_regret, agent_algorithm.observations))

                self.history.results[agent] = np.append(self.history.results[agent], temp_results.transpose().reshape((1, T, 3)), axis=0)
                self.history.actions[agent] = np.append(self.history.actions[agent], agent_algorithm.actions.transpose().reshape((1, T, 2)), axis=0)

        self._ensure_parent_dir(self.files['data'])
        with open(self.files['data'], 'wb') as outp:
            pickle.dump(self.history, outp, -1)

    def plot(self, agents=None, metrics=None, T=-1, iterations=-1, save_png=True, save_pdf=True, title='Dueling bandits', ylabel='Regret', xlabel='Time', file_suffix='', quantile=0.9):
        """Plot the mean cumulative metric over time for each agent.

        Args:
            agents: Agent names to draw; agents without stored runs are
                skipped with a warning.
            metrics: Names out of ``'weak regret'``, ``'strong regret'``,
                ``'observations'``; unknown names are ignored.  Defaults to
                ``['weak regret']``.
            T: Number of rounds to show; non-positive or too large means all
                rounds stored for the agent.
            iterations: Number of runs to average; non-positive or too large
                means all runs stored for the agent.
            save_png, save_pdf: Whether to write the respective file.
            title, ylabel, xlabel: Figure texts.
            file_suffix: Appended to the file name before the extension.
            quantile: Upper quantile of the shaded band; the lower one is
                ``1 - quantile``.  The band is drawn only for more than one run.
        """
        if agents is None:
            agents = []
        if metrics is None:
            metrics = ['weak regret']

        self._apply_plot_style(tick_labelsize=12)

        fig, ax = plt.subplots()
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        drew_something = False
        for agent in agents:
            results = self.history.results.get(agent)
            if results is None or results.shape[0] == 0:
                print(f'Warning: no results stored for {agent}, skipped')
                continue

            # Resolved per agent: writing the clamped value back into T /
            # iterations would silently truncate every following agent.
            n_iterations = self._resolve_limit(iterations, results.shape[0])
            horizon = self._resolve_limit(T, results.shape[1])
            time_axis = range(horizon)

            for metric in metrics:
                metric_index = self._METRIC_INDEX.get(metric)
                if metric_index is None:
                    continue

                cumulative = np.cumsum(results[:n_iterations, :horizon, metric_index], axis=1)
                line, = ax.plot(time_axis, np.mean(cumulative, axis=0))
                line.set_label(agent)

                if n_iterations > 1:
                    quantile_low = np.quantile(cumulative, 1 - quantile, axis=0)
                    quantile_high = np.quantile(cumulative, quantile, axis=0)
                    ax.fill_between(time_axis, quantile_low, quantile_high, alpha=0.15)

                drew_something = True

        if drew_something:
            ax.legend(loc=2)

        self._save_figure(fig, file_suffix, save_pdf, save_png)

    def plot_arms(self, agents=None, T=-1, iterations=-1, save_png=True, save_pdf=True, title='', ylabel='Arm Count', xlabel='Arms', file_suffix=''):
        """Draw, per agent, a stacked bar chart of how often each arm was played.

        The solid part of a bar counts the first action of each round, the
        translucent part stacked on top counts the second action; both are
        averaged over the selected runs.

        Args:
            agents: Agent names to draw; agents without stored runs are
                skipped with a warning.
            T: Number of rounds to count; non-positive or too large means all
                rounds stored for the agent.
            iterations: Number of runs to average; non-positive or too large
                means all runs stored for the agent.
            save_png, save_pdf: Whether to write the respective file.
            title, ylabel, xlabel: Figure texts.
            file_suffix: Appended to the file name before the extension.
        """
        if agents is None:
            agents = []

        self._apply_plot_style(tick_labelsize=14)

        fig, ax = plt.subplots()
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        K = self.history.P.shape[0]
        arm_labels = np.arange(K) + 1
        bar_width = 1 / (len(agents) + 2)

        drew_something = False
        for slot, agent in enumerate(agents, start=1):
            actions = self.history.actions.get(agent)
            if actions is None or actions.shape[0] == 0:
                print(f'Warning: no actions stored for {agent}, skipped')
                continue

            n_iterations = self._resolve_limit(iterations, actions.shape[0])
            horizon = self._resolve_limit(T, actions.shape[1])
            if self.debug:
                print(n_iterations, horizon)

            count = self._count_arm_pulls(actions[:n_iterations, :horizon], K)

            # Cycle through the palette so more agents than colors still plot.
            color = self.colors[(slot - 1) % len(self.colors)]
            positions = arm_labels - 0.5 + (slot + 0.5) * bar_width

            bar = ax.bar(positions, count[:, 0], bar_width, color=color)
            ax.bar(positions, count[:, 1], bar_width, bottom=count[:, 0], color=color, alpha=0.2)
            bar.set_label(agent)
            drew_something = True

        if drew_something:
            ax.legend(loc=2)

        self._save_figure(fig, file_suffix, save_pdf, save_png, bbox_inches='tight')


class History:
    """Plain container object that gets pickled to the experiment data file.

    A new instance starts with two independent, empty dictionaries:
    ``results`` and ``actions``.
    """

    def __init__(self):
        self.results, self.actions = dict(), dict()
