import torch
import copy
import os
import argparse
from argparse import Namespace
from util import parse_bool, parse_int_list, set_seed

from train_fns import get_val_loss, get_model_and_optimizer_context
from dataset_context import get_loaders_synthetic_with_context

def get_row_embeds(config, model, loader, device):
    """Collect the 'Z' entry of every batch in *loader* into one tensor.

    Args:
        config: Unused; kept so existing callers keep working.
        model: Unused; kept so existing callers keep working.
        loader: Iterable of batches, each a mapping that contains a 'Z'
            tensor.
        device: Device the 'Z' tensors are moved to before concatenation.

    Returns:
        The per-batch 'Z' tensors concatenated along dim 0, in loader order.
        Each batch contributes exactly once.
    """
    with torch.no_grad():
        # Append once per batch. Appending inside a loop over the batch's
        # keys would duplicate 'Z' once for every key in the batch.
        row_embeds = [batch['Z'].to(device) for batch in loader]

    return torch.cat(row_embeds)

def get_row_embeds_fname(out_dir, prefix='best_loss_'):
    """Return the path of the row-embeddings file inside *out_dir*."""
    basename = f'{prefix}row_embeds.pt'
    return out_dir + '/' + basename

def get_predictions_fname(out_dir, prefix='best_loss_'):
    """Return the path of the predictions file inside *out_dir*."""
    basename = f'{prefix}predictions.pt'
    return out_dir + '/' + basename

def get_posterior_samples_fname(out_dir, prefix='best_loss_'):
    """Return the path of the posterior-samples file inside *out_dir*."""
    basename = f'{prefix}posterior_samples.pt'
    return out_dir + '/' + basename


def save_row_embeds(config, model, loader_dict, out_dir, device, loader_names, prefix='best_loss_', recalc=False):
    """Compute row embeddings for the named loaders and save them to disk.

    Results are stored in a dict keyed by loader name. Unless *recalc* is
    true, a previously saved dict is loaded (onto the CPU) and only the
    loader names missing from it are computed. The dict is written back to
    the same file on every call.

    Args:
        config, model, device: Forwarded to ``get_row_embeds``.
        loader_dict: Mapping that holds each loader under the key
            ``<loader_name>_loader``.
        out_dir: Directory containing the output file.
        loader_names: Loader names to process.
        prefix: Filename prefix passed to ``get_row_embeds_fname``.
        recalc: When true, ignore any existing file and recompute everything.

    Returns:
        The dict mapping loader name to its row embeddings.
    """
    cache_path = get_row_embeds_fname(out_dir, prefix)
    if recalc or not os.path.exists(cache_path):
        embeds_by_loader = {}
    else:
        embeds_by_loader = torch.load(cache_path, map_location='cpu')

    with torch.no_grad():
        for name in loader_names:
            # Checked per iteration so entries added earlier in this loop
            # (or loaded from disk) are never recomputed.
            if name not in embeds_by_loader:
                print(f'Saving row embeds for {name}')
                loader = loader_dict[name + '_loader']
                embeds_by_loader[name] = get_row_embeds(config, model, loader, device)

    torch.save(embeds_by_loader, cache_path)
    return embeds_by_loader


def save_model_predictions(model, loader_dict, out_dir, device, loader_names, prefix='best_loss_', 
                           recalc=False, embed_data=False, 
                           use_X_model=False):
    """Evaluate the model on the named loaders and save the results to disk.

    Results are stored in a dict keyed by loader name. Unless *recalc* is
    true, a previously saved dict is loaded (onto the CPU) and only the
    loader names missing from it are evaluated. The dict is written back to
    the same file on every call.

    Args:
        model, device, embed_data, use_X_model: Forwarded to
            ``get_val_loss``.
        loader_dict: Mapping that holds each loader under the key
            ``<loader_name>_loader``.
        out_dir: Directory containing the output file.
        loader_names: Loader names to process.
        prefix: Filename prefix passed to ``get_predictions_fname``.
        recalc: When true, ignore any existing file and recompute everything.

    Returns:
        The dict mapping loader name to whatever ``get_val_loss`` returned
        for that loader.
    """
    cache_path = get_predictions_fname(out_dir, prefix)
    if recalc or not os.path.exists(cache_path):
        results = {}
    else:
        results = torch.load(cache_path, map_location='cpu')

    for name in loader_names:
        # Checked per iteration so existing entries are never recomputed.
        if name not in results:
            print(f'Saving model loss for {name}')
            current_loader = loader_dict[name + '_loader']
            results[name] = get_val_loss(
                model,
                current_loader,
                device,
                embed_data=embed_data,
                use_X_model=use_X_model,
            )

    torch.save(results, cache_path)
    return results


