import os
import sys
import time
import argparse
import importlib

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from npf.data.gp import GPFiniteSampler, GPSampler
from npf.utils.log import RunningAverage

from npf.utils.launch import launch

from nxcl.config import save_config
from nxcl.rich.progress import Progress


def setup_argparse(parser: argparse.ArgumentParser):
    """Register this script's command-line options on ``parser``.

    Options fall into five groups: model, data, training, evaluation and
    out-of-distribution (kernel / observation-noise) settings. Registration
    order is kept fixed because it decides the order of the ``--help`` listing.
    """
    add = parser.add_argument

    # --- model --------------------------------------------------------------
    add('--x_dim', type=int, default=1)
    add('-c', '--checkpoint', type=str, required=True)
    add('--freeze', action='store_true')

    # --- data ---------------------------------------------------------------
    add('--max_num_points', type=int, default=50)
    add('--min_num_points', type=int, default=5)

    # --- training -----------------------------------------------------------
    add('--train_seed', type=int, default=0)
    add('--train_batch_size', type=int, default=16)
    add('--train_num_samples', type=int, default=4)
    # -1 selects an infinite sampler; any other value is the number of
    # pre-generated (finite) batches.
    add('--train_num_bs', type=int, default=-1)
    add('--lr', type=float, default=5e-4)
    add('--num_steps', type=int, default=100000)
    add('--print_freq', type=int, default=200)
    add('--eval_freq', type=int, default=5000)
    add('--save_freq', type=int, default=1000)

    # --- evaluation ---------------------------------------------------------
    add('--eval_seed', type=int, default=0)
    add('--eval_num_batches', type=int, default=3000)
    add('--eval_batch_size', type=int, default=16)
    add('--eval_num_samples', type=int, default=50)

    # --- out-of-distribution settings ---------------------------------------
    add('--eval_kernel', type=str, default='rbf')
    add('--train_kernel', type=str, default='rbf')
    add('--t_noise', type=float, default=None)


def build_model(cfg):
    """Instantiate the model described by ``cfg.model`` and move it to the GPU.

    The class is resolved dynamically: ``cfg.model.name`` selects the module
    ``npf.models.<name>``, and the class named ``<name>.upper()`` inside it.
    Every other entry of ``cfg.model`` is forwarded as a constructor keyword
    argument.

    Raises:
        ValueError: if the module cannot be imported or lacks the class (or
            anything else goes wrong while resolving it). The underlying
            exception is chained as ``__cause__`` so genuine import errors
            inside a valid model module stay diagnosable.
    """
    model_name = cfg.model.name

    try:
        module = importlib.import_module(f"npf.models.{model_name}")
        model_cls = getattr(module, model_name.upper())
    except Exception as e:
        # Keep the original error attached instead of discarding it.
        raise ValueError(f'Invalid model {model_name}') from e

    model_kwargs = {k: v for k, v in cfg.model.items() if k != 'name'}
    model = model_cls(**model_kwargs)
    model.cuda()
    return model


def load_checkpoint(model, checkpoint, logger, args, map_location="cuda"):
    """Load pretrained weights from ``checkpoint`` into ``model`` (non-strict).

    Only the ``"model"`` entry of the saved dict is used. Checkpoint entries
    that the model does not have, or whose tensor shape differs from the
    model's, are skipped and logged. This lets a checkpoint from a related
    architecture be partially reused. Model entries absent from the
    checkpoint keep their current values (``strict=False``).

    If ``args.freeze`` is set, every parameter is frozen
    (``requires_grad = False``) except those of the direct child modules
    named ``dec`` and ``predictor`` (when present), which stay trainable.

    Args:
        model: ``torch.nn.Module`` to load into; modified in place.
        checkpoint: path to a ``torch.save`` file holding a dict with a
            ``"model"`` state dict.
        logger: logger used to report skipped entries.
        args: parsed CLI namespace; only ``args.freeze`` is read.
        map_location: forwarded to ``torch.load`` (default ``"cuda"``).

    Returns:
        The same ``model`` instance.
    """
    model_ckpt = torch.load(checkpoint, map_location=map_location)["model"]
    state_dict = model.state_dict()

    for k in list(model_ckpt.keys()):
        if k not in state_dict:
            # Previously this raised KeyError, aborting the "partial" load.
            logger.info(f"- \"{k}\" not loaded (not present in model)")
            model_ckpt.pop(k)
        elif state_dict[k].shape != model_ckpt[k].shape:
            logger.info(f"- \"{k}\" not loaded (shape not compatible)")
            model_ckpt.pop(k)

    model.load_state_dict(model_ckpt, strict=False)

    if args.freeze:
        trainable = {"dec", "predictor"}
        # Freeze everything first. The old loop only set requires_grad=True on
        # dec/predictor, which is already the default, so --freeze did nothing.
        for p in model.parameters():
            p.requires_grad = False
        for name, child in model.named_children():
            if name in trainable:
                for p in child.parameters():
                    p.requires_grad = True
    return model


