import numpy as np
from experiment_setting import Experiment
from util import feasibility_check, optimal_solution, indicator, update_ratio, beta, glrt
from sampling_rule import USR, ESR, ASR, SEQSR, FWSR

class Agent:
    """Sequential sampling agent built on top of an ``Experiment``.

    The agent keeps running estimates indexed by ``(task, arm)`` and sample
    counts indexed by ``(task, arm, cons)``.  ``cons`` ranges over
    ``0..m``: indices below ``m`` update the constraint estimates
    ``F_esti`` and index ``m`` updates the mean estimate ``X_mean_esti``.

    Typical driving loop: call :meth:`reset`, then alternate
    :meth:`sample` / :meth:`step` until :meth:`stop` returns ``True``, and
    read the answer with :meth:`select`.  :meth:`reset` must be called
    before the first :meth:`sample` / :meth:`step` because it creates the
    estimate arrays.
    """

    # Sampling rule names understood by ``sample``.
    _SAMPLING_RULES = ("USR", "ESR", "ASR", "SEQSR", "FWSR")

    def __init__(self, experiment_id, n0, PCS, delta, obsve_budget, sampling_rule):
        """Create the agent and copy the problem description from the experiment.

        Args:
            experiment_id: Identifier passed to ``Experiment``.
            n0: Warm-up threshold; ``stop`` never fires while some
                alternative has been sampled ``n0`` times or fewer.
            PCS: If truthy, ``stop`` uses the fixed observation budget;
                otherwise it uses the GLRT statistic against ``beta``.
            delta: Parameter forwarded to ``beta`` in the GLRT stopping rule.
            obsve_budget: Number of rows of the ratio history and the
                budget used by ``stop`` when ``PCS`` is truthy.
            sampling_rule: One of ``"USR"``, ``"ESR"``, ``"ASR"``,
                ``"SEQSR"``, ``"FWSR"``.
        """
        self.experiment_id = experiment_id
        self.sampling_rule = sampling_rule
        self.n0 = n0
        self.PCS = PCS
        self.delta = delta
        self.obsve_budget = obsve_budget

        # Everything below is mirrored from the experiment definition.
        self.EXP = Experiment(self.experiment_id)
        self.dist = self.EXP.dist
        self.s = self.EXP.s
        self.k = self.EXP.k
        self.m = self.EXP.m
        self.F = self.EXP.F
        self.b = self.EXP.b
        self.X_mean = self.EXP.X_mean
        self.X_variance = self.EXP.X_variance
        self.feasibility = self.EXP.feasibility
        self.opt_solution = self.EXP.opt_solution
        self.phi = self.EXP.phi

        self.fws_hist = self._uniform_allocation()

    def _uniform_allocation(self):
        """Return an ``(s, k, m + 1)`` array whose entries are all equal and sum to 1."""
        shape = (self.s, self.k, self.m + 1)
        return np.full(shape, 1.0 / (self.s * self.k * (self.m + 1)))

    def reset(self) -> None:
        """Re-initialise all per-run state so the agent can start a fresh run."""
        self.X_mean_esti = np.zeros((self.s, self.k))
        self.feasibility_esti = np.ones((self.s, self.k))
        self.opt_solution_esti = np.zeros(self.s, dtype=int)
        self.F_esti = np.zeros((self.s, self.k, self.m))
        self.alternative_count = np.zeros((self.s, self.k, self.m + 1))
        self.ratio_hist = np.zeros((self.obsve_budget, self.s, self.k, self.m + 1))
        # ``sample`` overwrites ``fws_hist`` on every FWSR call; restore the
        # uniform start so a reused agent does not inherit the previous run's
        # ratios.
        self.fws_hist = self._uniform_allocation()

    def sample(self):
        """Ask the configured sampling rule which alternative to observe next.

        Returns:
            The tuple ``(next_task, next_arm, next_cons)`` produced by the rule.

        Raises:
            ValueError: If ``sampling_rule`` is not a supported rule name.
        """
        rule = self.sampling_rule
        if rule not in self._SAMPLING_RULES:
            raise ValueError(
                f"Unknown sampling rule {rule!r}; expected one of {self._SAMPLING_RULES}"
            )

        if rule == "USR":
            next_task, next_arm, next_cons = USR(self.alternative_count)
            return next_task, next_arm, next_cons

        # Positional arguments shared (in this order) by every other rule.
        shared = (
            self.s, self.k, self.m, self.n0, self.feasibility_esti,
            self.phi, self.F_esti, self.X_mean_esti, self.EXP.X_variance,
            self.opt_solution_esti, self.alternative_count,
        )

        if rule == "ESR":
            next_task, next_arm, next_cons = ESR(self.dist, *shared)
        elif rule == "ASR":
            # ASR is the only rule that does not take the distribution.
            next_task, next_arm, next_cons = ASR(*shared)
        elif rule == "SEQSR":
            next_task, next_arm, next_cons = SEQSR(self.dist, *shared)
        else:  # "FWSR"
            next_task, next_arm, next_cons, fws_ratio = FWSR(
                self.dist, *shared, self.fws_hist
            )
            # FWSR hands back its updated ratios; keep them for the next call.
            self.fws_hist = fws_ratio

        return next_task, next_arm, next_cons

    def step(self, task, arm, cons) -> None:
        """Draw one observation of ``(task, arm, cons)`` and update the estimates.

        Args:
            task: Index into the first axis of the count/estimate arrays.
            arm: Index into the second axis.
            cons: Constraint index in ``0..m-1``, or ``m`` for the mean.
        """
        # Snapshot the allocation ratios before this observation is counted.
        # ``ratio_hist`` only has ``obsve_budget`` rows, while ``stop`` can let
        # the run continue past the budget (warm-up not finished, or GLRT
        # stopping), so later snapshots are dropped instead of raising
        # IndexError.
        drawn_so_far = int(np.sum(self.alternative_count))
        if drawn_so_far < self.ratio_hist.shape[0]:
            self.ratio_hist[drawn_so_far] = update_ratio(self.alternative_count)

        self.alternative_count[task, arm, cons] += 1
        n_obs = self.alternative_count[task, arm, cons]
        observation = self.EXP.generate_samples(self.dist, task, arm, cons)

        # Incremental running-average updates.
        if cons == self.m:
            self.X_mean_esti[task, arm] += (observation - self.X_mean_esti[task, arm]) / n_obs
        else:
            hit = indicator(observation, self.b[cons])
            self.F_esti[task, arm, cons] += (hit - self.F_esti[task, arm, cons]) / n_obs

        self.feasibility_esti = feasibility_check(self.s, self.k, self.F_esti, self.phi)
        self.opt_solution_esti = optimal_solution(self.s, self.X_mean_esti, self.feasibility_esti)

    def select(self):
        """Return the current estimated optimal solution (one entry per task)."""
        return self.opt_solution_esti

    def stop(self) -> bool:
        """Decide whether sampling should end.

        Returns:
            ``False`` while any alternative has at most ``n0`` samples.
            Afterwards, with ``PCS`` truthy, ``True`` once the total sample
            count reaches ``obsve_budget``; otherwise ``True`` once
            ``t * glrt(...)`` exceeds ``beta(t, delta)`` for total count ``t``.
        """
        if np.min(self.alternative_count) <= self.n0:
            return False

        t = np.sum(self.alternative_count)
        if self.PCS:
            return bool(t >= self.obsve_budget)

        curr_ratio = self.alternative_count / t
        val = glrt(self.s, self.k, self.m, self.dist, self.X_mean_esti, self.F_esti,
                   self.opt_solution_esti, self.X_variance, self.feasibility_esti,
                   self.phi, curr_ratio)
        return bool(t * val > beta(t, self.delta))














