import os
import os.path as osp
import sys
import time
import argparse
import importlib

from tqdm import tqdm

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from npf.data.gp import GPFiniteSampler, GPSampler
from npf.data.image import img_to_task, task_to_img
from npf.data.emnist import EMNIST, EMNIST_Partial
from npf.data.celeba import CelebA, CelebA_Partial
from npf.utils.log import RunningAverage, get_logger

from npf.utils.launch import launch
from npf.utils.paths import evalsets_path

from nxcl.config import save_config
from nxcl.rich.progress import Progress


def setup_argparse(parser: argparse.ArgumentParser):
    """Register this script's command-line options on ``parser``.

    The required ``-c/--checkpoint`` path and the ``--freeze`` switch come
    first.  Every other option is a typed value with a default. They are
    added in table order (data, train, eval, OOD settings), which is also
    the order they show up in ``--help``.
    """
    parser.add_argument('-c', '--checkpoint', type=str, required=True)
    parser.add_argument('--freeze', action='store_true')

    # (flag, type, default) in registration order.
    option_table = [
        # Data
        ('--max_num_points', int, 50),
        ('--img_max_num_points', int, 200),
        ('--min_num_points', int, 5),
        ('--class_range', int, [0, 10]),
        ('--x_dim', int, 2),
        # Train
        ('--train_seed', int, 0),
        ('--train_batch_size', int, 16),
        ('--train_batch_size_image', int, 100),
        ('--weight_decay', float, 0),
        ('--train_num_samples', int, 4),
        ('--train_num_bs', int, -1),  # -1 for infinite, otherwise finite
        ('--lr', float, 5e-4),
        ('--num_steps', int, 100000),
        ('--print_freq', int, 200),
        ('--eval_freq', int, 5000),
        ('--save_freq', int, 1000),
        ('--data_type', str, 'emnist'),
        # Eval
        ('--eval_seed', int, 0),
        ('--eval_num_batches', int, 3000),
        ('--eval_batch_size', int, 16),
        ('--eval_num_samples', int, 50),
        # OOD settings
        ('--eval_kernel', str, 'rbf'),
        ('--t_noise', float, None),
    ]
    # Options that need keyword arguments beyond type/default.
    extra_kwargs = {'--class_range': {'nargs': '*'}}

    for flag, value_type, default in option_table:
        parser.add_argument(
            flag, type=value_type, default=default, **extra_kwargs.get(flag, {})
        )


def build_model(cfg, data_type):
    """Instantiate ``npf.models.<name>.<NAME>`` from ``cfg.model`` and move it to CUDA.

    ``cfg.model.name`` selects the module, and the class is looked up under
    the upper-cased name.  All other ``cfg.model`` items are passed as
    constructor keyword arguments.

    Args:
        cfg: config whose ``model`` entry supports ``.name``, item
            assignment and ``.items()``.
        data_type: dataset identifier.  For ``'celeba'`` this function sets
            ``dim_x=2`` and ``dim_y=3`` in ``cfg.model``, mutating the
            caller's config.

    Returns:
        The constructed model, after ``.cuda()`` has been called on it.

    Raises:
        ValueError: if the module cannot be imported or does not define the
            class.  The underlying ImportError/AttributeError is chained as
            ``__cause__`` so the real reason stays visible.
    """
    model_name = cfg.model.name

    try:
        module = importlib.import_module(f"npf.models.{model_name}")
        model_cls = getattr(module, model_name.upper())
    except (ImportError, AttributeError) as err:
        raise ValueError(f'Invalid model {model_name}') from err

    if data_type == 'celeba':
        cfg.model['dim_x'] = 2
        cfg.model['dim_y'] = 3

    model = model_cls(**{k: v for k, v in cfg.model.items() if k != 'name'})
    model.cuda()
    return model

