from agent.MOLB_TS import MOLB_TS
from agent.MOLB_UCB import MOLB_UCB
from agent.MOLB_Greedy import MOLB_Greedy
from agent.PFIwR import PFIwR
from environment.multiobjective_env import Multiobjective_env
from environment.mo_context_linear_env import MO_Context_Linear_env
import json
import pickle
import sys
sys.path.append('../')
from tqdm import trange
import time
import numpy as np

class Run_algo:
    """Run one bandit agent against an environment and record its learning curves.

    One instance stores the results of a single agent across ``Rep``
    repetitions. Every ``interval`` rounds a checkpoint is written, giving
    ``T // interval`` checkpoints at indices ``1..save_t``. Index 0 stays
    all-zero and represents the state before any round is played.

    Args:
        K: Number of arms; sizes the per-arm selection counters.
        d: Context dimension (stored only; not used inside this class).
        L: Number of objectives; sizes the reward and theta-error records.
        T: Horizon. Only ``(T // interval) * interval`` rounds are actually
            played, so a remainder of ``T`` not divisible by ``interval`` is
            dropped.
        interval: Number of rounds between two consecutive checkpoints.
        Rep: Number of repetitions to allocate storage for.
    """

    def __init__(self, K, d, L=1, T=4000, interval=100, Rep=10):
        self.K = K
        self.d = d
        self.L = L
        self.T = T
        self.interval = interval
        self.Rep = Rep

        self.save_t = T // interval
        n_ckpt = self.save_t + 1  # checkpoint 0 is the untouched initial state

        # Cumulative regret (two definitions exposed by the env), per repetition.
        self.cumul_regret_v1 = np.zeros((Rep, n_ckpt))
        self.cumul_regret_v2 = np.zeros((Rep, n_ckpt))

        # Cumulative reward per objective: (L, Rep, n_ckpt).
        self.total_reward = np.zeros((L, Rep, n_ckpt))
        # Per-objective norm of (theta_hat - true theta): (Rep, n_ckpt, L).
        self.theta_err = np.zeros((Rep, n_ckpt, L))
        # Wall-clock seconds spent in agent selection + update.
        self.elapsed_time = np.zeros((Rep, n_ckpt))
        # How many times each of the K arms has been chosen so far.
        self.selected_rounds = np.zeros((Rep, n_ckpt, K))

    def run(self, agent, env, r):
        """Play ``save_t * interval`` rounds of ``agent`` on ``env`` as repetition ``r``.

        Each round reads the contexts, lets the agent pick an arm, observes the
        reward, and updates the agent. It then accumulates the two regret
        arrays exposed by ``env`` for the chosen arm and advances the
        environment. Only the agent-side work (select + update) is timed.

        Results are written in place into row ``r`` of the result arrays.
        All running totals are local, so re-running the same ``r`` overwrites
        that row instead of adding to it.

        Args:
            agent: Object exposing ``name``, ``select_ac``, ``update`` and
                ``theta_hat``.
            env: Object exposing ``warm_up``, ``view_context``,
                ``action_reward``, ``update_env``, ``pareto_regret_old``,
                ``pareto_regret_ours`` and ``theta_list``.
            r: Repetition index (row in the result arrays).
        """
        regret_v1 = 0.0
        regret_v2 = 0.0
        elapsed = 0.0
        # Presumably ``reward`` is a length-L vector (one entry per objective);
        # it must broadcast into shape (L,) — TODO confirm against the env.
        reward_sum = np.zeros(self.L)
        s_rounds = np.zeros(self.K)
        env.warm_up()

        print("Experiment %d with model %s" % (r, agent.name))
        for s in trange(1, self.save_t + 1):
            for _ in range(self.interval):
                start = time.time()
                contexts = env.view_context()
                idx = agent.select_ac(contexts)
                reward = env.action_reward(idx)
                agent.update(reward, contexts[idx])
                elapsed += time.time() - start

                reward_sum += reward
                # Regret is read before env.update_env() advances the round.
                regret_v1 += env.pareto_regret_old[idx]
                regret_v2 += env.pareto_regret_ours[idx]
                s_rounds[idx] += 1
                env.update_env()

            # PFIwR's theta_hat is transposed before comparison; its layout
            # presumably differs from the other agents' — verify in PFIwR.
            theta_hat = agent.theta_hat.T if agent.name == "PFIwR" else agent.theta_hat
            err = np.linalg.norm(theta_hat - env.theta_list, axis=1)

            self.elapsed_time[r, s] = elapsed
            self.cumul_regret_v1[r, s] = regret_v1
            self.cumul_regret_v2[r, s] = regret_v2
            self.theta_err[r, s] = err
            self.selected_rounds[r, s, :] = s_rounds
            self.total_reward[:, r, s] = reward_sum

        print("Total time : ", self.elapsed_time[r, -1])
        print("Total Regret_v1:", self.cumul_regret_v1[r, -1])
        print("Total Regret_v2:", self.cumul_regret_v2[r, -1])
        print("Theta Error:", self.theta_err[r, -1])
        print("Total Reward:", self.total_reward[:, r, -1])

    def save(self, agent):
        """Return every recorded array as nested lists, tagged with ``agent.name``.

        The returned dict contains only builtin types, so it can be pickled or
        JSON-encoded without numpy.
        """
        return {
            'model': agent.name,
            'total_reward': self.total_reward.tolist(),
            'regrets_v1': self.cumul_regret_v1.tolist(),
            'regrets_v2': self.cumul_regret_v2.tolist(),
            'theta_err': self.theta_err.tolist(),
            'selected_rounds': self.selected_rounds.tolist(),
            'time': self.elapsed_time.tolist(),
        }


