import argparse
from os.path import join as pjoin
import torch
import torch.nn.functional as nnf
from ivon import IVON
import sys

sys.path.append("..")
from common.utils import coro_timer, mkdirp
from common.models import STANDARDMODELS
from common.dataloaders import (
    TRAINDATALOADERS,
    TESTDATALOADER,
    NTRAIN,
    OUTCLASS,
    INSIZE,
)
from common.trainutils import (
    coro_log_timed,
    do_epoch,
    do_trainbatch,
    do_evalbatch,
    SummaryWriter,
    check_cuda,
    deteministic_run,
    savecheckpoint,
    loadcheckpoint,
)
from common.calibration import bins2acc, bins2conf, bins2ece
from common.ivon_pcm import IVONPCM
import wandb
import copy

# ---- Per-batch W&B logging helpers ----
# Module-level counters shared by the training-step functions below.
# NOTE(review): the original comments said these increase "only on rank 0";
# nothing in this file is rank-aware, so that presumably refers to a
# distributed variant of the script -- confirm or drop.
_WB_BATCH_STEP = 0  # number of per-batch W&B records logged so far (logged as "train/batch_step")
_WB_BACKWARDS = 0  # running count of backward passes (logged as "backwards" with epoch metrics)
_TEMP_MODEL = None  # not referenced anywhere in this file; possibly a leftover -- TODO confirm

def _wb_log_batch(loss_value: float, extra: dict = None):
    """Log one per-batch record to the active W&B run, if there is one.

    The record always carries the running batch-step counter and the loss;
    entries from ``extra`` (when given and non-empty) are merged on top and
    may override those keys. The counter advances only when a record was
    actually logged.
    """
    global _WB_BATCH_STEP
    if wandb.run is not None:
        record = {
            "train/batch_step": _WB_BATCH_STEP,
            "train/batch/loss": float(loss_value),
        }
        record.update(extra or {})
        wandb.log(record)
        _WB_BATCH_STEP += 1


def _wandb_pack(prefix, epoch, ret):
    """Turn an epoch result tuple into a flat dict suitable for ``wandb.log``.

    ``ret`` must unpack into exactly seven items:
    ``(bins, loss, nll, brier, acc5, ent, auroc)``. Scalar metrics are stored
    as floats under ``"<prefix>/<name>"``. When ``bins`` is not None the
    accuracy, confidence and ECE derived from it are added as well.
    The current backward-pass counter is included under ``"backwards"``.
    """
    bins, loss, nll, brier, acc5, ent, auroc = ret

    packed = {"epoch": epoch, "backwards": _WB_BACKWARDS}  # counter is only read here

    scalar_metrics = (
        ("loss", loss),
        ("nll", nll),
        ("brier", brier),
        ("acc5", acc5),
        ("entropy", ent),
        ("auroc", auroc),
    )
    for metric_name, metric_value in scalar_metrics:
        packed[f"{prefix}/{metric_name}"] = float(metric_value)

    if bins is None:
        return packed

    # calibration statistics are only available when the bins were collected
    packed[f"{prefix}/acc"] = float(bins2acc(bins))
    packed[f"{prefix}/confidence"] = float(bins2conf(bins))
    packed[f"{prefix}/ece"] = float(bins2ece(bins))
    return packed

