import torch
import argparse
from model import SetTransformer
from itertools import chain
import torch.nn.functional as F
import numpy as np
from scipy.special import gamma
import sys
import random

# Limit PyTorch to a single thread for intra-op CPU parallelism.
torch.set_num_threads(1)

def _load_opt_fns(args, models, _ff):
    """Build the optimizer(s) selected by ``args.opt_fn`` and log the choice.

    Args:
        args: namespace providing ``opt_fn``, ``lr``, ``lr_pool`` and
            ``weight_decay``.
        models: iterable of modules; all of their parameters are handed to
            the main optimizer.
        _ff: writable text stream that receives a one-line description of
            the chosen optimizer and learning rates.

    Returns:
        ``(optimizer, optimizer_pool)``.  The pooling parameter lists are
        left empty in this function, so ``optimizer_pool`` is always ``None``.
    """
    shared_params = [param for net in models for param in net.parameters()]
    # Nothing is ever added to these lists here, so no pool optimizer is built.
    pool_params_p, pool_params_q = [], []

    # Pick the optimizer class and any extra keyword arguments; the log line
    # is written before the optimizer is constructed.
    if args.opt_fn == 'rmsprop':
        _ff.write(f"rmsprop with lr ({args.lr_pool}, {args.lr})\n")
        opt_cls, extra = torch.optim.RMSprop, {}
    elif args.opt_fn == 'adamgan':
        _ff.write(f"adam_gan with lr ({args.lr_pool}, {args.lr})\n")
        opt_cls, extra = torch.optim.Adam, {'betas': (0.5, 0.999)}
    else:
        _ff.write(f"adam with lr ({args.lr_pool}, {args.lr})\n")
        opt_cls, extra = torch.optim.Adam, {}

    optimizer = opt_cls(shared_params, lr=args.lr, weight_decay=args.weight_decay, **extra)

    optimizer_pool = None
    if pool_params_p:
        # First group uses lr_pool (the default), second group uses lr.
        pool_groups = [{'params': pool_params_p},
                       {'params': pool_params_q, 'lr': args.lr}]
        optimizer_pool = opt_cls(pool_groups, lr=args.lr_pool, **extra)

    return optimizer, optimizer_pool

def _gen_data(args, n_train, n_val, n_test):
    """Sample synthetic train/val/test sets of zero-padded Gaussian sequences.

    For ``args.task`` in ``('closed_form', 'map_mu')`` every sequence gets its
    own mean drawn from a normal with mean ``mu_0`` and variance ``sigma_0``,
    and all observations share the variance ``sigma``.  For any other task the
    mean is fixed to ``mu`` and every sequence gets its own variance sampled
    from ``torch.distributions.gamma.Gamma(alpha, beta)``.  The
    hyper-parameters used are stored on ``args``.

    Train and validation lengths are integers in [20, 40); test lengths are
    integers in [50, 100).  Each split is padded to its own longest sequence.

    Args:
        args: namespace with a ``task`` attribute; mutated in place.
        n_train, n_val, n_test: number of sequences per split.

    Returns:
        dict with keys ``'train'``, ``'val'`` and ``'test'``; each value is
        ``(data, lengths, mask, mu, sigma)`` where ``mu`` or ``sigma`` is
        ``None`` when that quantity is not sampled for the task.
    """
    split_sizes = (n_train, n_val, n_test)
    # (offset, span) of the uniform length distribution of each split.
    length_specs = ((20., 20.), (20., 20.), (50., 50.))
    # Row k has ones in columns 0..k-1, i.e. it is the padding mask for length k.
    masks = torch.tril(torch.ones(101, 101).long(), diagonal=-1).detach()

    def _lengths_and_masks():
        # Random draws happen in train, val, test order.
        lens = [(offset + span * torch.rand(n)).long()
                for n, (offset, span) in zip(split_sizes, length_specs)]
        widths = [ln.max() for ln in lens]
        split_masks = [masks[ln][:, :width] for ln, width in zip(lens, widths)]
        return lens, widths, split_masks

    if args.task in ['closed_form', 'map_mu']:
        mu_0, sigma_0 = 0, (1 ** 2)
        sigma = (1 ** 2)
        mus = [(torch.randn(n) * (sigma_0 ** 0.5)) + mu_0 for n in split_sizes]
        sigmas = [None, None, None]
        lens, widths, split_masks = _lengths_and_masks()
        data = [
            ((torch.randn(n, width) * (sigma ** 0.5) + seq_mu.unsqueeze(-1)) * seq_mask).detach()
            for n, width, seq_mu, seq_mask in zip(split_sizes, widths, mus, split_masks)
        ]
        args.mu_0, args.sigma_0, args.sigma = mu_0, sigma_0, sigma
    else:
        alpha, beta = 1., 15.
        mu = 5
        dist = torch.distributions.gamma.Gamma(torch.tensor([alpha]), torch.tensor([beta]))
        mus = [None, None, None]
        sigmas = [torch.FloatTensor([dist.sample().item() for _ in range(n)])
                  for n in split_sizes]
        lens, widths, split_masks = _lengths_and_masks()
        data = [
            ((torch.randn(n, width) * (seq_var.unsqueeze(-1) ** 0.5) + mu) * seq_mask).detach()
            for n, width, seq_var, seq_mask in zip(split_sizes, widths, sigmas, split_masks)
        ]
        args.alpha, args.beta, args.mu = alpha, beta, mu

    return {
        split: (data[k], lens[k], split_masks[k], mus[k], sigmas[k])
        for k, split in enumerate(('train', 'val', 'test'))
    }

