import os
import sys
import time
import argparse
import importlib

import torch

# Prepend the parent of this file's directory to the import search path.
# NOTE(review): presumably this is what makes the `npf` imports below resolve
# when the script is run directly — confirm against the repository layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from npf.data.gp import GPFiniteSampler, GPSampler
from npf.utils.log import RunningAverage

from npf.utils.launch import launch

from nxcl.config import save_config
from nxcl.rich.progress import Progress


def setup_argparse(parser: argparse.ArgumentParser):
    """Register this script's command-line options on *parser*.

    The options cover four areas: context/target point counts (data),
    optimisation and logging cadence (train), evaluation sizes (eval) and
    the kernel / noise used for evaluation (OOD settings).  Only
    ``--train_dim_list`` is mandatory; every other option has a default.

    Args:
        parser: The argument parser to extend in place.
    """
    option_table = [
        # Data
        ('--max_num_points', {'type': int, 'default': 50}),
        ('--min_num_points', {'type': int, 'default': 5}),

        # Train
        ('--train_seed', {'type': int, 'default': 0}),
        ('--train_dim_list', {'nargs': '+', 'type': int, 'required': True}),
        ('--train_kernel_list', {'nargs': '+', 'type': str, 'default': ['rbf', 'matern']}),
        ('--train_batch_size', {'type': int, 'default': 16}),
        ('--train_num_samples', {'type': int, 'default': 4}),
        # -1 for infinite, otherwise finite
        ('--train_num_bs', {'type': int, 'default': -1}),
        ('--lr', {'type': float, 'default': 5e-4}),
        ('--num_steps', {'type': int, 'default': 100000}),
        ('--print_freq', {'type': int, 'default': 200}),
        ('--eval_freq', {'type': int, 'default': 5000}),
        ('--save_freq', {'type': int, 'default': 1000}),
        ('--weight_decay', {'type': float, 'default': 0}),

        # Eval
        ('--eval_seed', {'type': int, 'default': 0}),
        ('--eval_num_batches', {'type': int, 'default': 3000}),
        ('--eval_batch_size', {'type': int, 'default': 16}),
        ('--eval_num_samples', {'type': int, 'default': 50}),

        # OOD settings
        ('--eval_kernel', {'type': str, 'default': 'rbf'}),
        ('--t_noise', {'type': float, 'default': None}),
    ]

    # Registration order is kept identical to the table order so that the
    # generated help text lists the options in the same sequence.
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)


def build_model(cfg):
    """Instantiate the model described by ``cfg.model`` and move it to CUDA.

    The model class is looked up dynamically: module ``npf.models.<name>``
    must expose a class named ``<NAME>`` (the upper-cased model name).  All
    entries of ``cfg.model`` other than ``name`` are forwarded to the class
    constructor as keyword arguments.

    Args:
        cfg: Configuration object exposing ``cfg.model.name`` and
            ``cfg.model.items()``.

    Returns:
        The constructed model, after ``model.cuda()`` has been called on it.

    Raises:
        ValueError: If the model name is on the unsupported list, or if the
            module/class lookup fails for any reason.  In the latter case the
            underlying error is attached as ``__cause__``.
    """
    model_name = cfg.model.name
    if model_name in ["np", "anp", "cnp", "canp", "bnp", "banp", "tnp", "tnpa", "tnpd", "tnpnd"]:
        raise ValueError(f"{model_name} is not supported")

    try:
        module = importlib.import_module(f"npf.models.{model_name}")
        model_cls = getattr(module, model_name.upper())
    except Exception as e:
        # Chain the original failure so that a genuine import error inside the
        # model module is not hidden behind the generic "Invalid model" text.
        raise ValueError(f'Invalid model {model_name}') from e

    model = model_cls(**{k: v for k, v in cfg.model.items() if k != 'name'})
    model.cuda()
    return model


def _scaled_num_points(args, x_dim):
    """Return ``(max_num_points, min_num_points)`` scaled by ``x_dim ** 2``."""
    scale = x_dim ** 2
    return scale * args.max_num_points, scale * args.min_num_points


