import sys
import os
import argparse
from datetime import datetime
import numpy as np
import torch
from data import data_regression
import utils
from networks import full_networks
from models import niwmeta

def get_args():
    """Parse the command-line options of the training script.

    Every option is declared once in the table below as
    ``(flag, type, default, help)`` and registered in that order.

    Returns:
        argparse.Namespace: parsed options, one attribute per flag
        (without the leading dashes).
    """
    option_table = (
        # Data / episode layout.
        ('--dataset', str, 'sineline', 'dataset'),
        ('--k', int, 5, 'support set size'),
        ('--kq', int, 45, 'query set size'),
        # Network architecture.
        ('--feature', str, 'fcnet_for_sineline', 'feature network'),
        ('--head', str, 'ridge_0.1', 'head after feature'),
        # Meta-training schedule.
        ('--num_epochs', int, 200, 'training epochs'),
        ('--num_episodes_per_epoch', int, 10000, 'this many episodes constitutes a single training epoch'),
        ('--minibatch', int, 20, '#episodes to update meta-parameters'),
        ('--lr', float, 1e-3, 'lr for meta model'),
        ('--num_episodes_test', int, 10000, 'number of episodes for test'),
        ('--num_workers', int, 2, 'number of workers used in data loader'),
        ('--report_freq', int, 500, 'report training progress every this episodes'),
        ('--seed', int, 1, 'random seed'),
        # Model hyper-parameters.
        ('--model', str, 'niwmeta', 'model choice'),
        ('--spiky', int, 0, 'use spiky q_i(th_i)=N(th_i;m_i,V_i) or not; ie, V_i=eps^2*I as const or optimizable'),
        ('--gam0_init', float, 1e4, 'initial Gam0 values (the same for all dims)'),
        ('--gam0_max', float, 1e8, 'max value for Gam0'),
        ('--sgld_steps', int, 5, '# of SGLD steps for fitting with episode data'),
        ('--sgld_burnin', int, 2, 'this initial # of SGLD steps regarded as burnin in SGLD'),
        ('--sgld_alp', float, 1e-3, 'lr alpha in SGLD steps'),
        ('--sgld_ai_max', float, 1e3, 'max value for A_i in SGLD variance estimation'),
        ('--query_only_for_loss', int, 1, 'whether to use Q only for loss, or S+Q'),
        # Meta-test settings.
        ('--steps_test', int, 2, '# of VI SGD steps for meta test'),
        ('--lr_test', float, 1e-4, 'lr in VI SGD for meta test'),
        ('--nsamps_test', int, 5, '# of MC samples at test time'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)
    return parser.parse_args()

class Runner:
    """Builds a ``niwmeta.NIWMeta`` model and runs episodic training/evaluation.

    Attributes:
        args: parsed options; ``args.device`` must already be set by the caller.
        logger: object exposing ``info(msg)``, used for all progress output.
        model: the ``NIWMeta`` instance, moved to ``args.device``.
        lossfun: ``torch.nn.MSELoss`` passed to the model on every call.
    """

    def __init__(self, args, logger):
        """Create the network, wrap it in the meta model and set up the loss.

        Args:
            args: parsed options (see ``get_args``) plus ``device``.
            logger: object exposing ``info(msg)``.
        """
        self.args = args
        self.logger = logger
        full_net, derived_head = full_networks.get_full_network(args.feature, args.head, args)
        model_kwargs = {
            'spiky': args.spiky,
            'gam0_init': args.gam0_init,
            'gam0_max': args.gam0_max,
            'sgld_steps': args.sgld_steps,
            'sgld_burnin': args.sgld_burnin,
            'sgld_alp': args.sgld_alp,
            'sgld_ai_max': args.sgld_ai_max,
        }
        self.model = niwmeta.NIWMeta(full_net, derived_head, **model_kwargs).to(args.device)
        self.model.initialize_helper(args.device)
        self.lossfun = torch.nn.MSELoss()

    def _split_to_device(self, eps_data):
        """Split one episode into support/query parts and move them to the device.

        Returns:
            tuple: ``(x_t, y_t, x_v, y_v)`` taken from the dict returned by
            ``data_regression.split_support_query`` with ``k_shot=args.k``.
        """
        split_data = data_regression.split_support_query(eps_data=eps_data, k_shot=self.args.k)
        device = self.args.device
        return tuple(split_data[key].to(device) for key in ('x_t', 'y_t', 'x_v', 'y_v'))

    def train(self, train_dataloader, val_dataloader, resume_ckpt=None):
        """Train the model with Adam, one optimizer step per ``args.minibatch`` episodes.

        Args:
            train_dataloader: iterable yielding one training episode per item.
            val_dataloader: iterable of validation episodes, or ``None`` to
                skip the per-epoch validation pass.
            resume_ckpt: accepted for interface compatibility; currently unused.

        Raises:
            ValueError: if the accumulated minibatch loss becomes NaN.
        """
        args = self.args
        logger = self.logger
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=args.lr)
        for epoch_id in range(args.num_epochs):
            # Loss averaged over the current minibatch of episodes.
            # NOTE: it is reset at every epoch start, so episodes left over
            # when num_episodes_per_epoch is not a multiple of minibatch
            # never contribute to an optimizer step.
            loss_v = 0.
            # Per-term running sums for progress reports, sized on first use.
            loss_monitor = None
            for eps_count, eps_data in enumerate(train_dataloader):
                if eps_count >= args.num_episodes_per_epoch:
                    break
                x_t, y_t, x_v, y_v = self._split_to_device(eps_data)
                loss, _, _ = self.model(x_t, y_t, x_v, y_v, lossfun=self.lossfun, query_only_for_loss=args.query_only_for_loss)
                if not isinstance(loss, list):
                    loss = [loss]
                # Only the first term is optimized; the others are just monitored.
                loss_v = loss_v + loss[0] / args.minibatch
                if torch.isnan(loss_v):
                    logger.info('Loss is NaN.')
                    raise ValueError("Loss is NaN.")
                if loss_monitor is None:
                    loss_monitor = [0.] * len(loss)
                for term_idx, term in enumerate(loss):
                    loss_monitor[term_idx] += term.item()
                if (eps_count + 1) % args.minibatch == 0:
                    self.optimizer.zero_grad()
                    loss_v.backward()
                    self.optimizer.step()
                    loss_v = 0.
                if (eps_count + 1) % args.report_freq == 0:
                    loss_str = ', '.join("{:.4f}".format(total / (eps_count + 1)) for total in loss_monitor)
                    logger.info('Epoch %03d episode %05d: train loss (running avg): %s' % (epoch_id, eps_count, loss_str))
            if val_dataloader is not None:
                logger.info('Evaluation on validation...')
                loss_temp, accuracy_temp = self.evaluate(num_eps=args.num_episodes_test, eps_dataloader=val_dataloader)
                logger.info('**** Epoch %03d: val loss: %.4f, accuracy: %.4f ****' % (epoch_id, np.mean(loss_temp), np.mean(accuracy_temp)))

    def evaluate(self, num_eps, eps_dataloader):
        """Evaluate the model on at most ``num_eps`` episodes.

        Args:
            num_eps: maximum number of episodes to evaluate.
            eps_dataloader: iterable yielding one episode per item.

        Returns:
            tuple: ``(losses, accuracies)``, two lists with one float per
            evaluated episode. They are shorter than ``num_eps`` when the
            loader runs out of episodes first.
        """
        args = self.args
        losses = []
        accuracies = []
        for eps_id, eps_data in enumerate(eps_dataloader):
            if eps_id >= num_eps:
                break
            x_t, y_t, x_v, y_v = self._split_to_device(eps_data)
            loss, logits, _ = self.model.evaluate(x_t, y_t, x_v, y_v, lossfun=self.lossfun, steps=args.steps_test, lr=args.lr_test, nsamps=args.nsamps_test)
            accuracy = (logits.argmax(dim=1) == y_v).float().mean().item()  # meaningless for regression
            losses.append(loss.item())
            accuracies.append(accuracy * 100)
            if (eps_id + 1) % args.report_freq == 0:
                # Average over every episode evaluated so far, this one included.
                self.logger.info('    (Up to episode %s) Validation loss avg: %.4f' % (eps_id, np.mean(losses)))
        return losses, accuracies

if __name__ == "__main__":
    # Script entry point: parse options, derive run settings, then train.
    args = get_args()

    # 'sineline' is the only dataset handled here; reject anything else up front.
    if args.dataset != 'sineline':
        raise NotImplementedError
    args.output_dim = 1

    # Keep the raw head spec in head_str. For a 'ridge_<value>' spec, head
    # becomes 'ridge' and <value> is parsed into ridge_reg.
    args.head_str = args.head
    if args.head.startswith('ridge'):
        args.head = 'ridge'
        args.ridge_reg = float(args.head_str.split('_')[1])

    # Use the GPU when one is available.
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Seed torch (CPU and CUDA) and numpy, and request deterministic cuDNN.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(args.seed)

    # Log file lives under ./runs; `logger` stays a module-level name.
    args.logdir = os.path.join('runs')
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir, exist_ok=True)
    logger = utils.Logger.get(os.path.join(args.logdir, "log.txt"))
    logger.info(f"Command :: {' '.join(sys.argv)}\n")

    # Each episode holds k support plus kq query samples.
    train_dataloader, val_dataloader = data_regression.get_dataloaders_sine_line(nsamples=args.k+args.kq)

    runner = Runner(args, logger)
    runner.train(train_dataloader, val_dataloader)
