import os
import torch
import torch.nn as nn

from args_factory import get_args
from loaders import get_loaders
from utils import Scheduler, Statistics
from networks import get_network, fuse_BN_wrt_Flatten, add_BN_wrt_Flatten
from model_wrapper import BasicModelWrapper, PGDModelWrapper, BoxModelWrapper, HZonoModelWrapper, SmallDPModelWrapper, TAPSModelWrapper, STAPSModelWrapper, SmallBoxModelWrapper, SmallHZonoModelWrapper, GradAccuModelWrapper, MultiFacetModelWrapper, DeepPolyModelWrapper
from logging_wrapper import Neptune_SummaryWriter as SummaryWriter
from utils import write_perf_to_json, load_perf_from_json, fuse_BN, seed_everything
from tqdm import tqdm
import random
import numpy as np
from regularization import compute_fast_reg, compute_vol_reg, compute_L1_reg, compute_PI_reg, compute_neg_reg
import time
from datetime import datetime
from AIDomains.abstract_layers import Sequential
import logging
from AIDomains.zonotope import HybridZonotope

import warnings
warnings.filterwarnings("ignore")

from get_stat import PI_loop, relu_loop, test_loop

# Neptune is an optional logging backend: when it cannot be imported the
# module-level name is set to None and no experiment writer is created.
try:
    import neptune
except Exception:
    # Any import-time failure only disables Neptune logging. A bare `except:`
    # would additionally swallow KeyboardInterrupt / SystemExit.
    neptune = None
# Neptune can be disabled manually by setting it to None even if it is installed.
# It is currently force-disabled; remove the next line to enable Neptune logging.
neptune = None