class StepWindow:
    """Periodic on/off schedule driven by a minibatch counter.

    Steps are numbered from 0 by ``idx``. While ``idx < warmup`` the window is
    never active. From ``idx == warmup`` onwards the window is active for the
    first ``duration`` steps of every ``period`` steps and inactive for the
    remainder.

    The state exposed by the ``is_*`` / ``has_*`` accessors always describes
    the step numbered ``idx``, including the very first one: a window built
    with ``warmup=0`` and ``duration > 0`` starts out active with
    ``has_switched_on()`` true. (Previously the initial state was hard-coded
    to inactive, so with ``warmup=0`` step 0 was reported as neither warming
    up nor active, and the first window was one step short.)

    Args:
        period: length of one on/off cycle in steps; must be >= 1.
        duration: number of active steps at the start of each cycle.
        warmup: number of initial steps during which the window stays off.
        growth: factor applied to ``duration`` and ``period`` by ``grow()``.
        Ntrain: used only to derive the caps applied by ``grow()``:
            ``maxperiod = int(Ntrain / 50)`` and
            ``maxduration = int((Ntrain / 50) / 2)``.

    Raises:
        ValueError: if ``period`` is smaller than 1 (it is used as a modulus).
    """

    def __init__(self, period: int, duration: int, warmup: int = 0,
                 growth: float = 1, Ntrain: int = 50000):
        if period < 1:
            raise ValueError(f"period must be a positive integer, got {period!r}")

        self.period, self.duration, self.warmup = period, duration, warmup
        self.idx = 0     # increments each minibatch gradient computation
        self.growth = growth
        self.maxduration = int((Ntrain / 50) / 2)
        self.maxperiod = int(Ntrain / 50)

        # Derive the state of step 0 from the same rule on_update() uses, so
        # the accessors are consistent before the first on_update() call.
        self.active = self._in_window()
        self.switched_on = self.active
        self.switched_off = False

    def _in_window(self) -> bool:
        """Return whether step ``self.idx`` lies inside an active window."""
        if self.idx < self.warmup:
            return False
        phase = (self.idx - self.warmup) % self.period
        return phase < self.duration

    def is_warming_up(self) -> bool: return self.idx < self.warmup
    def is_active(self) -> bool: return self.active
    def has_switched_on(self) -> bool: return self.switched_on
    def has_switched_off(self) -> bool: return self.switched_off

    def on_update(self):
        """Advance to the next step and recompute the active/edge flags.

        ``has_switched_on()`` / ``has_switched_off()`` are true only for the
        step on which the active state changed.
        """
        self.idx += 1
        was_active = self.active
        self.active = self._in_window()
        self.switched_on = self.active and not was_active
        self.switched_off = was_active and not self.active

    def grow(self):
        """Scale ``duration`` and ``period`` by ``growth``, clamped to the caps.

        The caps are applied unconditionally, so values configured above
        ``maxduration`` / ``maxperiod`` are pulled down on the first call even
        when ``growth == 1``. ``period`` is kept at 1 or more so that a
        shrinking growth factor (or a tiny cap) cannot turn the modulus in the
        schedule rule into zero.
        """
        self.duration = min(self.maxduration, int(round(self.duration * self.growth)))
        self.period = max(1, min(self.maxperiod, int(round(self.period * self.growth))))