def _load_gts_mu_sigma(args, inputs, input_masks, input_lengths):
    """Closed-form posterior mean and variance of a Gaussian mean.

    Uses the prior ``(args.mu_0, args.sigma_0)`` and the observation variance
    ``args.sigma``.  ``input_masks`` is not used: the observation sum runs over
    every position of ``inputs``.

    Returns:
        ``(gt_mu, gt_sigma)``, both detached.
    """
    prior_precision = 1. / args.sigma_0
    data_precision = input_lengths / args.sigma
    gt_sigma = (1. / (prior_precision + data_precision)).detach()

    observed_sum = inputs.squeeze(-1).sum(-1)
    evidence = (args.mu_0 / args.sigma_0) + (observed_sum / args.sigma)
    gt_mu = (gt_sigma * evidence).detach()
    return gt_mu, gt_sigma

def _load_gts_sigma(args, inputs, input_masks, input_lengths):
    """Ground-truth variance estimate for a Gaussian with known mean ``args.mu``.

    Updates ``args.alpha`` and ``args.beta`` with the masked observations and
    returns ``beta / (alpha + 1)`` of the updated values, which is the mode of
    an inverse-gamma distribution with those parameters.

    Returns:
        ``(None, gt_sigma)`` with ``gt_sigma`` detached; no mean is computed
        for this task.
    """
    posterior_alpha = args.alpha + (input_lengths / 2.)
    squared_dev = (inputs.squeeze(-1) - args.mu) ** 2
    # The mask removes the contribution of padded positions.
    posterior_beta = args.beta + 0.5 * (input_masks.squeeze(-1) * squared_dev).sum(-1)
    gt_sigma = (posterior_beta / (posterior_alpha + 1.)).detach()
    return None, gt_sigma

def _compute_loss_closed_form(args, inputs, input_masks, input_lengths, preds_mu, preds_sigma, gt_mu, gt_sigma):
    """Summed squared error and summed absolute percentage error for both targets.

    Only the predictions and ground-truth tensors are used; the remaining
    arguments exist so all loss functions share one signature.

    Returns:
        ``(loss_mu, mape_mu, loss_sigma, mape_sigma)``, each summed over the
        batch (not averaged).
    """
    def _sse_and_ape(target, pred):
        # Sum of squared errors, and sum of |error| / |target|.
        return (torch.sum((target - pred) ** 2),
                torch.sum(torch.abs(target - pred) / torch.abs(target)))

    running_loss_mu, running_mape_mu = _sse_and_ape(gt_mu, preds_mu)
    running_loss_sigma, running_mape_sigma = _sse_and_ape(gt_sigma, preds_sigma)
    return running_loss_mu, running_mape_mu, running_loss_sigma, running_mape_sigma