def train_loop(model_wrapper:BasicModelWrapper, eps_scheduler:Scheduler, robust_weight_scheduler:Scheduler, train_loader, epoch_idx, optimizer, device, args, verbose:bool=False):
    """Train the wrapped network for one epoch.

    For every batch the current perturbation radius and robust weight are read
    from their schedulers (indexed by the global step
    ``epoch_idx * len(train_loader) + batch_idx``), the loss is obtained from
    ``model_wrapper.compute_model_stat``, optional regularizers are added, and
    one optimizer step is taken.

    Args:
        model_wrapper: wrapper that owns the network (``.net``) and computes
            the loss and accuracies.
        eps_scheduler: schedule of the input perturbation radius.
        robust_weight_scheduler: schedule of the robust-loss weight.
        train_loader: iterable of ``(x, y)`` batches that supports ``len``.
        epoch_idx: zero-based index of the current epoch.
        optimizer: optimizer stepping the network parameters.
        device: device the batches are moved to.
        args: parsed training arguments.
        verbose: if True, print the per-batch accuracies and loss.

    Returns:
        Tuple ``(nat_accu_avg, cert_accu_avg, eps, fastreg_avg, loss_avg)``
        where ``eps`` is the value used for the last batch. The tracked loss
        is the one returned by the wrapper, i.e. without the regularizers
        added in this function.

    Raises:
        NotImplementedError: if ``args.IBPR_reg > 0``.
    """
    model_wrapper.net.train()
    model_wrapper.summary_accu_stat = False
    model_wrapper.freeze_BN = False

    steps_per_epoch = len(train_loader)

    # Design of TAPS: use IBP for annealing to increase speed and performance.
    if args.use_TAPS_training:
        TAPS_grad_scale_scheduler = Scheduler(args.end_epoch_eps*steps_per_epoch, (args.end_epoch_eps+args.TAPS_anneal_length)*steps_per_epoch, 0, args.TAPS_grad_scale, "linear")
        if not args.no_ibp_anneal:
            # TAPS stays disabled on the wrapper until eps annealing has finished.
            model_wrapper.disable_TAPS = epoch_idx < args.end_epoch_eps
        else:
            model_wrapper.disable_TAPS = False
    model_wrapper.num_steps = args.train_steps

    # Design of fast regularization: use fast reg only for annealing.
    fast_reg = (args.fast_reg > 0) and epoch_idx < args.end_epoch_eps
    if fast_reg:
        model_wrapper.store_box_bounds = True

    # Define custom tracking of statistics here
    fastreg_stat, nat_accu_stat, cert_accu_stat, loss_stat = Statistics.get_statistics(4)

    # Define custom logging behavior for the first epoch here if verbose-first-epoch is set.
    log_first_epoch = args.verbose_first_epoch and epoch_idx == 0
    if log_first_epoch:
        epoch_perf = {"cert_loss_curve":[], "fast_reg_curve":[], "PI_curve":[]}
        # Clamp to >= 1: for a small verbose_gap or a short loader the product
        # truncates to 0 and the modulo below would raise ZeroDivisionError.
        verbose_interval = max(1, int(args.verbose_gap * steps_per_epoch))

    pbar = tqdm(train_loader)
    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(device), y.to(device)
        global_step = epoch_idx * steps_per_epoch + batch_idx
        eps = eps_scheduler.getcurrent(global_step)
        robust_weight = robust_weight_scheduler.getcurrent(global_step)
        model_wrapper.robust_weight = robust_weight

        # Define batch-wise behavior for different model_wrapper
        if args.use_TAPS_training and not model_wrapper.disable_TAPS:
            # anneal the TAPS gradient scale over args.TAPS_anneal_length epochs
            model_wrapper.TAPS_grad_scale = TAPS_grad_scale_scheduler.getcurrent(global_step)

        optimizer.zero_grad()
        (loss, nat_loss, cert_loss), (nat_accu, cert_accu), (is_nat_accu, is_cert_accu) = model_wrapper.compute_model_stat(x, y, eps)
        if verbose:
            print(f"Batch {batch_idx}:", nat_accu, cert_accu, loss.item())

        # Recorded before any regularizer is added below.
        loss_stat.update(loss.item(), len(x))

        # Define and update additional regularization here
        if fast_reg:
            # add fast reg to the loss
            reg_eps = max(eps, args.min_eps_reg)
            if reg_eps != eps:
                # recompute bounds for fast reg
                model_wrapper.net.reset_bounds()
                abs_x = HybridZonotope.construct_from_noise(x, reg_eps, "box")
                model_wrapper.net(abs_x)
            # The regularization weight decays linearly as reg_eps approaches the final eps.
            reg = args.fast_reg * (1 - reg_eps/eps_scheduler.end_value) * compute_fast_reg(model_wrapper.net, reg_eps)
            loss = loss + reg
            fastreg_stat.update(reg.item(), len(x))
        if args.L1_reg > 0:
            loss = loss + args.L1_reg * compute_L1_reg(model_wrapper.net)

        if args.IBPR_reg > 0:
            # TODO: large performance gap between released implementation and ours
            raise NotImplementedError("IBPR is not supported in this version.")
        #     vol_reg = args.IBPR_reg * compute_vol_reg(model_wrapper.net, x, eps, recompute_box=True, min_reg_eps=args.min_eps_reg, max_reg_eps=args.test_eps) / len(x) * model_wrapper.robust_weight
        #     loss = loss + vol_reg

        # Customize the verbose-first-epoch behavior here
        if log_first_epoch and batch_idx % verbose_interval == 0:
            epoch_perf["cert_loss_curve"].append(loss_stat.last)
            epoch_perf["fast_reg_curve"].append(fastreg_stat.last)
            write_perf_to_json(epoch_perf, args.save_root, "first_epoch.json")

        model_wrapper.net.reset_bounds()
        loss.backward()
        model_wrapper.grad_postprocess() # can be inherited to customize gradient postprocessing
        optimizer.step()
        model_wrapper.param_postprocess()

        nat_accu_stat.update(nat_accu, len(x))
        cert_accu_stat.update(cert_accu, len(x))
        postfix_str = f"nat_accu: {nat_accu_stat.avg:.3f}, cert_accu: {cert_accu_stat.avg:.3f}, train_loss: {loss_stat.avg:.3f}"

        pbar.set_postfix_str(postfix_str)

    return nat_accu_stat.avg, cert_accu_stat.avg, eps, fastreg_stat.avg, loss_stat.avg

