#%%
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "../src")))
sys.path.append(os.path.abspath(os.path.join(__file__, "../src/record")))
from tqdm import tqdm
import datetime
import numpy as np
import random
import yaml
import time
import datetime
from typing import Optional


import torch.nn as nn
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse

import logger
import tbwriter
from bort.models import give_model
from bort.datasets import give_dataset, give_dataloader
from bort.utils import give_visualizer
from bort.optimizers import give_optim, give_scheduler
from bort.utils.misc import MetricHanlder, load_all, save_all, resume_config, give_pbar, ValueMeter

def give_config(args, prefix=""):
    """Create a ``CONFIG`` instance populated from parsed command-line args.

    Every attribute of ``args`` whose name neither starts nor ends with an
    underscore is copied onto a fresh ``CONFIG`` object and also collected
    into a plain dict. The resulting public settings are printed to stdout.

    Args:
        args: Object holding the parsed options (e.g. an argparse namespace).
            It must provide ``save_path`` and ``debug``.
        prefix: Text inserted before the timestamp of an auto-generated
            ``save_path``.

    Returns:
        A ``(config, args_dict)`` tuple.
    """
    def _is_public(name):
        # Same filter as used when saving the config: skip private/dunder names.
        return not (name.startswith("_") or name.endswith("_"))

    config = CONFIG()

    args_dict = {name: getattr(args, name) for name in dir(args) if _is_public(name)}
    for name, value in args_dict.items():
        setattr(config, name, value)

    # The tag concatenates "<first three letters, upper-cased><value>" per option.
    tag = "-".join(f"{name.upper()[:3]}{value}" for name, value in args_dict.items())
    if config.save_path is None:
        stamp = datetime.datetime.now().strftime('%m%d%H%M')
        config.save_path = f"./log/{prefix}{stamp}-" + tag

    # Debug runs always share one fixed directory, overriding any other choice.
    if config.debug:
        config.save_path = "./log/debug"

    for name in filter(_is_public, dir(config)):
        print(f"{name}: {getattr(config, name)}")
    return config, args_dict

def save_config(path, config, exclude_list=("device",)):
    """Dump the public attributes of ``config`` to ``<path>/config.yaml``.

    Only the process whose ``config.rank`` is 0 writes the file; every other
    rank returns without doing anything.

    Args:
        path: Directory that receives ``config.yaml``. It must already exist.
        config: Object whose attributes are saved. Names that start or end
            with an underscore are skipped.
        exclude_list: Attribute names to leave out of the dump. The default
            is an immutable tuple so it cannot be modified across calls.
    """
    if config.rank != 0:
        return

    dict_to_save = {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith("_")
        and not name.endswith("_")
        and name not in exclude_list
    }
    with open(os.path.join(path, "config.yaml"), mode="w", encoding="utf-8") as f:
        yaml.dump(dict_to_save, f)

def force_replace_config(config, force_replace: dict = None):
    """Overwrite attributes of ``config`` with the entries of ``force_replace``.

    Args:
        config: Object to modify in place.
        force_replace: Mapping of attribute name to new value, or ``None`` to
            leave ``config`` untouched. Each applied override is printed.
    """
    if force_replace is None:
        return
    for name, value in force_replace.items():
        setattr(config, name, value)
        print(f"Force {name} to be {value}")

