import os
import torch
import torch.nn as nn
from args_factory import get_args
from loaders import get_loaders
from utils import Scheduler, Statistics
from networks import get_network, fuse_BN_wrt_Flatten, add_BN_wrt_Flatten
from model_wrapper import BasicModelWrapper, PGDModelWrapper, BoxModelWrapper, HZonoModelWrapper, TAPSModelWrapper, STAPSModelWrapper, SmallBoxModelWrapper, SmallHZonoModelWrapper, DeepPolyModelWrapper, SmallDPModelWrapper, GradAccuModelWrapper, MultiFacetModelWrapper, MNBaBDeepPolyModelWrapper
from utils import write_perf_to_json, load_perf_from_json, fuse_BN, seed_everything
from tqdm import tqdm
import random
import numpy as np
from regularization import compute_fast_reg, compute_vol_reg, compute_L1_reg, compute_PI_reg, compute_neg_reg
import time
from datetime import datetime
from AIDomains.abstract_layers import Sequential
import logging
from AIDomains.zonotope import HybridZonotope
import math

import warnings
# Silence all Python warnings for the rest of the process.
warnings.filterwarnings("ignore")

# neptune is an optional experiment-tracking dependency. Only a missing or
# unimportable package disables it; the previous bare `except:` also swallowed
# unrelated errors such as KeyboardInterrupt raised during the import.
try:
    import neptune
except ImportError:
    neptune = None
# # You can manually disable neptune by setting it to None even if you installed it
# neptune = None

from get_stat import PI_loop, relu_loop, test_loop
from evo_wrapper import PGPEEvoWrapper, GeneticsEvoWrapper, BasicEvoWrapper
# from rso_wrapper import RSOWrapper

def evo_train_loop(evo_wrapper: BasicEvoWrapper, epoch_idx, args):
    """Run one epoch of evolutionary training.

    Puts the wrapped network in train mode. Then configures the model wrapper
    with summary_accu_stat=True, freeze_BN=False and
    num_steps=args.train_steps, and delegates the epoch to
    ``evo_wrapper.train_one_epoch(epoch_idx)``.
    """
    wrapper = evo_wrapper.model_wrapper
    wrapper.net.train()

    # Applied in this order: summary_accu_stat, freeze_BN, num_steps.
    overrides = {
        "summary_accu_stat": True,
        "freeze_BN": False,
        "num_steps": args.train_steps,
    }
    for attr_name, attr_value in overrides.items():
        setattr(wrapper, attr_name, attr_value)

    evo_wrapper.train_one_epoch(epoch_idx)



def _build_single_mode_wrapper(args, net, device, input_dim):
    """Build the wrapper for the single training mode chosen by the ``use_*`` flags.

    Flags are checked in the order PGD, vanilla IBP, TAPS, DP, DPBox. If several
    are set, the first one wins. ``args.use_small_box`` switches IBP, TAPS and
    DP to their small-box variants, which are built with ``args.eps_shrinkage``.

    Raises:
        NotImplementedError: if none of the training-mode flags is set.
    """
    if args.use_pgd_training:
        return PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
    if args.use_vanilla_ibp:
        if args.use_small_box:
            return SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
        return BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
    if args.use_TAPS_training:
        if args.use_small_box:
            return STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
        return TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
    if args.use_DP_training:
        if args.use_small_box:
            return SmallDPModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
        # MNBaBDeepPolyModelWrapper is used for plain DP training. DeepPolyModelWrapper is only used for DPBox below.
        return MNBaBDeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
    if args.use_DPBox_training:
        return DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, use_dp_box=True)
    # Without this, an unselected mode fell through and crashed later with UnboundLocalError.
    raise NotImplementedError("Unknown training mode.")