def get_model_wrapper(args, net, device, input_dim):
    """Build the model wrapper that defines how loss and accuracies are computed.

    Without ``args.multi_facets`` the first enabled training flag decides the
    wrapper, in this order: PGD, vanilla IBP, HBox, Zono, TAPS, DP, DPZero,
    DPBox. Where a small-box variant is constructed, ``args.use_small_box``
    selects it. The result is wrapped in ``GradAccuModelWrapper`` when
    ``args.grad_accu_batch`` is set. With ``args.multi_facets`` one wrapper per
    facet name is created and combined in a ``MultiFacetModelWrapper``;
    ``args.grad_accu_batch`` is not consulted on that path.

    Args:
        args: parsed training arguments.
        net: abstract network shared by all created wrappers.
        device: device passed through to the wrappers.
        input_dim: input shape passed through to the wrappers.

    Returns:
        The model wrapper, with ``robust_weight`` initialised to 1.

    Raises:
        NotImplementedError: if no training flag is enabled, or a facet name
            is unknown.
        AssertionError: if fewer than two facets are given.
    """
    # Define model wrapper here
    if args.multi_facets is None:
        if args.use_pgd_training:
            model_wrapper = PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
        elif args.use_vanilla_ibp:
            if args.use_small_box:
                model_wrapper = SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
        elif args.use_HBox_training:
            if args.use_small_box:
                model_wrapper = SmallHZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage, domain='hbox')
            else:
                model_wrapper = HZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, domain='hbox')
        elif args.use_Zono_training:
            if args.use_small_box:
                model_wrapper = SmallHZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage, domain='zono')
            else:
                model_wrapper = HZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, domain='zono')
        elif args.use_TAPS_training:
            if args.use_small_box:
                model_wrapper = STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
        elif args.use_DP_training:
            relu_type = 'original' if args.loss_smoothing is None else f'smooth {args.loss_smoothing}'
            if args.use_small_box:
                model_wrapper = SmallDPModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, relu_type=relu_type, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, relu_type=relu_type)
        elif args.use_DPZero_training:
            # NOTE(review): args.use_small_box is not consulted here although
            # get_train_mode reports "SDPZero_trained" for it - confirm intent.
            model_wrapper = DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, relu_type='zero')
        elif args.use_DPBox_training:
            relu_type = 'original' if args.loss_smoothing is None else f'smooth {args.loss_smoothing}'
            if args.use_small_box:
                model_wrapper = SmallDPModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, relu_type=relu_type, use_dp_box=True, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, relu_type=relu_type, use_dp_box=True)
        else:
            # Fail with the same error as get_train_mode instead of an
            # UnboundLocalError further down.
            raise NotImplementedError("Unknown training mode.")
        if args.grad_accu_batch is not None:
            model_wrapper = GradAccuModelWrapper(model_wrapper, args)
    else:
        assert len(args.multi_facets) >= 2, "At least two facets if multiple facet loss is used."
        wrappers = []
        for facet in args.multi_facets:
            facet_name = facet.lower()
            if facet_name == "ibp":
                wrapper = BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
            elif facet_name == "pgd":
                wrapper = PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
            elif facet_name == "sabr":
                wrapper = SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
            elif facet_name == "taps":
                wrapper = TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
            elif facet_name == "staps":
                wrapper = STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
            else:
                raise NotImplementedError(f"Unknown facet specified: {facet}.")
            wrappers.append(wrapper)
        model_wrapper = MultiFacetModelWrapper(wrappers, args.facets_eps_ratio, args.facets_weight, args)
    model_wrapper.robust_weight = 1
    return model_wrapper

