import os
import os.path as osp
import sys
import time
import argparse
import importlib

from tqdm import tqdm

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from npf.data.gp import GPFiniteSampler, GPSampler
from npf.data.image import img_to_task, task_to_img
from npf.data.emnist import EMNIST
from npf.data.celeba import CelebA
from npf.utils.log import RunningAverage, get_logger

from npf.utils.launch import launch
from npf.utils.paths import evalsets_path

from nxcl.config import save_config
from nxcl.rich.progress import Progress


def setup_argparse(parser: argparse.ArgumentParser):
    """Register this script's command-line options on ``parser``.

    Options are grouped into data sizes, training, evaluation and OOD settings.
    ``train`` additionally reads ``args.resume`` and ``args.no_progress``, which
    are not registered here; presumably ``launch`` adds them (TODO confirm).
    """
    # Data. GP point budgets are per 1-D task; train() scales them by 4x / 9x
    # for the 2-D / 3-D GP samplers. In this file img_max_num_points is only
    # used when generating the EMNIST evaluation set (gen_evalset_1).
    parser.add_argument('--max_num_points', type=int, default=50)
    parser.add_argument('--img_max_num_points', type=int, default=200)
    parser.add_argument('--min_num_points', type=int, default=5)
    # Consumers unpack this as `c1, c2 = args.class_range`, so exactly two
    # values are required; enforce that at parse time instead of crashing later.
    parser.add_argument('--class_range', type=int, nargs=2, metavar=('C1', 'C2'),
                        default=[0, 10])

    # Train
    parser.add_argument('--train_seed', type=int, default=0)
    parser.add_argument('--train_batch_size', type=int, default=16)
    parser.add_argument('--train_batch_size_image', type=int, default=100)
    parser.add_argument('--weight_decay', type=float, default=0)

    parser.add_argument('--train_num_samples', type=int, default=4)
    parser.add_argument('--train_num_bs', type=int, default=-1)  # -1 for infinite, otherwise finite
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--print_freq', type=int, default=200)
    parser.add_argument('--eval_freq', type=int, default=5000)
    parser.add_argument('--save_freq', type=int, default=1000)

    # Eval
    parser.add_argument('--eval_seed', type=int, default=0)
    parser.add_argument('--eval_num_batches', type=int, default=3000)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--eval_num_samples', type=int, default=50)

    # OOD settings
    parser.add_argument('--eval_kernel', type=str, default='rbf')
    parser.add_argument('--t_noise', type=float, default=None)


def build_model(cfg):
    """Instantiate the model described by ``cfg.model`` and move it to CUDA.

    ``cfg.model.name`` selects module ``npf.models.<name>`` and its class
    ``<NAME>`` (the upper-cased name). Every other entry of ``cfg.model`` is
    forwarded to the constructor as a keyword argument.

    Raises:
        ValueError: if the name is in the unsupported list, or if the module or
            class cannot be loaded. In the latter case, the underlying error is
            chained as ``__cause__``.
    """
    model_name = cfg.model.name
    if model_name in ("anp", "canp", "banp", "tnp", "tnpd", "mpanp"):
        raise ValueError(f"{model_name} is not supported")

    try:
        module = importlib.import_module(f"npf.models.{model_name}")
        model_cls = getattr(module, model_name.upper())
    except Exception as err:
        # Chain the real failure (missing module, typo in class, an import-time
        # error inside the model module, ...) so it shows up in the traceback.
        raise ValueError(f'Invalid model {model_name}') from err

    model_kwargs = {k: v for k, v in cfg.model.items() if k != 'name'}
    model = model_cls(**model_kwargs)
    model.cuda()
    return model

def gen_evalset_2(args):
    """Build the CelebA evaluation tasks and save them under ``evalsets_path/celeba``.

    Every test-split batch (``args.eval_batch_size`` images, no shuffling) is
    converted with ``img_to_task`` using ``args.max_num_points`` and
    ``args.t_noise``, under the fixed ``args.eval_seed``. Afterwards the global
    RNGs are reseeded from the clock. The task list is written to
    ``no_noise.tar`` (no noise) or ``<t_noise>.tar``.
    """
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
        CelebA(train=False),
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=4,
    )
    tasks = [
        img_to_task(images, max_num_points=args.max_num_points, t_noise=args.t_noise)
        for images, _labels in tqdm(loader, ascii=True)
    ]

    # Drop the fixed eval seed again (presumably so later randomness is not
    # pinned to eval_seed).
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    out_dir = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(out_dir):
        os.makedirs(out_dir)

    if args.t_noise is None:
        out_name = 'no_noise.tar'
    else:
        out_name = f'{args.t_noise}.tar'
    torch.save(tasks, osp.join(out_dir, out_name))