def _build_facet_wrapper(facet, args, net, device, input_dim):
    """Build the wrapper for one facet name (case-insensitive) of a multi-facet loss.

    Raises:
        NotImplementedError: if the facet name is not one of ibp/pgd/sabr/taps/staps.
    """
    name = facet.lower()
    if name == "ibp":
        return BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
    if name == "pgd":
        return PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
    if name == "sabr":
        return SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
    if name == "taps":
        return TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
    if name == "staps":
        return STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
    raise NotImplementedError(f"Unknown facet specified: {facet}.")


def get_model_wrapper(args, net, device, input_dim):
    """Build the model wrapper that defines the training loss and robust statistics.

    Two cases, depending on ``args.multi_facets``:

    - Empty or None: one wrapper is chosen by the training-mode flags. It is
      wrapped in ``GradAccuModelWrapper`` when ``args.grad_accu_batch`` is set.
      Truthiness is used here to match ``get_train_mode``, so an empty list
      counts as single-mode.
    - Non-empty: one wrapper is built per facet, and they are combined with
      ``MultiFacetModelWrapper`` using ``args.facets_eps_ratio`` and
      ``args.facets_weight``.

    Args:
        args: parsed arguments.
        net: the abstract network to wrap.
        device: device passed to every wrapper.
        input_dim: input shape tuple passed to every wrapper.

    Returns:
        The model wrapper, with ``robust_weight`` set to 1.

    Raises:
        NotImplementedError: if no training mode is selected or a facet name is unknown.
        AssertionError: if fewer than two facets are given.
    """
    if not args.multi_facets:
        model_wrapper = _build_single_mode_wrapper(args, net, device, input_dim)
        if args.grad_accu_batch is not None:
            model_wrapper = GradAccuModelWrapper(model_wrapper, args)
    else:
        assert len(args.multi_facets) >= 2, "At least two facets if multiple facet loss is used."
        wrappers = [_build_facet_wrapper(facet, args, net, device, input_dim) for facet in args.multi_facets]
        model_wrapper = MultiFacetModelWrapper(wrappers, args.facets_eps_ratio, args.facets_weight, args)
    model_wrapper.robust_weight = 1
    return model_wrapper

def get_train_mode(args):
    """Return the run-name string describing how the model is trained.

    With multiple facets the name joins ``<facet><weight>`` pairs with "_".
    Otherwise the first set flag among PGD, vanilla IBP, TAPS, DP and DPBox
    decides the name. ``use_small_box`` turns IBP into SABR and TAPS into
    STAPS. Reorder the checks below to change preference when args are
    ambiguous.

    Raises:
        NotImplementedError: if no training-mode flag is set.
    """
    facets = args.multi_facets
    if facets:
        return "_".join(f"{facet}{weight}" for facet, weight in zip(facets, args.facets_weight))

    if args.use_pgd_training:
        return "PGD_trained"
    if args.use_vanilla_ibp:
        return "SABR_trained" if args.use_small_box else "IBP_trained"
    if args.use_TAPS_training:
        return "STAPS_trained" if args.use_small_box else "TAPS_trained"
    if args.use_DP_training:
        return "DP_trained"
    if args.use_DPBox_training:
        return "DPBox_trained"
    raise NotImplementedError("Unknown training mode.")

def parse_save_root(args, mode):
    """Build the directory where this run's artifacts are saved, create it, and return it.

    The layout is ``<save_dir>/<dataset>/eps<test_eps>/<mode>/<net>/<init>``,
    followed by one optional sub-directory per active option, in this order:
    fast_reg, small-box lambda, TAPS last block size, IBPR, L1, neptune id.
    ``<init>`` is ``init_pretrained`` when ``args.load_model`` is set.
    """
    init_part = "init_pretrained" if args.load_model is not None else f"init_{args.init}"
    parts = [args.save_dir, args.dataset, f"eps{args.test_eps:.5g}", mode, args.net, init_part]

    if args.fast_reg > 0:
        parts.append(f"fast_reg_{args.fast_reg}")
    if args.use_small_box:
        parts.append(f"lambda_{args.eps_shrinkage}")
    if args.use_TAPS_training:
        parts.append(f"last_block_{args.block_sizes[-1]}")
    if args.IBPR_reg > 0:
        parts.append(f"IBPR_{args.IBPR_reg}")
    if args.L1_reg > 0:
        parts.append(f"L1_{args.L1_reg}")
    if args.neptune_id is not None:
        parts.append(f"{args.neptune_id}")

    save_root = os.path.join(*parts)
    os.makedirs(save_root, exist_ok=True)
    logging.info(f"The model will be saved at: {save_root}")
    return save_root

