from typing import List

import numpy as np

from utils import solve_optim


class ResultItem:
    """Record of a single experiment run.

    Attributes:
        final_answer: the answer the run ended with.
        nt: per-arm sample counts for the run (one entry per arm).
        stopping_time: total number of samples over all arms, ``nt.sum()``.
    """

    def __init__(self, final_answer: np.ndarray, nt: np.ndarray):
        total_samples = nt.sum()
        self.nt = nt
        self.final_answer = final_answer
        self.stopping_time = total_samples


class ResultSummary:
    """Aggregate statistics over a collection of experiment runs.

    Args:
        experiments: the individual runs (``ResultItem`` instances) to summarise.
        delta: confidence parameter; enters ``lower_bound`` as ``log(1 / delta)``.
        epsilon: tolerance used when matching an answer against ``x1``/``x2``;
            also forwarded to ``solve_optim`` as ``eps``.
        mu: per-arm parameter vector; its length defines the number of arms.
        x1, x2: the two reference answers. A run counts as correct when the first
            two components of its answer are each within ``epsilon`` of the
            corresponding components of either reference.
    """

    # Reference allocations over the arms. NOTE(review): these hard-code 4 arms.
    _W1 = np.array([0.5, 0.5, 0.0, 0.0])
    _W2 = np.array([0.0, 0.0, 0.5, 0.5])

    def __init__(self, experiments: "List[ResultItem]", delta, epsilon, mu, x1, x2):
        self.experiments = experiments
        self.n_runs = len(experiments)
        self.delta = delta
        self.eps = epsilon
        self.x1 = x1
        self.x2 = x2
        self.mu = mu
        self.n_arms = len(self.mu)

    def _within_eps(self, target, answer) -> bool:
        """Return True if ``answer[0]`` and ``answer[1]`` are each within
        ``self.eps`` of ``target[0]`` and ``target[1]``."""
        return bool(abs(target[0] - answer[0]) <= self.eps
                    and abs(target[1] - answer[1]) <= self.eps)

    def correctness(self):
        """Return the fraction of runs whose final answer matches ``x1`` or ``x2``.

        Raises:
            ZeroDivisionError: if there are no experiments.
        """
        num_correct = sum(
            1
            for item in self.experiments
            if self._within_eps(self.x1, item.final_answer)
            or self._within_eps(self.x2, item.final_answer)
        )
        return num_correct / self.n_runs

    def sample_complexity(self):
        """Return the list of stopping times, one per run, in input order."""
        return [exp.stopping_time for exp in self.experiments]

    def lower_bound(self):
        """Return ``log(1 / delta)`` divided by ``solve_optim`` evaluated at the
        allocation ``(0.5, 0.5, 0, 0)``.

        NOTE(review): the result of ``solve_optim`` is presumably a divergence
        term (per the local name); confirm against ``utils.solve_optim``.
        """
        # Pass a copy so the shared class constant cannot be mutated by the helper.
        divergence = solve_optim(w=self._W1.copy(),
                                 mu=self.mu,
                                 eps=self.eps)

        return np.log(1 / self.delta) / divergence

    def get_nt_pulls(self, normalize=True):
        """Return an ``(n_runs, n_arms)`` matrix of per-arm counts.

        Args:
            normalize: if True, divide each row by that run's stopping time so
                rows are empirical allocations; otherwise return raw counts.
        """
        matrix = np.zeros((self.n_runs, self.n_arms))
        for i, exp in enumerate(self.experiments):
            matrix[i] = exp.nt / exp.stopping_time if normalize else exp.nt

        return matrix

    def get_distance_from_w1_w2(self):
        """Return, per run, the L1 distance from its empirical allocation
        ``nt / stopping_time`` to the nearer of ``(0.5, 0.5, 0, 0)`` and
        ``(0, 0, 0.5, 0.5)``."""
        res = []
        for exp in self.experiments:
            emp_w = exp.nt / exp.stopping_time
            d1 = np.linalg.norm(self._W1 - emp_w, ord=1)
            d2 = np.linalg.norm(self._W2 - emp_w, ord=1)
            res.append(min(d1, d2))

        return np.array(res)