def _build_train_samplers(args):
    """Create one training sampler per (dimension, kernel) combination.

    ``GPSampler`` is used when ``args.train_num_bs == -1``; otherwise a looping
    ``GPFiniteSampler`` backed by ``data/gp/<dim>d/train-<kernel>`` is used.
    The returned list is ordered dimension-major, kernel-minor.
    """
    infinite = args.train_num_bs == -1
    samplers = []
    for x_dim in args.train_dim_list:
        max_points, min_points = _scaled_num_points(args, x_dim)
        for kernel in args.train_kernel_list:
            shared = dict(
                batch_size=args.train_batch_size,
                max_num_points=max_points,
                min_num_points=min_points,
                seed=args.train_seed,
                t_noise=args.t_noise,
                kernel=kernel,
                x_dim=x_dim,
            )
            if infinite:
                samplers.append(GPSampler(**shared))
            else:
                samplers.append(GPFiniteSampler(
                    save_dir=f"data/gp/{x_dim}d/train-{kernel}",
                    num_batches=args.train_num_bs,
                    loop=True,
                    **shared,
                ))
    return samplers


def _build_eval_samplers(args, dims=(1, 2, 3, 4, 5)):
    """Create a finite evaluation sampler for each input dimension in *dims*.

    Returns a dict mapping the dimension to its sampler, in *dims* order.
    """
    samplers = {}
    for x_dim in dims:
        max_points, min_points = _scaled_num_points(args, x_dim)
        samplers[x_dim] = GPFiniteSampler(
            save_dir=f"data/gp/{x_dim}d/eval-{args.eval_kernel}",
            num_batches=args.eval_num_batches,
            batch_size=args.eval_batch_size,
            max_num_points=max_points,
            min_num_points=min_points,
            seed=args.eval_seed,
            t_noise=args.t_noise,
            kernel=args.eval_kernel,
            x_dim=x_dim,
        )
    return samplers


def _round_robin(samplers):
    """Yield batches forever, taking one batch from each sampler in turn."""
    iterators = [iter(sampler) for sampler in samplers]
    while True:
        for iterator in iterators:
            yield next(iterator)


