from lqr_env import LQREnv, RandController, LQRController
import lqr_env
import bandit_env
import numpy as np
import os
import pickle
from IPython import embed
from collect_data import rollin, rollin_bandit, generate_histories, generate_bandit_histories, generate_bandit_histories_for_arms, generate_topk_bandit_histories
from collect_data import generate_dr_histories, generate_dr_histories_for_goals, generate_dr_stitch_histories_for_goals, rollin_dr
from collect_data import generate_linear_bandit_histories




if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=100, help="Envs")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top k subset")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Bandit arm variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--mode", type=int, required=False, default=0, help="Mode")
    parser.add_argument("--orig", type=int, required=False, default=2, help="Top k subset")
    parser.add_argument("--lin_d", type=int, required=False, default=1, help="Linear Dimension")


    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    n_envs = args['envs']
    H = args['H']
    envname = args['env']
    var = args['var']
    cov = args['cov']

    dx = args['dim']
    du = args['dim']
    dim = args['dim']
    k = args['k']
    mode = args['mode']
    orig = args['orig']
    lin_d = args['lin_d']
    warm_start = False
    
    # if envname == 'bandit' and mode == 1:
    #     trajs = generate_bandit_histories_special(n_envs, 1, 1, H, dim, var=var, cov=cov, orig=orig)
    #     filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}_orig{orig}.pkl'

    # trajs = generate_histories(n_envs, 1, 1, H)
    if envname == 'bandit':  
        trajs = generate_bandit_histories(n_envs, 1, 1, H, dim, var=var, cov=cov, orig=orig)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}_orig{orig}.pkl'

    elif envname == 'bandit_thompson':
        trajs = generate_bandit_histories(n_envs, 1, 1, H, dim, var=var, cov=cov, type='bernoulli')
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'
    
    elif envname == 'bandit_ood':
        envs = list(range(dim // 2, dim)) * (n_envs // (dim // 2))
        trajs = generate_bandit_histories_for_arms(envs, 1, 1, H, dim, var=var, cov=cov)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'

    elif envname == 'bandit_topk':
        trajs = generate_topk_bandit_histories(n_envs, 1, 1, H, dim, k=k, var=var)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_k{k}.pkl'

    elif envname == 'linear_bandit':
        trajs = generate_linear_bandit_histories(n_envs, 1, 1, H, dim, lin_d, var=var)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_dlin{lin_d}_var{var}_ws{warm_start}.pkl'

    elif envname == 'darkroom':
        trajs = generate_dr_histories(n_envs, 1, 1, H, dim)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'
    
    elif envname == 'darkroom_heldout':
        goals = np.array([[(j, i) for i in range(dim)] for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        test_goals = goals[train_test_split:]
        test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

        trajs = generate_dr_histories_for_goals(test_goals, 1, 1, H, dim)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    elif envname == 'darkroom_stitch':
        goals = [np.array([dim // 2, dim - 1]), np.array([dim - 1, dim // 2])]
        test_goals = np.repeat(goals, n_envs // len(goals), axis=0)
        trajs = generate_dr_stitch_histories_for_goals(test_goals, 1, 1, H, dim, eval=True)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    else:
        raise NotImplementedError       

    if not os.path.exists('datasets'):
        os.makedirs('datasets', exist_ok=True)
    with open(filepath, 'wb') as file:
        pickle.dump(trajs, file)
    
    print(f"Saved to {filepath}.")
