from bandit_env import BanditEnv, BanditEnvVec, BanditTransformerController, EmpMeanPolicy, PessMeanPolicy, UCBPolicy, OptPolicy, ThompsonSamplingPolicy2, ETC
from bandit_env import TopKBanditEnv, TopKBanditTransformerController, GreedyOptPolicy, TopKRandCommitPolicy, TopKEpsGreedy, LinUCB
from bandit_env import LinearBanditEnv
import torch
import numpy as np
import scipy
import scipy.stats
import matplotlib.pyplot as plt
from IPython import embed

# Run on the GPU when torch reports CUDA as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



def deploy_online(env, controller, **kwargs):
    """Roll out ``controller`` on ``env`` for ``horizon`` rounds with a growing context.

    Before each round ``h`` the controller is handed (via ``set_batch``) the
    first ``h`` rows of the context buffers, i.e. everything collected in the
    previous rounds. ``env.deploy(controller)`` is then called once and
    element 0 of each array it returns is written into row ``h``.

    Args:
        env: Object exposing ``dx``, ``du``, ``deploy(controller)`` and
            ``get_arm_value(action)``.
        controller: Object exposing ``set_batch(batch)``.
        **kwargs: Must contain ``horizon``. Any other keys are ignored, so
            callers can forward one shared configuration dict.

    Returns:
        ``(values, meta)`` where ``values`` is ``np.array`` of the
        ``env.get_arm_value`` result of every round (one entry per round, not
        a running sum) and ``meta`` maps ``'xs'``, ``'us'``, ``'xps'``,
        ``'rs'`` to the collected context as NumPy arrays with the leading
        batch axis of size 1 removed.

    Raises:
        KeyError: If ``horizon`` is missing from ``kwargs``.
    """
    horizon = kwargs['horizon']

    def _context_buffer(width):
        # Allocate directly as float32 on the target device instead of
        # creating on the CPU and converting/moving afterwards.
        return torch.zeros((1, horizon, width), dtype=torch.float32, device=device)

    rollin_xs = _context_buffer(env.dx)
    rollin_us = _context_buffer(env.du)
    rollin_xps = _context_buffer(env.dx)
    rollin_rs = _context_buffer(1)

    arm_values = []
    for h in range(horizon):
        # Slicing with ``:h`` exposes only the rounds already played; at
        # h == 0 the controller sees an empty context.
        controller.set_batch({
            'rollin_xs': rollin_xs[:, :h, :],
            'rollin_us': rollin_us[:, :h, :],
            'rollin_xps': rollin_xps[:, :h, :],
            'rollin_rs': rollin_rs[:, :h, :],
        })
        xs_lnr, us_lnr, xps_lnr, rs_lnr = env.deploy(controller)

        # Only element 0 of each returned array is stored for this round.
        rollin_xs[0, h, :] = torch.tensor(xs_lnr[0])
        rollin_us[0, h, :] = torch.tensor(us_lnr[0])
        rollin_xps[0, h, :] = torch.tensor(xps_lnr[0])
        rollin_rs[0, h, :] = torch.tensor(rs_lnr[0])

        arm_values.append(env.get_arm_value(us_lnr.flatten()))

    meta = {
        'xs': rollin_xs.detach().cpu().numpy()[0],
        'us': rollin_us.detach().cpu().numpy()[0],
        'xps': rollin_xps.detach().cpu().numpy()[0],
        'rs': rollin_rs.detach().cpu().numpy()[0],
    }

    return np.array(arm_values), meta



