import argparse
import logging
import torch
import numpy as np
import random
import sys
sys.path.append("../../mlp")
from Models import MLP
import sys
import os 
from load_data import load_data_mlp
from scipy.io import  savemat
from Train_and_Test import Train, Test
import random
import wandb
from torch.optim.lr_scheduler import _LRScheduler
from dst_scheduler import DSTScheduler
from torch import optim

class WarmUpLR(_LRScheduler):
    """Learning-rate warm-up: scale each group's base LR by step / total_iters.

    The learning rate starts at 0 and grows linearly with the number of
    scheduler steps taken, reaching the base learning rate after
    ``total_iters`` steps.

    Args:
        optimizer: wrapped optimizer (e.g. SGD).
        total_iters: number of scheduler steps in the warm-up phase.
        last_epoch: index of the last step taken; -1 starts from scratch.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Assigned before the base-class constructor, which performs an
        # initial step that calls get_lr() and therefore reads it.
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * last_epoch / total_iters`` for every group.

        The 1e-8 offset keeps the division defined when total_iters is 0.
        """
        denom = self.total_iters + 1e-8
        scaled = []
        for base in self.base_lrs:
            scaled.append(base * self.last_epoch / denom)
        return scaled

class metrics:
    """Per-epoch history of evaluation results.

    ``update`` appends one value to each of ``loss``, ``top1`` and ``top5``.
    The ``layerN_overlap_rate`` lists are created (empty) by ``reset`` but
    are not appended to by this class.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Replace every history list with a fresh empty list."""
        self.loss, self.top1, self.top5 = [], [], []
        self.layer1_overlap_rate, self.layer2_overlap_rate, self.layer3_overlap_rate = [], [], []

    def update(self, loss, top1, top5):
        """Record one (loss, top1, top5) observation."""
        pairs = ((self.loss, loss), (self.top1, top1), (self.top5, top5))
        for history, value in pairs:
            history.append(value)