def load_checkpoint(model, checkpoint, logger, args):
    """Load weights from ``checkpoint`` into ``model``, skipping incompatible entries.

    Some checkpoint entries are dropped and logged before a non-strict
    ``load_state_dict``:

    * keys the model does not have, and
    * keys whose tensor shape differs from the model's.

    When ``args.freeze`` is set, every parameter is frozen except those of
    the ``dec`` and ``predictor`` child modules (if present), which stay
    trainable.

    Args:
        model: ``torch.nn.Module`` populated in place.
        checkpoint: path to a ``torch.save`` file holding a dict with a
            ``"model"`` state dict.
        logger: object with an ``info(str)`` method.
        args: parsed arguments.  Only ``args.freeze`` is read.

    Returns:
        The same ``model`` instance.
    """
    # Load to CPU.  load_state_dict copies each tensor into the model's
    # existing parameters, so they stay on their current device, and we
    # avoid holding a second copy of the weights in GPU memory.
    model_ckpt = torch.load(checkpoint, map_location="cpu")["model"]
    state_dict = model.state_dict()

    for key in list(model_ckpt):
        if key not in state_dict:
            logger.info(f"- \"{key}\" not loaded (not present in model)")
            model_ckpt.pop(key)
        elif state_dict[key].shape != model_ckpt[key].shape:
            logger.info(f"- \"{key}\" not loaded (shape not compatible)")
            model_ckpt.pop(key)

    model.load_state_dict(model_ckpt, strict=False)

    if args.freeze:
        # Parameters default to requires_grad=True, so the rest of the
        # network has to be frozen explicitly.  Only the dec/predictor heads
        # are then re-enabled.
        trainable_children = {"dec", "predictor"}
        for param in model.parameters():
            param.requires_grad = False
        for name, child in model.named_children():
            if name in trainable_children:
                for param in child.parameters():
                    param.requires_grad = True
    return model

def gen_evalset(args):
    """Build the EMNIST evaluation tasks for ``args.class_range`` and cache them.

    The test split is iterated in order with ``eval_seed`` set, and each
    batch is converted with ``img_to_task``.  The list is saved to
    ``evalsets_path/emnist/<c1>-<c2>[_<t_noise>].tar``.  Afterwards the
    torch RNGs are re-seeded from the wall clock.
    """
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
        EMNIST(train=False, class_range=args.class_range),
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=0,
    )
    batches = [
        img_to_task(
            images,
            max_num_points=args.img_max_num_points,
            t_noise=args.t_noise,
        )
        for images, _ in tqdm(loader, ascii=True)
    ]

    # Stop later code from inheriting the fixed evaluation seed.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    out_dir = osp.join(evalsets_path, 'emnist')
    if not osp.isdir(out_dir):
        os.makedirs(out_dir)

    lo, hi = args.class_range
    stem = f'{lo}-{hi}' if args.t_noise is None else f'{lo}-{hi}_{args.t_noise}'
    torch.save(batches, osp.join(out_dir, stem + '.tar'))

def gen_evalset_2(args):
    """Build the CelebA evaluation tasks and cache them on disk.

    The test split is iterated in order with ``eval_seed`` set, and each
    batch is converted with ``img_to_task`` (using ``args.max_num_points``).
    The list is saved to ``evalsets_path/celeba/no_noise.tar`` or
    ``evalsets_path/celeba/<t_noise>.tar``.  Afterwards the torch RNGs are
    re-seeded from the wall clock.
    """
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    loader = torch.utils.data.DataLoader(
        CelebA(train=False),
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=4,
    )
    batches = [
        img_to_task(
            images,
            max_num_points=args.max_num_points,
            t_noise=args.t_noise,
        )
        for images, _ in tqdm(loader, ascii=True)
    ]

    # Stop later code from inheriting the fixed evaluation seed.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    out_dir = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(out_dir):
        os.makedirs(out_dir)

    if args.t_noise is None:
        filename = 'no_noise.tar'
    else:
        filename = f'{args.t_noise}.tar'
    torch.save(batches, osp.join(out_dir, filename))
    
def _make_image_loader(args):
    """Return the shuffled training DataLoader for ``args.data_type``.

    ``num_bs`` is passed to the ``*_Partial`` dataset as
    ``train_batch_size_image * train_num_bs``.  Presumably this is the
    number of images kept.  NOTE(review): with the default
    ``train_num_bs=-1`` this value is negative, and how the datasets treat
    that is not visible here — confirm in ``npf.data``.

    Raises:
        ValueError: if ``args.data_type`` is not ``'emnist'`` or ``'celeba'``.
    """
    num_images = args.train_batch_size_image * args.train_num_bs
    if args.data_type == 'emnist':
        train_ds = EMNIST_Partial(train=True, class_range=args.class_range, num_bs=num_images)
    elif args.data_type == 'celeba':
        train_ds = CelebA_Partial(train=True, num_bs=num_images)
    else:
        raise ValueError(
            f"Unsupported data_type {args.data_type!r}; expected 'emnist' or 'celeba'"
        )
    return torch.utils.data.DataLoader(
        train_ds,
        batch_size=args.train_batch_size_image,
        shuffle=True,
        num_workers=0,
    )