def get_train_mode(args):
    """Return the name of the training mode selected by ``args``.

    With multiple facets the name joins every ``<facet><weight>`` pair with
    underscores. Otherwise the first enabled training flag decides the name,
    and ``args.use_small_box`` switches to the small-box name where one exists.

    Raises:
        NotImplementedError: if no training flag is enabled.
    """
    if args.multi_facets:
        pairs = zip(args.multi_facets, args.facets_weight)
        return "_".join(f"{facet}{weight}" for facet, weight in pairs)

    # Reorder the rows to change which mode wins when the args are ambiguous.
    # Each row: (flag on args, regular name, small-box name or None).
    precedence = (
        ("use_pgd_training", "PGD_trained", None),
        ("use_vanilla_ibp", "IBP_trained", "SABR_trained"),
        ("use_HBox_training", "HBox_trained", "SHBox_trained"),
        ("use_Zono_training", "Zono_trained", "SZono_trained"),
        ("use_TAPS_training", "TAPS_trained", "STAPS_trained"),
        ("use_DP_training", "DP_trained", "SDP_trained"),
        ("use_DPZero_training", "DPZero_trained", "SDPZero_trained"),
        ("use_DPBox_training", "DPBox_trained", "SDPBox_trained"),
    )
    for flag, regular_name, small_box_name in precedence:
        if not getattr(args, flag):
            continue
        if small_box_name is not None and args.use_small_box:
            return small_box_name
        return regular_name
    raise NotImplementedError("Unknown training mode.")

def parse_save_root(args, mode):
    """Create the output directory for this run and, if enabled, its logger.

    The directory is nested as
    ``save_dir/dataset/eps<test_eps>/mode/net/init_*`` followed by one extra
    level per non-default option (eps annealing length, fast reg, small-box
    lambda, loss smoothing, TAPS last block size, IBPR, L1). When Neptune is
    enabled a ``SummaryWriter`` is created and its run id becomes the last
    directory level.

    Args:
        args: parsed training arguments; ``args.save_root`` is set to the
            returned directory.
        mode: training mode name used as one directory level.

    Returns:
        Tuple ``(save_root, writer)``; ``writer`` is None when Neptune is
        disabled.
    """
    eps_str = f"eps{args.test_eps:.5g}{'_same' if args.test_eps == args.train_eps and args.dataset=='mnist' else ''}"
    init_str = f"init_{args.init}" if args.load_model is None else f"init_pretrained_{args.init}"
    save_root = os.path.join(args.save_dir, args.dataset, eps_str, mode, args.net, init_str)
    # An annealing length of 20 is treated as the default and adds no extra level.
    if args.load_model is not None or args.end_epoch_eps!=20:
        save_root = os.path.join(save_root, f"eps_ann_{args.end_epoch_eps}")
    if args.fast_reg > 0:
        save_root = os.path.join(save_root, f"fast_reg_{args.fast_reg}")
    if args.use_small_box:
        save_root = os.path.join(save_root, f"lambda_{args.eps_shrinkage}")
    if (args.use_DP_training or args.use_DPBox_training) and args.loss_smoothing is not None:
        save_root = os.path.join(save_root, f"smooth_{args.loss_smoothing}")
    if args.use_TAPS_training:
        save_root = os.path.join(save_root, f"last_block_{args.block_sizes[-1]}")
    if args.IBPR_reg > 0:
        save_root = os.path.join(save_root, f"IBPR_{args.IBPR_reg}")
    if args.L1_reg > 0:
        save_root = os.path.join(save_root, f"L1_{args.L1_reg}")
    os.makedirs(save_root, exist_ok=True)
    exp_name = f"{args.dataset}__{args.net}__eps_{args.eps_end}"
    date = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
    file_name = exp_name + "__" + date
    # dir_name is only handed to the writer and printed; it is not created here.
    dir_name = os.path.join(save_root, file_name)

    print('dir name',dir_name,flush=True)

    if neptune is not None:
        writer = SummaryWriter(dir_name, 'STAPS', tags=args.neptune_tags if args.neptune_tags is not None else [])
        save_root = os.path.join(save_root, writer.get_runid())

        writer.add_text('dir_name', dir_name)
        writer.add_text('file_name', file_name)
        writer.log_args(args)

        os.makedirs(save_root, exist_ok=True)
    else:
        writer = None

    # Set on both paths: train_loop writes first_epoch.json to args.save_root,
    # which was previously only assigned when Neptune was enabled.
    args.save_root = save_root
    logging.info(f"The model will be saved at: {save_root}")

    return save_root, writer


