import numpy as np
from TS import TS
from CORSA import CORSA
import abc
from experiment_setting import Experiment
from util import beta, glrt, update_ratio

class Algorithm(metaclass=abc.ABCMeta):
    """Base class for sequential sampling procedures over ``EXP.k`` alternatives.

    Holds the experiment configuration and the per-run state that subclasses
    update in :meth:`step`. Call :meth:`reset` before each run: the run-state
    attributes (``mean_estimation``, ``alternative_count``, ``samples_set``,
    ``ratio_hist``, ``best_alternative``) are only created there.
    """

    # Number of rows drawn by offline_samples() when no explicit size is given.
    # This bounds how many observations any single alternative can receive in
    # a run. Override on the class or on an instance to change it.
    n_offline_samples = 300000

    def __init__(self, experiment_id, rho, policy, budget):
        """Build the experiment and store the run parameters.

        Args:
            experiment_id: identifier forwarded to ``Experiment``.
            rho: parameter forwarded to ``Experiment`` (its meaning is defined
                there and is not visible in this class).
            policy: sampling-policy name; interpreted by subclasses.
            budget: total number of samples; sizes ``ratio_hist`` in reset().
        """
        self.experiment_id = experiment_id
        self.rho = rho
        self.EXP = Experiment(self.experiment_id, self.rho)
        self.policy = policy
        self.budget = budget
        # Uniform allocation over the k alternatives (not read by this class).
        self.average_ratio = np.ones(self.EXP.k) / self.EXP.k

    def reset(self):
        """Clear the per-run statistics and redraw the offline sample pool."""
        k = self.EXP.k
        self.mean_estimation = np.zeros(k)
        self.alternative_count = np.zeros(k)
        # Index of the alternative with the largest true mean.
        self.best_alternative = np.argmax(self.EXP.mean)
        self.samples_set = self.offline_samples()
        # One row per sample in the budget, one column per alternative.
        self.ratio_hist = np.zeros((self.budget, k))

    @abc.abstractmethod
    def step(self, alternative):
        """Take one sample from ``alternative`` and update the run state."""

    def offline_samples(self, size=None):
        """Draw a pool of joint observations for all alternatives.

        Args:
            size: number of rows to draw; ``None`` (the default) uses
                ``self.n_offline_samples``.

        Returns:
            Array of shape ``(size, len(EXP.mean))`` drawn from a multivariate
            normal with mean ``EXP.mean`` and covariance ``EXP.cov``. The global
            ``np.random`` state is used, so ``np.random.seed`` still makes runs
            reproducible.
        """
        if size is None:
            size = self.n_offline_samples
        return np.random.multivariate_normal(mean=self.EXP.mean, cov=self.EXP.cov, size=size)

class Agent(Algorithm):
    """Sampling agent that picks alternatives with a named policy ("TS" or "CORSA").

    The stopping rule depends on ``PCS``:

    * truthy ``PCS``: stop after ``budget`` samples; before every sample the
      current allocation (via ``update_ratio``) is recorded in ``ratio_hist``;
    * falsy ``PCS``: never stop before ``EXP.k * n0`` samples; afterwards stop
      once ``t * glrt(...)`` exceeds ``beta(t, delta, rho)``, where ``t`` is
      the total number of samples taken.
    """

    def __init__(self, experiment_id, rho, policy, n0, PCS, budget, delta) -> None:
        super().__init__(experiment_id, rho, policy, budget)
        # Forwarded to the policy functions and used as the per-alternative
        # warm-up size in stop() (k * n0 samples before testing).
        self.n0 = n0
        # Selects the fixed-budget stopping mode (see class docstring).
        self.PCS = PCS
        # Forwarded to beta() in the GLRT stopping rule.
        self.delta = delta

    def sample(self):
        """Return the index of the next alternative chosen by ``self.policy``.

        Raises:
            ValueError: if ``self.policy`` is neither ``"TS"`` nor ``"CORSA"``.
        """
        if self.policy == "TS":
            return TS(self.mean_estimation, self.EXP.variance, self.alternative_count, self.n0)
        if self.policy == "CORSA":
            return CORSA(self.mean_estimation, self.EXP.variance, self.rho, self.alternative_count, self.n0)
        # Returning None here would let step(None) index numpy with None,
        # incrementing every count before failing elsewhere; fail loudly instead.
        raise ValueError(f"Unknown policy {self.policy!r}; expected 'TS' or 'CORSA'")

    def select(self):
        """Return the alternative with the highest estimated mean (ties broken uniformly at random)."""
        best = np.flatnonzero(self.mean_estimation == np.max(self.mean_estimation))
        return np.random.choice(best)

    def step(self, alternative):
        """Take one observation of ``alternative`` and update its running mean."""
        if self.PCS:
            # Record the allocation *before* this sample; the row is the number
            # of samples taken so far, which stays below budget as long as
            # stop() is checked before each step.
            self.ratio_hist[int(np.sum(self.alternative_count))] = update_ratio(self.alternative_count)
        self.alternative_count[alternative] += 1
        n = self.alternative_count[alternative]
        # NOTE(review): the post-increment (1-based) count is used as the row
        # index, so row 0 of samples_set is never read. Kept unchanged so that
        # seeded runs reproduce earlier results; consider passing n - 1.
        observation = self.generate_sample(alternative, n)
        # Incremental mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n.
        self.mean_estimation[alternative] += (observation - self.mean_estimation[alternative]) / n

    def generate_sample(self, alternative, index):
        """Return the pre-drawn observation at row ``index`` for ``alternative``."""
        return self.samples_set[int(index), alternative]

    def stop(self):
        """Return True when sampling should end (see class docstring for the rules)."""
        t = np.sum(self.alternative_count)
        if self.PCS:
            return bool(t >= self.budget)
        # Warm-up phase: no stopping test until k * n0 samples have been taken.
        if t < self.EXP.k * self.n0:
            return False
        curr_ratio = self.alternative_count / t
        val = glrt(self.EXP.k, self.mean_estimation, self.EXP.variance, self.rho, curr_ratio)
        return bool(t * val > beta(t, self.delta, self.rho))












    