def _next_image_batch(loader, iterator):
    """Return ``(images, iterator)`` holding the next batch from ``loader``.

    If ``iterator`` is given it is advanced.  If it is ``None`` or
    exhausted, a fresh iterator over ``loader`` is started.  Only element
    ``[0]`` of each loader item is returned; the second element is unused.

    Raises:
        RuntimeError: if ``loader`` yields no batches at all.
    """
    if iterator is not None:
        try:
            return next(iterator)[0], iterator
        except StopIteration:
            pass
    iterator = iter(loader)
    try:
        return next(iterator)[0], iterator
    except StopIteration:
        raise RuntimeError("training image loader yielded no batches") from None


def train(args, cfg, logger, save_dir, link_output_dir):
    """Fine-tune a model loaded from ``--checkpoint`` on image-completion tasks.

    1. If the config has ``dim_x``, override it with ``--x_dim``.  Then save
       the config, build the model and load the checkpoint weights.
    2. Build the training image loader for ``--data_type`` (``emnist`` or
       ``celeba``).
    3. Run Adam with cosine annealing up to ``--num_steps``:

       * log running averages of the model outputs every ``--print_freq`` steps;
       * evaluate on the cached image eval set every ``--eval_freq`` steps;
       * write ``save_dir/checkpoint.pt`` every ``--save_freq`` steps
         (and on the last step).

    Args:
        args: parsed arguments (see ``setup_argparse``).  ``resume`` and
            ``no_progress`` are also read; presumably ``launch`` adds them.
        cfg: experiment config.  ``cfg.model`` holds the model kwargs.
        logger: logger with an ``info`` method.
        save_dir: path object supporting ``/`` for the config and checkpoint.
        link_output_dir: callable mapping an experiment sub-path to its
            output path.

    Raises:
        ValueError: if ``args.data_type`` is not ``'emnist'`` or ``'celeba'``.
    """
    if "dim_x" in cfg.model:
        cfg.model.dim_x = args.x_dim

    save_config(cfg, save_dir / "config.yaml")

    model = build_model(cfg, args.data_type)
    model_name = model.__class__.__name__.lower()
    # These models are called with an extra ``num_samples`` argument.
    sampling_models = ("anp", "banp", "danp", "mpanp")

    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)

    data_sub_path = "infinite" if args.train_num_bs == -1 else f"finite-nb{args.train_num_bs}-bs{args.train_batch_size}"
    # splitext(basename(...)) rather than split("."): a dot in a parent
    # directory name must not truncate the script name.
    script_name = osp.splitext(osp.basename(__file__))[0]
    exp_sub_path = os.path.join(script_name, model_name, data_sub_path)
    exp_path = link_output_dir(exp_sub_path)
    logger.info(f"Experiment path: \"{exp_path}\"")
    line = ' '.join(sys.argv)
    logger.info(f"code: {line}")

    logger.info(f"Load checkpoint from {args.checkpoint}")
    model = load_checkpoint(model, args.checkpoint, logger, args)

    # NOTE(review): the training loop never uses any of the GP samplers
    # (train_* here, eval_* further down).  They are kept only because
    # constructing them may have side effects, e.g. GPFiniteSampler
    # reading/writing data under save_dir.  Confirm, then remove.
    if args.train_num_bs == -1:
        train_sampler_2d = GPSampler(
            batch_size=args.train_batch_size, max_num_points=4 * args.max_num_points, min_num_points=4*args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf",
            x_dim=2,
        )
        train_sampler_1d = GPSampler(
            batch_size=args.train_batch_size, max_num_points=1 * args.max_num_points, min_num_points=1*args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf",
            x_dim=1,
        )
    else:
        train_sampler_2d = GPFiniteSampler(
            save_dir="data/gp/2d/train-rbf", num_batches=args.train_num_bs,
            batch_size=args.train_batch_size, max_num_points=4 * args.max_num_points, min_num_points=4*args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf", loop=True,
            x_dim=2,
        )
        train_sampler_1d = GPFiniteSampler(
            save_dir="data/gp/1d/train-rbf", num_batches=args.train_num_bs,
            batch_size=args.train_batch_size, max_num_points=1 * args.max_num_points, min_num_points=1*args.min_num_points,
            seed=args.train_seed, t_noise=args.t_noise, kernel="rbf", loop=True,
            x_dim=1,
        )

    # Raises ValueError for an unsupported data_type.  Previously such a
    # value left train_loader unbound.
    train_loader = _make_image_loader(args)
    image_eval = eval_image if args.data_type == 'emnist' else eval_image_2

    eval_sampler_1d = GPFiniteSampler(
        save_dir=f"data/gp/1d/eval-{args.eval_kernel}", num_batches=args.eval_num_batches,
        batch_size=args.eval_batch_size, max_num_points=1 * args.max_num_points, min_num_points=1*args.min_num_points,
        seed=args.eval_seed, t_noise=args.t_noise, kernel=args.eval_kernel,
        x_dim=1,
    )
    eval_sampler_2d = GPFiniteSampler(
        save_dir=f"data/gp/2d/eval-{args.eval_kernel}", num_batches=args.eval_num_batches,
        batch_size=args.eval_batch_size, max_num_points=4 * args.max_num_points, min_num_points=4*args.min_num_points,
        seed=args.eval_seed, t_noise=args.t_noise, kernel=args.eval_kernel,
        x_dim=2,
    )
    eval_sampler_3d = GPFiniteSampler(
        save_dir=f"data/gp/3d/eval-{args.eval_kernel}", num_batches=args.eval_num_batches,
        batch_size=args.eval_batch_size, max_num_points=9 * args.max_num_points, min_num_points=9*args.min_num_points,
        seed=args.eval_seed, t_noise=args.t_noise, kernel=args.eval_kernel,
        x_dim=3,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)

    if args.resume:
        # The checkpoint is the plain dict saved in the loop below, so it is
        # indexed by key.  Attribute access would raise AttributeError.
        ckpt = torch.load(save_dir / "checkpoint.pt")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        start_step = ckpt["step"]
    else:
        start_step = 1
        logger.info(f'Total number of parameters: {sum(p.numel() for p in model.parameters())}\n')

    ravg = RunningAverage()
    batch_iter = None

    with Progress(speed_estimate_period=300, disable=args.no_progress) as p:
        for step in p.trange(start_step, args.num_steps+1, description=f"{model_name.upper()}", remove=False):
            # Start a fresh, reshuffled pass every ``train_num_bs`` steps
            # (steps 1, nb+1, 2nb+1, ...).  _next_image_batch also starts
            # one when no iterator exists yet (e.g. right after resuming)
            # or the current one runs out.
            if args.train_num_bs > 0 and (step - 1) % args.train_num_bs == 0:
                batch_iter = None
            images, batch_iter = _next_image_batch(train_loader, batch_iter)

            model.train()
            optimizer.zero_grad()

            batch = img_to_task(images.cuda(), max_num_points=args.max_num_points)

            if model_name in sampling_models:
                outs = model(batch, num_samples=args.train_num_samples)
            else:
                outs = model(batch)

            outs.loss.backward()
            optimizer.step()
            scheduler.step()

            for key, val in outs.items():
                ravg.update(key, val)

            is_last_step = step == args.num_steps

            if step % args.print_freq == 0 or is_last_step:
                logger.info(f"step {step} lr {optimizer.param_groups[0]['lr']:.3e} [train_loss] {ravg.info()}")
                ravg.reset()

            if step % args.eval_freq == 0 or is_last_step:
                logger.info("image: " + image_eval(args, model) + "\n")

            if step % args.save_freq == 0 or is_last_step:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "step": step + 1,
                    },
                    save_dir / "checkpoint.pt",
                )