def gen_evalset_1(args):
    """Build the EMNIST evaluation tasks and save them under ``evalsets_path/emnist``.

    Test-split images restricted to ``args.class_range`` are batched
    (``args.eval_batch_size``, no shuffling). Each batch is converted with
    ``img_to_task`` using ``args.img_max_num_points`` and ``args.t_noise``,
    under the fixed ``args.eval_seed``. The global RNGs are then reseeded from
    the clock. The task list is written to ``<c1>-<c2>[_<t_noise>].tar``.
    """
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    test_set = EMNIST(train=False, class_range=args.class_range)
    loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=0,
    )
    tasks = [
        img_to_task(images, max_num_points=args.img_max_num_points, t_noise=args.t_noise)
        for images, _labels in tqdm(loader, ascii=True)
    ]

    # Drop the fixed eval seed again (presumably so later randomness is not
    # pinned to eval_seed).
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    out_dir = osp.join(evalsets_path, 'emnist')
    if not osp.isdir(out_dir):
        os.makedirs(out_dir)

    c1, c2 = args.class_range
    noise_tag = '' if args.t_noise is None else f'_{args.t_noise}'
    torch.save(tasks, osp.join(out_dir, f'{c1}-{c2}{noise_tag}.tar'))

def _build_train_gp_samplers(args):
    """Return the ``(2-D, 3-D)`` GP samplers used for training.

    ``args.train_num_bs == -1`` (the "infinite" setting) selects ``GPSampler``.
    Any other value selects ``GPFiniteSampler`` with ``args.train_num_bs``
    batches under ``data/gp/<d>d/train-rbf`` and ``loop=True``.
    Point budgets are 4x (2-D) and 9x (3-D) the 1-D values.
    """
    if args.train_num_bs == -1:
        sampler_2d = GPSampler(
            batch_size=args.train_batch_size,
            max_num_points=4 * args.max_num_points, min_num_points=4 * args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf",
            x_dim=2,
        )
        sampler_3d = GPSampler(
            batch_size=args.train_batch_size,
            max_num_points=9 * args.max_num_points, min_num_points=9 * args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf",
            x_dim=3,
        )
    else:
        sampler_2d = GPFiniteSampler(
            save_dir="data/gp/2d/train-rbf", num_batches=args.train_num_bs,
            batch_size=args.train_batch_size,
            max_num_points=4 * args.max_num_points, min_num_points=4 * args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf", loop=True,
            x_dim=2,
        )
        # 3-D data gets its own directory. It previously pointed at
        # "data/gp/1d/train-rbf", which could collide with 1-D data.
        sampler_3d = GPFiniteSampler(
            save_dir="data/gp/3d/train-rbf", num_batches=args.train_num_bs,
            batch_size=args.train_batch_size,
            max_num_points=9 * args.max_num_points, min_num_points=9 * args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf", loop=True,
            x_dim=3,
        )
    return sampler_2d, sampler_3d


def _build_image_loaders(args):
    """Return the shuffled ``(EMNIST, CelebA)`` training DataLoaders."""
    emnist_loader = torch.utils.data.DataLoader(
        EMNIST(train=True, class_range=args.class_range),
        batch_size=args.train_batch_size_image,
        shuffle=True, num_workers=0)
    celeba_loader = torch.utils.data.DataLoader(
        CelebA(train=True),
        batch_size=args.train_batch_size,
        shuffle=True, num_workers=4)
    return emnist_loader, celeba_loader


def _build_gp_eval_samplers(args):
    """Return ``{x_dim: GPFiniteSampler}`` evaluation samplers for x_dim 1, 2, 3.

    Point budgets scale with ``x_dim ** 2`` (1x, 4x, 9x the 1-D values).
    """
    return {
        x_dim: GPFiniteSampler(
            save_dir=f"data/gp/{x_dim}d/eval-{args.eval_kernel}",
            num_batches=args.eval_num_batches,
            batch_size=args.eval_batch_size,
            max_num_points=x_dim ** 2 * args.max_num_points,
            min_num_points=x_dim ** 2 * args.min_num_points,
            seed=args.eval_seed, t_noise=args.t_noise, kernel=args.eval_kernel,
            x_dim=x_dim,
        )
        for x_dim in (1, 2, 3)
    }