def get_device(gpu):
    """Return the torch device to run on.

    A non-negative *gpu* index selects that specific CUDA device. ``None``
    or a negative index falls back to the default CUDA device when CUDA is
    available, and to the CPU otherwise.
    """
    if gpu is None or int(gpu) < 0:
        fallback = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.device(fallback)
    return torch.device(f'cuda:{gpu}')


def load_old_model(config, sd, check=None):
    """Build a model from *config* and load the state dict *sd* into it.

    Args:
        config: Passed to ``get_model_and_optimizer_context``.
        sd: State dict loaded into the freshly built model.
        check: Unused; accepted so existing callers keep working.

    Returns:
        The model with *sd* loaded. The optimizer that is built alongside
        it is discarded.
    """
    model, _unused_optimizer = get_model_and_optimizer_context(config)
    # Move to the CPU before loading the weights.
    model.to('cpu')
    model.load_state_dict(sd)
    return model


def do_postprocessing(args):
    """Recompute and save predictions and row embeddings for a finished run.

    Loads ``best_loss.pt`` from ``args.run_dir``, rebuilds the model from the
    stored config and state dict, builds the data loaders, and then writes
    the model predictions and the row embeddings for every loader back into
    ``args.run_dir``.

    Args:
        args: Namespace providing ``run_dir``, ``device``, ``batch_size`` and
            ``postproc_force_recalc``.
    """
    print(f'ARGS: {args}')

    device = args.device
    # NOTE(review): the checkpoint stores a pickled config object next to the
    # weights; newer PyTorch releases default torch.load to weights_only=True
    # and may refuse to load it -- confirm against the pinned torch version.
    check = torch.load(args.run_dir + '/best_loss.pt', map_location='cpu')

    # Work on a copy so the loaded checkpoint's config is left untouched.
    config = copy.deepcopy(check['config'])
    # Configs without this attribute get it set to False here.
    if not hasattr(config, 'embed_data_dir'):
        config.embed_data_dir = False
    config.batch_size = args.batch_size
    config.device = args.device
    model = load_old_model(config, check['state_dict'], check)
    # Seeding happens after model construction and before the loaders are
    # built; the original ordering is kept.
    set_seed(config.seed)

    loaders = get_loaders_synthetic_with_context(config, train_deterministic_row_order=True, extras=True)

    model.to(device)
    model.eval()

    recalc = args.postproc_force_recalc

    # Strip only the trailing '_loader'. Splitting on the first occurrence
    # would truncate any key that also contains '_loader' earlier in its
    # name, and the later ``name + '_loader'`` lookup would then fail.
    suffix = '_loader'
    all_loader_names = [key[:-len(suffix)] for key in loaders if key.endswith(suffix)]

    # this will probably do some unnecessary computation, e.g. on val set,
    # where those outputs are already in e.g. best_loss.pt (but I think we don't care)
    predictions = save_model_predictions(model, loaders,
            args.run_dir, device, all_loader_names, recalc=recalc,
            embed_data=config.embed_data_dir,
            use_X_model=config.use_X_model)
    # Sanity print; requires a 'train' loader whose result has 'theta_hats'.
    print(predictions['train']['theta_hats'].shape)

    save_row_embeds(config, model, loaders,
                    args.run_dir, device, all_loader_names, recalc=recalc)

def add_default_postproc_params(parser):
    """Register the post-processing command-line options on *parser*."""
    add = parser.add_argument

    # true for actual usage
    add('--postproc_force_recalc', type=parse_bool, default=True)
    add(
        '--post_sample_all_num_prev_obs',
        type=parse_int_list,
        help='an integer or a list of integers separated by commas',
        default=[0, 1, 2, 5, 10, 25],
    )
    add('--post_sample_num_repetitions', type=int, default=250)
    add('--post_sample_num_imagined', type=int, default=500)

def main():
    """Parse the command line, start a wandb run and run post-processing."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--run_dir', type=str, help='directory with model outputs')
    cli.add_argument('--gpu', type=int, default=None)
    cli.add_argument('--batch_size', type=int, default=100)
    cli.add_argument('--wandb_entity', default='ps-autoregressive')
    add_default_postproc_params(cli)
    args = cli.parse_args()

    # Imported here rather than at module level, so importing this module
    # does not require wandb.
    import wandb
    wandb.login()
    wandb.init(project='postprocessing', entity=args.wandb_entity)

    # Stash the resolved device on args; do_postprocessing reads args.device.
    args.device = get_device(args.gpu)
    print(f'Arg gpu: {args.gpu}, Device: {args.device}')

    do_postprocessing(args)

# Run the post-processing pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