@torch.inference_mode()
def eval_image(args, model):
    """Evaluate ``model`` on the cached EMNIST evaluation set.

    The evaluation set for the current ``class_range``/``t_noise`` is read
    from ``evalsets_path/emnist/<c1>-<c2>[_<t_noise>].tar``.  If that file
    is missing, ``gen_evalset`` generates it first.  The torch RNGs are
    seeded with ``args.eval_seed`` for the pass and re-seeded from the
    clock afterwards.

    Returns:
        A one-line summary: model name, class range, ``tn <t_noise>`` when
        noise is set, and ``RunningAverage.info()`` of the model outputs.
    """
    lo, hi = args.class_range
    stem = f'{lo}-{hi}' if args.t_noise is None else f'{lo}-{hi}_{args.t_noise}'
    evalset_file = osp.join(evalsets_path, 'emnist', stem + '.tar')
    if not osp.isfile(evalset_file):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(evalset_file)

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    meter = RunningAverage()
    model.eval()
    model_name = model.__class__.__name__.lower()
    takes_num_samples = model_name in ["anp", "banp", "danp", "mpanp"]

    with torch.no_grad():
        for batch in tqdm(eval_batches, ascii=True):
            # Move tensors to the GPU in place so the batch object itself
            # (and its type) is what the model receives.
            for key, tensor in batch.items():
                batch[key] = tensor.cuda()

            if takes_num_samples:
                outs = model(batch, args.eval_num_samples)
            else:
                outs = model(batch)

            for key, val in outs.items():
                meter.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    noise_tag = '' if args.t_noise is None else f'tn {args.t_noise} '
    return f'{model_name}: {lo}-{hi} {noise_tag}' + meter.info()