def _compute_loss_map_mu(args, inputs, input_masks, input_lengths, preds_mu, preds_sigma, gt_mu, gt_sigma):
    """Negative log posterior of the predicted mean, plus its MAPE.

    The loss is the sum of three terms per sequence: the Gaussian prior
    ``(args.mu_0, args.sigma_0)`` evaluated at ``preds_mu``, the likelihood
    normaliser (one per unmasked observation) and the masked squared
    residuals scaled by ``args.sigma``.

    Returns:
        ``(loss_mu, mape_mu, None, None)``, both values summed over the batch.
    """
    observations = inputs.detach().squeeze(-1)
    valid = input_masks.detach().squeeze(-1)

    # -log N(preds_mu | mu_0, sigma_0)
    prior_term = np.log(np.sqrt(2 * 3.14159265358979 * args.sigma_0)) + 0.5 * (((preds_mu - args.mu_0) ** 2) / args.sigma_0)
    # Normalising constant of the likelihood, once per valid observation.
    norm_term = torch.sum(valid, dim=-1) * np.log(np.sqrt(2 * 3.14159265358979 * args.sigma))
    # Quadratic part of the likelihood over valid observations only.
    squared_residual = (observations - preds_mu.unsqueeze(-1)) ** 2
    quad_term = 0.5 * torch.sum((squared_residual / args.sigma) * valid, dim=-1)

    running_loss_mu = torch.sum(prior_term + norm_term + quad_term)
    running_mape_mu = torch.sum(torch.abs(gt_mu - preds_mu) / torch.abs(gt_mu))

    return running_loss_mu, running_mape_mu, None, None

def _compute_loss_map_sigma(args, inputs, input_masks, input_lengths, preds_mu, preds_sigma, gt_mu, gt_sigma):
    """Negative log posterior of the predicted variance, plus its MAPE.

    The loss is the sum of three terms per sequence: an inverse-gamma prior
    term with parameters ``(args.alpha, args.beta)`` evaluated at the
    predicted variance, the Gaussian normaliser (one per unmasked
    observation) and the masked squared deviations from ``args.mu`` divided
    by the predicted variance.

    The prediction is clamped to at least ``1e-12`` before it is used in any
    logarithm or division, so a non-positive prediction yields a large finite
    loss instead of NaN/inf.  The MAPE is computed on the raw prediction.

    Returns:
        ``(None, None, loss_sigma, mape_sigma)``, both values summed over the
        batch.
    """
    observations = inputs.detach().squeeze(-1)
    valid = input_masks.detach().squeeze(-1)
    # Clamp once and reuse: previously only one of the log terms was guarded,
    # so the other log and the divisions could still produce NaN/inf.
    safe_sigma = torch.clamp(preds_sigma, min=1e-12)

    _loss_sigma_first = -args.alpha * np.log(args.beta) + np.log(gamma(args.alpha)) + (args.beta / safe_sigma) + (args.alpha + 1) * torch.log(safe_sigma)
    _loss_sigma_second = torch.sum(valid, dim=-1) * (0.5 * torch.log(2 * 3.14159265358979 * safe_sigma))
    _loss_sigma_third = 0.5 * torch.sum((((observations - args.mu) ** 2) / safe_sigma.unsqueeze(-1)) * valid, dim=-1)

    running_loss_sigma = torch.sum(_loss_sigma_first + _loss_sigma_second + _loss_sigma_third)
    running_mape_sigma = torch.sum(torch.abs(gt_sigma - preds_sigma) / torch.abs(gt_sigma))

    return None, None, running_loss_sigma, running_mape_sigma