def train(args, cfg, logger, save_dir, link_output_dir):
    """Fine-tune a pretrained model on Gaussian-process regression batches.

    Builds the model from ``cfg`` (overriding ``cfg.model.dim_x`` with
    ``--x_dim`` when that key exists) and saves the config. It then loads
    weights from ``--checkpoint`` and runs ``--num_steps`` Adam steps with a
    cosine-annealed learning rate.

    Training batches come from an infinite ``GPSampler`` when
    ``--train_num_bs == -1``, otherwise from a looping ``GPFiniteSampler``.
    Evaluation runs once before training, then every ``--eval_freq`` steps and
    at the final step. A resumable checkpoint (model/optimizer/scheduler state
    plus the next step) is written to ``save_dir / "checkpoint.pt"`` every
    ``--save_freq`` steps and at the final step.

    Args:
        args: parsed CLI namespace (see ``setup_argparse``). Also reads
            ``args.resume`` and ``args.no_progress``, which are not registered
            there; presumably ``launch`` adds them (TODO confirm).
        cfg: config object whose ``model`` section holds ``name`` plus model
            constructor kwargs.
        logger: logger for progress and evaluation lines.
        save_dir: directory path supporting ``/`` (e.g. ``pathlib.Path``) for
            the config and checkpoint files.
        link_output_dir: callable taking the experiment sub-path and returning
            the experiment path that gets logged.
    """
    if "dim_x" in cfg.model:
        cfg.model.dim_x = args.x_dim

    save_config(cfg, save_dir / "config.yaml")

    model = build_model(cfg)
    model_name = model.__class__.__name__.lower()

    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)

    if args.train_num_bs == -1:
        data_sub_path = "infinite"
    else:
        data_sub_path = f"finite-nb{args.train_num_bs}-bs{args.train_batch_size}"
    # Script name without its extension. The old ``__file__.split(".")[0]``
    # cut the *whole path* at its first dot, so any dotted directory
    # component (e.g. "/home/j.doe/...") produced the wrong name.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    exp_sub_path = os.path.join(script_name, f"{args.x_dim}d", model_name, data_sub_path)
    exp_path = link_output_dir(exp_sub_path)
    logger.info(f"Experiment path: \"{exp_path}\"")
    logger.info(f"code: {' '.join(sys.argv)}")

    logger.info(f"Load checkpoint from {args.checkpoint}")
    model = load_checkpoint(model, args.checkpoint, logger, args)

    # Point-count bounds scale with x_dim**2; shared by all samplers below.
    max_num_points = args.x_dim**2 * args.max_num_points
    min_num_points = args.x_dim**2 * args.min_num_points

    if args.train_num_bs == -1:
        train_sampler = GPSampler(
            batch_size=args.train_batch_size, max_num_points=max_num_points, min_num_points=min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel=args.train_kernel,
            x_dim=args.x_dim,
        )
    else:
        train_sampler = GPFiniteSampler(
            # The directory follows the training kernel, mirroring the eval
            # sampler. It was hard-coded to "train-rbf", so every kernel
            # pointed at the same directory.
            save_dir=f"data/gp/{args.x_dim}d/train-{args.train_kernel}", num_batches=args.train_num_bs,
            batch_size=args.train_batch_size, max_num_points=max_num_points, min_num_points=min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel=args.train_kernel, loop=True,
            x_dim=args.x_dim,
        )

    eval_sampler = GPFiniteSampler(
        save_dir=f"data/gp/{args.x_dim}d/eval-{args.eval_kernel}", num_batches=args.eval_num_batches,
        batch_size=args.eval_batch_size, max_num_points=max_num_points, min_num_points=min_num_points,
        seed=args.eval_seed, t_noise=args.t_noise, kernel=args.eval_kernel,
        x_dim=args.x_dim,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)

    if args.resume:
        # torch.save below writes a plain dict, so index it. Attribute
        # access (ckpt.model, ...) raised AttributeError.
        ckpt = torch.load(save_dir / "checkpoint.pt")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_step = ckpt["step"]
    else:
        start_step = 1
        logger.info(f'Total number of parameters: {sum(p.numel() for p in model.parameters())}\n')

    # These model classes take an extra number-of-samples argument.
    sampling_models = ("anp", "banp", "danp", "mpanp")

    ravg = RunningAverage()
    sample_iter = iter(train_sampler)

    with Progress(speed_estimate_period=300, disable=args.no_progress) as p:
        logger.info(eval(args, model, eval_sampler, p) + "\n")
        for step in p.trange(start_step, args.num_steps+1, description=f"{model_name.upper()}", remove=False):
            model.train()
            optimizer.zero_grad()

            batch = next(sample_iter)

            if model_name in sampling_models:
                outs = model(batch, num_samples=args.train_num_samples)
            else:
                outs = model(batch)

            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

            is_last = step == args.num_steps

            if step % args.print_freq == 0 or is_last:
                logger.info(f"step {step} lr {optimizer.param_groups[0]['lr']:.3e} [train_loss] {ravg.info()}")
                ravg.reset()

            if step % args.eval_freq == 0 or is_last:
                logger.info(eval(args, model, eval_sampler, p) + "\n")

            if step % args.save_freq == 0 or is_last:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        # Next step to run, so resuming does not repeat `step`.
                        "step": step + 1,
                    },
                    save_dir / "checkpoint.pt",
                )