if __name__ == "__main__":
    import os

    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    # Other (K, d) grids that have been used:
    # K = [50, 50, 100, 100, 100, 200, 200]
    # d = [5,  10,   5,  10,  15,  10,  15]
    # K = [50, 100, 200]
    # d = [5, 10,  15]
    K = [50, 100]  # number of arms, one entry per setting
    d = [5, 10]    # context dimension, paired with K by index
    L = 4          # number of objectives

    # Posterior sample count for "MOLB_TS_opt": 1 + ceil(log(L) / log(1 / 0.85)).
    # num_samples = 1
    num_samples = int(1 + np.ceil(np.log(L) / np.log(1 / 0.85)))

    rep = 2        # repetitions per (K, d) setting
    num_agent = 4  # must equal the length of the `agents` list built below
    trial = 0      # tag that only appears in the output filename
    ##################################################################

    num_settings = len(K)
    # run_algo[i][j] collects results for setting i and agent j.
    run_algo = [
        [Run_algo(K[i], d[i], L=L, T=10000, interval=20, Rep=rep) for _ in range(num_agent)]
        for i in range(num_settings)
    ]

    # The pickle below is written into ./results; create it so a fresh
    # checkout does not fail with FileNotFoundError.
    os.makedirs('./results', exist_ok=True)

    for i in range(num_settings):
        print(f"start Experiment K : {K[i]}, d : {d[i]}, L : {L}, trial : {trial}")
        agents = []  # defined up front so the save loop works even if rep == 0
        for rep_idx in range(rep):
            env = MO_Context_Linear_env(K=K[i], d=d[i], num_obj=L, sig=1, version=rep_idx, fixed_context=False)
            # env.warm_up()
            # agent_PFI = PFIwR(env.view_context(), L, None)
            agent_TS = MOLB_TS(d=d[i], m=L, num_samples=1, name="MOLB_TS")
            agent_TS_opt = MOLB_TS(d=d[i], m=L, num_samples=num_samples, name="MOLB_TS_opt")
            agent_UCB = MOLB_UCB(d=d[i], m=L)
            agent_Greedy = MOLB_Greedy(d=d[i], m=L, epsilon=0.05)

            agents = [agent_TS, agent_TS_opt, agent_UCB, agent_Greedy]
            # NOTE(review): the same env object is shared by all agents in this
            # repetition. Run_algo.run calls env.warm_up() first, presumably to
            # reset it between agents — confirm in MO_Context_Linear_env.
            for j, agent in enumerate(agents):
                run_algo[i][j].run(agent, env, rep_idx)

        # Agents from the last repetition are only used for their names.
        results = {agent.name: run_algo[i][j].save(agent) for j, agent in enumerate(agents)}

        with open('./results/Experiment_K%d_d%d_L%d_trial_%d.pkl' % (K[i], d[i], L, trial), 'wb') as f:
            pickle.dump(results, f)
    print("finished")
    