def train(args, cfg, logger, save_dir, link_output_dir):
    """Train ``cfg.model`` on a rotating mix of GP and image tasks.

    Steps cycle with period 4:
    - ``step % 4`` in (0, 1): a GP batch, alternating between the 2-D and 3-D
      samplers.
    - ``step % 4 == 2``: a CelebA batch.
    - ``step % 4 == 3``: an EMNIST batch.

    GP (1/2/3-D) and image evaluations run every ``args.eval_freq`` steps. A
    checkpoint dict is written to ``save_dir / "checkpoint.pt"`` every
    ``args.save_freq`` steps and at the final step.

    ``args.resume`` and ``args.no_progress`` are read here but not registered
    by ``setup_argparse``; presumably ``launch`` adds them (TODO confirm).
    """
    model = build_model(cfg)
    model_name = model.__class__.__name__.lower()

    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)

    if args.train_num_bs == -1:
        data_sub_path = "infinite"
    else:
        data_sub_path = f"finite-nb{args.train_num_bs}-bs{args.train_batch_size}"
    # Strip the extension from the basename only. Splitting the whole path on
    # "." broke whenever a directory name contained a dot.
    script_name = osp.splitext(osp.basename(__file__))[0]
    exp_path = link_output_dir(osp.join(script_name, model_name, data_sub_path))
    logger.info(f"Experiment path: \"{exp_path}\"")
    line = ' '.join(sys.argv)
    logger.info(f"code: {line}")

    train_sampler_2d, train_sampler_3d = _build_train_gp_samplers(args)
    train_loader_1, train_loader_2 = _build_image_loaders(args)  # EMNIST, CelebA
    eval_samplers = _build_gp_eval_samplers(args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)

    if args.resume:
        # The checkpoint is the plain dict saved below, so it must be indexed
        # by key; attribute access (ckpt.model) raised AttributeError.
        ckpt = torch.load(save_dir / "checkpoint.pt")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_step = ckpt["step"]
    else:
        start_step = 1
        logger.info(f'Total number of parameters: {sum(p.numel() for p in model.parameters())}\n')

    ravg = RunningAverage()

    def gp_batches():
        """Yield GP training batches forever, alternating 2-D and 3-D."""
        it_2d = iter(train_sampler_2d)
        it_3d = iter(train_sampler_3d)
        while True:
            yield next(it_2d)
            yield next(it_3d)

    sample_iter_1 = gp_batches()
    # EMNIST / CelebA batch iterators; created lazily inside the loop.
    iter_loader_1 = iter_loader_2 = None

    with Progress(speed_estimate_period=300, disable=args.no_progress) as p:
        for step in p.trange(start_step, args.num_steps + 1, description=f"{model_name.upper()}", remove=False):
            # Re-create (and so re-shuffle) both image iterators every 900
            # steps. The `is None` clause also creates them on the first image
            # step of a resumed run, whose start step need not satisfy
            # `step % 900 == 2`. Each loader must yield >= 225 batches per
            # 900-step window (one image step in four).
            if step % 900 == 2 or (iter_loader_1 is None and step % 4 >= 2):
                iter_loader_1 = iter(train_loader_1)
                iter_loader_2 = iter(train_loader_2)

            model.train()
            optimizer.zero_grad()
            phase = step % 4
            if phase in (0, 1):
                batch = next(sample_iter_1)
            elif phase == 2:
                # CelebA images (author's note: [16, 3, 32, 32])
                batch = img_to_task(next(iter_loader_2)[0].cuda(), max_num_points=args.max_num_points)
            else:
                # EMNIST images (author's note: [100, 1, 28, 28]).
                # NOTE(review): training uses args.max_num_points here, while
                # the EMNIST eval set uses args.img_max_num_points. Confirm
                # this is intended.
                batch = img_to_task(next(iter_loader_1)[0].cuda(), max_num_points=args.max_num_points)

            if model_name in ["anp", "banp", "danp", "mpanp"]:
                outs = model(batch, num_samples=args.train_num_samples)
            else:
                outs = model(batch)

            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

            if step % args.print_freq == 0 or step == args.num_steps:
                logger.info(f"step {step} lr {optimizer.param_groups[0]['lr']:.3e} [train_loss] {ravg.info()}")
                ravg.reset()

            if step % args.eval_freq == 0 or step == args.num_steps:
                for x_dim, eval_sampler in eval_samplers.items():
                    logger.info(f"{x_dim}d: " + eval(args, model, eval_sampler, p))
                logger.info("image_emnist: " + eval_image_1(args, model))
                logger.info("image_celeba: " + eval_image_2(args, model) + "\n")

            if step % args.save_freq == 0 or step == args.num_steps:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        # Resume from the step after the one just completed.
                        "step": step + 1,
                    },
                    save_dir / "checkpoint.pt",
                )