def get_args():
    """Define the command-line interface and return the parsed arguments.

    Two positionals select the architecture and the dataset; everything else
    is optional. Options are registered in a fixed order because that order
    is what ``--help`` prints.
    """
    cli = argparse.ArgumentParser(description="CIFAR10/100 IVON training")
    add = cli.add_argument

    # --- what to train, and on which data ---
    add("arch", choices=STANDARDMODELS,
        help=f"model architecture: {' | '.join(STANDARDMODELS)}")
    add("dataset", choices=TRAINDATALOADERS,
        help=f"datasets: {' | '.join(TRAINDATALOADERS)}")

    # --- data loading and run length ---
    add("-j", "--workers", type=int, default=0, metavar="N",
        help="number of data loading workers")
    add("-tb", "--tbatch", type=int, default=512, metavar="N",
        help="train mini-batch size")
    add("-vb", "--vbatch", type=int, default=512, metavar="N",
        help="eval mini-batch size")
    add("-sp", "--tvsplit", type=float, default=1.0, metavar="RATIO",
        help="ratio of data used for training")
    add("-e", "--epochs", type=int, default=400, metavar="N",
        help="number of total epochs to run")

    # --- generic optimizer hyperparameters ---
    add("-lr", "--learning_rate", type=float, default=1.0, metavar="LR",
        help="initial learning rate")
    add("--wd", "--weight-decay", dest="weight_decay", type=float,
        default=1e-4, metavar="W", help="weight decay (default: 1e-4)")
    add("--momentum", type=float, default=0.9, metavar="M", help="momentum")

    # --- bookkeeping: printing, checkpoints, device, seed, directories ---
    add("-pf", "--printfreq", type=int, default=200, metavar="N",
        help="print frequency")
    add("-r", "--resume", type=str, default="", metavar="PATH",
        help="resume training from checkpoint")
    add("-d", "--device", type=str, default="cpu", metavar="DEV",
        help="run on cpu/cuda")
    add("-s", "--seed", type=int,
        help="if specified, fixes seed for reproducibility")
    add("-sd", "--save_dir", type=str, default="save_temp",
        help="The directory used to save the trained models")
    add("-dd", "--data_dir", type=str, default="../data",
        help="The directory to store dataset")

    # --- calibration / reporting ---
    add("-nb", "--bins", type=int, default=20,
        help="number of bins for ece & reliability diagram")
    add("-pd", "--plotdiagram", action="store_true",
        help="plot reliability diagram for best val")
    add("-tbd", "--tensorboard_dir", type=str, default="",
        help="if specified, record data for tensorboard.")

    # --- IVON hyperparameters and LR warmup length (no help text) ---
    for flag, default, kind in (
        ("--mc_samples", 1, int),
        ("--momentum_hess", 0.999, float),
        ("--hess_init", 1.0, float),
        ("--ess", 5e4, float),
        ("--clip_radius", float("inf"), float),
        ("--warmup", 5, int),
    ):
        add(flag, default=default, type=kind)

    add("-opt", "--optimizer", type=str, default="ivon",
        choices=["ivon", "sgd", "adamw", "ivon_pcm"],
        help="optimizer to use")

    # --- ivon_pcm-specific settings (no help text) ---
    for flag, default, kind in (
        ("--warmupfraction", 0.2, float),   # start correcting after 20%
        ("--refreshperiod", 512, int),      # compute a big batch roughly 2x per epoch
        ("--refreshsteps", 256, int),       # half big batch computation, half correction
        ("--corralpha", 0.6, float),
        ("--htermweight", 1.0, float),      # by default we use the hterm
        ("--rho1", 1.0, float),             # no outer momentum
        ("--rho2", 1.0, float),             # no outer momentum
        ("--final_lr", 0.0, float),         # anneal to zero
        ("--refreshgrowth", 1.0, float),    # growth factor for outer batch
    ):
        add(flag, default=default, type=kind)

    # --- Weights & Biases ---
    add("--wandb", action="store_true")
    add("--wandb_project", type=str, default="CIFAR10")
    add("--wandb_run", type=str, default=None)
    add("--wandb_group", type=str, default=None,
        help="W&B group, e.g., experiment name")

    return cli.parse_args()