def setup_seed(seed):
    """Seed all random number generators used during training.

    Covers Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU and on every CUDA device. It also sets
    ``torch.backends.cudnn.deterministic``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def Args(argv=None):
    """Build the command-line parser and parse the experiment configuration.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and notebooks).

    Returns:
        argparse.Namespace holding every option declared below. ``dim`` is
        always derived from ``dataset`` (1 for CIFAR datasets, 2 otherwise),
        overriding any ``--dim`` given on the command line.
    """
    parser = argparse.ArgumentParser()
    # Shared choices for the various per-neuron distribution options.
    dist_choices = ["fixed", "gaussian", "uniform", "spatial_gaussian", "spatial_inversegaussian"]

    # normal training arguments
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=50)
    # Required: `dim` is derived from it below, so a missing value used to
    # crash with a TypeError on `"CIFAR" in None`.
    parser.add_argument("--dataset", type=str, required=True, help="dataset name, including "
        "CIFAR10, CIFAR100, EMNIST, Fashion_MNIST, MNIST, FER2013, SVHN, tiny-imagenet-ori, tiny-imagenet-crop, tiny-imagenet-resize, OxfordFlowers, Caltech256, OxfordIIITPet, INaturalist.")
    parser.add_argument("--network_structure", type=str, default="mlp")  # i.e. the architecture
    parser.add_argument("--cnn_structure", type=str, default="D")
    parser.add_argument("--cuda_device", type=str)
    parser.add_argument("--check_exist", action="store_true", help="checking the experiments whether has been done or not")
    parser.add_argument("--dropout", type=float, default=0, help="dropout rate")
    parser.add_argument("--optimizer", type=str, default="sgd", help="adam, sgd...")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight_decay")

    # sparse training arguments
    parser.add_argument("--sparsity", type=float, default=0.99, help="directly give the sparsity to each layer")
    parser.add_argument("--update_interval", type=int, default=1, help="the number of intervals for topology evolution")
    parser.add_argument("--zeta", type=float, default=0.3, help="the fraction of removal and regrown links")
    parser.add_argument("--adaptive_zeta", action="store_true", help="add this argument to make the zeta reducing across the epochs")
    parser.add_argument("--remove_method", type=str, default="weight_magnitude", help="how to remove links, Magnitude or MEST")
    parser.add_argument("--regrow_method", type=str, default="random", help="how to regrow new links. "
                                                                            "Including: random, gradient, CH3_L3p_soft, CH3_L3n_soft, CH2_L3n_soft")
    parser.add_argument("--init_mode", type=str, default="kaiming", help="how to initialize the weights of the model. "
                                                                         "Including: kaiming, swi")
    parser.add_argument("--chain_removal", action="store_true", help="use forward removal and backward removal")
    parser.add_argument("--self_correlated_sparse", action="store_true")
    parser.add_argument("--soft_self_correlated_sparse", action="store_true")
    parser.add_argument("--dim", type=int, default=1)
    parser.add_argument("--dimension", type=int, default=0)
    parser.add_argument("--WS", action="store_true")
    parser.add_argument("--DNM", action="store_true")
    parser.add_argument("--cross", action="store_true")
    parser.add_argument("--random_rewiring", type=float, default=1.0)
    parser.add_argument("--M", type=int, help="number of dendrites")
    parser.add_argument("--M_dist", type=str, help="Distribution of the number of dendrites among neurons", choices=dist_choices, default="fixed")
    parser.add_argument("--gamma", type=float, help="dendritic spreading parameter")
    parser.add_argument("--gamma_dist", type=str, default="fixed", choices=dist_choices)
    parser.add_argument("--degree_dist", type=str, default="fixed", choices=dist_choices)
    parser.add_argument("--synaptic_dist", type=str, help="distribution of the nodes in the synapses", default="fixed", choices=dist_choices)
    parser.add_argument("--BHI", action="store_true", help="Bipartite Hyperbolic Initialisation")
    parser.add_argument("--CWS", action="store_true", help="Cannistraci Watts Strogatz")
    parser.add_argument("--WS1", action="store_true", help="First variation of WS initialisation")
    parser.add_argument("--WS2", action="store_true", help="Second variation of WS initialisation")
    parser.add_argument("--WS3", action="store_true", help="Third variation of WS initialisation")
    parser.add_argument("--delta", type=float, help="δ locality parameter (0 ≤ δ ≤ 1)")
    parser.add_argument("--F", type=str, help="distribution of δ over the neurons", default="fixed")
    parser.add_argument("--sigma_x", type=float, help="standard deviation over the x")
    parser.add_argument("--sigma_y", type=float, help="standard deviation over the y")
    parser.add_argument("--rho", type=float, help="correlation")
    parser.add_argument("--QHI", action="store_true", help="General Hyperbolic Initialisation")
    parser.add_argument("--BHI_T", type=float, default=0.0, help="nPSO temperature. If 0: purely greedy by hyperbolic distance; >0: probabilistic sampling.")
    parser.add_argument("--BHI_gamma", type=float, default=2.0, help="Power-law exponent gamma controlling the radial coordinate dynamics in nPSO.")
    parser.add_argument("--BHI_distr", type=int, default=0, help=(
        "angular coordinate distribution: "
        "0 = uniform on [0,2π) (PSO model); "
        "C>0 = integer number of equidistant GaussianMixture components. "
        "(A sklearn GaussianMixture object or an (angles, probs, centers) "
        "3-tuple for a fully custom mixture cannot be passed on the command line.)"
        )
    )
    parser.add_argument("--rewire_mode", type=str, choices=["none", "uniform", "random"], default="none", help=(
        "Optional bipartite rewiring after nPSO wiring: "
        "'none' = keep original edges; "
        "'uniform' = redistribute B-endpoint degrees as evenly as possible; "
        "'random' = reassign B endpoints at random (preserving A-degrees)."
        )
    )
    parser.add_argument("--degree_allocation", action="store_true", help="Enable the spatial-sorting degree allocation strategy: for each bipartite block, compute each neuron's degree and then reorder them so that the highest-degree neuron sits at the center, the next two occupy the centers of the two halves, the next four the centers of the four quarters, and so on. For the third sandwich only the input side is sorted (the final output layer remains unconstrained).")
    parser.add_argument("--no_log", action="store_true")
    parser.add_argument("--linearlr", action="store_true")
    parser.add_argument('--milestone', type=int, nargs='+', default=[60, 120, 160],
                        help='Decrease learning rate at these epochs.')
    parser.add_argument("--iterative_warmup_steps", type=int, default=0)
    parser.add_argument("--warmup", action="store_true")
    parser.add_argument("--discretelr", action="store_true")
    parser.add_argument("--end_factor", type=float, default=0.01)
    parser.add_argument("--T_decay", type=str, default="no_decay", choices=["no_decay", "linear"], help="decay the temperature of the sampling")
    parser.add_argument("--decay_factor", type=float, default=0.9, help="fraction of the epochs over which --linearlr decays the learning rate")
    parser.add_argument("--dst_scheduler", action="store_true", help="use the dst scheduler")
    parser.add_argument("--itop", action="store_true", help="use the itop logger")
    parser.add_argument("--EM_S", action="store_true")
    parser.add_argument("--early_stop", action="store_true")
    parser.add_argument("--early_stop_thre", type=float, default=0.9)
    parser.add_argument("--record_anp", action="store_true")
    parser.add_argument("--method", type=str, default="", help="the method name of the experiment")
    parser.add_argument("--old_version", action="store_true")
    parser.add_argument("--history_weights", action="store_true")
    parser.add_argument("--tiedrank", action="store_true")
    parser.add_argument("--start_T", type=float, help="set the initial value of delta", default=1.0)
    parser.add_argument("--end_T", type=float, help="set the final value of delta", default=3.0)
    parser.add_argument("--factor", type=float, default=0.01)
    parser.add_argument("--ssam", action="store_true")
    parser.add_argument("--function", type=str, help="varies alpha according to some function", choices=["linear", "sigmoid", "tempering_hard", "tempering_strength", "piecewise_constant_linear", "piecewise_constant_abrupt", "decreasing_step", "fixed_height_step", "wave", "decaying_wave", "cosine_annealing_wr"], default="linear")
    parser.add_argument("--k", type=float, help="constant that regulates the sigmoid functions's steepness", default=1)
    parser.add_argument("--granet", action="store_true")
    parser.add_argument("--granet_init_sparsity", type=float, default=0.9)
    parser.add_argument("--granet_init_epoch", type=int, default=0)
    parser.add_argument("--gmp", action="store_true")
    parser.add_argument("--pruning_scheduler", type=str, default="none", help="none, linear, granet, s_shape")
    parser.add_argument("--pruning_method", type=str, default="none", help="ri, weight_magnitude, MEST")
    parser.add_argument("--sparsity_distribution", type=str, default="uniform", help="uniform, non-uniform")

    args = parser.parse_args(argv)

    # `dim` is dictated by the dataset; any --dim value is overridden here.
    args.dim = 1 if "CIFAR" in args.dataset else 2

    return args

def get_save_path_origion(args):
    """Build the legacy, fully descriptive results directory for a run.

    The path encodes the network, dataset and core hyper-parameters, plus a
    tag for every enabled initialisation / sparse-training option. There are
    two layouts:

    * without ``args.method``, every tag is a path component and the leaf is
      ``<regrow_method>_<remove_method>``;
    * with ``args.method``, the leaf directory is named after the method.

    Args:
        args: parsed options (see ``Args``).

    Returns:
        The save path as a string ending in "/".
    """
    # `--delta_dist` is not declared by the parser; fall back to `--F`,
    # whose help text describes it as the distribution of δ over the neurons.
    delta_dist = getattr(args, "delta_dist", getattr(args, "F", None))

    save_path_parts = [
        args.network_structure,
        args.dataset,
        f"s_{args.seed}_lr_{args.learning_rate}_e_{args.epochs}",
        f"s_{args.sparsity}_i_{args.update_interval}_z_{args.zeta}_df_{args.decay_factor}",
    ]
    if not args.method:
        if args.adaptive_zeta:
            save_path_parts.append("az_")
        if args.init_mode == "swi":
            save_path_parts.append("swi_")
        if args.init_mode == "kaiming":
            save_path_parts.append("kaiming_")

        if args.self_correlated_sparse:
            save_path_parts.append("scs_")

        if args.soft_self_correlated_sparse:
            save_path_parts.append("sscs_")

        if args.WS:
            save_path_parts.append(f"random_rewiring_{args.random_rewiring}_")

        if args.DNM:
            save_path_parts.append(f"degreedist_{args.degree_dist}/M_{args.M}_Mdist_{args.M_dist}/gamma_{args.gamma}_gammadist_{args.gamma_dist}/synapticdist_{args.synaptic_dist}_")

        if args.cross:
            save_path_parts.append(f"cross_{args.random_rewiring}")

        if args.BHI:
            save_path_parts.append(f"BHI_T_{args.BHI_T}/gamma_{args.BHI_gamma}_dist_{args.BHI_distr}/rm_{args.rewire_mode}_")
            if args.degree_allocation:
                save_path_parts.append("degree_allocation_")

        if args.QHI:
            save_path_parts.append(f"QHI_T_{args.BHI_T}/gamma_{args.BHI_gamma}_dist_{args.BHI_distr}_")
            if args.degree_allocation:
                save_path_parts.append("degree_allocation_")

        if args.CWS:
            save_path_parts.append(f"CWS_sigmax_{args.sigma_x}/sigmay_{args.sigma_y}_rho_{args.rho}_random_rewiring_{args.random_rewiring}_")

        if args.WS1:
            save_path_parts.append(f"WS1_deltadist_{delta_dist}_delta_{args.delta}_random_rewiring_{args.random_rewiring}_")

        if args.WS2:
            save_path_parts.append(f"WS2_deltadist_{delta_dist}_delta_{args.delta}_random_rewiring_{args.random_rewiring}_")

        if args.WS3:
            save_path_parts.append(f"WS3_deltadist_{delta_dist}_delta_{args.delta}")

        if args.chain_removal:
            save_path_parts.append("chain_")

        if args.early_stop:
            save_path_parts.append("es_")

        if args.remove_method.split("_")[-1] == "soft":
            save_path_parts.append(f"{args.T_decay}_")

        if args.start_T != args.end_T:
            save_path_parts.append(f"st_{args.start_T}_et_{args.end_T}_{args.function}")
        else:
            save_path_parts.append(f"t_{args.start_T}_fixed")

        if args.function == "sigmoid":
            save_path_parts.append(f"k_{args.k}")

        # Tags present in every no-method path.
        save_path_parts.append(f"d_{args.dim}_")
        save_path_parts.append(f"dist_{args.sparsity_distribution}_")
        if args.history_weights:
            save_path_parts.append("hw_")

        # Training-regime tag. Note: "az_" can appear twice (also added above);
        # kept so existing result directories still resolve.
        if args.gmp:
            save_path_parts.append(f"gmp_{args.granet_init_sparsity}_{args.pruning_scheduler}_{args.pruning_method}_")
        elif args.regrow_method == "fc":
            save_path_parts.append("fc_")
        elif args.granet:
            save_path_parts.append(f"granet_{args.granet_init_sparsity}_{args.pruning_scheduler}_{args.pruning_method}_")
        elif args.adaptive_zeta:
            save_path_parts.append("az_")
        elif args.EM_S:
            save_path_parts.append("EM_S_")
        elif args.remove_method not in ("MEST", "history_gradient"):
            save_path_parts.append("fix_")
        # MEST / history_gradient get no regime tag in this layout.

        save_path_parts.append(f"{args.regrow_method}_{args.remove_method}")
        if args.ssam:
            print("Using SSAM!!!!")
            save_path_parts.append("ssam_")
        save_path = "/".join(save_path_parts) + "/"
    else:
        if args.self_correlated_sparse:
            save_path_parts.append("scs_")
        if args.soft_self_correlated_sparse:
            save_path_parts.append("sscs_")
        if args.WS:
            save_path_parts.append(f"random_rewiring_{args.random_rewiring}_")
        if args.DNM:
            save_path_parts.append(f"degreedist_{args.degree_dist}/M_{args.M}_Mdist_{args.M_dist}/gamma_{args.gamma}_gammadist_{args.gamma_dist}/synapticdist_{args.synaptic_dist}_")
        if args.cross:
            save_path_parts.append(f"cross_{args.random_rewiring}")
        if args.BHI:
            save_path_parts.append(f"T_{args.BHI_T}/gamma_{args.BHI_gamma}_dist_{args.BHI_distr}/rm_{args.rewire_mode}_")
            if args.degree_allocation:
                save_path_parts.append("degree_allocation_")

        if args.QHI:
            save_path_parts.append(f"QHI_T_{args.BHI_T}/gamma_{args.BHI_gamma}_dist_{args.BHI_distr}_")
            # Was `args.degree_allocationH`, which raised AttributeError.
            if args.degree_allocation:
                save_path_parts.append("degree_allocation_")

        if args.CWS:
            save_path_parts.append(f"CWS_sigmax_{args.sigma_x}/sigmay_{args.sigma_y}_rho_{args.rho}_random_rewiring_{args.random_rewiring}_")

        if args.WS1:
            save_path_parts.append(f"WS1_deltadist_{delta_dist}_delta_{args.delta}_random_rewiring_{args.random_rewiring}_")

        if args.EM_S:
            save_path_parts.append("EM_S_")

        if args.start_T != args.end_T:
            save_path_parts.append(f"st_{args.start_T}_et_{args.end_T}_{args.function}")
        else:
            save_path_parts.append(f"t_{args.start_T}_fixed")

        if args.function == "sigmoid":
            save_path_parts.append(f"k_{args.k}")

        save_path_parts.append(f"d_{args.dim}_")
        if args.gmp:
            save_path_parts.append(f"gmp_{args.granet_init_sparsity}_{args.pruning_scheduler}_{args.pruning_method}_")
            # Previously no path was built in this branch, so the function
            # failed with UnboundLocalError; use the standard method leaf.
            save_path = "/".join(save_path_parts) + f"/{args.method}_{args.regrow_method}_{args.remove_method}/"
        elif args.regrow_method == "fc":
            save_path_parts.append("fc_")
            save_path = "/".join(save_path_parts) + f"/{args.method}_{args.regrow_method}/"
        else:
            if args.adaptive_zeta:
                save_path_parts.append("az_")
            elif args.EM_S:
                save_path_parts.append("EM_S_")
            elif args.granet:
                save_path_parts.append(f"granet_{args.granet_init_sparsity}_{args.pruning_scheduler}_{args.pruning_method}_")

            save_path_parts.append(f"st_{args.start_T}_et_{args.end_T}")

            if args.remove_method in ("MEST", "history_gradient"):
                save_path = "/".join(save_path_parts) + f"/{args.method}_{args.regrow_method}_{args.remove_method}_{args.factor}/"
            else:
                save_path = "/".join(save_path_parts) + f"/{args.method}_{args.regrow_method}_{args.remove_method}/"

        if args.ssam:
            print("Using SSAM!!!!")
            save_path += "ssam_/"

    return save_path

def get_save_path(args, base_dir=None):
    """Return the directory where this run's checkpoint and results are stored.

    Layout: ``<base_dir>/<NETWORK_STRUCTURE upper-cased>/<dataset>/s_<sparsity>``.

    Args:
        args: parsed options; reads ``network_structure``, ``dataset`` and
            ``sparsity``.
        base_dir: root directory. Defaults to
            ``/home/<user>/snn_conversion/input``, where <user> is taken from
            the ``USERNAME`` environment variable, or ``USER`` if that is unset.

    Raises:
        KeyError: if ``base_dir`` is omitted and neither variable is set.
    """
    if base_dir is None:
        # Fall back to USER (the usual POSIX variable) when USERNAME is unset.
        user = os.environ.get("USERNAME")
        if user is None:
            user = os.environ["USER"]
        # No nested quotes inside the f-string, so this also parses on Python < 3.12.
        base_dir = f"/home/{user}/snn_conversion/input"
    return os.path.join(base_dir, args.network_structure.upper(), args.dataset, f"s_{args.sparsity}")

def _build_optimizer(model, args):
    """Create the optimizer named by ``--optimizer``.

    "adam" selects Adam. Any other value keeps the historical behaviour of
    SGD with momentum 0.9; a notice is printed for names other than "sgd".
    """
    name = (args.optimizer or "sgd").lower()
    if name == "adam":
        return optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    if name != "sgd":
        print(f"Unknown optimizer {args.optimizer!r}; falling back to SGD.")
    return optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)


def _itop_rate(pruner):
    """Average over ``pruner.record_mask`` of ``sum(mask) / mask.numel()``.

    This is the fraction of set entries per layer, assuming 0/1 masks.
    Returns 0.0 when there are no recorded masks.
    """
    masks = pruner.record_mask
    if len(masks) == 0:
        return 0.0
    rate = 0.0
    for mask in masks:
        rate += (torch.sum(mask) / mask.numel()) / len(masks)
    return rate.item()


def _active_neuron_rate(pruner):
    """Fraction of neurons with at least one nonzero connection in the masks.

    For every mask in ``pruner.backward_masks``, count the columns that have
    any nonzero entry. For the last mask, also count its nonzero rows. This
    assumes masks are shaped like ``nn.Linear`` weights,
    (out_features, in_features), so the count covers every layer's inputs
    plus the final outputs — TODO confirm against DSTScheduler.

    Raises:
        ValueError: if the pruner exposes no masks.
    """
    masks = list(pruner.backward_masks)
    if not masks:
        raise ValueError("pruner.backward_masks is empty; cannot compute the active neuron rate")
    active_neurons = 0
    total_neurons = 0
    for mask in masks:
        active_neurons += torch.sum(torch.sum(mask, dim=0) > 0)
        total_neurons += mask.shape[1]
    last = masks[-1]
    active_neurons += torch.sum(torch.sum(last, dim=1) > 0)
    total_neurons += last.shape[0]
    return (active_neurons / total_neurons).item()


def train_model(seed, device, args):
    """Train one model and write its results under ``get_save_path(args)``.

    Outputs:
        * ``best_model.pth``: state dict of the highest top-1 epoch. With
          EM_S / granet, only epochs past 75% of training are eligible.
        * ``res.mat``: per-epoch top-1, top-5, test loss, and optional ITOP /
          active-neuron rates.

    Args:
        seed: RNG seed passed to ``setup_seed``.
        device: torch device the model is moved to.
        args: parsed options (see ``Args``).

    Raises:
        ValueError: for an unsupported ``network_structure``, or when
            ``--itop`` / ``--record_anp`` is used without ``--dst_scheduler``.
    """
    setup_seed(seed)
    print(args)
    save_path = get_save_path(args)
    print("Save path is:", save_path)
    os.makedirs(save_path, exist_ok=True)

    # Same path that savemat writes to at the end. The old check concatenated
    # without a separator, so finished runs were never detected.
    res_path = os.path.join(save_path, "res.mat")
    if args.check_exist and os.path.exists(res_path):
        print("This simulation has already finished!!!!")
        return

    # Both metrics read the DST pruner's masks; fail now rather than with an
    # AttributeError after the first training epoch.
    if (args.itop or args.record_anp) and not args.dst_scheduler:
        raise ValueError("--itop and --record_anp require --dst_scheduler")

    # wandb logging is disabled. To re-enable it, call wandb.init(
    #     project=f"{args.dataset}_{args.network_structure}",
    #     name=save_path + args.regrow_method, config=vars(args),
    #     mode="disabled" if args.no_log else "online") here.

    if args.network_structure.lower() != "mlp":
        raise ValueError(f"Unsupported network_structure {args.network_structure!r}; only 'mlp' is implemented")
    train_loader, test_loader, indim, outdim, hiddim = load_data_mlp(args.dataset, args.batch_size, args.dim)
    model = MLP(indim, hiddim, outdim, args.dropout).to(device)

    # Horizon (in epochs) handed to DSTScheduler. It is shortened to 75% of
    # training for adaptive zeta / EM_S, unless --old_version is set.
    if (args.adaptive_zeta or args.EM_S) and not args.old_version:
        T_end = args.epochs * 0.75
    else:
        T_end = args.epochs

    optimizer = _build_optimizer(model, args)
    train_scheduler = None
    if args.linearlr:
        train_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=args.end_factor, total_iters=int(args.epochs * args.decay_factor))
    elif args.discretelr:
        train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestone, gamma=0.2)

    # Warm-up spans the first epoch: one step per training batch.
    warmup_scheduler = WarmUpLR(optimizer, len(train_loader)) if args.warmup else None

    if args.dst_scheduler:
        pruner = DSTScheduler(model, optimizer, alpha=args.zeta, delta=args.update_interval * len(train_loader), sparsity_distribution=args.sparsity_distribution, static_topo=False, T_end=T_end * len(train_loader), ignore_linear_layers=False, grad_accumulation_n=1, args=args)
    else:
        pruner = None

    m = metrics()
    itop_rates = []
    anp_rates = []
    best_acc = 0.0
    for epoch in range(args.epochs):
        Train(args, model, device, train_loader, optimizer, epoch, warmup_scheduler, pruner)
        top1, top5, test_loss = Test(model, device, test_loader)
        m.update(test_loss, top1, top5)

        # Only save the best results after sparsity reached the preset value.
        sparsity_settled = (not args.EM_S and not args.granet) or epoch > args.epochs * 0.75
        if top1 > best_acc and sparsity_settled:
            best_acc = top1
            torch.save(model.state_dict(), os.path.join(save_path, 'best_model.pth'))

        if args.discretelr:
            # Passing the epoch is deprecated in PyTorch. It is kept so the
            # milestones fire at the same epochs as in earlier runs.
            train_scheduler.step(epoch)
        elif args.linearlr:
            train_scheduler.step()
        for param_group in optimizer.param_groups:
            print("Current learning rate is:", param_group['lr'])

        if args.itop:
            itop_rates.append(_itop_rate(pruner))

        if args.record_anp:
            anp_rate = _active_neuron_rate(pruner)
            anp_rates.append(anp_rate)
            print("Active neurons percentage is:", anp_rate)

    # Save the per-epoch results.
    savemat(res_path, {'top1': m.top1, 'top5': m.top5, 'loss': m.loss, "itop_rate": itop_rates, "anp_rate": anp_rates})


if __name__ == '__main__':
    args = Args()
    print(args.cuda_device)
    torch.cuda.set_device(args.cuda_device)
    device = torch.device(args.cuda_device)
    print("using GPU: ", torch.cuda.get_device_name())

    # Fall back to USER when USERNAME is unset. No nested quotes inside the
    # f-string, so this also parses on Python < 3.12.
    user = os.environ.get("USERNAME")
    if user is None:
        user = os.environ["USER"]
    log_file = f"/home/{user}/snn_conversion/ann/train100/train100.log"
    # logging's FileHandler opens the file immediately, so the directory
    # must exist before basicConfig runs.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    try:
        train_model(args.seed, device, args)
    except Exception:
        # The traceback goes only to the log file; leave a pointer on stderr.
        logging.exception("exception in main")
        print(f"Training failed; see {log_file}", file=sys.stderr)
        sys.exit(1)