def online(eval_trajs, model, **kwargs):
    """Evaluate the learned controller and baselines online, one environment at a time.

    For each of the first ``n_eval`` entries of ``eval_trajs`` an environment
    is built according to ``envname`` and every controller is rolled out on
    it with ``deploy_online``. The per-round gap to the ``'opt'`` controller
    is then averaged over the evaluated environments and drawn on the current
    matplotlib figure (log-scaled y axis) with a standard-error band.

    Args:
        eval_trajs: Indexable collection of dicts. Each must provide
            ``'means'``; ``'theta'`` and ``'arms'`` are also read when
            ``envname == 'linear_bandit'``.
        model: Passed to the transformer controller.
        **kwargs: Must contain ``var`` (forwarded as ``std`` to
            ``ThompsonSamplingPolicy2``), ``test_var`` (forwarded as ``var``
            to the environment constructors), ``n_eval``, ``horizon``,
            ``envname`` (``'bandit'``, ``'linear_bandit'`` or
            ``'bandit_topk'``), ``k``, ``prior_mean`` and ``prior_var``.
            The full dict is also forwarded to ``deploy_online``.

    Returns:
        Dict mapping controller name to an array stacking, per evaluated
        environment, the values returned by ``deploy_online``. This mirrors
        what ``online_vec`` returns; earlier versions returned ``None``.

    Raises:
        NotImplementedError: If ``envname`` is not one of the supported names.
        KeyError: If a required key is missing from ``kwargs`` or a trajectory.
    """
    var = kwargs['var']
    n_eval = kwargs['n_eval']
    horizon = kwargs['horizon']
    envname = kwargs['envname']
    k = kwargs['k']
    prior_mean = kwargs['prior_mean']
    prior_var = kwargs['prior_var']
    test_var = kwargs['test_var']

    all_means = {}

    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")
        traj = eval_trajs[i_eval]
        means = traj['means']

        if envname in ('bandit', 'linear_bandit'):
            if envname == 'bandit':
                env = BanditEnv(means, horizon, var=test_var)
            else:
                env = LinearBanditEnv(traj['theta'], traj['arms'], horizon, var=test_var)

            controllers = {
                'opt': OptPolicy(env),
                'Lnr': BanditTransformerController(model, sample=True),
                'Emp': EmpMeanPolicy(env, online=True),
                'UCB1.0': UCBPolicy(env, const=1.0),
                'TS_nws': ThompsonSamplingPolicy2(
                    env,
                    std=var,
                    sample=True,
                    prior_mean=prior_mean,
                    prior_var=prior_var,
                    warm_start=False,
                ),
            }
            if envname == 'linear_bandit':
                controllers['LinUCB'] = LinUCB(env, horizon)

        elif envname == 'bandit_topk':
            env = TopKBanditEnv(means, horizon, var=test_var, k=k)
            controllers = {
                'opt': OptPolicy(env),
                'lnr': TopKBanditTransformerController(model, sample=True, k=k),
                'rnd': TopKRandCommitPolicy(env, k, horizon),
                'eps': TopKEpsGreedy(env, k, horizon),
                'ucb': LinUCB(env, k),
            }

        else:
            raise NotImplementedError(f"Unsupported envname: {envname!r}")

        # ``name`` (not ``k``) as the loop variable: ``k`` is the top-k size.
        for name, controller in controllers.items():
            values, _ = deploy_online(env, controller, **kwargs)
            all_means.setdefault(name, []).append(values)

    all_means = {name: np.array(runs) for name, runs in all_means.items()}
    # Suboptimality of each controller relative to 'opt' on the same env.
    all_means_diff = {name: all_means['opt'] - runs for name, runs in all_means.items()}

    means = {name: np.mean(diff, axis=0) for name, diff in all_means_diff.items()}
    sems = {name: scipy.stats.sem(diff, axis=0) for name, diff in all_means_diff.items()}

    rounds = np.arange(horizon)
    for key, mean in means.items():
        lower, upper = mean - sems[key], mean + sems[key]
        if key == 'opt':
            # 'opt' minus itself is identically zero; drawn as a dashed
            # black reference line.
            plt.plot(mean, label=key, linestyle='--', color='black', linewidth=2)
            plt.fill_between(rounds, lower, upper, alpha=0.2, color='black')
        else:
            plt.plot(mean, label=key)
            plt.fill_between(rounds, lower, upper, alpha=0.2)

    plt.legend()
    plt.yscale('log')
    plt.ylim(bottom=1e-4)
    plt.xlabel('Episodes')
    plt.ylabel('Suboptimality')
    plt.title('Online Evaluation')

    return all_means
    