def train(args, cfg, logger, save_dir, link_output_dir):
    """Train the configured model on GP data, with periodic eval and checkpoints.

    One training sampler is created per (dimension, kernel) pair taken from
    ``args.train_dim_list`` x ``args.train_kernel_list`` and the training loop
    draws batches from them in round-robin order.  Evaluation runs on fixed
    1d-5d samplers using ``args.eval_kernel``.

    Args:
        args: Parsed command-line namespace.  Besides the options registered
            by ``setup_argparse`` this function reads ``args.resume`` and
            ``args.no_progress`` (NOTE(review): presumably added by the
            launcher — confirm).
        cfg: Configuration object passed to ``build_model``.
        logger: Object with an ``info(str)`` method.
        save_dir: Directory for ``checkpoint.pt``; must support the ``/``
            operator (e.g. a ``pathlib.Path``).
        link_output_dir: Callable mapping the experiment sub-path to the
            experiment path that gets logged.

    Raises:
        ValueError: If no training sampler could be built (empty dimension or
            kernel list), which would otherwise make the batch generator spin
            forever without yielding.
    """
    model = build_model(cfg)
    model_name = model.__class__.__name__.lower()

    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)

    data_sub_path = "infinite" if args.train_num_bs == -1 else f"finite-nb{args.train_num_bs}-bs{args.train_batch_size}"
    # Take the basename *before* stripping the extension: splitting the whole
    # path on "." breaks as soon as a directory name (or a "./" prefix)
    # contains a dot.
    script_name = os.path.basename(__file__).split(".")[0]
    exp_sub_path = os.path.join(script_name, model_name, data_sub_path)
    exp_path = link_output_dir(exp_sub_path)
    logger.info(f"Experiment path: \"{exp_path}\"")
    line = ' '.join(sys.argv)
    logger.info(f"code: {line}")

    train_samplers = _build_train_samplers(args)
    if not train_samplers:
        raise ValueError("train_dim_list and train_kernel_list must both be non-empty")
    eval_samplers = _build_eval_samplers(args)
    last_eval_dim = list(eval_samplers)[-1]

    # --weight_decay is forwarded to the optimizer; its default of 0 matches
    # Adam's own default, so runs that do not set the flag are unaffected.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)

    if args.resume:
        # The checkpoint is written below as a plain dict, so it has to be
        # read with item access (attribute access raises AttributeError).
        ckpt = torch.load(save_dir / "checkpoint.pt")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_step = ckpt["step"]
    else:
        start_step = 1
        logger.info(f'Total number of parameters: {sum(p.numel() for p in model.parameters())}\n')

    ravg = RunningAverage()
    sample_iter = _round_robin(train_samplers)
    # Loop-invariant: whether the model's forward takes a num_samples argument.
    takes_num_samples = model_name in ["anp", "banp", "danp", "mpanp"]

    with Progress(speed_estimate_period=300, disable=args.no_progress) as p:
        for step in p.trange(start_step, args.num_steps+1, description=f"{model_name.upper()}", remove=False):
            model.train()
            optimizer.zero_grad()

            batch = next(sample_iter)
            if takes_num_samples:
                outs = model(batch, num_samples=args.train_num_samples)
            else:
                outs = model(batch)

            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

            is_last_step = step == args.num_steps

            if step % args.print_freq == 0 or is_last_step:
                logger.info(f"step {step} lr {optimizer.param_groups[0]['lr']:.3e} [train_loss] {ravg.info()}")
                ravg.reset()

            if step % args.eval_freq == 0 or is_last_step:
                for x_dim, sampler in eval_samplers.items():
                    # A blank line after the last dimension separates eval rounds in the log.
                    suffix = "\n" if x_dim == last_eval_dim else ""
                    logger.info(f"{x_dim}d: " + eval(args, model, sampler, p) + suffix)

            if step % args.save_freq == 0 or is_last_step:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        # Stored as the *next* step so that resuming does not repeat this one.
                        "step": step + 1,
                    },
                    save_dir / "checkpoint.pt",
                )


@torch.inference_mode()
def eval(args, model, eval_sampler, progress):
    """Evaluate *model* on every batch of *eval_sampler* and summarise the result.

    The model is switched to eval mode, each batch's outputs are accumulated
    in a ``RunningAverage``, and afterwards the torch CPU and CUDA seeds are
    reset from the current wall-clock time.
    NOTE(review): the reseed presumably keeps later training randomness from
    depending on the fixed evaluation seed — confirm.

    Args:
        args: Namespace providing ``eval_kernel``, ``t_noise`` and, for models
            that take a sample count, ``eval_num_samples``.
        model: The model to evaluate; called as ``model(batch)`` or
            ``model(batch, eval_num_samples)`` depending on its class name.
        eval_sampler: Iterable of evaluation batches.
        progress: Progress helper whose ``track`` wraps the sampler.

    Returns:
        A string of the form ``"<kernel> [tn <t_noise> ]<running-average info>"``.
    """
    model.eval()
    needs_num_samples = model.__class__.__name__.lower() in ("anp", "banp", "danp", "mpanp")

    stats = RunningAverage()
    tracked_batches = progress.track(eval_sampler, description="Eval", remove=True)
    for batch in tracked_batches:
        outs = model(batch, args.eval_num_samples) if needs_num_samples else model(batch)
        for name, value in outs.items():
            stats.update(name, value)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    pieces = [f'{args.eval_kernel} ']
    if args.t_noise is not None:
        pieces.append(f'tn {args.t_noise} ')
    pieces.append(stats.info())
    return ''.join(pieces)


if __name__ == "__main__":
    # Hand the training entry point and the option registration hook to the
    # launcher, then propagate its return value as the process exit status.
    code = launch(
        train,
        setup_argparse,
        aliases={},
    )
    # sys.exit rather than the bare `exit` helper: the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist.
    sys.exit(code)