def do_trainbatch_ivonpcm(batchinput, model, optimizer, window, old_model):
    """Run one training batch for the ``ivon_pcm`` optimizer.

    What happens depends on the state of ``window`` for the current step:

    * window active: forward/backward through ``old_model`` only; no
      ``optimizer.step()`` is taken for this batch.
    * warming up: forward/backward through ``model`` only, then
      ``optimizer.step()``.
    * otherwise: forward/backward through both ``model`` and ``old_model`` on
      the same batch, then ``optimizer.step()``.

    ``window.on_update()`` is called once per batch in every case.

    The exact meaning of ``optimizer.sampled_params``,
    ``optimizer.refresh_old_model`` and ``optimizer.collect_full_grads`` is
    defined by the optimizer class, which is not part of this file.

    Args:
        batchinput: pair ``(images, target)``.
        model: network being trained.
        optimizer: optimizer providing ``sampled_params``, ``zero_grad``,
            ``step``, ``refresh_old_model`` and ``collect_full_grads``.
        window: schedule object queried via ``is_active``, ``is_warming_up``,
            ``has_switched_on``, ``has_switched_off`` and advanced via
            ``on_update``.
        old_model: second network, evaluated inside the window and in the
            post-warm-up branch.

    Returns:
        ``(prob, target, loss)``: softmax of the detached output, the targets,
        and the loss as a Python float.
    """
    global _WB_BACKWARDS

    images, target = batchinput
    # Exactly one entry is appended to each list per call (there is no loop
    # over args.mc_samples here), so the means taken below equal that entry.
    loss_samples = []
    prob_samples = []

    if window.is_active(): 
        # if we just entered the window, refresh old model 
        if window.has_switched_on(): 
            print(">!", end=" ", flush=True)
            optimizer.refresh_old_model() 

        # simply accumulate gradients inside window 
        # (only old_model is evaluated; optimizer.step() is not called in this branch)
        with optimizer.sampled_params(train=True):
            optimizer.zero_grad()
            output = old_model(images)
            loss = nnf.cross_entropy(output, target)
            loss.backward()
            _WB_BACKWARDS += 1

        loss_samples.append(loss.detach())
        prob_samples.append(nnf.softmax(output.detach(), -1))
    else: 
        if window.is_warming_up(): 
            # warm-up: ordinary step on `model`; old_model is not touched
            with optimizer.sampled_params(train=True):
                optimizer.zero_grad()
                output = model(images)
                loss = nnf.cross_entropy(output, target)
                loss.backward()
                _WB_BACKWARDS += 1
        else: 
            # if we just left the window, collect accumulated grads
            if window.has_switched_off():
                print("!<", end=" ", flush=True)
                optimizer.collect_full_grads()

            # both networks see the same batch; two backward passes are counted
            with optimizer.sampled_params(train=True):
                optimizer.zero_grad()
                output = model(images)
                output_old = old_model(images) 
                loss = nnf.cross_entropy(output, target)
                loss_old = nnf.cross_entropy(output_old, target)
                loss.backward()
                loss_old.backward() 
                _WB_BACKWARDS += 2

        # reported loss/probabilities come from `model`, not from old_model
        loss_samples.append(loss.detach())
        prob_samples.append(nnf.softmax(output.detach(), -1))

        optimizer.step()

    loss = torch.mean(torch.stack(loss_samples, dim=0), dim=0)
    prob = torch.mean(torch.stack(prob_samples, dim=0), dim=0)

    window.on_update() # increase counter 

    # NOTE(review): the flags below are read after on_update(), so they
    # describe the window state for the NEXT batch rather than the one just
    # processed -- confirm this is intended.
    # Logging is best-effort: any exception it raises is discarded.
    try:
        _wb_log_batch(loss.item(),
                      extra={"train/batch/svrg_active": float(window.is_active()),
                             "train/batch/svrg_warmup": float(window.is_warming_up())})
    except Exception:
        pass

    return prob, target, loss.item()

# noinspection PyShadowingNames
def do_trainbatch_ivon(batchinput, model, optimizer):
    """Run one IVON training batch with ``args.mc_samples`` sampled passes.

    Each pass runs forward/backward inside ``optimizer.sampled_params`` after
    clearing the gradients; a single ``optimizer.step()`` follows all passes.
    Returns the sample-averaged probabilities, the targets and the
    sample-averaged loss as a Python float. The loss is also sent to W&B on a
    best-effort basis (logging errors are discarded).
    """
    global _WB_BACKWARDS

    images, target = batchinput
    per_sample_loss, per_sample_prob = [], []

    for _ in range(args.mc_samples):
        with optimizer.sampled_params(train=True):
            optimizer.zero_grad(set_to_none=True)
            logits = model(images)
            sample_loss = nnf.cross_entropy(logits, target)
            sample_loss.backward()
            _WB_BACKWARDS += 1
        per_sample_loss.append(sample_loss.detach())
        per_sample_prob.append(nnf.softmax(logits.detach(), -1))

    optimizer.step()

    # average over the Monte-Carlo samples (leading dimension of the stack)
    mean_loss = torch.stack(per_sample_loss, dim=0).mean(dim=0)
    mean_prob = torch.stack(per_sample_prob, dim=0).mean(dim=0)

    try:
        _wb_log_batch(mean_loss.item())
    except Exception:
        pass

    return mean_prob, target, mean_loss.item()

def do_trainbatch_wrap(batchinput, model, optimizer):
    """Delegate to ``do_trainbatch`` and add the W&B bookkeeping around it.

    Counts one backward pass per call and logs the batch loss on a
    best-effort basis (logging errors are discarded). The three values
    returned by ``do_trainbatch`` are passed through as a tuple.
    """
    global _WB_BACKWARDS

    probs, labels, batch_loss = do_trainbatch(batchinput, model, optimizer)
    _WB_BACKWARDS += 1

    try:
        _wb_log_batch(batch_loss)
    except Exception:
        pass

    return probs, labels, batch_loss


