import time
import sys
import numpy as np
import argparse
import torch
import wandb
from transformers import get_scheduler
import logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

# Project-local imports
from util import make_parent_dir, set_seed, parse_bool
from dataset_context import get_loaders_synthetic_with_context
from train_fns import get_val_loss, get_model_and_optimizer_context, loss_from_loss_matrix
from postprocessing import do_postprocessing, add_default_postproc_params

##############################################################################
# Functions for parsing arguments and training
##############################################################################

def get_argparser():
    """Build the command-line parser for the training script.

    The options are declared as an ordered table of ``(flag, kwargs)`` pairs
    and registered in that order, so ``--help`` lists them grouped as below.

    Returns:
        argparse.ArgumentParser: parser holding every training option.
        Nothing is parsed here; callers may register further arguments
        before calling ``parse_args``.
    """
    option_table = (
        # File Saving -------------------------------------
        ('--save_name', {'type': str, 'help': 'save name for project'}),
        ('--prefix', {'type': str, 'default': None,
                      'help': 'prefix for descriptive_name for a run (within a project)'}),
        ('--data_dir', {'type': str, 'help': 'directory with data for training'}),
        ('--embed_data_dir', {'type': parse_bool, 'default': False,
                              'help': 'data has embeddings dump'}),
        ('--extra_eval_data', {'type': str, 'default': None,
                               'help': 'location of extra eval dataset (optional)'}),
        ('--wandb_user', {'type': str}),

        # General Training -------------------------------------
        ('--learning_rate', {'type': float, 'default': 1e-2}),
        ('--weight_decay', {'type': float, 'default': 1e-2}),
        ('--epochs', {'type': int, 'default': 5}),
        ('--gpu', {'type': int, 'default': None}),
        # batch sizes count rows OR observations
        ('--batch_size', {'type': int, 'default': 32}),
        ('--eval_batch_size', {'type': int, 'default': 32}),
        ('--marginal_vs_sequential', {'type': str,
                                      'choices': ['sequential', 'marginal']}),
        ('--seed', {'type': int, 'default': 2340923}),
        ('--onelayer', {'type': parse_bool, 'default': False}),
        ('--MLP_width', {'type': int, 'default': 50}),
        ('--MLP_layer', {'type': int, 'default': 3}),
        ('--X_MLP_width', {'type': int, 'default': 0}),
        ('--X_MLP_layer', {'type': int, 'default': 0}),
        ('--MLP_last_fn', {'type': str, 'default': 'sigmoid'}),
        ('--suffstat_eps', {'type': float, 'default': 1}),
        ('--repeat_suffstat', {'type': int, 'default': 1,
                               'help': "number of times to repeat sufficient statitic input"}),
        # integer flag (0/1) selecting randomized prior functions for the MLP
        ('--rand_prior', {'type': int, 'default': 0,
                          'help': "Booling for whether the MLP uses randomized prior functions"}),
        # scale for Osband's randomized prior code in ModelWithPrior
        ('--prior_scale', {'type': float, 'default': 0}),
        ('--bootstrap_seed', {'type': int, 'default': None}),
        # integer flag (0/1): postprocess on every new best model vs. only at the end
        ('--postprocess_often', {'type': int, 'default': 0,
                                 'help': "Bool for whether to postprocess everytime a new model is saved (vs at the end)"}),

        # Train on only one timestep / number of observations to condition on
        ('--sequential_one_length', {'type': int, 'default': None}),
        ('--weight_factor', {'type': float, 'default': 1}),
        ('--scheduler_type', {'type': str, 'default': 'constant'}),

        # Dataset Processing -------------------------------------
        ('--dataset_type', {'type': str, 'choices': ['synthetic']}),
        ('--sample_frac', {'type': float, 'default': 1.0,
                           'help': 'for testing purposes, subsample train and eval datasets for faster testing'}),
        ('--num_loader_obs', {'type': int, 'default': 500}),
        ('--num_loader_obs_train', {'type': int, 'default': 500}),

        # Context X
        ('--use_X_model', {'type': parse_bool, 'default': False}),
        # Context only: the dataset uses the same X across rows, so there is
        # one X per column
        ('--one_X_per_column', {'type': parse_bool, 'default': False}),

        # Synthetic Arguments --------------------------------------
        ('--Z_dim', {'type': int, 'default': 1}),
        ('--X_dim', {'type': int, 'default': 1}),

        ('--save_every', {'type': int, 'default': -1}),
        ('--use_dataset_Y', {'type': parse_bool, 'default': False}),
        # run the val loss check and model saving in the middle of an epoch
        ('--check_val_every_n_rows', {'type': int, 'default': 0}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

def _build_descriptive_name(config):
    """Return the run name derived from the hyper-parameters in ``config``.

    The name is a comma-separated list of ``key=value`` fragments ending with
    ``seed=<seed>``. It also validates option combinations that are only
    meaningful for one of the two training modes.

    Raises:
        AssertionError: if ``onelayer``, ``rand_prior`` or a non-zero
            ``prior_scale`` is combined with a non-marginal mode, or
            ``sequential_one_length`` with a non-sequential mode.
    """
    is_marginal = config.marginal_vs_sequential == 'marginal'
    is_sequential = config.marginal_vs_sequential == 'sequential'

    if config.onelayer:
        assert is_marginal, "--onelayer requires --marginal_vs_sequential marginal"

    # NOTE(review): with the default --prefix (None) the name literally starts
    # with "None". Kept as-is to avoid renaming run directories.
    parts = [
        f"{config.prefix}{config.marginal_vs_sequential}:"
        f"epochs={config.epochs},bs={config.batch_size},"
        f"lr={config.learning_rate},wd={config.weight_decay},"
        f"MLP_layers={config.MLP_layer},MLP_width={config.MLP_width},"
        f"weight_factor={config.weight_factor},max_obs={config.num_loader_obs},"
    ]
    if is_sequential:
        parts.append(f"repeat_suffstat={config.repeat_suffstat},")

    if config.rand_prior:
        assert is_marginal, "--rand_prior requires --marginal_vs_sequential marginal"
        parts.append("rand_prior=True,")
    # this uses the new randomized prior code
    if config.prior_scale != 0:
        assert is_marginal, "--prior_scale requires --marginal_vs_sequential marginal"
        parts.append(f"prior_scale={config.prior_scale},")

    if config.num_loader_obs != config.num_loader_obs_train:
        parts.append(f'max_obs_train={config.num_loader_obs_train},')

    if config.sequential_one_length is not None:
        assert is_sequential, \
            "--sequential_one_length requires --marginal_vs_sequential sequential"
        parts.append(f'sequential_one_length={config.sequential_one_length},')

    parts.append(f'Zdim={config.Z_dim},')
    if config.bootstrap_seed is not None:
        parts.append(f'boot_seed:{config.bootstrap_seed},')
    if config.embed_data_dir is True:
        parts.append('embed_data,')
    if config.scheduler_type != 'linear':
        parts.append(f'sched={config.scheduler_type},')

    if config.X_MLP_layer > 0 and config.X_MLP_width > 0:
        parts.append(f'X_MLP_layer={config.X_MLP_layer},X_MLP_width={config.X_MLP_width},')
    parts.append(f'suffstat_eps={config.suffstat_eps},')
    if config.use_dataset_Y:
        parts.append('useY,')
    # every fragment above ends with a comma, so the seed closes the name
    parts.append(f'seed={config.seed}')
    return ''.join(parts)


def _log_mse_metrics(name, loss_dict, max_obs, step):
    """Log squared-error summaries for one evaluation split to wandb.

    Reads ``theta_hats``, ``click_rates`` and ``loss_per_t`` from
    ``loss_dict`` and stores the three aggregate MSE values back into it
    (``mse_any_obs``, ``mse_0_obs``, ``mse_max_obs``) so they are saved
    alongside the checkpoint.

    Args:
        name: split label used in the metric keys (e.g. ``'val'``).
        loss_dict: evaluation result dict; mutated in place.
        max_obs: only per-length metrics with ``t < max_obs`` are logged.
        step: wandb step to log under.
    """
    predicted_probs = loss_dict['theta_hats']
    ground_truth = loss_dict['click_rates']
    # presumably (rows, observation lengths) -- the column indexing below
    # relies on a 2-D layout; TODO confirm against get_val_loss
    squared_error = (predicted_probs - ground_truth)**2

    # MSE on average, over all observed lengths
    mse_any_obs = squared_error.mean()
    wandb.log({f'{name} mse, all lengths': mse_any_obs}, step=step)
    loss_dict['mse_any_obs'] = mse_any_obs

    # MSE on average, after 0 observations
    mse_0_obs = squared_error[:, 0].mean()
    wandb.log({f'{name} mse, after 0 obs': mse_0_obs}, step=step)
    loss_dict['mse_0_obs'] = mse_0_obs

    # MSE on average, after max number of observations
    mse_max_obs = squared_error[:, -1].mean()
    wandb.log({f'{name} mse, after max obs': mse_max_obs}, step=step)
    loss_dict['mse_max_obs'] = mse_max_obs

    loss_per_t = loss_dict['loss_per_t']
    mse_per_t = squared_error.mean(axis=0)
    for t in [0, 1, 2, 3, 4, 5, 10, 25, 100]:
        if t < max_obs and t < len(loss_per_t):
            wandb.log({f'{name}_loss, after {str(t)} obs': loss_per_t[t]}, step=step)
            wandb.log({f'{name}_mse, after {str(t)} obs': mse_per_t[t]}, step=step)


def main():
    """Parse arguments, train the model, checkpoint it and run postprocessing.

    Steps, in order:
      1. Parse CLI arguments (training options plus postprocessing options).
      2. Derive the run name / save directory, save the config, start wandb.
      3. Build the data loaders, the model, the optimizer and the LR scheduler.
      4. Train; at the end of every epoch (and optionally every
         ``check_val_every_n_rows`` rows) evaluate on the validation loader
         and on the fixed training subset, log metrics, and save
         ``latest.pt`` / ``best_loss.pt`` / periodic checkpoints.
      5. Run postprocessing on the finished run directory.
    """
    parser = get_argparser()
    add_default_postproc_params(parser)

    # Parse Arguments, Initialize save files and logging ============================================
    config = parser.parse_args()

    descriptive_name = _build_descriptive_name(config)

    save_dir = config.data_dir + '/models/' + config.save_name + '/' + descriptive_name + '/'
    logging.info(descriptive_name)
    logging.info(f"Saving in {save_dir}")
    make_parent_dir(save_dir)
    torch.save(config, save_dir + '/config.pt')

    wandb.login()
    print('=========')
    print(config.save_name, config.wandb_user)
    wandb.init(project=config.save_name, entity=config.wandb_user,
            dir=save_dir,
            config=config,
            name=descriptive_name+config.data_dir.split("/")[-1])

    logging.info(config)
    set_seed(config.seed)

    # An explicit non-negative --gpu picks that CUDA device; otherwise use
    # CUDA when available and fall back to CPU.
    if config.gpu is not None and int(config.gpu) >= 0:
        config.device = torch.device(f'cuda:{config.gpu}')
    else:
        config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Loading Datasets and Making Dataset Objects =====================================================================

    loader_dict = get_loaders_synthetic_with_context(config)

    train_loader = loader_dict['train_loader']
    train_dataset = loader_dict['train_dataset']
    train_fixed_subset_loader = loader_dict['train_fixed_subset_loader']
    val_loader = loader_dict['val_loader']
    val_dataset = loader_dict['val_dataset']

    assert train_dataset[0]['Z'].shape[0] == config.Z_dim

    if config.extra_eval_data is not None:
        extra_eval_loader = loader_dict['extra_eval_loader']
        extra_eval_dataset = loader_dict['extra_eval_dataset']

    # Initialize Prediction Models ===================================================================================
    # Re-seed before building the model so a bootstrap seed changes the model
    # seed while the seed set before data loading stays config.seed.
    model_seed = config.seed
    if config.bootstrap_seed is not None:
        model_seed += config.bootstrap_seed
    set_seed(model_seed)

    assert config.dataset_type == "synthetic"
    model, optimizer_dict = get_model_and_optimizer_context(config)

    # 'end2end' is not defined by get_argparser; treat it as off when absent.
    if getattr(config, 'end2end', False):
        optimizer = optimizer_dict['encoder']
    else:
        optimizer = optimizer_dict['all']

    logging.info(model)
    num_batches = len(train_loader)
    total_batches = num_batches * config.epochs
    scheduler = get_scheduler(config.scheduler_type, optimizer,
            num_training_steps=total_batches, num_warmup_steps=0)

    # Training Loop =================================================================
    logging.info("Begin training")
    best_loss = np.inf

    check_val_every_n_iters = 0
    if config.check_val_every_n_rows > 0:
        # Clamp to at least one iteration: a row count smaller than the batch
        # size would otherwise floor to 0 and silently disable the check.
        check_val_every_n_iters = max(1, config.check_val_every_n_rows // config.batch_size)

    use_X_model = getattr(config, 'use_X_model', False)
    has_weighted_objective = config.weight_factor != 1

    start = time.time()
    # `epoch` counts evaluation checkpoints: it advances at the end of every
    # pass over the data and also at every mid-epoch validation check.
    epoch = 0
    logging.info(f"=== Epoch {epoch} ===")
    epoch_loss_unweighted = 0
    epoch_obs = 0
    epoch_loss_train = 0
    for _ in range(config.epochs):
        for i, batch in enumerate(train_loader):

            for k, v in batch.items():
                batch[k] = v.to(config.device)
            model.train()
            optimizer.zero_grad()

            click_mask = batch['click_length_mask']
            model_input = batch['Z']

            if use_X_model:
                loss_matrix = model.eval_seq(model_input,
                                         batch['X'], batch['click_obs'])
            else:
                loss_matrix = model.eval_seq(model_input,
                                         batch['click_obs'], N=None, exact=False)
            if config.sequential_one_length is not None:
                # train on a single column of the loss matrix only
                loss_matrix = loss_matrix[:, [config.sequential_one_length]]
                click_mask = click_mask[:, [config.sequential_one_length]]

            loss_train = loss_from_loss_matrix(loss_matrix, click_mask,
                                    how='sum_avg_per_row', weight_factor=config.weight_factor)
            loss_train.backward()
            optimizer.step()
            scheduler.step()

            if i % 50 == 0:
                logging.info(f"  iter [{i}]: training loss (weight={config.weight_factor}) {loss_train.item():5.4f} {batch['click_obs'].mean()}")

            epoch_loss_train += loss_train.detach().cpu().item()

            if has_weighted_objective:
                # less computation for regular loss by not tracking gradients
                with torch.no_grad():
                    loss_unweighted = loss_from_loss_matrix(loss_matrix, click_mask,
                                                 how='sum_avg_per_row')

                # log (unweighted) loss, regardless of training objective
                if i % 50 == 0:
                    logging.info(f"  iter [{i}]: loss {loss_unweighted.item():5.4f} {batch['click_obs'].mean()}")

                epoch_loss_unweighted += loss_unweighted.detach().cpu().item()

            epoch_obs += len(batch['click_obs'])

            # Evaluate and checkpoint at the end of an epoch, *or* once enough
            # iterations have passed; otherwise move on to the next batch.
            is_last_batch = i == num_batches - 1
            is_mid_epoch_check = (check_val_every_n_iters > 0
                                  and i % check_val_every_n_iters == check_val_every_n_iters - 1)
            if not (is_last_batch or is_mid_epoch_check):
                continue

            train_objective = epoch_loss_train / epoch_obs
            logging.info(f'Finished epoch {epoch}; epoch loss {train_objective}')
            if has_weighted_objective:
                # Same convention as 'val_loss' / 'val_loss_weighted':
                # the plain key carries the unweighted loss, the
                # 'weighted' key carries the weighted training objective.
                wandb.log({'train_loss': epoch_loss_unweighted / epoch_obs}, step=epoch)
                wandb.log({'weighted_train_loss': train_objective}, step=epoch)
            else:
                wandb.log({'train_loss': train_objective}, step=epoch)

            val_start = time.time()
            # unweighted val loss
            val_loss_dict = get_val_loss(model, val_loader, config.device, weight_factor=1,
                                         embed_data=config.embed_data_dir, use_X_model=use_X_model)
            val_end = time.time()
            wandb.log({'epoch_val_secs': (val_end - val_start)}, step=epoch)
            train_subset_loss_dict = get_val_loss(model, train_fixed_subset_loader, config.device,
                                                  weight_factor=1, embed_data=config.embed_data_dir,
                                                  use_X_model=use_X_model)

            val_loss = val_loss_dict['loss']
            train_subset_loss = train_subset_loss_dict['loss']

            wandb.log({'val_loss': val_loss}, step=epoch)
            wandb.log({'train_subset_loss': train_subset_loss}, step=epoch)

            if has_weighted_objective:
                # use_X_model is forwarded here as in every other
                # get_val_loss call, so the weighted evaluation uses the
                # same model path as the unweighted one.
                val_loss_dict_weighted = get_val_loss(
                    model, val_loader, config.device, val_dataset,
                    weight_factor=config.weight_factor,
                    embed_data=config.embed_data_dir, use_X_model=use_X_model)
                train_subset_loss_dict_weighted = get_val_loss(
                    model, train_fixed_subset_loader, config.device, train_dataset,
                    weight_factor=config.weight_factor,
                    embed_data=config.embed_data_dir, use_X_model=use_X_model)

                wandb.log({'val_loss_weighted': val_loss_dict_weighted['loss']}, step=epoch)
                wandb.log({'train_subset_loss_weighted': train_subset_loss_dict_weighted['loss']}, step=epoch)
                logging.info(f'val_loss_weighted: {val_loss_dict_weighted["loss"]}')

            # log MSE for val and train_subset
            for name, loss_dict in [('val', val_loss_dict), ('train_subset', train_subset_loss_dict)]:
                _log_mse_metrics(name, loss_dict, config.num_loader_obs, epoch)

            logging.info(f'val_loss: {val_loss}')

            save_dict = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer_dict,
                #'scheduler': scheduler.state_dict(),
                'epoch': epoch,
                'val_loss_dict': val_loss_dict,
                'train_subset_loss_dict': train_subset_loss_dict,
                'config': config,
            }
            if val_loss < best_loss:
                best_loss = val_loss

                # the extra eval set is only scored for new best models
                if config.extra_eval_data is not None:
                    extra_eval_loss_dict = get_val_loss(
                        model, extra_eval_loader, config.device, extra_eval_dataset,
                        embed_data=config.embed_data_dir, use_X_model=use_X_model)
                    save_dict['extra_eval_loss_dict'] = extra_eval_loss_dict

                torch.save(save_dict, save_dir + '/best_loss.pt')
                if config.postprocess_often:
                    setattr(config, 'run_dir', save_dir)
                    do_postprocessing(config)

            # save latest
            if config.save_every > 0 and epoch % config.save_every == 0:
                torch.save(save_dict, save_dir + f'/epoch_{epoch}.pt')
            torch.save(save_dict, save_dir + '/latest.pt')
            end = time.time()
            wandb.log({'epoch_train_secs': (end - start)}, step=epoch)

            # reset the running statistics for the next logging period
            epoch += 1
            start = time.time()
            logging.info(f"=== Epoch {epoch} ===")
            epoch_loss_unweighted = 0
            epoch_obs = 0
            epoch_loss_train = 0

    # do postprocessing
    setattr(config, 'run_dir', save_dir)
    do_postprocessing(config)

# Run the training entry point only when executed as a script, not on import.
if "__main__" == __name__:
    main()