def evaluate(args, models, eval_data, n_eval, gt_fn, loss_fn):
    """Run both models over ``eval_data`` and return per-sample metrics.

    Args:
        args: namespace providing ``batch_size`` and ``device`` (and whatever
            ``gt_fn`` / ``loss_fn`` read from it).
        models: ``(model_mu, model_sigma)``; both are switched to eval mode.
        eval_data: ``(data, lengths, mask, mu, sigma)``; only the first three
            entries are used.
        n_eval: number of samples to iterate over and to normalise by.
        gt_fn: callable returning ``(gt_mu, gt_sigma)`` for a batch.
        loss_fn: callable returning
            ``(loss_mu, mape_mu, loss_sigma, mape_sigma)`` for a batch, where
            any entry may be ``None``.

    Returns:
        Tuple of four Python floats
        ``(loss_mu, mape_mu, loss_sigma, mape_sigma)``: batch sums divided by
        ``n_eval``.  A metric reported as ``None`` is counted as ``-1`` per
        sample in the batch, so it averages to ``-1.0``.
    """
    # load settings
    model_mu, model_sigma = models
    model_mu.eval()
    model_sigma.eval()
    data, lengths, mask, _, _ = eval_data

    # running sums of (loss_mu, mape_mu, loss_sigma, mape_sigma)
    totals = [0., 0., 0., 0.]

    # No gradients are needed for evaluation, whatever the caller's context.
    with torch.no_grad():
        for start in range(0, n_eval, args.batch_size):
            stop = start + args.batch_size
            # load mini-batch and compute ground-truth value
            inputs = data[start:stop].unsqueeze(-1).detach().to(args.device)
            input_masks = mask[start:stop].unsqueeze(-1).detach().to(args.device)
            input_lengths = lengths[start:stop].detach().to(args.device)
            gt_mu, gt_sigma = gt_fn(args, inputs, input_masks, input_lengths)

            # Eval: each element is fed as (value, 1); build the features once.
            features = torch.cat((inputs, torch.ones_like(inputs)), dim=-1)
            preds_sigma = model_sigma(features, input_masks)
            preds_mu = model_mu(features, input_masks)
            batch_metrics = loss_fn(args, inputs, input_masks, input_lengths,
                                    preds_mu, preds_sigma, gt_mu, gt_sigma)

            # Update eval loss.  Use the real batch size for the placeholder so
            # a short final batch does not skew the -1 sentinel, and store
            # plain floats (as train() does) rather than tensors.
            n_batch = inputs.size(0)
            for k, value in enumerate(batch_metrics):
                totals[k] += value.item() if value is not None else -n_batch

    return tuple(total / n_eval for total in totals)

def train(args, models, optimizers, train_data, n_train, gt_fn, loss_fn):
    """Train both models for one epoch over a random permutation of the data.

    Args:
        args: namespace providing ``batch_size``, ``device`` and
            ``norm_limit`` (and whatever ``gt_fn`` / ``loss_fn`` read).
        models: ``(model_mu, model_sigma)``; both are switched to train mode.
        optimizers: ``(optimizer, optimizer_pool)``; the second may be
            ``None``.
        train_data: ``(data, lengths, mask, mu, sigma)``; only the first three
            entries are used.
        n_train: number of samples to permute, iterate over and normalise by.
        gt_fn: callable returning ``(gt_mu, gt_sigma)`` for a batch.
        loss_fn: callable returning
            ``(loss_mu, mape_mu, loss_sigma, mape_sigma)`` for a batch, where
            any entry may be ``None``.

    Returns:
        Tuple of four Python floats
        ``(loss_mu, mape_mu, loss_sigma, mape_sigma)``: batch sums divided by
        ``n_train``.  A metric reported as ``None`` is counted as ``-1`` per
        sample in the batch, so it averages to ``-1.0``.
    """
    # load settings
    model_mu, model_sigma = models
    optimizer, optimizer_pool = optimizers
    model_mu.train()
    model_sigma.train()
    indices = torch.randperm(n_train)
    data, lengths, mask, _, _ = train_data

    # running sums of (loss_mu, mape_mu, loss_sigma, mape_sigma)
    totals = [0., 0., 0., 0.]

    for start in range(0, n_train, args.batch_size):
        # load mini-batch and compute ground-truth value
        batch_idx = indices[start:start + args.batch_size]
        inputs = data[batch_idx].unsqueeze(-1).detach().to(args.device)
        input_masks = mask[batch_idx].unsqueeze(-1).detach().to(args.device)
        input_lengths = lengths[batch_idx].detach().to(args.device)
        gt_mu, gt_sigma = gt_fn(args, inputs, input_masks, input_lengths)

        # Train
        optimizer.zero_grad()
        if optimizer_pool is not None:
            optimizer_pool.zero_grad()
        # Each element is fed as (value, 1); build the features once.
        features = torch.cat((inputs, torch.ones_like(inputs)), dim=-1)
        preds_sigma = model_sigma(features, input_masks)
        preds_mu = model_mu(features, input_masks)

        batch_metrics = loss_fn(args, inputs, input_masks, input_lengths,
                                preds_mu, preds_sigma, gt_mu, gt_sigma)

        _loss_mu, _, _loss_sigma, _ = batch_metrics
        if _loss_mu is not None:
            _loss_mu.backward()
        if _loss_sigma is not None:
            _loss_sigma.backward()

        # Gradient norms are clipped separately for each model.
        with torch.no_grad():
            for model in models:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_limit)
        optimizer.step()
        if optimizer_pool is not None:
            optimizer_pool.step()

        # Update training loss.  Use the real batch size for the placeholder
        # so a short final batch does not skew the -1 sentinel.
        n_batch = inputs.size(0)
        for k, value in enumerate(batch_metrics):
            totals[k] += value.item() if value is not None else -n_batch

    return tuple(total / n_train for total in totals)
    