# Per-batch training routine for each value accepted by --optimizer.
# sgd and adamw share the generic wrapper; the IVON variants have their own.
train_functions = dict(
    sgd=do_trainbatch_wrap,
    adamw=do_trainbatch_wrap,
    ivon=do_trainbatch_ivon,
    ivon_pcm=do_trainbatch_ivonpcm,
)


def get_optimizer(args, model, old_model):
    """Build the optimizer selected by ``args.optimizer``.

    Args:
        args: namespace with ``optimizer`` plus the hyperparameters the chosen
            optimizer reads (``learning_rate``, ``weight_decay``, and for the
            relevant choices ``momentum``, ``mc_samples``, ``momentum_hess``,
            ``hess_init``, ``ess``, ``clip_radius``, ``corralpha``,
            ``htermweight``, ``rho1``, ``rho2``).
        model: network whose parameters are optimized.
        old_model: second network, required only for ``"ivon_pcm"``, whose
            parameters form the optimizer's second parameter group.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``args.optimizer`` is not one of ``"ivon"``, ``"sgd"``,
            ``"adamw"``, ``"ivon_pcm"`` (previously this silently returned
            ``None``), or if ``"ivon_pcm"`` is requested without ``old_model``.
    """
    name = args.optimizer

    if name == "ivon_pcm" and old_model is None:
        raise ValueError("optimizer 'ivon_pcm' requires old_model, got None")

    if name in ("ivon", "ivon_pcm"):
        # hyperparameters shared by both IVON variants
        ivon_kwargs = dict(
            lr=args.learning_rate,
            mc_samples=args.mc_samples,
            beta1=args.momentum,
            beta2=args.momentum_hess,
            weight_decay=args.weight_decay,
            hess_init=args.hess_init,
            ess=args.ess,
            clip_radius=args.clip_radius,
        )
        if name == "ivon":
            return IVON(model.parameters(), **ivon_kwargs)

        # ivon_pcm: group 0 holds the live model, group 1 the old model
        params = [{"params": model.parameters()}, {"params": old_model.parameters()}]
        return IVONPCM(
            params,
            **ivon_kwargs,
            debias=True,  # debias when training from epoch 0
            rescale_lr=True,
            alpha=args.corralpha,
            h_term_weight=args.htermweight,
            rho1=args.rho1,
            rho2=args.rho2,
        )

    if name == "sgd":
        return torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

    if name == "adamw":
        return torch.optim.AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    raise ValueError(
        f"unknown optimizer {name!r}; expected one of 'ivon', 'sgd', 'adamw', 'ivon_pcm'"
    )
    
def wandb_setup(args, config_dict):
    """Start a W&B run when ``args.wandb`` is set; otherwise return a stub.

    With W&B disabled the returned object accepts ``log``, ``finish``,
    ``define_metric`` and ``Artifact`` calls without doing anything, and its
    ``run`` attribute is None. With W&B enabled a run is initialised from
    ``args.wandb_project`` / ``wandb_run`` / ``wandb_group`` and
    ``config_dict``, the step metrics are declared, and the ``wandb`` module
    itself is returned.
    """
    if not args.wandb:
        class _Noop:
            """Do-nothing stand-in for the parts of the wandb API used here."""

            @property
            def run(self):
                return None

            def log(self, *_a, **_k):
                return None

            def finish(self):
                return None

            def define_metric(self, *_a, **_k):
                return None

            def Artifact(self, *a, **k):
                return None

        return _Noop()

    wandb.init(
        project=args.wandb_project,
        name=args.wandb_run,
        group=args.wandb_group,
        config=config_dict,
    )

    # (metric glob, step metric) pairs, declared in this order; None means the
    # metric is declared without a step metric.
    metric_axes = (
        ("train/batch_step", None),
        ("train/batch/*", "train/batch_step"),
        ("epoch", None),
        ("test/*", "epoch"),
        ("train/*", "epoch"),
        ("time/*", "epoch"),
        ("lr", "epoch"),
    )
    for metric, axis in metric_axes:
        if axis is None:
            wandb.define_metric(metric)
        else:
            wandb.define_metric(metric, step_metric=axis)

    return wandb

