import time
from mnist import MNIST
import pandas as pd
import numpy as np
import argparse

from tqdm import trange

from models.dbsl import DBSL
from models.oful import OFUL
from models.soful import SOFUL

import pickle as pkl
from scipy import stats
from line_profiler import profile


def groupby(features, labels):
    """Split feature rows into per-label clusters.

    Args:
        features: 2-D array with one sample per row.
        labels: 1-D array holding one label per row of ``features``.

    Returns:
        A ``(labels, clusters)`` tuple. ``labels`` is the list of group keys
        cast to ``int``, in the order produced by ``DataFrame.groupby``;
        ``clusters`` maps each of those ints to the 2-D array of feature
        rows carrying that label.
    """
    # Column 0 carries the label; the remaining columns are the features.
    label_column = labels[:, np.newaxis]
    table = pd.DataFrame(np.hstack([label_column, features]))

    # Strip the label column again when extracting each group's rows.
    pairs = [(int(key), rows.values[:, 1:]) for key, rows in table.groupby(0)]

    return [key for key, _ in pairs], dict(pairs)


# @profile
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str)
    args = parser.parse_args()

    method = args.method

    mndata = MNIST("datasets/MNIST/raw")
    images, labels = mndata.load_training()
    features = np.array(images)
    features = np.divide(features, 255.0)
    labels = np.array(labels)

    labels, clusters = groupby(features, labels)

    arms = len(labels)
    d = features.shape[-1]

    T = 2000
    l = 20
    betas = [10**-2]
    lmds = [20]

    for beta in betas:
        for lmd in lmds:
            file_stem = "mnist_" + method + f",beta={beta},lmd={lmd}"
            if method != "oful":
                # l = args.l
                file_stem += f",l={l}"

            acc_regret = 0.0
            acc_regrets = []

            match method:
                case "oful":
                    bandit = OFUL(d, beta=beta, lmd=lmd)
                case "soful":
                    bandit = SOFUL(d, beta=beta, lmd=lmd, m=l)
                case "cbscfd":
                    bandit = SOFUL(d, beta=beta, lmd=lmd, m=l, robust=True)
                case "dbsl":
                    eps = 1000
                    file_stem += f",eps={eps}"
                    robust = True
                    if robust:
                        file_stem += ",robust"
                    bandit = DBSL(d, 2, eps, beta, lmd, robust=robust)

            observe = lambda arm, x: 1 if arm == 4 else 0

            start_time = time.process_time_ns()
            for i in trange(T, desc=file_stem):
                decision_set = np.zeros((arms, d))
                for key in clusters:
                    row_num = clusters[key].shape[0]
                    decision_index = np.random.choice(row_num)
                    decision_set[key] = clusters[key][decision_index]

                reward = bandit.fit(decision_set, observe=observe)
                best_rewards = 1
                regret = best_rewards - reward
                acc_regret += regret
                acc_regrets.append(acc_regret)

            end_time = time.process_time_ns()
            elapsed_time = end_time - start_time
            sum_time_second = elapsed_time // (10**6)
            print(sum_time_second)

            X = bandit.X

            results = {"acc_regrets": acc_regrets, "X": X, "time": sum_time_second}
            if method == "soful":
                results["deltas"] = bandit.deltas

            file_stem += "time"

            with open(f"results/{file_stem}.pkl", "wb") as f:
                pkl.dump(results, f)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()