def main(args):
    """Train and evaluate the mean and variance set models for ``args.task``.

    Builds two ``SetTransformer`` models, generates synthetic data, trains for
    ``args.n_epochs`` epochs and evaluates on the validation and test splits
    after every epoch.  A model's weights are saved whenever its validation
    MAPE improves.  One line per epoch is written to a log file under
    ``./logs``; checkpoints are written under ``./checkpoints``.  Both
    directories are created if missing, and the log file is closed on exit,
    including when an exception is raised.

    Args:
        args: parsed command-line namespace.  ``args.device`` is set here and
            the data generator stores task hyper-parameters on it.

    Raises:
        KeyError: if ``args.task`` is not one of ``closed_form``, ``map_mu``
            or ``map_sigma``.
    """
    import os  # local import: only needed to create the output directories

    args.device = device = ('cuda:' + args.gpu) if torch.cuda.is_available() else 'cpu'

    # Define model
    input_dim = 2  # train()/evaluate() feed every element as (value, 1)

    model_mu = SetTransformer(input_dim, 1, 1, dim_hidden=args.n_hidden, block_type=args.gtype).to(device)
    model_sigma = SetTransformer(input_dim, 1, 1, dim_hidden=args.n_hidden, block_type=args.gtype).to(device)
    model_mu.train()
    model_sigma.train()

    # Setting (from arguments)
    task = args.task
    run_tag = f"{args.task}_{args.pooling_type}_{args.n_layers}_{args.opt_fn}_{args.lr_pool}_{args.lr}_{args.gtype}_{args.seed}"
    log_file_name = f"./logs/{run_tag}.txt"
    checkpoint_mu_name = f"./checkpoints/{run_tag}_mu.pt"
    checkpoint_sigma_name = f"./checkpoints/{run_tag}_sigma.pt"

    # Make sure the output directories exist before anything is written.
    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_mu_name), exist_ok=True)

    # The context manager closes the log even if training raises.
    with open(log_file_name, "w") as _ff:
        optimizer, optimizer_pool = _load_opt_fns(args, (model_mu, model_sigma), _ff)

        n_params = sum(p.numel() for p in model_mu.parameters()) + sum(p.numel() for p in model_sigma.parameters())
        n_nonpool, n_pool = 0, 0

        # Log the parameter names of the mean model.
        for param_name, _ in model_mu.named_parameters():
            _ff.write(f"{param_name}\n")

        for group in optimizer.param_groups:
            for p in group['params']:
                n_nonpool += p.numel()
        if optimizer_pool is not None:
            for group in optimizer_pool.param_groups:
                for p in group['params']:
                    n_pool += p.numel()
        _ff.write(f"# of parameters: {n_params}, (non-pool {n_nonpool}, pool {n_pool})\n")
        _ff.flush()

        n_train, n_val, n_test = 4000, 500, 500

        # Generate Train/Val/Test data
        data = _gen_data(args, n_train, n_val, n_test)

        gt_fns = {'closed_form': _load_gts_mu_sigma, 'map_mu': _load_gts_mu_sigma, 'map_sigma': _load_gts_sigma}
        loss_fns = {'closed_form': _compute_loss_closed_form, 'map_mu': _compute_loss_map_mu, 'map_sigma': _compute_loss_map_sigma}
        gt_fn, loss_fn = gt_fns[task], loss_fns[task]

        # [best val MAPE, test MAPE at best val, best test MAPE]
        best_mu = [1e10, 1e10, 1e10]
        best_sigma = [1e10, 1e10, 1e10]
        for i in range(args.n_epochs):
            train_loss_mu, train_mape_mu, train_loss_sigma, train_mape_sigma = train(args, (model_mu, model_sigma), (optimizer, optimizer_pool), data['train'], n_train, gt_fn, loss_fn)
            with torch.no_grad():
                val_loss_mu,  val_mape_mu,  val_loss_sigma,  val_mape_sigma  = evaluate(args, (model_mu, model_sigma),  data['val'],  n_val, gt_fn, loss_fn)
                test_loss_mu, test_mape_mu, test_loss_sigma, test_mape_sigma = evaluate(args, (model_mu, model_sigma), data['test'], n_test, gt_fn, loss_fn)

            # Checkpoint each model on its own validation improvement.
            if best_mu[0] > val_mape_mu:
                best_mu[0] = val_mape_mu
                best_mu[1] = test_mape_mu
                torch.save(model_mu.state_dict(), checkpoint_mu_name)
            if best_mu[2] > test_mape_mu:
                best_mu[2] = test_mape_mu

            if best_sigma[0] > val_mape_sigma:
                best_sigma[0] = val_mape_sigma
                best_sigma[1] = test_mape_sigma
                torch.save(model_sigma.state_dict(), checkpoint_sigma_name)
            if best_sigma[2] > test_mape_sigma:
                best_sigma[2] = test_mape_sigma

            # NOTE: the values logged under "*_loss" are the MAPE metrics.
            _ff.write(f"Epoch #{i+1} | train_loss mu {train_mape_mu} sigma^2 {train_mape_sigma} | val_loss mu {val_mape_mu} sigma^2 {val_mape_sigma} | test_loss mu {test_mape_mu} sigma^2 {test_mape_sigma}\n")
            _ff.flush()
        
