import argparse
import pickle as pkl
from pathlib import Path

import numpy as np
import scipy.io
from line_profiler import profile
from scipy import linalg
from tqdm import tqdm

from models.dbsl import DBSL
from models.oful import OFUL
from models.soful import SOFUL


@profile
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str)
    parser.add_argument("-l", type=int, help="Sketch size", required=False)
    parser.add_argument("-T", type=int, help="Time", default=2000, required=False)
    parser.add_argument("-a", type=int, help="coefficient", default=4, required=False)
    parser.add_argument("-d", type=int, help="demision", default=500, required=False)
    args = parser.parse_args()

    method = args.method
    T = args.T
    d = args.d
    arms = 100
    betas = np.logspace(-4, 0, 5, 10)
    lmds = 2 * np.logspace(-4, 4, 9, 10)
    ls = [400]

    R = 1
    a = args.a

    # if T = 1000:
    #     dataset = scipy.io.loadmat("datasets/synthetic.mat")["X"]

    theta = np.random.multivariate_normal(np.zeros(d), np.eye(d))
    theta = theta / linalg.norm(theta)

    for beta in betas:
        for lmd in lmds:
            for l in ls:
                dataset = a * np.random.multivariate_normal(
                    np.zeros(d), R * np.eye(d), (1000, arms)
                )
                file_stem = (
                    "synthetic_" + method + f",d={d},beta={beta},lmd={lmd},T={T},a={a}"
                )
                if method == "soful" or method == "cbscfd":
                    # l = args.l
                    file_stem += f",l={l}"

                acc_regret = 0.0
                acc_regrets = []

                match method:
                    case "oful":
                        bandit = OFUL(d, beta=beta, lmd=lmd)
                    case "soful":
                        bandit = SOFUL(d, beta=beta, lmd=lmd, m=l)
                    case "cbscfd":
                        bandit = SOFUL(d, beta=beta, lmd=lmd, m=l, robust=True)
                    case "dbsl":
                        robust = True
                        if robust:
                            file_stem += ",robust"
                        bandit = DBSL(d, 4, 50000, beta=beta, lmd=lmd, robust=robust)

                etas = np.random.normal(0, R, arms)
                # etas = np.random.normal(0, R)
                observe = lambda arm, x: x @ theta + etas[arm]

                for i in tqdm(range(T), desc=file_stem):
                    if i != 0 and i % 1000 == 0:
                        dataset = a * np.random.multivariate_normal(
                            np.ones(d), R * np.eye(d), (1000, arms)
                        )
                    decision_set = dataset[i % 1000]
                    reward = bandit.fit(decision_set, observe=observe)
                    real_rewards = decision_set @ theta + etas
                    best_real_reward = np.max(real_rewards)
                    regret = best_real_reward - reward
                    acc_regret += regret
                    acc_regrets.append(acc_regret)

                X = bandit.X

                results = {"acc_regrets": acc_regrets, "X": X}
                if method == "soful":
                    results["deltas"] = bandit.deltas

                with open(f"results/{file_stem}.pkl", "wb") as f:
                    pkl.dump(results, f)

        #     break
        # break


if __name__ == "__main__":
    run()
