# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import numpy as np
import utils
from ray.util.multiprocessing import Pool
from Instance import Instance
from Algorithm import Algorithm
    

class Worker:
    """Run repeated, independently seeded trials of ``Algorithm`` and aggregate them.

    Each trial builds a fresh ``Instance`` and ``Algorithm`` from the stored
    configuration, seeds NumPy's global RNG with the trial index, and runs the
    algorithm.  Trials are dispatched through a Ray multiprocessing ``Pool``.

    Constructor arguments are stored verbatim as attributes:

    * ``W``, ``Z``, ``theta``, ``Gamma``, ``instance_type``, ``sigma_x``,
      ``sigma_y`` are forwarded to ``Instance``.
    * ``sampling_method``, ``elimination_method``, ``design_method``,
      ``delta``, ``alg_type``, ``burn_in_sampling_method``,
      ``burn_in_length``, ``horizon``, ``reuse_gamma`` are forwarded to
      ``Algorithm``.  ``sampling_method == 'ucb'`` selects ``run_ucb()``;
      anything else selects ``run_phased_alg()``.  ``alg_type`` also selects
      which statistics ``aggregate_results`` computes (``'fixed_budget'`` or
      ``'fixed_conf'``).
    * ``key`` is the single top-level key of ``save_results``, the dictionary
      that receives the aggregated statistics.
    """

    # (name prefix used on this class and in ``save_results``,
    #  name prefix used on the per-run result object)
    _FIXED_BUDGET_METRICS = (
        ('accuracy', 'acc'),
        ('simple_regret', 'simple_regret'),
        ('bias', 'bias'),
    )
    # (estimator suffix used on attributes, suffix used in ``save_results`` keys).
    # The ``pseudo_sls`` key spelling is kept as-is so that previously saved
    # results and downstream readers keep working.
    _FIXED_BUDGET_ESTIMATORS = (
        ('oracle', 'oracle'),
        ('ols', 'ols'),
        ('two_sls', 'two_sls'),
        ('pseudo_two_sls', 'pseudo_sls'),
    )
    # (name prefix used on this class and in ``save_results``,
    #  attribute read from the per-run result object)
    _FIXED_CONF_METRICS = (
        ('accuracy', 'accuracy'),
        ('sample_complexity', 'stop_time'),
    )

    def __init__(self, W, Z, theta, Gamma, instance_type, sigma_x, sigma_y,
                 sampling_method, elimination_method, design_method,
                 delta, alg_type, burn_in_sampling_method, burn_in_length, horizon, reuse_gamma, key):
        """Store the experiment configuration and create an empty result dictionary."""
        self.W = W
        self.Z = Z
        self.theta = theta
        self.Gamma = Gamma
        self.instance_type = instance_type
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.sampling_method = sampling_method
        self.elimination_method = elimination_method
        self.design_method = design_method
        self.delta = delta
        self.alg_type = alg_type
        self.burn_in_sampling_method = burn_in_sampling_method
        self.burn_in_length = burn_in_length
        self.horizon = horizon
        self.key = key
        self.reuse_gamma = reuse_gamma
        self.save_results = {self.key: {}}

    def run_individual(self, idx):
        """Run one trial seeded with ``idx`` and return the finished ``Algorithm``.

        Side effects: reseeds NumPy's global RNG and overwrites
        ``self.instance`` / ``self.Gamma`` with the values of the freshly
        built instance.
        """
        np.random.seed(idx)

        instance = Instance(W=self.W, Z=self.Z, theta=self.theta, Gamma=self.Gamma,
                            instance_type=self.instance_type, sigma_x=self.sigma_x, sigma_y=self.sigma_y)
        self.instance = instance
        self.Gamma = instance.Gamma

        alg = Algorithm(instance=instance, sampling_method=self.sampling_method,
                        elimination_method=self.elimination_method,
                        design_method=self.design_method, delta=self.delta, alg_type=self.alg_type,
                        burn_in_sampling_method=self.burn_in_sampling_method,
                        burn_in_length=self.burn_in_length,
                        reuse_gamma=self.reuse_gamma, horizon=self.horizon)
        if self.sampling_method == 'ucb':
            alg.run_ucb()
        else:
            alg.run_phased_alg()

        return alg

    def run(self, num_repeats, aggregate=True, num_cpu=None):
        """Run ``num_repeats`` trials in a pool and optionally aggregate them.

        Trial ``i`` is seeded with ``i`` for ``i`` in ``range(num_repeats)``.
        The finished ``Algorithm`` objects are stored in ``self.results``;
        ``self.instance`` and ``self.Gamma`` are taken from the first one.

        Args:
            num_repeats: Number of trials; must be at least 1.
            aggregate: When true, call ``aggregate_results`` afterwards.
            num_cpu: Passed to ``Pool`` as ``processes``; ``None`` lets the
                pool pick its own default.

        Raises:
            ValueError: If ``num_repeats`` is smaller than 1.
        """
        if num_repeats < 1:
            # Fail before any pool is started; otherwise the method would
            # only fail later with an IndexError on an empty result list.
            raise ValueError('num_repeats must be at least 1, got %r' % (num_repeats,))

        pool = Pool(processes=num_cpu)
        try:
            jobs = [pool.apply_async(self.run_individual, (idx,))
                    for idx in range(num_repeats)]
            self.results = [job.get() for job in jobs]
        finally:
            # Always release the pool's workers, including when a trial raised.
            pool.terminate()

        first = self.results[0]
        self.instance = first.instance
        self.Gamma = first.instance.Gamma
        if aggregate:
            self.aggregate_results()

    def aggregate_results(self):
        """Compute the mean and standard error of each metric across ``self.results``.

        Every statistic is stored both as an attribute on this object and
        under ``self.save_results[self.key]``.  The standard error is the
        population standard deviation (``np.std`` default) divided by the
        square root of the number of runs.

        * ``alg_type == 'fixed_budget'``: for each metric (accuracy, simple
          regret, bias) and estimator (oracle, ols, two_sls, pseudo_two_sls)
          the per-run arrays are stacked with ``np.vstack`` and reduced along
          axis 0.
        * ``alg_type == 'fixed_conf'``: ``accuracy`` and ``stop_time`` of the
          runs are reduced over all values.
        * Any other ``alg_type`` leaves everything untouched.
        """
        num_runs = len(self.results)

        if self.alg_type == 'fixed_budget':
            summary = self.save_results[self.key]
            for metric, result_prefix in self._FIXED_BUDGET_METRICS:
                for estimator, save_suffix in self._FIXED_BUDGET_ESTIMATORS:
                    stacked = np.vstack([getattr(res, f'{result_prefix}_{estimator}')
                                         for res in self.results])
                    mean = np.mean(stacked, axis=0)
                    ste = np.std(stacked, axis=0) / np.sqrt(num_runs)

                    setattr(self, f'{metric}_mean_{estimator}', mean)
                    setattr(self, f'{metric}_ste_{estimator}', ste)
                    summary[f'{metric}_mean_{save_suffix}'] = mean
                    summary[f'{metric}_ste_{save_suffix}'] = ste

        elif self.alg_type == 'fixed_conf':
            summary = self.save_results[self.key]
            for metric, result_attr in self._FIXED_CONF_METRICS:
                values = [getattr(res, result_attr) for res in self.results]
                mean = np.mean(values)
                ste = np.std(values) / np.sqrt(num_runs)

                setattr(self, f'{metric}_mean', mean)
                setattr(self, f'{metric}_ste', ste)
                summary[f'{metric}_mean'] = mean
                summary[f'{metric}_ste'] = ste