# Script entry point: parse the command line, seed every RNG in use, request
# deterministic PyTorch behaviour, then run the experiment.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Posterior')
    parser.add_argument("--lr", type=float, default=3e-3,
                        help="learning rate for the other parameters")
    parser.add_argument("--lr-pool", type=float, default=3e-2,
                        help="learning rate for p")
    parser.add_argument("--n-epochs", type=int, default=300,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=64,
                        help="number of hidden units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden layers before pooling")
    parser.add_argument("--weight-decay", type=float, default=0,
                        help="Weight for L2 loss")
    parser.add_argument("--pooling-type", type=str, default="transformer",
                        help="Pooling type: transformer")
    parser.add_argument("--batch-size", type=int, default=50,
                        help="batch_size")
    parser.add_argument("--norm-limit", type=float, default=1e4)
    parser.add_argument("--opt-fn", type=str, default="rmsprop",
                        help="Function type: rmsprop/adam/adamgan")
    parser.add_argument("--task", type=str, default="closed_form",
                        help="task type: closed_form/map_mu/map_sigma")
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--gtype", type=int, default=0,
                        help="0: 2 ISAB blocks, 1: 2 SAB blocks")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    print(args)

    # Seed every random number generator the script relies on.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.set_deterministic was deprecated in favour of
    # torch.use_deterministic_algorithms; prefer the new API
    # and fall back to the old one on versions that lack it.
    if hasattr(torch, "use_deterministic_algorithms"):
        torch.use_deterministic_algorithms(True)
    else:
        torch.set_deterministic(True)

    main(args)