def run(args):
    """Run the full training pipeline described by ``args``.

    Builds the data loaders, schedulers, abstract network, model wrapper and
    optimizer, then for every epoch: trains, evaluates on the validation
    loader (or the test loader when no validation loader is returned), tracks
    propagation tightness and ReLU status, performs model selection, and
    writes ``monitor.json`` / ``train_args.json`` to the save directory.
    Finally the checkpoint stored as ``model.ckpt`` is reloaded and evaluated
    on the test loader at ``args.test_eps``.

    Args:
        args: parsed training arguments. ``args.num_classes`` is set here.

    Raises:
        NotImplementedError: for an unknown eps schedule.
        ValueError: for an unsupported optimizer name.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    perf_dict = {'val_nat_curve':[], 'val_cert_curve':[], 'val_adv_curve':[], 'val_loss_curve':[], 'train_nat_curve':[], 'train_cert_curve':[], 'train_loss_curve':[], 'lr_curve':[], "eps_curve":[]}
    PI_dict = {"PI_curve":[]}
    relu_dict = {"dead_relu_curve":[], "active_relu_curve":[], "unstable_relu_curve":[]}
    reg_dict = {"fastreg_curve":[]}

    # Add more perf_dict here to track more statistics.
    perf_dict = perf_dict | PI_dict | relu_dict | reg_dict
    perf_dict["start_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
    verbose = False

    # Define dataset here
    loaders, input_size, input_channel, n_class = get_loaders(args)
    input_dim = (input_channel, input_size, input_size)
    args.num_classes = n_class
    if len(loaders) == 4:
        train_loader, val_loader, test_loader, train_test_loader = loaders
    else:
        train_loader, test_loader, train_test_loader = loaders
        val_loader = None
    # Per-epoch evaluation falls back to the test loader when there is no validation loader.
    eval_loader = val_loader if val_loader is not None else test_loader
    perf_dict['model_selection'] = args.model_selection # tradition for certified training: use test set for model selection :(

    # Adjust the sequence to have different preference of training when the args is ambiguous.
    # Define custom behavior for different training here.
    mode = get_train_mode(args)

    steps_per_epoch = len(train_loader)

    # Schedule for robust weight
    # Assume the loss is in the form (1-w) * natural_loss + w * robust_loss
    robust_weight_scheduler = Scheduler(args.start_epoch_robust_weight*steps_per_epoch, args.end_epoch_robust_weight*steps_per_epoch, args.robust_weight_start, args.robust_weight_end, mode="linear", s=steps_per_epoch)

    # Schedule for input epsilon
    if args.no_anneal:
        # use const eps == train_eps
        eps_scheduler = Scheduler(args.start_epoch_eps*steps_per_epoch, args.end_epoch_eps*steps_per_epoch, args.train_eps, args.train_eps, "linear", s=steps_per_epoch)
    else:
        if args.schedule in ["smooth", "linear", "step"]:
            eps_scheduler = Scheduler(args.start_epoch_eps*steps_per_epoch, args.end_epoch_eps*steps_per_epoch, args.eps_start, args.eps_end, args.schedule, s=args.step_epoch*steps_per_epoch)
        else:
            raise NotImplementedError(f"Unknown schedule: {args.schedule}")

    # Define concrete (torch) model and convert it to abstract model here
    torch_net = get_network(args.net, args.dataset, device, init=args.init if args.load_model is None else 'default')
    # summary(net, (input_channel, input_size, input_size))
    net = Sequential.from_concrete_network(torch_net, input_dim, disconnect=False)
    # Use the selected device instead of a hard-coded 'cuda' so CPU-only machines work.
    net.set_dim(torch.zeros((test_loader.batch_size, *input_dim), device=device))
    if args.load_model:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        net.load_state_dict(torch.load(args.load_model, map_location=device))
        print("Loaded:", args.load_model)
    print(net)

    # Parse save root here
    save_root, writer = parse_save_root(args, mode)
    ckpt_path = os.path.join(save_root, "model.ckpt")

    # Define model wrapper here: this wraps how to compute the loss and how to compute the robust accuracy.
    model_wrapper = get_model_wrapper(args, net, device, input_dim)

    # Define training hyperparameter here
    # Exclude the parameters of the first layer (normalization). An ordered
    # list is used because a set of tensors has no stable iteration order.
    excluded_param_ids = {id(p) for p in model_wrapper.net[0].parameters()}
    param_list = [p for p in model_wrapper.net.parameters() if id(p) not in excluded_param_ids]
    lr = args.lr

    perf_dict["best_val_cert_accu"] = -1
    perf_dict["best_val_loss"] = 1e8

    if args.opt == 'adam':
        optimizer = torch.optim.Adam(param_list, lr=lr)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(param_list, lr=lr)
    else:
        raise ValueError(f"{args.opt} not supported.")

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_milestones, gamma=args.lr_decay_factor)

    train_time = 0.0
    for epoch_idx in range(args.n_epochs):
        print("Epoch", epoch_idx)

        #### ---- train loop below ----
        train_start_time = time.time()
        train_nat_accu, train_cert_accu, eps, fast_reg_avg, train_loss = train_loop(model_wrapper, eps_scheduler, robust_weight_scheduler, train_loader, epoch_idx, optimizer, device, args, verbose=verbose)
        time_epoch = time.time() - train_start_time
        train_time += time_epoch
        print(f"train_nat_accu: {train_nat_accu: .4f}, train_cert_accu: {train_cert_accu: .4f}, train_loss:{train_loss: .4f}")

        # Track train statistics here.
        perf_dict["fastreg_curve"].append(fast_reg_avg)
        perf_dict['train_nat_curve'].append(train_nat_accu)
        perf_dict['train_cert_curve'].append(train_cert_accu)
        perf_dict["train_loss_curve"].append(train_loss)

        # Update learning rate here.
        lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        perf_dict['lr_curve'].append(lr)

        # Evaluate at the training eps, capped at the test eps.
        eps = min(eps, args.test_eps)
        perf_dict["eps_curve"].append(eps)
        print("current eps:", eps)
        print("current robust_weight:", model_wrapper.robust_weight)

        #### ---- test loop below ----
        val_nat_accu, val_cert_accu, val_loss, val_adv_accu = test_loop(model_wrapper, eps, eval_loader, device, args, extra_adv_attack=True)
        print(f"val_nat_accu: {val_nat_accu: .4f}, val_cert_accu: {val_cert_accu: .4f}, val_adv_accu: {val_adv_accu: .4f}, val_loss:{val_loss: .4f}")
        perf_dict['val_nat_curve'].append(val_nat_accu)
        perf_dict['val_cert_curve'].append(val_cert_accu)
        perf_dict['val_adv_curve'].append(val_adv_accu)
        perf_dict["val_loss_curve"].append(val_loss)

        # #### ---- additional model statistics tracking below ----
        # -- propagation tightness --
        PI = PI_loop(model_wrapper.net, 1e-5, eval_loader, device, args.num_classes, args, relu_adjust="local")
        print(f"Propagation Tightness: {PI:.4f}")
        perf_dict["PI_curve"].append(PI)

        # -- relu status --
        dead, unstable, active = relu_loop(model_wrapper.net, max(eps, 1e-6), eval_loader, device, args)
        perf_dict["dead_relu_curve"].append(dead)
        perf_dict["unstable_relu_curve"].append(unstable)
        perf_dict["active_relu_curve"].append(active)
        print(f"Dead: {dead:.3f}; Unstable: {unstable:.3f}; Active: {active:.3f}")

        if writer is not None:
            writer.add_scalar('eps',eps,epoch_idx)
            writer.add_scalar('epoch',epoch_idx,epoch_idx)
            writer.add_scalar('robust_weight',model_wrapper.robust_weight,epoch_idx)

            writer.add_scalar('train_time', time_epoch, epoch_idx)
            writer.add_scalar('train_nat_acc', train_nat_accu, epoch_idx)
            writer.add_scalar('train_cert_acc', train_cert_accu, epoch_idx)
            writer.add_scalar('train_loss', train_loss, epoch_idx)

            writer.add_scalar('val_nat_acc', val_nat_accu, epoch_idx)
            writer.add_scalar('val_cert_acc', val_cert_accu, epoch_idx)
            writer.add_scalar('val_loss', val_loss, epoch_idx)

            writer.add_scalar('Prop Invariance', PI, epoch_idx)
            writer.add_scalar('Dead_ReLU', dead, epoch_idx)
            writer.add_scalar('Unstable_ReLU', unstable, epoch_idx)
            writer.add_scalar('Active_ReLU', active, epoch_idx)

        #### ---- model selection below ----
        # Checkpoints are only selected once eps has reached the test eps.
        if eps == args.test_eps:
            if (perf_dict["model_selection"] == "robust_accu" and val_cert_accu > perf_dict["best_val_cert_accu"]) or (perf_dict["model_selection"] == "loss" and val_loss < perf_dict["best_val_loss"]):
                torch.save(model_wrapper.net.state_dict(), ckpt_path)
                print("New checkpoint saved.")
                perf_dict["best_val_cert_accu"] = val_cert_accu
                perf_dict["best_val_nat_accu"] = val_nat_accu
                perf_dict["best_val_adv_accu"] = val_adv_accu
                perf_dict["best_ckpt_epoch"] = epoch_idx
                if writer is not None:
                    writer.add_scalar('best_val_cert_accu', val_cert_accu)
                    writer.add_scalar('best_val_nat_accu', val_nat_accu)
                    writer.add_scalar('best_ckpt_epoch', epoch_idx)
            perf_dict["best_val_cert_accu"] = max(perf_dict["best_val_cert_accu"], val_cert_accu)
            perf_dict["best_val_loss"] = min(perf_dict["best_val_loss"], val_loss)

        if args.save_every_epoch:
            os.makedirs(os.path.join(save_root, "Every_Epoch_Model"), exist_ok=True)
            torch.save(model_wrapper.net.state_dict(), os.path.join(save_root, "Every_Epoch_Model", f"epoch_{epoch_idx}.ckpt"))

        if perf_dict["model_selection"] is None:
            # No model selection. Save the final model.
            torch.save(model_wrapper.net.state_dict(), ckpt_path)
            print("New checkpoint saved.")

        # Write the logs to json file
        write_perf_to_json(perf_dict, save_root, "monitor.json")
        write_perf_to_json(args.__dict__, save_root, "train_args.json")

    # test for the best ckpt
    # NOTE: model.ckpt only exists if a checkpoint was saved above, i.e. eps
    # reached args.test_eps under model selection, or model selection is None.
    print("-"*10 + f"Model Selection: {perf_dict['model_selection']}. Testing selected checkpoint." + "-"*10)
    model_wrapper.net.load_state_dict(torch.load(ckpt_path, map_location=device))
    test_nat_accu, test_cert_accu, loss = test_loop(model_wrapper, args.test_eps, test_loader, device, args)
    print(f"test_nat_accu: {test_nat_accu: .4f}, test_cert_accu: {test_cert_accu: .4f}")
    perf_dict["test_nat_accu"] = test_nat_accu
    perf_dict["test_cert_accu"] = test_cert_accu
    perf_dict["train_time"] = train_time
    perf_dict["end_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    write_perf_to_json(perf_dict, save_root, "monitor.json")
    write_perf_to_json(args.__dict__, save_root, "train_args.json")

    if writer is not None:
        writer.close()


def main():
    """Parse the "basic" and "train" argument groups, seed, and start training."""
    args = get_args(["basic", "train"])
    # Both distributed-training switches are forced off before running.
    for flag in ("use_ddp", "dist_sampler_train"):
        setattr(args, flag, False)
    seed_everything(args.random_seed)
    run(args)

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