def initiate(config, force_replace: dict = None):
    """Build every object needed for a training run from ``config``.

    In order: apply ``force_replace`` overrides, join the NCCL process group
    when ``config.is_dist`` is set, seed the random number generators,
    configure ``logger`` and ``tbwriter``, restore or save the run
    configuration, build datasets, loaders, model, loss, visualizer, optimizer,
    scheduler and metric handler, wrap the model in DDP when distributed, and
    finally try to restore a checkpoint.

    Args:
        config: Run configuration. Reads ``device``, ``is_dist``,
            ``dist_port``, ``world_size``, ``rank``, ``seed``, ``save_path``
            and the optional ``resume``; ``start_epoch`` is set here.
        force_replace: Optional ``{attribute: value}`` overrides. They are
            applied once up front and again after a resumed config has been
            loaded, so they win over the resumed values.

    Returns:
        dict with the keys ``config``, ``loaders`` (``train``/``test``),
        ``model``, ``criterion``, ``optimizer``, ``scheduler``,
        ``visualizer``, ``metric_handler`` and ``device``.
    """
    force_replace_config(config, force_replace)
    device = config.device
    resume = getattr(config, "resume", None)
    if config.is_dist:
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://localhost:{config.dist_port}",
            world_size=config.world_size,
            rank=config.rank
        )

    # Seed every RNG in use so that runs are repeatable.
    os.environ["PYTHONHASHSEED"] = str(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    logger.config(output_dir=config.save_path, dist_rank=config.rank)
    tbwriter.config(output_dir=config.save_path, dist_rank=config.rank)

    # resume config
    if resume is not None:
        try:
            resume_config(resume, config)
            force_replace_config(config, force_replace)
        except Exception as err:
            # Catch Exception only: KeyboardInterrupt/SystemExit must still
            # stop the run instead of being treated as "nothing to resume".
            save_config(config.save_path, config)
            logger.info(f"No resume! save the config to the path: {config.save_path}")
            logger.info(f"Resuming the config from {resume} failed: {err!r}")
    else:
        save_config(config.save_path, config)

    train_dataset = give_dataset(config, is_train=True)
    test_dataset = give_dataset(config, is_train=False)
    train_loader = give_dataloader(config, train_dataset, is_train=True)
    test_loader = give_dataloader(config, test_dataset, is_train=False)
    model = give_model(config).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    # give_visualizer returns the model as well, so it may hand back a
    # different object than it received; use the returned one from here on.
    model, visualizer = give_visualizer(config, model)
    optimizer = give_optim(config, model)
    scheduler = give_scheduler(config, optimizer)
    metric_handler = MetricHanlder(return_epoch=True)

    if config.is_dist:
        model = DDP(model, device_ids=[config.rank])

    # resume model
    config.start_epoch = 0
    if resume is not None:
        try:
            # load_all returns the epoch stored in the checkpoint; continue
            # with the one after it.
            config.start_epoch = load_all(device, resume, model, criterion, optimizer, scheduler) + 1
            logger.info(f"Epoch={config.start_epoch}, resume from {resume}")
        except Exception:
            # Best effort: report the failure and train from epoch 0.
            import traceback
            print(f"{traceback.format_exc()}")
            logger.info(f"Resume failed! Path {resume}")
    else:
        logger.info("Train from scratch without resuming")
    tbwriter.set_visualizer(visualizer)

    return {
        "config": config,
        "loaders": {"train": train_loader, "test": test_loader},
        "model": model,
        "criterion": criterion,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "visualizer": visualizer,
        "metric_handler": metric_handler,
        "device": device
    }

def train_one_epoch(epoch, model, loader, criterion, optimizer, scheduler, device, config):
    """Train ``model`` for one pass over ``loader``.

    Args:
        epoch: Epoch index, used only in log messages.
        model: Network to optimize; switched to training mode here.
        loader: Iterable of ``(data, target)`` batches that supports ``len``.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional per-iteration scheduler exposing
            ``step_update(num_updates=..., metric=...)``; ``None`` disables it.
        device: Device the batches are moved to.
        config: Run configuration. Reads ``rank``, ``print_interval`` and
            ``stats_interval``; ``train_iteration`` is advanced by one for
            every optimizer step.
    """
    loss_meter = ValueMeter()
    time_meter = ValueMeter()
    num_batches = len(loader)
    mib = 1024.0 * 1024.0

    model.train()
    pbar = give_pbar(loader, config.rank)
    for idx, (data, target) in enumerate(pbar):
        s_time = time.time()

        data = data.to(device)
        target = target.to(device)
        output = model(data)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Work with a plain float from here on so that neither the scheduler
        # nor the meter keeps the loss tensor (and its autograd history) alive.
        loss_value = loss.item()
        if scheduler is not None:
            scheduler.step_update(num_updates=config.train_iteration, metric=loss_value)

        duration = time.time() - s_time
        loss_meter.update(loss_value)
        time_meter.update(duration)
        if idx % config.print_interval == 0:
            # "Mem" reports the peak allocated memory against the total memory
            # of the CUDA device with index config.rank.
            cur_mem = torch.cuda.max_memory_allocated(config.rank) / mib
            max_mem = torch.cuda.get_device_properties(config.rank).total_memory / mib

            lr = optimizer.param_groups[0]["lr"]
            logger.info(
                f"Epoch: {epoch}, "
                f"{idx}/{num_batches}, "
                f"Eta={datetime.timedelta(seconds=int(time_meter.avg() * (num_batches-idx)))}, "
                f"Mem={cur_mem:.0f}/{max_mem:.0f} MiB, "
                f"Loss={loss_value:.6f}({loss_meter.avg():.6f}), "
                f"Lr={lr:.3e}"
            )
            tbwriter.log.add_scalar("train/loss", loss_value, config.train_iteration)
            tbwriter.log.add_scalar("train/lr", lr, config.train_iteration)

        # Count every optimizer step, not only the logged ones: the counter is
        # what the scheduler receives as ``num_updates``.
        config.train_iteration += 1

        tbwriter.set_vis_conv((idx + 1) % config.stats_interval == 0)

def validate(model, loader, device, config):
    """Measure top-1 accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. Progress is logged every ``config.print_interval`` batches.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(data, target)`` batches that supports ``len``.
        device: Device the batches are moved to.
        config: Run configuration; reads ``rank`` and ``print_interval``.

    Returns:
        Fraction of samples whose arg-max prediction equals the target.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    batch_timer = ValueMeter()

    model.eval()
    with torch.no_grad():
        batches = give_pbar(loader, config.rank)
        seen, correct = 0, 0
        for step, (data, target) in enumerate(batches):
            tic = time.time()

            data = data.to(device)
            target = target.to(device)
            pred = torch.argmax(model(data), dim=-1)

            correct += torch.sum(pred == target).item()
            seen += len(target)

            batch_timer.update(time.time() - tic)
            if step % config.print_interval != 0:
                continue

            # Peak allocated memory vs. total memory of device index config.rank.
            to_mib = 1024.0 * 1024.0
            cur_mem = torch.cuda.max_memory_allocated(config.rank) / to_mib
            max_mem = torch.cuda.get_device_properties(config.rank).total_memory / to_mib

            remaining = datetime.timedelta(seconds=int(batch_timer.avg() * (len(loader)-step)))
            logger.info(
                f"{step}/{len(loader)}, "
                f"Eta={remaining}, "
                f"Mem={cur_mem:.0f}/{max_mem:.0f} MiB"
            )
    return correct / seen

def run(obj: dict):
    """Run the full train/evaluate loop on the objects built by ``initiate``.

    The test set is evaluated once before training. Each epoch then trains,
    measures accuracy on the train and test loaders, steps the scheduler and
    saves a checkpoint whenever the test accuracy is the best so far.

    Args:
        obj: Mapping with the keys ``config``, ``device``, ``model``,
            ``criterion``, ``optimizer``, ``scheduler``, ``metric_handler``
            and ``loaders`` (itself holding ``train`` and ``test``).
    """
    config = obj["config"]
    device = obj["device"]
    model = obj["model"]
    criterion = obj["criterion"]
    optimizer = obj["optimizer"]
    scheduler = obj["scheduler"]
    metric_handler = obj["metric_handler"]
    loaders = obj["loaders"]
    train_loader, test_loader = loaders["train"], loaders["test"]

    # Accuracy of the starting weights, before any training in this process.
    initial_acc = validate(model, test_loader, device, config)
    logger.info(f"Meta acc: {initial_acc}")

    epoch = config.start_epoch
    while epoch < config.epochs:
        logger.info(f"Epoch {epoch}")
        logger.info("Start to train")

        train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler, device, config)
        train_acc = validate(model, train_loader, device, config)
        tbwriter.log.add_scalar("train/accuracy", train_acc, epoch)

        val_acc = validate(model, test_loader, device, config)
        is_best, best_epoch = metric_handler.update(val_acc, epoch=epoch)

        # Epoch-level scheduler step, keyed by the index of the next epoch.
        if scheduler is not None:
            scheduler.step(epoch + 1, val_acc)

        logger.info(f"Epoch: {epoch}, TrainAcc: {train_acc:.4f}, ValAcc: {val_acc:.4f}, Best: {metric_handler.best_metric:.4f}, Best_epoch: {best_epoch}")
        if is_best:
            save_all(config.save_path, epoch, model, criterion, optimizer, scheduler)
        tbwriter.log.add_scalar("test/accuracy", val_acc, epoch)

        epoch += 1

def main_worker(device, config):
    """Entry point of one training process.

    Args:
        device: Integer used both as the CUDA device index and as the process
            rank. ``main`` passes the selected device id in single-device mode;
            ``mp.spawn`` passes the process index in multi-device mode.
        config: Run configuration; ``rank`` and ``device`` are overwritten.
    """
    # NOTE(review): rank is set to the device index, so a single-device run on
    # a device other than 0 has a non-zero rank and save_config() writes nothing.
    config.rank = device
    config.device = torch.device(f"cuda:{device}")
    run(initiate(config))

def main(config):
    """Start training on one device, or spawn one process per device.

    Args:
        config: Run configuration. ``device`` is either an int or a sequence
            of device ids; ``is_dist`` (and ``world_size`` for several
            devices) is set here.
    """
    device = config.device
    use_single_device = isinstance(device, int) or len(device) == 1

    if not use_single_device:  # Multi devices
        config.is_dist = True
        config.world_size = len(device)
        # NOTE(review): mp.spawn passes the process index (0..nprocs-1) as the
        # first argument, so the actual ids listed in ``device`` are not used.
        mp.spawn(main_worker, nprocs=len(device), args=(config,))
        return

    # Single device
    config.is_dist = False
    if isinstance(device, list):
        device = device[0]
    main_worker(device, config)

class CONFIG:
    """Default settings that are not exposed as command-line options.

    ``give_config`` instantiates this class and then copies the parsed
    command-line options onto the instance as additional attributes.
    """

    # --- training ---
    train_iteration: int = 0  # step counter advanced by train_one_epoch
    nw: int = 8               # presumably the number of data-loader workers; confirm in give_dataloader

    # --- optimizer ---
    momen: float = 0.9        # presumably the SGD momentum; confirm in give_optim
    betas: tuple = (0.9, 0.999)
    eps: float = 1e-8

    # --- general setting ---
    print_interval: int = 50      # log every N batches
    stats_interval: float = 1e8   # enables tbwriter conv visualisation every N batches; formerly 50
    seed: int = 42
    is_dist: bool = False         # set by main()
    world_size: int = 1           # set by main() for multi-device runs
    dist_port: int = 12345        # TCP port used by dist.init_process_group
    rank: int = 0                 # set by main_worker()


#%%
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the config, train.
    parser = argparse.ArgumentParser("bort")
    add_arg = parser.add_argument

    # general
    add_arg("--debug", action="store_true")
    add_arg("--device", type=int, nargs="+", default=[0])
    add_arg("--save_path", type=str, default=None)

    # dataset
    add_arg("--dataset", type=str, default="cifar10",
            help="cifar10 / mnist / imagenet")

    # model
    add_arg("--model", type=str, default="simple",
            help="simple / lenetfc / lenet / lenetc / nomaxnetfc / resnet50(nn.Conv2d) / resnet18(nn.Conv2d) / vgg16")
    add_arg("--recon_ratio", type=float, default=0.95)
    add_arg("--setting", type=str, default=None,
            help="It only applies to AllConv12 model")
    add_arg("--act_type", type=str, default="guided",
            help="It only applies to AllConv12 model: guided / leaky")

    # optimizer
    add_arg("--optim", type=str, default="bort",
            help="sgd / adamw / bort / abort")
    add_arg("--lr", type=float, default=0.001)
    add_arg("--wc", type=float, default=0.001)
    add_arg("--wd", type=float, default=0.01)
    add_arg("--scheduler", type=str, default=None, help="cosine / none")
    add_arg("--warmup_lr", type=float, default=1e-6)
    add_arg("--min_lr", type=float, default=1e-7)
    add_arg("--warmup_epochs", type=int, default=5)

    # training
    add_arg("--bs", type=int, default=32)
    add_arg("--epochs", type=int, default=20)

    opt = parser.parse_args()
    config, args_dict = give_config(opt)
    main(config)