# NOTE: the star import stays first so that the explicit imports below take
# precedence if `algorithms` happens to export names such as `np` or `time`.
from algorithms import *

import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
T = 10000           # number of rounds played in every single run
rep = 20            # number of independent repetitions per (d, K) setting
sigma = 1           # forwarded to the LinUCB / LinTS / INFEX constructors
delta = 0.01        # forwarded to the LinUCB / LinTS / INFEX constructors
seed = 1234         # base seed; repetition r is seeded with seed * r
instance = 'Gauss'  # not referenced anywhere else in this script

# ---------------------------------------------------------------------------
# Main sweep: for every (dimension d, number of arms K) pair, run every
# algorithm `rep` times for T rounds on a fixed arm set, then store the mean
# and standard deviation of cumulative regret and of CPU time.
# ---------------------------------------------------------------------------
for d in [10, 20, 40]:
    for K in [10, 100, 1000]:

        S = 1  # forwarded to the LinUCB / LinTS / INFEX constructors

        linucb = LinUCB(d, sigma, delta, S)
        lints = LinTS(d, sigma, delta, S)

        greedy = Greedy(d)

        # INFEX wrappers around each base algorithm; only the last
        # constructor argument (5 / 10 / 20 / 100) differs between variants.
        MGucb5 = INFEX(d, LinUCB, sigma, delta, S, 5)
        MGts5 = INFEX(d, LinTS, sigma, delta, S, 5)

        # NOTE(review): the "10" variants are constructed but are not part of
        # `algs` below, so they are never run. Kept in case construction has
        # side effects that are not visible from this file.
        MGucb10 = INFEX(d, LinUCB, sigma, delta, S, 10)
        MGts10 = INFEX(d, LinTS, sigma, delta, S, 10)

        MGucb20 = INFEX(d, LinUCB, sigma, delta, S, 20)
        MGts20 = INFEX(d, LinTS, sigma, delta, S, 20)

        MGucb100 = INFEX(d, LinUCB, sigma, delta, S, 100)
        MGts100 = INFEX(d, LinTS, sigma, delta, S, 100)

        ols = OLSBandit(d, 1 / (24 * d * d), np.sqrt(d) / (2 * K))
        eps = epsGreedy(d)

        # The order of this list fixes the row order of the output files.
        algs = [linucb, MGucb5, MGucb20, MGucb100,
                lints, MGts5, MGts20, MGts100,
                greedy, ols, eps]

        regret_data = []  # [repetition][algorithm][round] -> instantaneous regret
        time_data = []    # [repetition][algorithm] -> CPU seconds for T rounds

        for r in range(1, rep + 1):
            np.random.seed(seed * r)
            random.seed(seed * r)

            # Unit-norm parameter vector for this repetition.
            theta = np.random.normal(0, 1, d)
            theta /= np.sqrt(theta @ theta)

            regrets_nrun = []
            time_nrun = []

            # One fixed arm set, shared by every round of this repetition.
            arms = np.random.normal(0, np.sqrt(1 / 2 / d), (K, d))
            # NOTE(review): `sq_norm` is the *squared* Euclidean norm, so a row
            # with norm > 1 ends up with norm 1/||x|| rather than exactly 1.
            # Every arm still lies in the unit ball; the arithmetic is left
            # unchanged on purpose so results stay comparable - confirm intent.
            sq_norm = np.sum(arms * arms, axis=1)
            arms /= np.expand_dims(np.maximum(sq_norm, 1), 1)

            # The arm set does not change over time, so a single (K, d) matrix
            # and length-K reward/regret vectors are enough. Replicating them
            # T times (as np.tile would) costs T * K * d floats of memory.
            arm_rewards = arms @ theta
            arm_regrets = np.max(arm_rewards) - arm_rewards

            for alg in algs:
                # Same seeds for every algorithm within a repetition.
                np.random.seed(seed * r)
                random.seed(seed * r)
                alg.reset()
                regret_1run = []

                start_time = time.process_time()
                for _ in range(T):
                    # Assumes select_ac does not modify its argument in place
                    # (the same array object is passed every round).
                    arm = alg.select_ac(arms)
                    # Binary feedback in {+1, -1}: +1 is drawn with probability
                    # (reward + 1) / 2, so its expectation equals the reward.
                    p_plus = (arm_rewards[arm] + 1) / 2
                    rwd = random.choices([1, -1], cum_weights=[p_plus, 1])[0]
                    alg.update(rwd)
                    regret_1run.append(arm_regrets[arm])
                elapsed_time = time.process_time() - start_time
                time_nrun.append(elapsed_time)

                regrets_nrun.append(regret_1run)

            regret_data.append(regrets_nrun)
            time_data.append(time_nrun)

        # (rep, n_algs, T): turn per-round regret into cumulative regret.
        regret_data = np.cumsum(np.array(regret_data), axis=2)

        avg_regret = np.mean(regret_data, axis=0)
        std_regret = np.std(regret_data, axis=0)

        time_data = np.array(time_data)
        avg_time = np.mean(time_data, axis=0)
        std_time = np.std(time_data, axis=0)

        # Create the output directory without going through a shell.
        out_dir = f"d{d}K{K}T{T}"
        os.makedirs(out_dir, exist_ok=True)

        filename = f"{out_dir}/{out_dir}"

        # Regret file: first len(algs) rows are means, the next len(algs) rows
        # are standard deviations; one column per round.
        np.savetxt(f"{filename}.csv",
                   np.concatenate((avg_regret, std_regret), axis=0))

        # Time file: len(algs) mean CPU times followed by len(algs) standard
        # deviations, written as a single 1-D array.
        np.savetxt(f"{filename}_time.csv",
                   np.concatenate((avg_time, std_time), axis=0))