def _init_perf_dict(neptune_id):
    """Return a fresh metrics dict: empty per-epoch curves plus start time and neptune id.

    To track more statistics, add another dict of ``*_curve`` lists and merge it in below.
    """
    perf_dict = {'val_nat_curve':[], 'val_cert_curve':[], 'val_loss_curve':[], 'train_nat_curve':[], 'train_cert_curve':[], 'train_loss_curve':[], 'lr_curve':[], "eps_curve":[]}
    pi_dict = {"PI_curve":[]}
    relu_dict = {"dead_relu_curve":[], "active_relu_curve":[], "unstable_relu_curve":[]}
    reg_dict = {"fastreg_curve":[]}
    adv_dict = {"train_adv_accu_curve":[], "val_adv_accu_curve":[]}
    perf_dict = perf_dict | pi_dict | relu_dict | reg_dict | adv_dict
    perf_dict["start_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
    perf_dict["neptune_id"] = neptune_id
    return perf_dict


def _build_eps_scheduler(args, n_batches):
    """Return the input-eps Scheduler. Its epoch bounds are converted to iterations via ``n_batches``.

    With ``args.no_anneal`` the start and end values are both ``args.train_eps``.

    Raises:
        NotImplementedError: if ``args.schedule`` is not smooth/linear/step.
    """
    if args.no_anneal:
        # use const eps == train_eps
        return Scheduler(args.start_epoch_eps*n_batches, args.end_epoch_eps*n_batches, args.train_eps, args.train_eps, "linear", s=n_batches)
    if args.schedule not in ["smooth", "linear", "step"]:
        raise NotImplementedError(f"Unknown schedule: {args.schedule}")
    return Scheduler(args.start_epoch_eps*n_batches, args.end_epoch_eps*n_batches, args.eps_start, args.eps_end, args.schedule, s=args.step_epoch*n_batches)


def _build_evo_wrapper(args, train_loader, torch_net, model_wrapper, eps_scheduler, device, nep_log):
    """Build the evolutionary optimiser selected by ``args`` (PGPE or genetic algorithm).

    Raises:
        NotImplementedError: if neither ``args.use_PGPE_evo`` nor ``args.use_GA_evo`` is set.
    """
    if args.use_PGPE_evo:
        return PGPEEvoWrapper(train_loader, torch_net, model_wrapper, eps_scheduler, args.lr, std_lr=args.lr_std, num_actors=args.num_actors, device=device, args=args, popsize=args.popsize, std_init=args.std_init, use_current_std=args.use_current_std, std_min=args.std_min, std_max=args.std_max, subbatch_size=None, nep_log=nep_log, optimizer=args.opt)
    if args.use_GA_evo:
        # Optional "log_linear" schedule from 1 to std_min/std_init. It is only
        # built when both start_epoch_std and end_epoch_std are given.
        mp_scheduler = None if args.start_epoch_std is None or args.end_epoch_std is None else Scheduler(args.start_epoch_std, args.end_epoch_std, 1, args.std_min/args.std_init, "log_linear")
        # psr_scheduler = None if args.start_epoch_psr is None or args.end_epoch_psr is None else Scheduler(args.start_epoch_psr, args.end_epoch_psr, args.psr_init, args.psr_min, "step", s=len(train_loader))
        # scheduler epochs would be multiplied by the number of iterations through the dataset per epoch
        return GeneticsEvoWrapper(train_loader, torch_net, model_wrapper, eps_scheduler, popsize=args.popsize, num_elites=args.num_elites, num_parents=args.num_parents, mutation_power=args.std_init, use_current_std=args.use_current_std, pert_space_ratio=args.psr_init, mp_scheduler=mp_scheduler, psr_scheduler=None, popsize_scheduler=None, num_actors=args.num_actors, device=device, args=args, subbatch_size=None, nep_log=nep_log)
    # evo_wrapper = RSOWrapper(model_wrapper, args.lr_milestones, args.lr_decay_factor, train_loader, eps_scheduler)
    raise NotImplementedError("Must specify an evolutional wrapper for using evo_train.")


def run(args):
    """Train a network with an evolutionary optimiser, then evaluate it.

    Setup:
        Build the data loaders and eps schedule. Convert the concrete network
        into an abstract ``Sequential``. Wrap it in a model wrapper (loss and
        robust statistics) and an evolutionary wrapper (PGPE or GA).

    Each epoch:
        Train. Evaluate on the train-eval loader and on the validation loader
        (or the test loader if there is no validation split). Record metrics in
        ``monitor.json``. Save ``model.ckpt`` according to ``args.model_selection``.

    Finally:
        Reload the saved checkpoint and report test accuracy at ``args.test_eps``.

    Side effects: sets ``args.neptune_id``, ``args.num_classes`` and
    ``args.save_root``, and writes checkpoints and JSON logs under the save root.
    """
    # neptune logging: placed here to enable remote interruption.
    # No neptune run is created in this script, so nep_log stays None. Both
    # names must be bound on every path. Previously neptune_id was unbound
    # whenever the neptune import succeeded.
    nep_log = None
    neptune_id = None
    args.neptune_id = neptune_id

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    perf_dict = _init_perf_dict(neptune_id)

    # Define dataset here
    loaders, input_size, input_channel, n_class = get_loaders(args, num_workers=0)
    input_dim = (input_channel, input_size, input_size)
    args.num_classes = n_class
    if len(loaders) == 4:
        train_loader, val_loader, test_loader, train_test_loader = loaders
    else:
        train_loader, test_loader, train_test_loader = loaders
        val_loader = None
    perf_dict['model_selection'] = args.model_selection # tradition for certified training: use test set for model selection :(

    # Define train mode (name of this train) here
    mode = get_train_mode(args)

    # Schedule for input epsilon
    eps_scheduler = _build_eps_scheduler(args, len(train_loader))

    # Define concrete (torch) model and convert it to abstract model here
    torch_net = get_network(args.net, args.dataset, device, init=args.init)
    net = Sequential.from_concrete_network(torch_net, input_dim, disconnect=False, retain_concrete=True)
    # Use the selected device rather than a hard-coded 'cuda' so CPU-only runs work.
    net.set_dim(torch.zeros((test_loader.batch_size, *input_dim), device=device))
    if args.load_model:
        net.load_state_dict(torch.load(args.load_model, map_location=device))
        print("Loaded:", args.load_model)
    print(net)

    # Parse save root here
    save_root = parse_save_root(args, mode)
    args.save_root = save_root
    ckpt_path = os.path.join(save_root, "model.ckpt")

    # Define model wrapper here: this wraps how to compute the loss and how to compute the robust accuracy.
    model_wrapper = get_model_wrapper(args, net, device, input_dim)
    evo_wrapper = _build_evo_wrapper(args, train_loader, torch_net, model_wrapper, eps_scheduler, device, nep_log)

    perf_dict["best_val_cert_accu"] = -1
    perf_dict["best_val_loss"] = 1e8

    if nep_log is not None:
        nep_log["args"] = args.__dict__

    # Fall back to the test loader when there is no separate validation split.
    eval_loader = val_loader if val_loader is not None else test_loader
    # Tracks whether THIS run wrote ckpt_path. The save root can be shared with
    # earlier runs, so checking for the file's existence is not enough.
    ckpt_saved = False
    train_time = 0.0
    for epoch_idx in range(args.n_epochs):
        print("Epoch", epoch_idx + 1, "/", args.n_epochs)

        #### ---- train loop below ----
        train_start_time = time.time()
        evo_train_loop(evo_wrapper, epoch_idx, args)
        train_time += time.time() - train_start_time

        # eps at the iteration count reached at the end of this epoch, capped at test_eps.
        eps = eps_scheduler.getcurrent(len(train_loader)*(epoch_idx+1))
        eps = min(eps, args.test_eps)

        train_nat_accu, train_cert_accu, train_loss, train_adv_accu = test_loop(model_wrapper, eps, train_test_loader, device, args, extra_adv_attack=True, max_batches=args.max_batches_train_eval)
        print(f"train_nat_accu: {train_nat_accu: .4f}, train_cert_accu: {train_cert_accu: .4f}, train_loss:{train_loss: .4f}, train_adv_accu:{train_adv_accu: .4f}")

        # Track train statistics here.
        perf_dict['train_nat_curve'].append(train_nat_accu)
        perf_dict['train_cert_curve'].append(train_cert_accu)
        perf_dict["train_loss_curve"].append(train_loss)
        perf_dict["train_adv_accu_curve"].append(train_adv_accu)

        # Update learning rate here.
        evo_wrapper.step_lr()
        # lr = evo_wrapper.lr_scheduler.get_last_lr()[0]
        # perf_dict['lr_curve'].append(lr)

        perf_dict["eps_curve"].append(eps)
        print("current eps:", eps)
        print("current robust_weight:", model_wrapper.robust_weight)

        #### ---- test loop below ----
        val_nat_accu, val_cert_accu, val_loss, val_adv_accu = test_loop(model_wrapper, eps, eval_loader, device, args, extra_adv_attack=True, max_batches=args.max_batches)
        print(f"val_nat_accu: {val_nat_accu: .4f}, val_cert_accu: {val_cert_accu: .4f}, val_loss:{val_loss: .4f}, val_adv_accu:{val_adv_accu: .4f}")
        perf_dict['val_nat_curve'].append(val_nat_accu)
        perf_dict['val_cert_curve'].append(val_cert_accu)
        perf_dict["val_loss_curve"].append(val_loss)
        perf_dict["val_adv_accu_curve"].append(val_adv_accu)

        # #### ---- additional model statistics tracking below ----
        # # -- propagation tightness --
        # PI = PI_loop(model_wrapper.net, 1e-5, eval_loader, device, args.num_classes, args, relu_adjust="local")
        # print(f"Propagation Tightness: {PI:.4f}")
        # perf_dict["PI_curve"].append(PI)

        # # -- relu status --
        # dead, unstable, active = relu_loop(model_wrapper.net, max(eps, 1e-6), eval_loader, device, args)
        # perf_dict["dead_relu_curve"].append(dead)
        # perf_dict["unstable_relu_curve"].append(unstable)
        # perf_dict["active_relu_curve"].append(active)
        # print(f"Dead: {dead:.3f}; Unstable: {unstable:.3f}; Active: {active:.3f}")

        if nep_log is not None:
            nep_log["train_nat_accu_curve"].append(train_nat_accu)
            nep_log["train_cert_accu_curve"].append(train_cert_accu)
            nep_log["train_loss_curve"].append(train_loss)
            nep_log["val_nat_accu_curve"].append(val_nat_accu)
            nep_log["val_cert_accu_curve"].append(val_cert_accu)
            nep_log["val_loss_curve"].append(val_loss)
            nep_log["eps_curve"].append(eps)
            nep_log["epoch"].append(epoch_idx)
            nep_log["train_adv_accu_curve"].append(train_adv_accu)
            nep_log["val_adv_accu_curve"].append(val_adv_accu)

        #### ---- model selection below ----
        # Selection only starts once the (capped) eps has reached test_eps.
        if eps == args.test_eps:
            if (perf_dict["model_selection"] == "robust_accu" and val_cert_accu > perf_dict["best_val_cert_accu"]) or (perf_dict["model_selection"] == "loss" and val_loss < perf_dict["best_val_loss"]):
                torch.save(model_wrapper.net.state_dict(), ckpt_path)
                ckpt_saved = True
                if nep_log is not None:
                    nep_log["model"].upload(ckpt_path)
                    nep_log["best_val_cert_accu"].append(val_cert_accu)
                    nep_log["best_val_nat_accu"].append(val_nat_accu)
                print("New checkpoint saved.")
                perf_dict["best_ckpt_epoch"] = epoch_idx
            perf_dict["best_val_cert_accu"] = max(perf_dict["best_val_cert_accu"], val_cert_accu)
            perf_dict["best_val_loss"] = min(perf_dict["best_val_loss"], val_loss)

        if args.save_every_epoch:
            os.makedirs(os.path.join(save_root, "Every_Epoch_Model"), exist_ok=True)
            torch.save(model_wrapper.net.state_dict(), os.path.join(save_root, "Every_Epoch_Model", f"epoch_{epoch_idx}.ckpt"))

        if perf_dict["model_selection"] is None:
            # No model selection. Save the final model.
            torch.save(model_wrapper.net.state_dict(), ckpt_path)
            ckpt_saved = True
            if nep_log is not None:
                nep_log["model"].upload(ckpt_path)
            print("New checkpoint saved.")

        # Write the logs to json file
        write_perf_to_json(perf_dict, save_root, "monitor.json")
        write_perf_to_json(args.__dict__, save_root, "train_args.json")

    # test for the best ckpt
    print("-"*10 + f"Model Selection: {perf_dict['model_selection']}. Testing selected checkpoint." + "-"*10)
    if not ckpt_saved:
        # Model selection never saved a checkpoint, e.g. because eps never
        # reached test_eps within n_epochs. Save and test the final weights
        # instead of crashing or loading a stale checkpoint from an earlier run.
        print("No checkpoint was selected in this run; saving and testing the final model.")
        torch.save(model_wrapper.net.state_dict(), ckpt_path)
    model_wrapper.net.load_state_dict(torch.load(ckpt_path, map_location=device))
    test_nat_accu, test_cert_accu, _test_loss = test_loop(model_wrapper, args.test_eps, test_loader, device, args, max_batches=args.max_batches)
    print(f"test_nat_accu: {test_nat_accu: .4f}, test_cert_accu: {test_cert_accu: .4f}")
    perf_dict["test_nat_accu"] = test_nat_accu
    perf_dict["test_cert_accu"] = test_cert_accu
    perf_dict["train_time"] = train_time
    perf_dict["end_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    write_perf_to_json(perf_dict, save_root, "monitor.json")
    write_perf_to_json(args.__dict__, save_root, "train_args.json")

    # if nep_log is not None:
    #     # neptune logging: save the logs in the end to enable quick web visualization. Logs during the run can be found locally in monitor.json
    #     nep_log["args"] = args.__dict__
    #     nep_log["result"] = perf_dict
    #     for k, v in perf_dict.items():
    #         if isinstance(v, list) and len(v)>0:
    #             for i in v:
    #                 nep_log[f"result/{k}"].append(i)
    #     nep_log["model"].upload(ckpt_path)
    if nep_log is not None:
        nep_log.stop()


def main():
    """Script entry point.

    Parses the "basic", "train" and "evo" argument groups, seeds via
    ``seed_everything(random_seed)``, and runs training.
    """
    run_args = get_args(["basic", "train", "evo"])
    seed_everything(run_args.random_seed)
    run(run_args)


if __name__ == "__main__":
    main()