def deploy_online_numpy_vec(vec_env, controller, **kwargs):
    """Roll out ``controller`` on all environments of ``vec_env`` at once.

    NumPy counterpart of ``deploy_online``: before each round ``h`` the
    controller receives (via ``set_batch_numpy_vec``) the first ``h`` rows of
    every environment's context, then ``vec_env.deploy(controller)`` is called
    once and its results are written into row ``h`` for all environments.

    Args:
        vec_env: Object exposing ``num_envs``, ``envs`` (the first element
            supplies ``dx`` and ``du``), ``deploy(controller)`` and
            ``get_arm_value(actions)``.
        controller: Object exposing ``set_batch_numpy_vec(batch)``.
        **kwargs: Must contain ``horizon``. Any other keys are ignored, so
            callers can forward one shared configuration dict.

    Returns:
        ``(values, meta)`` where ``values`` stacks the per-round
        ``vec_env.get_arm_value`` results along axis 1 (rounds on the second
        axis) and ``meta`` maps ``'xs'``, ``'us'``, ``'xps'``, ``'rs'`` to the
        context arrays of shape ``(num_envs, horizon, width)``.

    Raises:
        KeyError: If ``horizon`` is missing from ``kwargs``.
    """
    horizon = kwargs['horizon']

    num_envs = vec_env.num_envs
    # State/action widths are taken from the first environment only.
    first_env = vec_env.envs[0]
    dx = first_env.dx
    du = first_env.du

    rollin_xs = np.zeros((num_envs, horizon, dx))
    rollin_us = np.zeros((num_envs, horizon, du))
    rollin_xps = np.zeros((num_envs, horizon, dx))
    rollin_rs = np.zeros((num_envs, horizon, 1))

    arm_values = []
    for h in range(horizon):
        print("Step: ", h)
        # Slicing with ``:h`` exposes only the rounds already played; at
        # h == 0 the controller sees an empty context.
        controller.set_batch_numpy_vec({
            'rollin_xs': rollin_xs[:, :h, :],
            'rollin_us': rollin_us[:, :h, :],
            'rollin_xps': rollin_xps[:, :h, :],
            'rollin_rs': rollin_rs[:, :h, :],
        })
        xs_lnr, us_lnr, xps_lnr, rs_lnr = vec_env.deploy(controller)

        rollin_xs[:, h, :] = xs_lnr
        rollin_us[:, h, :] = us_lnr
        rollin_xps[:, h, :] = xps_lnr
        # Rewards get a trailing axis added to fit the width-1 reward slot.
        rollin_rs[:, h, :] = rs_lnr[..., None]

        arm_values.append(vec_env.get_arm_value(us_lnr))

    meta = {
        'xs': rollin_xs,
        'us': rollin_us,
        'xps': rollin_xps,
        'rs': rollin_rs,
    }

    return np.stack(arm_values, axis=1), meta