@torch.inference_mode()
def eval(args, model, eval_sampler, progress):
    """Run ``model`` over every batch of ``eval_sampler`` and return a summary line.

    Runs under ``torch.inference_mode`` with the model switched to eval mode.
    Models whose lower-cased class name is anp/banp/danp/mpanp get
    ``args.eval_num_samples`` as a second positional argument. All output
    entries are averaged with a ``RunningAverage``.

    After the pass, the torch CPU and CUDA RNGs are reseeded from
    ``time.time()``. NOTE(review): this discards any seed the caller set, so
    training that continues after an evaluation is not reproducible from
    ``--train_seed``. Confirm this is intended.

    Returns:
        ``"<eval_kernel> "``, then ``"tn <t_noise> "`` when a noise level is
        set, then the running-average summary.
    """
    model.eval()
    uses_num_samples = model.__class__.__name__.lower() in ["anp", "banp", "danp", "mpanp"]

    meter = RunningAverage()
    for batch in progress.track(eval_sampler, description="Eval", remove=True):
        outs = model(batch, args.eval_num_samples) if uses_num_samples else model(batch)
        for key, val in outs.items():
            meter.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    pieces = [f'{args.eval_kernel} ']
    if args.t_noise is not None:
        pieces.append(f'tn {args.t_noise} ')
    pieces.append(meter.info())
    return ''.join(pieces)


if __name__ == "__main__":
    # Use sys.exit, not the ``exit`` builtin. ``exit`` is injected by the
    # ``site`` module for interactive use and is missing under ``python -S``.
    code = launch(
        train,
        setup_argparse,
        aliases={},
    )
    sys.exit(code)