if __name__ == "__main__":
    # Script entry point: parse the CLI, build model / optimizer / data, then
    # for every epoch run training, checkpoint, record timing and run the
    # test pass, optionally mirroring metrics to TensorBoard and W&B.
    timer = coro_timer()
    t_init = next(timer)
    print(f">>> Training initiated at {t_init.isoformat()} <<<\n")

    args = get_args()
    print(args, end="\n\n")

    # The resume branch below never creates `old_model`, which the ivon_pcm
    # training step requires. Fail here with a clear message instead of a
    # NameError inside the first training epoch.
    if args.resume and args.optimizer == "ivon_pcm":
        raise NotImplementedError(
            "resuming from a checkpoint is not supported with --optimizer ivon_pcm"
        )

    wb_config = {
        "arch": args.arch,
        "optimizer": args.optimizer,
        "batch_size_global": args.tbatch,
        "epochs": args.epochs,
        "lr": args.learning_rate,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
        "mc_samples": args.mc_samples,
        "momentum_hess": args.momentum_hess,
        "hess_init": args.hess_init,
        "ess": args.ess,
        "clip_radius": args.clip_radius,
        "warmup_epochs": args.warmup,
        "printfreq": args.printfreq,
        "data_dir": args.data_dir,
        "save_dir": args.save_dir,
        "svrg": args.optimizer == "ivon_pcm",
        "svrg.refreshperiod": args.refreshperiod,
        "svrg.refreshsteps": args.refreshsteps,
        "svrg.warmup": args.warmupfraction,
        "svrg.corralpha": args.corralpha,
        "svrg.htermweight": args.htermweight,
        "seed": args.seed,
        "final_lr": args.final_lr,
        "refreshgrowth": args.refreshgrowth,
    }

    wb = wandb_setup(args, wb_config)

    if args.optimizer == 'ivon_pcm':
        # The window stays off for the first `warmupfraction` of all training
        # steps (ceil-divided batches per epoch times the number of epochs).
        ntrain = int(NTRAIN[args.dataset] * args.tvsplit)
        steps_per_epoch = max(1, (ntrain + args.tbatch - 1) // args.tbatch)
        totalsteps = int(round(steps_per_epoch * args.epochs * args.warmupfraction))  # long warmup before starting svrg
        # NOTE(review): StepWindow's Ntrain argument is left at its default
        # here, so its growth caps do not depend on the dataset -- confirm.
        window = StepWindow(args.refreshperiod, args.refreshsteps, totalsteps, args.refreshgrowth)

    # if seed is specified, run deterministically
    if args.seed is not None:
        deteministic_run(seed=args.seed)

    # get device for this experiment
    device = torch.device(args.device)

    if device != torch.device("cpu"):
        check_cuda()

    # build train_dir for this experiment
    mkdirp(args.save_dir)

    # resume or initialize
    if args.resume:
        startepoch, model, optimizer, scheduler, dic = loadcheckpoint(
            args.resume, device
        )
        modelargs, modelkwargs = dic["modelargs"], dic["modelkwargs"]
        print(f"resumed from {args.resume}\n")
    else:
        startepoch = 0
        modelargs, modelkwargs = (
            OUTCLASS[args.dataset],
            INSIZE[args.dataset],
        ), {}
        model = STANDARDMODELS[args.arch](*modelargs, **modelkwargs).to(
            args.device
        )
        if args.optimizer == 'ivon_pcm':
            old_model = copy.deepcopy(model).to(args.device)
        else:
            old_model = None

        optimizer = get_optimizer(args, model, old_model)

        # linear LR warmup over the first `args.warmup` epochs; replaced by
        # cosine annealing inside the epoch loop once e == args.warmup
        scheduler = (
            torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=1.0 / args.warmup,
                end_factor=1.0,
                total_iters=args.warmup,
            )
            if args.warmup > 0
            else None
        )

    data_size = int(NTRAIN[args.dataset] * args.tvsplit)

    # try compile
    # model = torch.compile(model)

    # prep tensorboard if specified
    if args.tensorboard_dir:
        mkdirp(args.tensorboard_dir)
        sw = SummaryWriter(args.tensorboard_dir)
    else:
        sw = None

    # load data
    train_loader, val_loader = TRAINDATALOADERS[args.dataset](
        args.data_dir,
        args.tvsplit,
        args.workers,
        (device != torch.device("cpu")),
        args.tbatch,
        args.vbatch,
    )

    # NOTE(review): the test loader is given args.tbatch, not args.vbatch
    # ("eval mini-batch size") -- confirm which one is intended.
    test_loader = TESTDATALOADER[args.dataset](
        args.data_dir,
        args.workers,
        (device != torch.device("cpu")),
        args.tbatch,
    )

    # perform training
    log_ece = coro_log_timed(sw, args.printfreq, args.bins, args.save_dir)

    # data_size already includes the train/val split ratio; the previous
    # message multiplied by args.tvsplit a second time.
    print(
        f"datasize {data_size}, paramsize "
        f"{sum(p.nelement() for p in model.parameters())}"
    )

    print(f">>> Training starts at {next(timer)[0].isoformat()} <<<\n")

    # 0-based epoch indices after which an extra, numbered checkpoint is kept
    checkpoint_epochs = [0, 1, 2, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200]

    for e in range(startepoch, args.epochs):
        # run training part
        log_ece.send((e, "train", len(train_loader), None))
        if e == args.warmup:
            # Creating a new scheduler will already change the learning rate
            print("End of warmup epochs, starting cosine annealing")
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, eta_min=args.final_lr, T_max=args.epochs - args.warmup
            )
        model.train()

        if args.optimizer == 'ivon_pcm':
            extra = {"window": window, "old_model": old_model}
        else:
            extra = {}

        do_epoch(
            train_loader,
            train_functions[args.optimizer],
            log_ece,
            device,
            model=model,
            optimizer=optimizer,
            **extra
        )
        # scheduler is None only if no warmup scheduler was built and the
        # cosine schedule has not been installed yet
        if scheduler is not None:
            scheduler.step()

        if args.optimizer == 'ivon_pcm':
            window.grow()

        train_ret = log_ece.throw(StopIteration)  # (bins, loss, nll, brier, acc5, ent, auroc)
        if args.wandb and wb.run is not None:
            wb.log(_wandb_pack("train", e, train_ret))

        # save checkpoint (overwritten every epoch)
        savecheckpoint(
            pjoin(args.save_dir, "checkpoint.pt"),
            args.arch,
            modelargs,
            modelkwargs,
            model,
            optimizer,
            scheduler,
        )

        if e in checkpoint_epochs:
            savecheckpoint(
                pjoin(args.save_dir, "checkpoint%03d.pt" % (e + 1)),
                args.arch,
                modelargs,
                modelkwargs,
                model,
                optimizer,
                scheduler,
            )
        print(f'Max memory usage {torch.cuda.max_memory_allocated()}')
        time_per_epoch = next(timer)[1]
        print(f">>> Time elapsed: {time_per_epoch} <<<\n")

        # log time per epochs
        with open(pjoin(args.save_dir, "time.csv"), "a+") as file:
            file.write("%d,%f\n" % (e, time_per_epoch.total_seconds()))

        # W&B logging happens after the CSV file is closed; it does not need
        # the file handle.
        if args.wandb and wb.run is not None:
            try:
                lr_now = scheduler.get_last_lr()[0] if hasattr(scheduler, "get_last_lr") else optimizer.param_groups[0]["lr"]
                wb.log({
                    "epoch": e,
                    "lr": lr_now,
                    "time/epoch_sec": time_per_epoch.total_seconds(),
                })
            except Exception as _e:
                print(f"[W&B] log error (train epoch): {_e}")

        # run evaluation part
        log_ece.send((e, "test", len(test_loader), None))
        with torch.no_grad():
            model.eval()
            do_epoch(test_loader, do_evalbatch, log_ece, device, model=model)
        test_ret = log_ece.throw(StopIteration)  # (bins, loss, nll, brier, acc5, ent, auroc)
        if args.wandb and wb.run is not None:
            wb.log(_wandb_pack("test", e, test_ret))

        # NOTE(review): val_loader is never evaluated; an empty val_loader
        # only skips the timing printout below -- confirm this is intended.
        if len(val_loader) == 0:
            continue

        print(f">>> Time elapsed: {next(timer)[1]} <<<\n")

    log_ece.close()

    print(f">>> Training completed at {next(timer)[0].isoformat()} <<<\n")
    if args.wandb and wb.run is not None:
        wb.finish()