def online_vec(eval_trajs, model, **kwargs):
    """Evaluate controllers online, batching across environments where possible.

    For ``'bandit'`` and ``'linear_bandit'`` only the ``'opt'`` controller
    (plus ``'LinUCB'`` for linear bandits) is rolled out per environment with
    ``deploy_online``. The transformer, empirical-mean, UCB and Thompson
    sampling controllers are each rolled out once over a ``BanditEnvVec``
    wrapping all environments, via ``deploy_online_numpy_vec``. For
    ``'bandit_topk'`` every controller is rolled out per environment. The
    per-round gap to ``'opt'`` is then averaged over environments and drawn
    on the current matplotlib figure (log-scaled y axis) with a
    standard-error band.

    Args:
        eval_trajs: Indexable collection of dicts. Each must provide
            ``'means'``; ``'theta'`` and ``'arms'`` are also read when
            ``envname == 'linear_bandit'``.
        model: Passed to the transformer controllers.
        **kwargs: Must contain ``var`` (forwarded as ``std`` to
            ``ThompsonSamplingPolicy2``), ``test_var`` (forwarded as ``var``
            to the environment constructors), ``n_eval``, ``horizon``,
            ``envname`` (``'bandit'``, ``'linear_bandit'`` or
            ``'bandit_topk'``), ``k``, ``prior_mean`` and ``prior_var``.
            The full dict is also forwarded to the rollout helpers.

    Returns:
        Dict mapping controller name to an array of per-round values with one
        row per evaluated environment.

    Raises:
        NotImplementedError: If ``envname`` is not one of the supported names.
        KeyError: If a required key is missing from ``kwargs`` or a trajectory.
    """
    var = kwargs['var']
    test_var = kwargs['test_var']
    n_eval = kwargs['n_eval']
    horizon = kwargs['horizon']
    envname = kwargs['envname']
    k = kwargs['k']
    prior_mean = kwargs['prior_mean']
    prior_var = kwargs['prior_var']

    vectorized = envname in ('bandit', 'linear_bandit')

    all_means = {}
    envs = []

    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")
        traj = eval_trajs[i_eval]
        means = traj['means']

        if vectorized:
            if envname == 'bandit':
                env = BanditEnv(means, horizon, var=test_var)
            else:
                env = LinearBanditEnv(traj['theta'], traj['arms'], horizon, var=test_var)

            # Only the oracle (and LinUCB for linear bandits) runs per
            # environment; the remaining controllers are evaluated in the
            # vectorized pass after this loop.
            controllers = {'opt': OptPolicy(env)}
            if envname == 'linear_bandit':
                controllers['LinUCB'] = LinUCB(env, horizon)

            envs.append(env)

        elif envname == 'bandit_topk':
            # Uses ``test_var`` like every other environment built in this
            # file (previously ``var``, unlike the same branch in ``online``).
            env = TopKBanditEnv(means, horizon, var=test_var, k=k)
            controllers = {
                'opt': OptPolicy(env),
                'lnr': TopKBanditTransformerController(model, sample=True, k=k),
                'rnd': TopKRandCommitPolicy(env, k, horizon),
                'eps': TopKEpsGreedy(env, k, horizon),
                'ucb': LinUCB(env, k),
            }

        else:
            raise NotImplementedError(f"Unsupported envname: {envname!r}")

        # ``name`` (not ``k``) as the loop variable: ``k`` is the top-k size.
        for name, controller in controllers.items():
            values, _ = deploy_online(env, controller, **kwargs)
            all_means.setdefault(name, []).append(values)

    if vectorized:
        vec_env = BanditEnvVec(envs)
        batch_size = len(envs)

        # Factories rather than instances: each controller is constructed
        # immediately before its own rollout, in the order listed.
        vec_controller_factories = {
            'Lnr': lambda: BanditTransformerController(
                model,
                sample=True,
                batch_size=batch_size),
            'Emp': lambda: EmpMeanPolicy(
                envs[0],
                online=True,
                batch_size=batch_size),
            'UCB1.0': lambda: UCBPolicy(
                envs[0],
                const=1.0,
                batch_size=batch_size),
            'TS': lambda: ThompsonSamplingPolicy2(
                envs[0],
                std=var,
                sample=True,
                prior_mean=prior_mean,
                prior_var=prior_var,
                warm_start=False,
                batch_size=batch_size),
        }
        for name, make_controller in vec_controller_factories.items():
            values, _ = deploy_online_numpy_vec(vec_env, make_controller(), **kwargs)
            # One row per evaluated environment is expected.
            assert values.shape[0] == n_eval
            all_means[name] = values

    all_means = {name: np.array(runs) for name, runs in all_means.items()}
    # Suboptimality of each controller relative to 'opt' on the same env.
    all_means_diff = {name: all_means['opt'] - runs for name, runs in all_means.items()}

    means = {name: np.mean(diff, axis=0) for name, diff in all_means_diff.items()}
    sems = {name: scipy.stats.sem(diff, axis=0) for name, diff in all_means_diff.items()}

    rounds = np.arange(horizon)
    for key, mean in means.items():
        lower, upper = mean - sems[key], mean + sems[key]
        if key == 'opt':
            # 'opt' minus itself is identically zero; drawn as a dashed
            # black reference line.
            plt.plot(mean, label=key, linestyle='--', color='black', linewidth=2)
            plt.fill_between(rounds, lower, upper, alpha=0.2, color='black')
        else:
            plt.plot(mean, label=key)
            plt.fill_between(rounds, lower, upper, alpha=0.2)

    plt.legend()
    plt.yscale('log')
    plt.ylim(bottom=1e-4)
    plt.xlabel('Episodes')
    plt.ylabel('Suboptimality')
    plt.title('Online Evaluation')

    return all_means
