# Copyright (C) king.com Ltd 2025
# License: Apache 2.0

import os.path
import argparse

import matplotlib.pyplot as plt
import torch
import numpy as np
import pickle
from sys import exit

from bandits import UCBBandit, FactoredUCBBandit
from bandits import EpsGreedyBandit, FactoredEpsGreedyBandit
from bandits import TSBandit, FactoredTSBandit


def synthetic_reward(full_action_vec, target_vec):
    """Score an action by how many entries hit the target, plus Gaussian noise.

    The score is the fraction of positions where ``full_action_vec`` and
    ``target_vec`` differ by less than 0.01. One sample of Gaussian noise
    (mean 0, standard deviation 0.05, drawn from torch's global RNG) is added
    to it.

    Args:
        full_action_vec: NumPy array holding the chosen action.
        target_vec: NumPy array holding the target action (same length).

    Returns:
        The noisy match fraction as a Python float.
    """
    action_t = torch.from_numpy(full_action_vec)
    target_t = torch.from_numpy(target_vec)
    hits = (action_t - target_t).abs() < 0.01
    match_fraction = hits.float().mean().item()
    noise = torch.normal(torch.tensor([0.0]), torch.tensor([0.05])).item()
    return match_fraction + noise


def simulate(bandit, num_steps=1000, target_vec=np.array([5, 5, 5]), r_fun=synthetic_reward, seed=0):
    """Play ``bandit`` for ``num_steps`` rounds and record the outcomes.

    Each round:
      1. the bandit picks a full action (a torch tensor);
      2. the action is converted to NumPy and scored with
         ``r_fun(action, target_vec)``;
      3. the bandit is updated with the observed reward.

    Args:
        bandit: object exposing ``select_action()``, ``update(action, reward)``
            (whose return value is recorded as the loss), and the attributes
            ``name``, ``num_segments`` and ``num_choices_per_segment``, which
            are used for progress output.
        num_steps: number of rounds to play.
        target_vec: target vector passed through to ``r_fun``. It is not
            mutated, so the shared default array is safe.
        r_fun: reward function ``r_fun(action_np, target_vec) -> float``.
        seed: only used to label progress output; no RNG is seeded here.

    Returns:
        Tuple ``(rewards, regrets, losses)`` of per-round lists. Regret is
        ``1 - reward``, i.e. measured against a best reward of 1.
    """
    rewards = []
    regrets = []
    losses = []

    # Report roughly 10 times per run. Integer arithmetic keeps the modulo test
    # exact: a float interval such as 2001 / 10 would almost never divide evenly.
    report_every = max(1, num_steps // 10)
    avg_window = 20

    for step in range(num_steps):
        full_action = bandit.select_action()

        r = r_fun(full_action.detach().cpu().numpy(), target_vec)
        loss = bandit.update(full_action, r)

        rewards.append(r)
        losses.append(loss)
        regrets.append(1 - r)

        if (step + 1) % report_every == 0:
            # Average over the rewards actually available (can be fewer than the window early on).
            recent = rewards[-avg_window:]
            avg = sum(recent) / len(recent)
            print(f"Seed {seed} | {bandit.name} | {bandit.num_segments} segments | {bandit.num_choices_per_segment} choices | Step {step+1}: Avg reward (last 20): {avg:.3f}")

    return rewards, regrets, losses


def compute_stats(all_rewards, all_regrets, smoothing_window=100):
    """Aggregate per-seed reward and regret traces into summary curves.

    Args:
        all_rewards: per-seed reward traces, convertible to an array of shape
            ``(n_seeds, n_steps)``.
        all_regrets: per-seed, per-round regret traces with the same shape.
        smoothing_window: width of the moving average applied to the mean and
            std reward curves and to the mean cumulative-regret curve. It is
            clamped to ``[1, n_steps]`` for two reasons:
              - a zero-width kernel makes ``np.convolve`` raise;
              - a kernel longer than the trace makes ``mode='valid'`` swap its
                operands and return a meaningless curve.

    Returns:
        Tuple ``(mean_rewards, std_rewards, mean_regrets, std_regrets,
        smoothed_mean_rewards, smoothed_std_rewards, smoothed_mean_regrets)``.
        Regret statistics are taken over cumulative regret (cumsum along
        rounds). Smoothed curves have length ``n_steps - window + 1``.
    """
    rewards = np.array(all_rewards)
    regrets = np.array(all_regrets)

    # Statistics across seeds (axis 0), one value per round.
    mean_rewards = np.mean(rewards, axis=0)
    std_rewards = np.std(rewards, axis=0)

    # Cumulative regret per seed, then statistics across seeds.
    cum_regrets = np.cumsum(regrets, axis=1)
    mean_regrets = np.mean(cum_regrets, axis=0)
    std_regrets = np.std(cum_regrets, axis=0)

    n_steps = min(len(mean_rewards), len(mean_regrets))
    window = max(1, min(int(smoothing_window), n_steps))
    kernel = np.ones(window) / window

    def _smooth(x):
        # Moving average; 'valid' drops the partially-covered edges.
        return np.convolve(x, kernel, mode='valid')

    return (mean_rewards, std_rewards, mean_regrets, std_regrets,
            _smooth(mean_rewards), _smooth(std_rewards), _smooth(mean_regrets))


def compute_interaction_term(factored_bandit, bandit_family):
    """Mean absolute error of the factored bandit's additive reward model.

    For every (action, reward) pair in the bandit's history:
      1. each segment's model predicts a reward for that segment's choice;
      2. the full prediction is the mean of those segment predictions;
      3. the error is the absolute difference between the observed reward and
         the full prediction.

    This error is treated as the interaction term: the part of the reward the
    per-segment models do not capture.

    Args:
        factored_bandit: fitted factored bandit exposing ``all_seen_actions``,
            ``all_seen_rewards`` (indexed in parallel) and ``num_segments``,
            plus the model state the chosen family needs:
              - Thompson sampling: ``mus`` and ``action_to_one_hot``;
              - UCB: ``models``, ``device`` and ``action_to_one_hot``;
              - anything else: ``models`` and ``device``.
        bandit_family: ``"thompson_sampling"`` or ``"ucb"``. Any other value is
            handled like eps-greedy: the raw per-segment action value is fed
            to the segment model as a ``(1, 1)`` tensor.

    Returns:
        The mean absolute prediction error over the whole history. This is NaN,
        with a NumPy warning, when the history is empty.
    """
    num_segments = factored_bandit.num_segments

    def _segment_prediction(s_idx, choice):
        # Predicted reward of segment ``s_idx`` for its chosen value ``choice``.
        if bandit_family == "thompson_sampling":
            a_one_hot = factored_bandit.action_to_one_hot(choice)
            return np.dot(factored_bandit.mus[s_idx], a_one_hot)
        if bandit_family == "ucb":
            x = torch.from_numpy(factored_bandit.action_to_one_hot(choice)).float().to(factored_bandit.device)
            return factored_bandit.models[s_idx](x.unsqueeze(0)).item()
        x = torch.from_numpy(np.array([choice])).float().reshape(-1, 1).to(factored_bandit.device)
        return factored_bandit.models[s_idx].forward(x).squeeze(0).item()

    rewards = factored_bandit.all_seen_rewards
    all_diffs = []
    # Pure inference: skip building autograd graphs for the torch-based models.
    with torch.no_grad():
        for i, full_action in enumerate(factored_bandit.all_seen_actions):
            full_pred = sum(_segment_prediction(s_idx, full_action[s_idx]) for s_idx in range(num_segments))
            full_pred /= num_segments  # final prediction is the mean of segment predictions
            all_diffs.append(np.abs(rewards[i] - full_pred))

    mean_interaction_term = np.mean(all_diffs)
    print("Mean interaction term:", mean_interaction_term)
    return mean_interaction_term


def run_bandit_comparison(
        target_vec,
        num_segments,
        choice_per_segment,
        num_steps,
        n_seeds=3,
        lr=0.005,
        epsilon=0.1,
        beta=2.0,
        update_on_buffer=True,
        reset_param=False,
        buffer_capacity=1000,
        batch_size=64,
        updates_per_sample=1,
        smoothing_window=None,
        r_fun=synthetic_reward,
        return_stats=False,
        bandit_family="ucb",
        hide_plots=True,
):
    """Compare a factored bandit against its flat counterpart on one problem.

    For each of ``n_seeds`` repetitions:
      1. a factored bandit of ``bandit_family`` is simulated for ``num_steps``
         rounds against ``r_fun`` / ``target_vec``;
      2. its mean interaction term (see ``compute_interaction_term``) is
         recorded;
      3. the flat bandit of the same family is simulated the same way.

    Reward curves (smoothed) and cumulative-regret curves (mean ± std over
    seeds) are plotted and saved to
    ``results/<family>_segments_<S>_choices_<C>.png``.

    Args:
        target_vec: target action vector. Its entries must not exceed
            ``choice_per_segment`` (checked with ``assert``).
        num_segments: number of segments per action.
        choice_per_segment: number of choices per segment.
        num_steps: rounds per simulation.
        n_seeds: number of repetitions; must be >= 1.
        lr, update_on_buffer, reset_param, buffer_capacity, batch_size,
            updates_per_sample: forwarded to the UCB / eps-greedy bandit
            constructors.
        epsilon: forwarded to the eps-greedy bandits only.
        beta: forwarded to the UCB bandits only.
        smoothing_window: moving-average width for the reward curves.
            Defaults to ``num_steps // 10``, with a minimum of 1.
        r_fun: reward function ``r_fun(action, target_vec)``.
        return_stats: if True, return a dict with the configuration and all
            statistics.
        bandit_family: ``"ucb"``, ``"eps-greedy"`` or ``"thompson_sampling"``.
        hide_plots: if True, close the figure without showing it.

    Returns:
        The statistics dict when ``return_stats`` is True, otherwise ``{}``.

    Raises:
        ValueError: if ``n_seeds < 1``.
        NotImplementedError: for an unknown ``bandit_family``.
    """
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}")
    assert max(target_vec) <= choice_per_segment, "target vector must contain choices in 1, ..., choice_per_segment"

    all_rewards_factored = []
    all_regrets_factored = []
    all_losses_factored = []
    all_interaction_terms = []

    all_rewards_flat = []
    all_regrets_flat = []
    all_losses_flat = []

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the output directory before any work so the savefig below always has a target.
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)

    # Constructor arguments shared by the UCB and eps-greedy bandits.
    nn_kwargs = dict(
        update_on_buffer=update_on_buffer,
        reset_param=reset_param,
        num_segments=num_segments,
        num_choices_per_segment=choice_per_segment,
        lr=lr,
        updates_per_sample=updates_per_sample,
        buffer_capacity=buffer_capacity,
        batch_size=batch_size,
    )

    def _make_bandit(factored):
        # Build the factored (factored=True) or flat bandit of `bandit_family`.
        if bandit_family == "ucb":
            cls = FactoredUCBBandit if factored else UCBBandit
            return cls(beta=beta, **nn_kwargs).to(device)
        if bandit_family == "eps-greedy":
            cls = FactoredEpsGreedyBandit if factored else EpsGreedyBandit
            return cls(epsilon=epsilon, **nn_kwargs).to(device)
        if bandit_family == "thompson_sampling":
            cls = FactoredTSBandit if factored else TSBandit
            # Factored TS uses context_dim=1; flat TS uses one context dim per segment.
            return cls(
                context_dim=1 if factored else num_segments,
                num_segments=num_segments,
                num_choices_per_segment=choice_per_segment,
            )
        raise NotImplementedError(f"Bandit family '{bandit_family}' not implemented")

    for seed in range(n_seeds):
        # NOTE(review): `seed` only labels progress output; no RNG is seeded in this
        # function, so repeated runs differ unless the bandits seed themselves — confirm.
        print(f"Running factored bandit (family = {bandit_family})...")
        factored_bandit = _make_bandit(factored=True)
        rewards, regrets, losses = simulate(factored_bandit, num_steps=num_steps, target_vec=target_vec, r_fun=r_fun, seed=seed)
        all_rewards_factored.append(rewards)
        all_regrets_factored.append(regrets)
        all_losses_factored.append(losses)

        # Interaction term of the semi-factored bandit: its mean prediction error.
        all_interaction_terms.append(compute_interaction_term(factored_bandit, bandit_family))

        print(f"Running flat bandit (family = {bandit_family})...")
        flat_bandit = _make_bandit(factored=False)
        rewards, regrets, losses = simulate(flat_bandit, num_steps=num_steps, target_vec=target_vec, r_fun=r_fun, seed=seed)
        all_rewards_flat.append(rewards)
        all_regrets_flat.append(regrets)
        all_losses_flat.append(losses)

    # Compute statistics. A window of 0 (num_steps < 10) would crash np.convolve.
    if smoothing_window is None:
        smoothing_window = max(1, num_steps // 10)
    (
        mean_rewards_factored, std_rewards_factored,
        mean_regrets_factored, std_regrets_factored,
        smoothed_mean_rewards_factored, smoothed_std_rewards_factored,
        smoothed_mean_regrets_factored
    ) = compute_stats(all_rewards_factored, all_regrets_factored, smoothing_window)

    (
        mean_rewards_flat, std_rewards_flat,
        mean_regrets_flat, std_regrets_flat,
        smoothed_mean_rewards_flat, smoothed_std_rewards_flat,
        smoothed_mean_regrets_flat
    ) = compute_stats(all_rewards_flat, all_regrets_flat, smoothing_window)

    def _plot_band(axis, mean, std, label):
        # Plot `mean` with a shaded ±1 std band.
        axis.plot(mean, label=label)
        axis.fill_between(np.arange(len(mean)), mean - std, mean + std, alpha=0.2)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: smoothed per-round reward.
    _plot_band(ax[0], smoothed_mean_rewards_factored, smoothed_std_rewards_factored, "Factored bandit")
    _plot_band(ax[0], smoothed_mean_rewards_flat, smoothed_std_rewards_flat, "Flat bandit")
    ax[0].set_xlabel('Round')
    ax[0].set_ylabel("Reward")
    ax[0].axhline(1, label="Optimal", c="k", ls="--")
    ax[0].legend()

    # Right panel: cumulative regret with sqrt(t) and t reference curves.
    _plot_band(ax[1], mean_regrets_factored, std_regrets_factored, "Factored bandit")
    _plot_band(ax[1], mean_regrets_flat, std_regrets_flat, "Flat bandit")

    rounds = np.arange(1, num_steps + 1)
    reference_scale = 1
    ax[1].plot(reference_scale * np.sqrt(rounds), label="Reference: O(√t)", linestyle='--', color='tab:green')
    ax[1].plot(reference_scale * rounds, label="Reference: O(t)", linestyle='--', color='tab:red')
    ax[1].set_xlabel('Round')
    ax[1].set_ylabel("Cumulative regret")
    ax[1].legend()

    mean_interaction_term = np.mean(all_interaction_terms)
    std_interaction_term = np.std(all_interaction_terms)

    fig.suptitle(
        f"Segments: {num_segments}, Choices per segment: {choice_per_segment} | Target sequence: {target_vec} | interaction_term: {np.round(mean_interaction_term, decimals=3)}")
    fig.savefig(f"{results_dir}/{bandit_family}_segments_{num_segments}_choices_{choice_per_segment}.png")
    if not hide_plots:
        plt.show()
    plt.close(fig)

    if not return_stats:
        return {}

    return {
        "bandit_family": bandit_family,
        "num_segments": num_segments,
        "num_choices_per_segment": choice_per_segment,
        "target_vec": target_vec,
        "num_steps": num_steps,
        "buffer_capacity": buffer_capacity,
        "updates_per_sample": updates_per_sample,
        "batch_size": batch_size,
        "lr": lr,
        "reset_param": reset_param,

        "mean_rewards_factored": mean_rewards_factored,
        "std_rewards_factored": std_rewards_factored,
        "mean_regrets_factored": mean_regrets_factored,
        "std_regrets_factored": std_regrets_factored,
        "all_losses_factored": all_losses_factored,
        "all_rewards_factored": all_rewards_factored,
        "all_regrets_factored": all_regrets_factored,
        "mean_interaction_term": mean_interaction_term,
        "std_interaction_term": std_interaction_term,

        "mean_rewards_flat": mean_rewards_flat,
        "std_rewards_flat": std_rewards_flat,
        "mean_regrets_flat": mean_regrets_flat,
        "std_regrets_flat": std_regrets_flat,
        "all_losses_flat": all_losses_flat,
        "all_rewards_flat": all_rewards_flat,
        "all_regrets_flat": all_regrets_flat,
    }


def generate_sweeps(
        num_segments=(2, 3, 4, 5,),
        num_choices=(3, 5, 7, 10),
        *,
        steps=2000,
):
    """Build one sweep configuration per distinct problem size.

    The problem size is the number of full actions, ``n_choices ** n_segments``.
    Combinations are visited with ``num_segments`` in the outer loop and
    ``num_choices`` in the inner loop. When several combinations give the same
    size, only the first one visited is kept.

    Args:
        num_segments: candidate segment counts.
        num_choices: candidate choices-per-segment counts.
        steps: number of bandit rounds stored in every config (default 2000).

    Returns:
        Dict mapping ``str(size)`` to a config dict with these keys:
          - ``n_segments`` and ``n_choices``: the combination that produced it;
          - ``steps``: the value of the ``steps`` argument;
          - ``target_vec``: a random vector from the global NumPy RNG with
            entries in ``1..n_choices``.
    """
    all_configs = {}
    for n_seg in num_segments:
        for n_choice in num_choices:
            size_key = str(n_choice ** n_seg)
            if size_key in all_configs:
                continue  # first combination with this size wins; no RNG draw for duplicates
            all_configs[size_key] = {
                "n_segments": n_seg,
                "n_choices": n_choice,
                "steps": steps,
                "target_vec": np.random.randint(1, n_choice + 1, size=n_seg),
            }

    print(f"Generated {len(all_configs)} configs")
    return all_configs


if __name__ == "__main__":
    # Command-line entry point.
    # - Default: a single factored-vs-flat comparison on a fixed
    #   3-segment / 5-choice problem.
    # - With --sweep: the problem-size sweep from generate_sweeps(), pickling
    #   each result dict to results/.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bandit_family",
        type=str,
        default="thompson_sampling",
        choices=["eps-greedy", "thompson_sampling", "ucb"],
        help="Which bandit family to test. One of ['eps-greedy', 'thompson_sampling', 'ucb']",
    )
    parser.add_argument("--num_rounds", type=int, default=2000, help="Number of bandit rounds.")
    parser.add_argument(
        "--sweep",
        action="store_true",
        help="Run the problem-size sweep instead of the single comparison "
             "(sweep configs set their own number of rounds; --num_rounds is ignored).",
    )
    args = parser.parse_args()
    bandit_family = args.bandit_family

    if not args.sweep:
        run_bandit_comparison(
            target_vec=np.array([4, 3, 2]),
            num_steps=args.num_rounds,
            buffer_capacity=args.num_rounds,
            num_segments=3,
            choice_per_segment=5,
            updates_per_sample=100,
            reset_param=True,
            update_on_buffer=True,
            return_stats=True,
            n_seeds=2,
            bandit_family=bandit_family,
            beta=1.0,
            hide_plots=False,
            smoothing_window=100,
            batch_size=args.num_rounds,
        )
    else:
        sweeps = generate_sweeps()
        for size, config in sweeps.items():
            print(f"Running bandit with size {size}: {config}")

            stats = run_bandit_comparison(
                target_vec=config["target_vec"],
                num_steps=config["steps"],
                buffer_capacity=config["steps"],
                num_segments=config["n_segments"],
                choice_per_segment=config["n_choices"],
                updates_per_sample=100,
                reset_param=True,
                update_on_buffer=True,
                return_stats=True,
                n_seeds=3,
                bandit_family=bandit_family,
                beta=1.0,
                hide_plots=True,
                batch_size=config["steps"]
            )

            # Attach the sweep config and problem size so each pickle is self-describing.
            stats.update(config)
            stats["problem_size"] = size

            # run_bandit_comparison has already created the results/ directory.
            with open(f"results/{bandit_family}_segments_{config['n_segments']}_choices_{config['n_choices']}.pkl", "wb") as f:
                pickle.dump(stats, f)

    exit(0)