@torch.inference_mode()
def eval_image_2(args, model):
    """Evaluate ``model`` on the cached CelebA evaluation tasks and return a summary line.

    The task file under ``evalsets_path/celeba`` is ``no_noise.tar`` or
    ``<t_noise>.tar``, the same names ``gen_evalset_2`` writes. It is generated
    on demand when missing. Evaluation runs under ``args.eval_seed``; the global
    RNGs are reseeded from the clock afterwards.

    NOTE(review): the summary line includes ``args.class_range``, which only
    parameterizes the EMNIST data, not CelebA.
    """
    path = osp.join(evalsets_path, 'celeba')
    os.makedirs(path, exist_ok=True)
    # Must match gen_evalset_2's naming. The previous '_<t_noise>.tar' was never
    # written, so noisy runs regenerated the set and then failed to load it.
    filename = 'no_noise.tar' if args.t_noise is None else f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset_2(args)

    eval_batches = torch.load(osp.join(path, filename))

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model_name = model.__class__.__name__.lower()
    model.eval()
    # inference_mode (decorator) already disables autograd for this whole call.
    for batch in tqdm(eval_batches, ascii=True):
        for key, val in batch.items():
            batch[key] = val.cuda()

        if model_name in ("anp", "banp", "danp", "mpanp"):
            outs = model(batch, args.eval_num_samples)
        else:
            outs = model(batch)

        for key, val in outs.items():
            ravg.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    c1, c2 = args.class_range
    line = f'{model_name}: {c1}-{c2} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()
    return line

@torch.inference_mode()
def eval_image_1(args, model):
    """Evaluate ``model`` on the cached EMNIST evaluation tasks and return a summary line.

    The task file ``<c1>-<c2>[_<t_noise>].tar`` under ``evalsets_path/emnist``
    is produced by ``gen_evalset_1`` when it does not exist yet. Evaluation runs
    under ``args.eval_seed``; the global RNGs are reseeded from the clock
    afterwards. Models whose lower-cased class name is in ``sampling_models``
    also receive ``args.eval_num_samples``.
    """
    sampling_models = {
        "np", "anp", "bnp", "banp", "danp", "tdanp", "tldanp", "dtanpd",
        "dttanpd", "dlttanpd", "dltttanpd", "dddltttanpd", "danploss",
        "ddanp", "danpc",
    }

    c1, c2 = args.class_range
    noise_tag = '' if args.t_noise is None else f'_{args.t_noise}'
    eval_file = osp.join(evalsets_path, 'emnist', f'{c1}-{c2}{noise_tag}.tar')
    if not osp.isfile(eval_file):
        print('generating evaluation sets...')
        gen_evalset_1(args)

    eval_batches = torch.load(eval_file)

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    meter = RunningAverage()
    model.eval()
    model_name = type(model).__name__.lower()
    with torch.no_grad():
        for batch in tqdm(eval_batches, ascii=True):
            for field, tensor in batch.items():
                batch[field] = tensor.cuda()

            if model_name in sampling_models:
                outs = model(batch, args.eval_num_samples)
            else:
                outs = model(batch)

            for metric, value in outs.items():
                meter.update(metric, value)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    pieces = [f'{model_name}: {c1}-{c2} ']
    if args.t_noise is not None:
        pieces.append(f'tn {args.t_noise} ')
    pieces.append(meter.info())
    return ''.join(pieces)

@torch.inference_mode()
def eval(args, model, eval_sampler, progress):
    """Evaluate ``model`` on every batch of ``eval_sampler`` and return a summary line.

    The name shadows the builtin ``eval``; it is kept because ``train`` calls
    it by this name. ``progress.track`` drives the iteration (a transient
    "Eval" bar). Models whose lower-cased class name is in ``sampling_models``
    also receive ``args.eval_num_samples``.
    """
    # Seed the way eval_image_1/eval_image_2 do, so models that sample in their
    # forward pass give reproducible eval numbers. Previously only the
    # clock-based reseed at the end ran, so results varied run to run.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    model.eval()
    model_name = model.__class__.__name__.lower()
    # NOTE(review): eval_image_1's list additionally contains "dddltttanpd" and
    # "danpc". Confirm whether they should sample here too.
    sampling_models = {
        "np", "anp", "bnp", "banp", "danp", "tdanp", "tldanp", "dtanpd",
        "dttanpd", "dlttanpd", "dltttanpd", "danploss", "ddanp",
    }

    ravg = RunningAverage()

    for batch in progress.track(eval_sampler, description="Eval", remove=True):
        if model_name in sampling_models:
            outs = model(batch, args.eval_num_samples)
        else:
            outs = model(batch)

        for key, val in outs.items():
            ravg.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.eval_kernel} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()
    return line


if __name__ == "__main__":
    # Hand train() and the option registrar to the project launcher; its return
    # value becomes the process exit status.
    code = launch(
        train,
        setup_argparse,
        aliases={},
    )
    # sys.exit rather than the site-provided `exit`, which is meant for the
    # interactive shell and is missing when Python runs with -S.
    sys.exit(code)
