import time
import numpy as np
import torch
import argparse
from util import parse_bool, set_seed, make_parent_dir
from postprocessing import load_old_model
from rl import SequentialAlgWithContext, LinearGaussianContextTS
from rl import NeuralLinearGaussianContextTS_general
from rl import run_bandit, get_bandit_envs, get_bandit_envs_from_dgp
from rl import GreedySequentialWithContext
from rl import get_bandit_envs_from_dgp_withZ
from rl import LinUCBDisjoint
from rl import LinearGaussianContextTS_general
import os
from context_dgp_functions import CONTEXT_DGPs

def get_gaussian_prior_params(dgp, feats=False, prior_name=None):
    """Load saved Gaussian prior parameters for ``dgp``.

    The file is looked up in the ``saved_params`` directory next to this module,
    at ``saved_params/<dgp>[_feats][_<prior_name>].pt``, and whatever
    ``torch.load`` returns is handed back unchanged.

    Args:
        dgp: name of the data-generating process (first part of the file name).
        feats: if truthy, load the ``_feats`` variant of the parameters.
        prior_name: optional suffix selecting a named prior; ``None`` for the default.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    feats_tag = '_feats' if feats else ''
    prior_tag = '' if prior_name is None else '_' + prior_name
    return torch.load(here + f'/saved_params/{dgp}{feats_tag}{prior_tag}.pt')



def get_article_ordering(seed, N):
    """Return a reproducible random permutation of ``np.arange(N)``.

    The same ``seed`` always yields the same ordering, so environments that are
    shuffled with it can be reconstructed later.
    """
    return np.random.default_rng(seed).permutation(np.arange(N))

def load_bandit_rewards(bandit_dir, all_bandit_envs, success_p_all):
    """Collect saved bandit-run results from ``bandit_dir``, ordered by environment index.

    Every entry of ``bandit_dir`` must be named ``<key>=<idx>.<ext>`` (e.g. the
    ``env_idx=<idx>.pt`` files written by ``main``). Files whose index is
    ``>= len(all_bandit_envs)`` are ignored. Each loaded run is checked against
    the expected environment: its saved chosen arms must equal
    ``all_bandit_envs[idx][1]``, and its saved click rates must equal
    ``success_p_all[chosen_arms]``.

    Args:
        bandit_dir: directory containing one saved result file per environment.
        all_bandit_envs: sequence whose ``[idx][1]`` entry holds the chosen arms
            of environment ``idx``.
        success_p_all: per-arm success probabilities, indexable by the chosen arms.

    Returns:
        dict with ``'rewards'`` (``np.array`` of each run's
        ``reward_dict['expected_rewards']``, stacked in index order) and
        ``'action_arms'`` (list of each run's ``reward_dict['action_arms']``).

    Raises:
        AssertionError: a loaded run does not match its expected environment.
        KeyError: a loaded run has no ``'reward_dict'`` entry.
        ValueError: some index in ``range(len(all_bandit_envs))`` has no file.
    """
    n_envs = len(all_bandit_envs)
    env_rewards_dict = {}
    action_arms_dict = {}
    for fname in os.listdir(bandit_dir):
        idx = int(fname.split('.')[0].split('=')[1])
        if idx >= n_envs:
            continue

        saved = torch.load(os.path.join(bandit_dir, fname))
        # convert top-level tensors to numpy so they compare against numpy inputs
        saved = {k: v.detach().numpy() if isinstance(v, torch.Tensor) else v
                 for k, v in saved.items()}

        # verify environments are the same between loaded runs
        chosen_arms = all_bandit_envs[idx][1]
        assert np.abs(chosen_arms - saved['env_chosen_arms']).mean() == 0
        assert np.abs(success_p_all[chosen_arms] - saved['env_click_rates']).mean() == 0
        if 'reward_dict' not in saved:
            raise KeyError(f"'reward_dict' not found in {fname}; keys: {list(saved.keys())}")
        env_rewards_dict[idx] = saved['reward_dict']['expected_rewards']
        action_arms_dict[idx] = saved['reward_dict']['action_arms']

    # both dicts are filled together, so checking one is enough
    missing = [idx for idx in range(n_envs) if idx not in env_rewards_dict]
    if missing:
        # indices are ints: stringify before joining (str.join rejects ints)
        raise ValueError(f'Missing idxs: {" ".join(map(str, missing))}')

    return {'rewards': np.array([env_rewards_dict[idx] for idx in range(n_envs)]),
            'action_arms': [action_arms_dict[idx] for idx in range(n_envs)]}


def get_file_savename(args):
    """Return the relative save path ``<config>/env_idx=<env_idx>.pt`` for a bandit run.

    ``<config>`` is a comma-separated list of ``key=value`` settings and bare
    flags describing the run (arm count, horizon, seed, data source, algorithm,
    and any algorithm-specific options). Runs with different settings are
    therefore written to different directories.
    """
    alg = args.bandit_alg
    # algorithms for which the ",finite_horizon_alg" tag is never added
    no_horizon_tag_algs = ('linearTSgeneral', 'linearTS11', 'linearTS1_0.25', 'linearTS11Z',
                           'linearTS1_0.25Z', 'linearfeatureTS11', 'linearfeatureTS1_0.25',
                           'greedy', 'linucb_disjoint', 'predetor')

    if args.dgp is None:
        source = f'dataset={args.dataset}'
    else:
        source = f'dgp={args.dgp},dimX={args.X_dim}'

    tokens = [f'num_arms={args.num_arms}', f'T={args.T}']
    if alg != 'greedy':  # greedy runs are named without num_imagined
        tokens.append(f'num_imagined={args.num_imagined}')
    tokens.extend([f'seed={args.seed}', source, f'alg={alg}'])

    if alg == 'sequential' and args.randomly_break_ties:
        tokens.append('rand_break_ties')
    if args.t2p is not None and args.t2p.lower() != 'none':
        tokens.append(f't2p={args.t2p}')
    if args.finite_horizon_alg and alg not in no_horizon_tag_algs:
        tokens.append('finite_horizon_alg')
    if args.no_shuffle_boot:
        tokens.append('no_shuffle_boot')
    if args.use_bandit_split:
        tokens.append('bandit_split')
    if args.epsilon > 0 and alg == 'greedy':
        tokens.append(f'eps={args.epsilon}')
    if alg == 'linucb_disjoint':
        tokens.append(f'alpha={args.alpha}')
    if alg == 'sequential_context' and args.draw_X_samples:
        tokens.append('draw_X_samples')
    if args.prior_name is not None and 'general' in alg:
        tokens.append(f'prior_name={args.prior_name}')
    if 'linear' in alg:
        tokens.append(f'add_const={args.add_const_feature}')
    if args.cholesky and 'general' in alg:
        tokens.append('cholesky')

    return ','.join(tokens) + f"/env_idx={args.env_idx}.pt"


def _build_arg_parser():
    """Return the command-line parser for a single contextual-bandit run."""
    parser = argparse.ArgumentParser()
    # where to save outputs: if save_dir is not given but a model_dir is, outputs go under model_dir
    parser.add_argument('--save_dir', type=str, help='directory to save in', default=None)

    # load model
    parser.add_argument('--model_dir', type=str, help='directory for model to load', default=None)
    parser.add_argument('--filename', type=str, default=None)
    parser.add_argument('--num_imagined', type=int, default=100)
    parser.add_argument('--bandit_alg', type=str, default='sequential_context',
                        choices=['sequential_context', 'sequential_context_ignore_x',
                                 'linearTS11', 'linearTS1_0.25', 'linearTS11Z', 'linearTS1_0.25Z',
                                 'linearfeatureTS11', 'linearfeatureTS1_0.25', 'linucb_disjoint',
                                 'predetor', 'greedy', 'linearTSgeneral',
                                 'neural_linear_context_general'])
    parser.add_argument('--randomly_break_ties', type=parse_bool, default=False)  # mostly for sequential PSAR
    # bandit env params
    parser.add_argument('--env_idx', type=int)
    parser.add_argument('--T', type=int, help='number of timesteps')
    parser.add_argument('--num_arms', type=int, help='number of bandit arms')
    # seeds env generation, the article shuffle, and set_seed before the run
    parser.add_argument('--seed', type=int, default=23485223)
    parser.add_argument('--dataset', type=str, default='val')
    parser.add_argument('--horizonDependent', type=int, default=0)  # boolean

    # use a dgp instead of a dataset? ('none'/'None' means: use --dataset)
    parser.add_argument('--dgp', type=str, choices=[None, 'none', 'None'] + list(CONTEXT_DGPs.keys()),
                        default=None)
    parser.add_argument('--X_dim', type=int, default=1, help='dimension for dgp')

    parser.add_argument('--t2p', type=str, default=None)
    parser.add_argument('--finite_horizon_alg', type=parse_bool, default=False, help='finite horizon algorithm')

    parser.add_argument('--bandit_dir', type=str, default='bandit')
    parser.add_argument('--no_shuffle_boot', type=parse_bool, default=False)  # for debugging bootstrap

    parser.add_argument('--use_bandit_split', type=parse_bool, default=False)  # use bandit data split
    parser.add_argument('--verbose', type=parse_bool, default=False)  # passed to run_bandit
    parser.add_argument('--alpha', type=float, default=1, help='linucb disjoint alpha')

    # for semisynthetic context: both files must be given together, and only with --dgp
    parser.add_argument('--usable_Z_file', type=str, default=None)
    parser.add_argument('--generate_Z_file', type=str, default=None)

    parser.add_argument('--epsilon', type=float, default=0, help='epsilon for epsilon greedy only, and only for context')
    parser.add_argument('--draw_X_samples', type=parse_bool, default=True)
    parser.add_argument('--prior_name', type=str, default=None)
    parser.add_argument('--cholesky', type=parse_bool, default=False)
    parser.add_argument('--add_const_feature', type=parse_bool, default=True)
    return parser


def main():
    """Run one contextual bandit experiment and save its results.

    Steps:
      1. Parse the CLI arguments.
      2. Load a trained model from ``--model_dir``.
      3. Build bandit environment ``--env_idx``, either from a ``--dgp`` (optionally
         with semisynthetic Z files) or from ``--dataset`` click rates.
      4. Construct ``--bandit_alg`` and run it for ``--T`` steps.
      5. ``torch.save`` a results dict to
         ``<save_dir>/<get_file_savename(args)>``.
    """
    start = time.time()
    args = _build_arg_parser().parse_args()
    print(args)
    args.context = True

    if args.dgp is None or args.dgp.lower() == 'none':
        args.dgp = None
    else:
        args.dataset = None

    if args.t2p is not None and args.t2p.lower() == 'none':
        args.t2p = None

    if args.model_dir is None:
        raise ValueError('--model_dir is required')

    # save outputs in model_dir if save_dir is not provided
    if args.save_dir is None:
        args.save_dir = args.model_dir + '/' + args.bandit_dir + '/'
        if args.filename is not None and args.filename.lower() != 'none':
            args.save_dir = args.save_dir + args.filename + '/'

    # load model (on CPU); 'none' is treated like an omitted filename, as above
    if args.filename is None or args.filename.lower() == 'none':
        args.filename = 'best_loss.pt'
    cpu = torch.device('cpu')
    check = torch.load(args.model_dir + "/" + args.filename, map_location=cpu)
    config = torch.load(args.model_dir + "/config.pt", map_location=cpu)
    config.device = 'cpu'
    model = load_old_model(config, check['state_dict'], check)
    model.eval()

    # semisynthetic Z files come as a pair and are only consumed on the dgp path
    if (args.generate_Z_file is None) != (args.usable_Z_file is None):
        raise ValueError('--generate_Z_file and --usable_Z_file must be given together')
    if args.dgp is None and args.generate_Z_file is not None:
        raise ValueError('--generate_Z_file/--usable_Z_file require --dgp')

    # Load features where relevant
    if args.dgp is not None and args.generate_Z_file is None:
        generate_fn = CONTEXT_DGPs[args.dgp]
        dgp_fn = lambda D, N, g: generate_fn(D=D,
                                             N=N,
                                             dimX=args.X_dim,
                                             ave_U=False,
                                             one_X_per_col=True,
                                             g=g)

        all_bandit_envs = get_bandit_envs_from_dgp(dgp_fn, args.num_arms, args.T, args.env_idx+1, seed=args.seed)
        bandit_env, data_dict = all_bandit_envs[args.env_idx]
        Z_representation = data_dict['Z']
        orig_click_rates = data_dict['click_rate'].numpy()
        article_ordering = get_article_ordering(args.seed, len(orig_click_rates))
        click_rates = orig_click_rates[article_ordering]

    elif args.dgp is None:
        # click rates for env
        click_rates = check[f'{args.dataset}_loss_dict']['click_rates']
        embeds = torch.load(args.model_dir + "/best_loss_row_embeds.pt", map_location=cpu)

        # shuffle bandit environment click rates
        orig_click_rates = click_rates.numpy()
        article_ordering = get_article_ordering(args.seed, len(orig_click_rates))
        click_rates = orig_click_rates[article_ordering]

        if model.z_encoder is not None:
            Z_representation = embeds[args.dataset][article_ordering]
        else:
            Z_representation = None

    # make bandit envs, then select the correct one
    if args.dgp is not None:
        if args.generate_Z_file is not None:
            generate_fn = CONTEXT_DGPs[args.dgp]
            dgp_fn = lambda D, N, g, Z: generate_fn(D=D,
                                                    N=N,
                                                    dimX=args.X_dim,
                                                    ave_U=False,
                                                    one_X_per_col=True,
                                                    g=g,
                                                    Z=Z)
            bandit_generate_Z = torch.load(args.generate_Z_file)
            bandit_usable_Z = torch.load(args.usable_Z_file)

            all_bandit_envs = get_bandit_envs_from_dgp_withZ(
                dgp_fn, args.num_arms, args.T, args.env_idx+1, seed=args.seed, context=True,
                all_usable_Z=bandit_usable_Z['Z'],
                all_generate_Z=bandit_generate_Z['Z'])
            bandit_env, data_dict = all_bandit_envs[args.env_idx]

            # more X's to pass around, for generating \hat \tau's for TS-Gen
            extra_bandit_envs = get_bandit_envs_from_dgp_withZ(
                dgp_fn, args.num_arms, args.T*2, args.env_idx+2, seed=args.seed, context=args.context,
                all_usable_Z=bandit_usable_Z['Z'],
                all_generate_Z=bandit_generate_Z['Z'])
            _, extra_data_dict = extra_bandit_envs[args.env_idx+1]

        else:
            all_bandit_envs = get_bandit_envs_from_dgp(dgp_fn, args.num_arms, args.T, args.env_idx+1,
                                                       seed=args.seed, context=True)
            bandit_env, data_dict = all_bandit_envs[args.env_idx]

            # more X's to pass around, for generating \hat \tau's for TS-Gen
            extra_bandit_envs = get_bandit_envs_from_dgp(dgp_fn, args.num_arms, args.T*2, args.env_idx+2,
                                                         seed=args.seed, context=args.context)
            _, extra_data_dict = extra_bandit_envs[args.env_idx+1]

    else:
        all_bandit_envs = get_bandit_envs(args.num_arms, args.T, args.env_idx+1, click_rates, seed=args.seed,
                                          horizonDependent=args.horizonDependent)
        bandit_env, chosen_arms = all_bandit_envs[args.env_idx]
        # Z_representation is None when the model has no z_encoder
        if Z_representation is not None:
            Z_representation = Z_representation[chosen_arms]

    file_savename = args.save_dir + '/' + get_file_savename(args)
    make_parent_dir(file_savename)

    # prediction loss matrix: best-effort, only on the dgp path that built Z_representation
    loss_matrix = None
    if args.dgp is not None and args.generate_Z_file is None:
        try:
            loss_matrix = model.eval_seq(Z_representation, bandit_env.X, bandit_env.potential_outcomes)
        except Exception as err:  # deliberate best-effort: report and continue without it
            print(f'no loss matrix ({type(err).__name__}: {err})')
    else:
        print('no loss matrix')
    print('Make alg')

    # xgb_* algorithms are not among the --bandit_alg choices, so this is currently always {}
    extra_xgb_params = {
        'xgb_depth2': {'max_depth': 2},
        'xgb_depth1': {'max_depth': 1},
        'xgb_depth2_max50': {'max_depth': 2, 'n_estimators': 50},
    }.get(args.bandit_alg, {})

    # NOTE(review): data_dict / extra_data_dict are only assigned on the dgp path above;
    # with --dataset every algorithm below appears to hit a NameError -- confirm whether
    # the dataset path is still supported.
    if args.bandit_alg == 'linearTS11':
        bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=data_dict['X'],
                                             hyparam_dict={'lam': 1, 'sig': 1})
    elif args.bandit_alg == 'linearTS1_0.25':
        bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=data_dict['X'],
                                             hyparam_dict={'lam': 1, 'sig': 0.5})

    elif args.bandit_alg == 'linearTS11Z':
        bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=data_dict['X'],
                                             hyparam_dict={'lam': 1, 'sig': 1}, Z=data_dict['Z'])
    elif args.bandit_alg == 'linearTS1_0.25Z':
        bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=data_dict['X'],
                                             hyparam_dict={'lam': 1, 'sig': 0.5}, Z=data_dict['Z'])

    elif args.bandit_alg in ('linearfeatureTS11', 'linearfeatureTS1_0.25'):
        assert hasattr(model, 'get_features')  # only implemented for some, e.g. MarginalPredictorContext
        XZ_features = model.get_features(data_dict['Z'], data_dict['X'])
        bandit_env.X = XZ_features
        if args.bandit_alg == 'linearfeatureTS11':
            bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=XZ_features,
                                                 hyparam_dict={'lam': 1, 'sig': 1})
        else:
            # noise variance: 0.25, prior variance: 1
            bandit_alg = LinearGaussianContextTS(num_arms=args.num_arms, X=XZ_features,
                                                 hyparam_dict={'lam': 1, 'sig': 0.5})

    # linear TS
    elif args.bandit_alg == 'linearTSgeneral':
        gaussian_params = get_gaussian_prior_params(args.dgp, prior_name=args.prior_name)
        bandit_alg = LinearGaussianContextTS_general(num_arms=args.num_arms, X=data_dict['X'],
                                                     **gaussian_params, add_const_feature=args.add_const_feature)
    # neural linear TS
    elif args.bandit_alg == 'neural_linear_context_general':
        # turn model into a featurizer by truncating its top layer to its first
        # 5 child modules (indices 0 to 4).
        # NOTE(review): an older comment here said "first 6 layers (indices 0 to 5)"
        # while the code keeps [:5]; confirm which is intended.
        top_layer_model = model.top_layer.model
        model.top_layer.model = torch.nn.Sequential(*list(top_layer_model.children())[:5])
        print('adjusted model')
        print(model)

        gaussian_params = get_gaussian_prior_params(args.dgp, feats=True, prior_name=args.prior_name)
        bandit_alg = NeuralLinearGaussianContextTS_general(
            num_arms=args.num_arms, X=data_dict['X'], Z=data_dict['Z'], neural_model=model,
            **gaussian_params, add_const_feature=args.add_const_feature, cholesky=args.cholesky)
    elif args.bandit_alg == 'greedy':
        with torch.no_grad():
            Z_rep = model.z_encoder(data_dict['Z'])
        bandit_alg = GreedySequentialWithContext(model, Z_rep, args.num_arms, args.T, data_dict['X'], args.epsilon)
    elif args.bandit_alg == 'predetor':
        with torch.no_grad():
            Z_rep = model.z_encoder(data_dict['Z'])
        bandit_alg = GreedySequentialWithContext(model, Z_rep, args.num_arms, args.T, data_dict['X'], tau=0.05)

    elif args.bandit_alg == 'linucb_disjoint':
        bandit_alg = LinUCBDisjoint(args.num_arms, data_dict['X'].shape[-1], args.alpha)
    elif args.bandit_alg in ('sequential_context', 'sequential_context_ignore_x'):
        ignore_context = args.bandit_alg == 'sequential_context_ignore_x'
        X_samples = extra_data_dict['X'] if args.draw_X_samples else None
        # --t2p may be omitted (None); pass None through instead of calling .split on it.
        # NOTE(review): confirm SequentialAlgWithContext accepts t2p=None.
        t2p = args.t2p.split('_')[0] if args.t2p is not None else None

        bandit_alg = SequentialAlgWithContext(model, data_dict['Z'], args.num_arms, args.T, data_dict['X'],
                                              ignore_context=ignore_context, t2p=t2p,
                                              num_imagined=args.num_imagined,
                                              finite_horizon_alg=args.finite_horizon_alg,
                                              no_shuffle_boot=args.no_shuffle_boot,
                                              extra_xgb_params=extra_xgb_params,
                                              X_samples=X_samples)
        print(f"dim Z: {data_dict['Z'].shape}")
    else:
        raise ValueError(f'unrecognized bandit alg: {args.bandit_alg} with context')

    print('run bandits')
    set_seed(args.seed)
    reward_dict = run_bandit(bandit_env, bandit_alg, args.T, return_extra=True, context=True, verbose=args.verbose)
    res = {
        'reward_dict': reward_dict,
        'loss_matrix': loss_matrix,
    }
    if args.dgp is None:
        res['env_chosen_arms'] = chosen_arms
        res['env_article_ordering'] = article_ordering
        res['orig_click_rates'] = orig_click_rates
        res['env_click_rates'] = click_rates[chosen_arms]
    else:
        # mostly to check across methods
        res['success_p'] = bandit_env.success_p
    torch.save(res, file_savename)
    end = time.time()
    print(f'Saved to {file_savename}')
    print(f'Total time: {(end-start)} seconds = {(end-start)/60:0.2f} minutes')


if __name__ == "__main__":
    # Script entry point: run a single bandit experiment from the command line.
    main()