@torch.inference_mode()
def eval_image_2(args, model):
    """Evaluate ``model`` on the cached CelebA evaluation set.

    The evaluation set is read from ``evalsets_path/celeba/no_noise.tar``,
    or ``<t_noise>.tar`` when ``args.t_noise`` is set.  If the file is
    missing, ``gen_evalset_2`` generates it first.  The torch RNGs are
    seeded with ``args.eval_seed`` for the pass and re-seeded from the
    clock afterwards.

    Returns:
        A one-line summary: model name, ``args.class_range``,
        ``tn <t_noise>`` when noise is set, and ``RunningAverage.info()`` of
        the model outputs.
    """
    path = osp.join(evalsets_path, 'celeba')
    os.makedirs(path, exist_ok=True)
    # Must match the filename gen_evalset_2 saves ('no_noise.tar' or
    # '<t_noise>.tar').  Otherwise the freshly generated set is never found
    # and torch.load below fails.
    filename = 'no_noise.tar' if args.t_noise is None else f'{args.t_noise}.tar'
    evalset_file = osp.join(path, filename)
    if not osp.isfile(evalset_file):
        print('generating evaluation sets...')
        gen_evalset_2(args)

    eval_batches = torch.load(evalset_file)

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    ravg = RunningAverage()
    model_name = model.__class__.__name__.lower()
    model.eval()
    # Gradients are already disabled by the inference_mode decorator.
    for batch in tqdm(eval_batches, ascii=True):
        for key, val in batch.items():
            batch[key] = val.cuda()

        if model_name in ["anp", "banp", "danp", "mpanp"]:
            outs = model(batch, args.eval_num_samples)
        else:
            outs = model(batch)

        for key, val in outs.items():
            ravg.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    # NOTE(review): class_range does not affect the CelebA set.  It is
    # presumably printed only to match eval_image's log format.
    c1, c2 = args.class_range
    line = f'{model_name}: {c1}-{c2} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()
    return line

@torch.inference_mode()
def eval(args, model, eval_sampler, progress):
    """Evaluate ``model`` on the batches yielded by ``eval_sampler``.

    NOTE: this name shadows the builtin ``eval`` inside this module.  It is
    kept for compatibility with existing callers.

    Args:
        args: parsed arguments.  Reads ``eval_seed``, ``eval_num_samples``,
            ``eval_kernel`` and ``t_noise``.
        model: model whose call returns a mapping of named outputs.
        eval_sampler: iterable of evaluation batches.
        progress: object whose ``track(iterable, description=..., remove=...)``
            wraps the iteration.

    Returns:
        A one-line summary: kernel name, ``tn <t_noise>`` when noise is set,
        and ``RunningAverage.info()`` of the model outputs.
    """
    # Seed with eval_seed, as eval_image / eval_image_2 do, so models that
    # sample internally are evaluated reproducibly.  The RNGs are re-seeded
    # from the clock at the end.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    model.eval()
    model_name = model.__class__.__name__.lower()

    ravg = RunningAverage()

    for batch in progress.track(eval_sampler, description="Eval", remove=True):
        if model_name in ["anp", "banp", "danp", "mpanp"]:
            outs = model(batch, args.eval_num_samples)
        else:
            outs = model(batch)

        for key, val in outs.items():
            ravg.update(key, val)

    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    line = f'{args.eval_kernel} '
    if args.t_noise is not None:
        line += f'tn {args.t_noise} '
    line += ravg.info()
    return line


if __name__ == "__main__":
    code = launch(
        train,
        setup_argparse,
        aliases={},
    )
    # Use sys.exit rather than the bare ``exit`` helper.  The ``site``
    # module injects ``exit`` for interactive use, so it is missing under
    # ``python -S``.
    sys.exit(code)
