import torch, os, logging
import functions.loss_f as loss_f
import numpy as np
from torch.nn.utils import clip_grad_norm_
import tools.global_v as glv
from tools.global_v import device
from tools.attack import fast_gradient_method, projected_gradient_descent, bim_attack, r_fgsm, gaussian_noise_attack
from tools.apgd import APGD
import torch.distributed as dist
import wandb
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
import torch.nn.functional as F

def print_rank0(*args, **kwargs):
    """Print only on the main process.

    When torch.distributed is unavailable or no process group has been
    initialised (single-process run), every call prints. Otherwise only
    rank 0 prints. Arguments are forwarded unchanged to ``print``.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(*args, **kwargs)

def print_tau(net, indices):
    """Print and collect mean/min/max statistics for selected parameters.

    ``indices`` are positions into ``list(net.named_parameters())``. For each
    selected parameter a summary line is printed (via ``print_rank0``), and the
    entries ``<name>_mean``, ``<name>_min`` and ``<name>_max`` are stored in
    the returned dict.
    """
    named_params = list(net.named_parameters())
    summary = {}
    for idx in indices:
        pname, tensor = named_params[idx]
        values = tensor.data
        stats = {
            'mean': values.mean().item(),
            'min': values.min().item(),
            'max': values.max().item(),
        }
        print_rank0(f'Layer {pname},  Mean: {stats["mean"]:.3f},  Min:  {stats["min"]:.3f}, Max:  {stats["max"]:.3f} ')
        for suffix, value in stats.items():
            summary[pname + '_' + suffix] = value
    return summary

def wandb_init_rank0(*args, **kwargs):
    """Call ``wandb.init`` on the main process only.

    In a single-process run (torch.distributed unavailable or not
    initialised) the call is always made, so non-distributed runs no longer
    fail inside ``dist.get_rank()``.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    wandb.init(*args, **kwargs)

def wandb_log_rank0(*args, **kwargs):
    """Call ``wandb.log`` on the main process only.

    In a single-process run (torch.distributed unavailable or not
    initialised) the call is always made instead of raising in
    ``dist.get_rank()``.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    wandb.log(*args, **kwargs)

def init_distributed_mode():
    """Initialise an NCCL process group from torchrun-style or SLURM environment variables.

    - torchrun style: ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` are read.
    - SLURM: ``SLURM_PROCID`` gives the global rank, ``SLURM_NTASKS`` the world
      size, and the GPU is chosen as ``rank % torch.cuda.device_count()``.
    - Otherwise a notice is printed and nothing is initialised.

    ``MASTER_ADDR``/``MASTER_PORT`` must still be present in the environment
    (``env://`` initialisation).

    Raises:
        RuntimeError: if the process group is not initialised afterwards.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        gpu = rank % torch.cuda.device_count()
    else:
        print_rank0('Not using distributed mode')
        return

    torch.cuda.set_device(gpu)
    # env:// falls back to the RANK/WORLD_SIZE env vars only when these are not
    # passed explicitly. SLURM does not export them, so always pass them here.
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed process group failed to initialise")

def set_target_signal(n_steps, network_config):
    """Build the desired output spike kernel used as the "kernel" loss target.

    For ``n_steps >= 10`` the raw target spike train alternates 0, 1, 0, 1, ...
    For shorter windows it is 0 at the first step and 1 at every later step.
    The raw train is placed on ``device``, passed through ``loss_f.psp`` and
    returned viewed as shape ``(1, 1, 1, n_steps)``.
    """
    if n_steps >= 10:
        # Alternating 0/1 train of exactly n_steps entries. Repeating a fixed
        # 10-step pattern would come up short (and break the view below)
        # whenever n_steps is not a multiple of 10.
        desired_spikes = torch.arange(n_steps) % 2
    else:
        desired_spikes = torch.cat((torch.tensor([0]), torch.ones(n_steps - 1)))

    desired_spikes = desired_spikes.view(1, 1, 1, 1, n_steps).to(device)
    desired_spikes = loss_f.psp(desired_spikes, network_config).view(1, 1, 1, n_steps)

    return desired_spikes

def calculate_loss(outputs, labels, network_config, layers_config, err, desired_spikes):
    """Compute the loss selected by ``network_config['loss']``.

    Args:
        outputs: network output. Its first two dims are indexed as
            (sample, class) when building targets.
        labels: per-sample class indices.
        network_config: dict providing 'loss' and, for "count",
            'desired_count' / 'undesired_count'.
        layers_config: mapping of layer configs. The entry under its last key
            is passed to ``err.spike_count``.
        err: loss object providing ``spike_count`` / ``spike_kernel`` /
            ``spike_soft_max``.
        desired_spikes: target written at each sample's label position for
            the "kernel" loss.

    Returns:
        The value returned by the selected ``err`` method.

    Raises:
        ValueError: if the loss name is not "count", "kernel" or "softmax".
    """
    loss_function = network_config['loss']

    if loss_function == "count":
        desired_count = network_config['desired_count']
        undesired_count = network_config['undesired_count']

        # (batch, n_class, 1, 1): undesired count everywhere, desired count at the label.
        targets = torch.ones(outputs.shape[0], outputs.shape[1], 1, 1, device=device) * undesired_count
        for i, label in enumerate(labels):
            targets[i, label, ...] = desired_count

        return err.spike_count(outputs, targets, network_config, layers_config[list(layers_config.keys())[-1]])

    if loss_function == "kernel":
        targets = torch.zeros(outputs.shape, device=device)
        for i, label in enumerate(labels):
            targets[i, label, ...] = desired_spikes

        return err.spike_kernel(outputs, targets, network_config)

    if loss_function == "softmax":
        return err.spike_soft_max(outputs, labels)

    raise ValueError(f'Unrecognized loss function: {loss_function!r}')

def calculate_cam_loss(outputs, labels, network_config, layers_config, err, desired_spikes):
    """Return the total output activity on each sample's labelled class.

    Only the "kernel" loss configuration is supported. ``layers_config``,
    ``err`` and ``desired_spikes`` are accepted for signature parity with
    ``calculate_loss`` but are not used.

    Returns:
        A scalar tensor: the sum of ``outputs`` over every position
        ``[i, labels[i], ...]``.

    Raises:
        ValueError: if ``network_config['loss']`` is not "kernel".
    """
    loss_function = network_config['loss']
    if loss_function != "kernel":
        raise ValueError(f'Unrecognized loss function: {loss_function!r}')

    # One-hot mask over the class dimension selecting each sample's label.
    mask = torch.zeros(outputs.shape, device=device)
    for i, label in enumerate(labels):
        mask[i, label, ...] = 1
    out = torch.sum(outputs * mask, dim=-1)
    return torch.sum(out)



def save_ckpt(network, epoch, name) :
    """Save ``{'net': state_dict, 'epoch': epoch}`` to ``./checkpoint/tmp/<name>best.pth``.

    Only rank 0 writes, or the single process when torch.distributed is not
    initialised. Letting every rank ``torch.save`` to the same path
    concurrently can leave a corrupted file. The target directory is created
    if it does not exist.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return

    ckpt_dir = './checkpoint/tmp'
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {
        'net': network.state_dict(),
        'epoch': epoch,
    }
    path = os.path.join(ckpt_dir, name + 'best.pth')
    torch.save(state, path)
    print_rank0("Saved to ", path)

def save_target(avg_mem_trains_acc, dataloader, mode, method, act_eps):
    """Average accumulated per-layer membrane traces across ranks and save them.

    Each value in ``avg_mem_trains_acc`` is summed over all processes with
    ``dist.all_reduce``, then divided in place by
    ``len(dataloader) * world_size``. Rank 0 writes the whole dict to
    ``./spikedata/tmp/avg_mem_<mode>_<method>_<eps>.pt``, where ``<eps>`` is
    the first four characters of ``str(act_eps)``.

    Returns:
        The same dict object, now holding the averaged values.
    """
    for layer_name in avg_mem_trains_acc:
        dist.all_reduce(avg_mem_trains_acc[layer_name])
        avg_mem_trains_acc[layer_name] /= len(dataloader) * dist.get_world_size()

    if dist.get_rank() != 0:
        return avg_mem_trains_acc

    eps_tag = str(act_eps)[:4]
    out_path = f"./spikedata/tmp/avg_mem_{mode}_{method}_{eps_tag}.pt"
    torch.save(avg_mem_trains_acc, out_path)
    print_rank0("Saved target to ", out_path)
    return avg_mem_trains_acc

def train(network, trainloader, opti, epoch, network_config, layers_config, err, method = None, act_eps = None, tau_v= None):
    """Run one training epoch, optionally on adversarially perturbed inputs.

    Args:
        network: wrapped model. ``network.module`` must provide ``layers``,
            ``get_parameters()`` and ``weight_clipper()``.
        trainloader: iterable of ``(inputs, labels)`` batches. Inputs with
            fewer than 5 dims get a new trailing axis repeated ``n_steps`` times.
        opti: optimizer stepped once per batch.
        epoch: epoch index, used only for logging.
        network_config: dict with at least 'n_steps', 'n_class' and 'loss'.
        layers_config, err: forwarded to ``calculate_loss``.
        method: ``None`` or "clean" for clean training. "fgm" or "pgd" trains
            on adversarial examples from the corresponding attack.
        act_eps: attack budget forwarded to the attack.
        tau_v: unused; kept for call compatibility.

    Returns:
        ``(accuracy_percent, summed_batch_loss / n_samples, avg_mem_trains_acc)``,
        where ``avg_mem_trains_acc`` maps each "conv"/"linear" layer name to the
        running sum of the per-batch ``avg_mem_trains`` from the forward pass.

    Raises:
        ValueError: if ``method`` is not supported.
    """
    attack_fns = {
        "fgm": fast_gradient_method,
        "pgd": projected_gradient_descent,
    }
    if method not in (None, "clean") and method not in attack_fns:
        # Unknown methods used to fall through to clean training silently.
        raise ValueError(f"Unsupported training method: {method!r}")
    attack_fn = attack_fns.get(method)

    network.train()
    logging.info('\nEpoch: %d', epoch)
    train_loss = 0
    correct = 0
    total = 0
    n_steps = network_config['n_steps']
    n_class = network_config['n_class']
    use_kernel_loss = network_config['loss'] == "kernel"

    desired_spikes = set_target_signal(n_steps, network_config) if use_kernel_loss else None

    # Running per-layer sums of the avg_mem_trains reported by the forward pass.
    avg_mem_trains_acc = {l.name: 0 for l in network.module.layers if l.type in ["conv", "linear"]}
    for inputs, labels in tqdm(trainloader):
        if inputs.dim() < 5:
            # Static input: replicate along a new trailing time axis.
            inputs = inputs.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps)
        labels = labels.to(device)
        # Tensor.type() returns a new tensor, so the cast must be assigned.
        inputs = inputs.to(device=device, dtype=torch.float32)

        # Targets for the attacks: kernel targets at each label, zeros otherwise.
        targets = torch.zeros(labels.shape[0], n_class, 1, 1, n_steps, device=device)
        if use_kernel_loss:
            for i, label in enumerate(labels):
                targets[i, label, ...] = desired_spikes

        if attack_fn is not None:
            inputs = attack_fn(network, inputs, targets, act_eps, network_config)

        outputs, avg_mem_trains = network.forward(inputs, True)

        if avg_mem_trains:
            for l_name in avg_mem_trains_acc:
                value = avg_mem_trains[l_name]
                # Detach so the running sum does not keep every batch's graph alive.
                if torch.is_tensor(value):
                    value = value.detach()
                avg_mem_trains_acc[l_name] += value

        loss = calculate_loss(outputs, labels, network_config, layers_config, err, desired_spikes)

        opti.zero_grad()
        loss.backward()
        clip_grad_norm_(network.module.get_parameters(), 1)
        opti.step()
        network.module.weight_clipper()

        # Predicted class = class with the largest total output over time.
        spike_counts = outputs.detach().sum(dim=4).squeeze(-1).squeeze(-1).cpu().numpy()
        predicted = np.argmax(spike_counts, axis=1)
        train_loss += torch.sum(loss).item()
        labels_np = labels.cpu().numpy()
        total += len(labels_np)
        correct += (predicted == labels_np).sum().item()

    total_accuracy = correct / total
    total_loss = train_loss / total
    if total_accuracy > glv.max_accuracy:
        glv.max_accuracy = total_accuracy
    if glv.min_loss > total_loss:
        glv.min_loss = total_loss
    return 100. * total_accuracy, total_loss, avg_mem_trains_acc

def test(network, testloader, epoch, network_config, method = None, act_eps = None):
    """Evaluate ``network`` on ``testloader``, optionally under an adversarial attack.

    Requires an initialised torch.distributed process group, because counts
    are summed across ranks with ``dist.all_reduce``.

    Args:
        network: wrapped model. ``network.module.layers`` is scanned for
            "conv"/"linear" layers whose ``avg_mem_trains`` are accumulated.
        testloader: iterable of ``(inputs, labels)``. Inputs with fewer than 5
            dims get a new trailing axis repeated ``n_steps`` times.
        epoch: stored in ``glv.best_epoch`` when rank 0 sees a new best accuracy.
        network_config: dict with 'n_steps' and 'n_class'. 'save_atk_data' is
            read for non-clean methods.
        method: one of "clean", "fgm", "pgd", "rfgm", "bim", "apgd", "gn".
        act_eps: attack budget forwarded to the attack.

    Returns:
        ``(accuracy_percent, avg_mem_trains_acc)``. The accuracy is aggregated
        over all ranks; ``avg_mem_trains_acc`` holds this rank's per-layer sums.

    Raises:
        ValueError: for an unsupported ``method``.

    When the method is not "clean" and ``network_config['save_atk_data']`` is
    truthy, adversarial inputs and labels are gathered on rank 0 and saved to
    ``./adv_samples/`` and ``./adv_labels/``.
    """
    supported_methods = ("clean", "fgm", "pgd", "rfgm", "bim", "apgd", "gn")
    if method not in supported_methods:
        # Raise (rather than exit()) so the caller gets a traceback.
        raise ValueError(f"wrong method: {method!r}")

    network.eval()

    correct = 0
    total = 0
    n_steps = network_config['n_steps']
    n_class = network_config['n_class']
    # Only buffer adversarial samples when they will actually be saved.
    save_adv = method != "clean" and network_config["save_atk_data"]
    adv_samples = []
    adv_labels = []

    desired_spikes = set_target_signal(n_steps, network_config)

    avg_mem_trains_acc = {l.name: 0 for l in network.module.layers if l.type in ["conv", "linear"]}

    apgd = APGD(network, eps=act_eps, network_config=network_config) if method == "apgd" else None

    for inputs, labels in tqdm(testloader):
        if inputs.dim() < 5:
            inputs = inputs.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps)
        labels = labels.to(device)
        inputs = inputs.to(device)

        # Kernel targets at each sample's label, used by the gradient attacks.
        targets = torch.zeros(labels.shape[0], n_class, 1, 1, n_steps, device=device)
        for i, label in enumerate(labels):
            targets[i, label, ...] = desired_spikes

        if method == "clean":
            adv_inputs = inputs
        elif method == "fgm":
            adv_inputs = fast_gradient_method(network, inputs, targets, act_eps, network_config)
        elif method == "pgd":
            adv_inputs = projected_gradient_descent(network, inputs, targets, act_eps, network_config)
        elif method == "rfgm":
            adv_inputs = r_fgsm(network, inputs, targets, act_eps, act_eps/2 , network_config)
        elif method == "bim":
            adv_inputs = bim_attack(network, inputs, targets, act_eps, network_config)
        elif method == "apgd":
            adv_inputs = apgd(inputs, labels)
        else:  # "gn"
            adv_inputs = gaussian_noise_attack(inputs, act_eps, network_config)

        # The evaluation forward needs no autograd graph (attacks above build their own).
        with torch.no_grad():
            outputs, avg_mem_trains = network.forward(adv_inputs, False)

        if save_adv:
            adv_samples.append(adv_inputs.detach().cpu())
            adv_labels.append(labels.cpu())

        if avg_mem_trains:
            for l_name in avg_mem_trains_acc:
                avg_mem_trains_acc[l_name] += avg_mem_trains[l_name]

        # Predicted class = class with the largest total output over time.
        spike_counts = outputs.detach().sum(dim=4).squeeze(-1).squeeze(-1).cpu().numpy()
        predicted = np.argmax(spike_counts, axis=1)
        labels_np = labels.cpu().numpy()
        total += len(labels_np)
        correct += (predicted == labels_np).sum().item()

    # Aggregate sample and correct counts over all ranks.
    total_tensor = torch.tensor([total]).to(device)
    dist.all_reduce(total_tensor)
    total = total_tensor.item()

    correct_tensor = torch.tensor([correct]).to(device)
    dist.all_reduce(correct_tensor)
    correct = correct_tensor.item()

    test_accuracy = correct / total

    if dist.get_rank() == 0:
        if test_accuracy > glv.best_acc:
            glv.best_acc = test_accuracy
            glv.best_epoch = epoch

    if save_adv:
        all_adv_samples = torch.cat(adv_samples, dim=0).to(device)
        all_adv_labels = torch.cat(adv_labels, dim=0).to(device)

        # Every rank participates in gather; only rank 0 receives the data.
        is_main = dist.get_rank() == 0
        world_size = dist.get_world_size()
        sample_gather_list = [torch.zeros_like(all_adv_samples) for _ in range(world_size)] if is_main else None
        label_gather_list = [torch.zeros_like(all_adv_labels) for _ in range(world_size)] if is_main else None

        dist.gather(all_adv_samples, sample_gather_list, dst=0)
        dist.gather(all_adv_labels, label_gather_list, dst=0)

        if is_main:
            combined_adv_samples = torch.cat(sample_gather_list, dim=0)
            combined_adv_labels = torch.cat(label_gather_list, dim=0)
            torch.save(combined_adv_samples, f"./adv_samples/{method}_eps_{act_eps}.pt")
            torch.save(combined_adv_labels, f"./adv_labels/{method}_eps_{act_eps}.pt")

    acc = 100. * test_accuracy

    return acc, avg_mem_trains_acc


def advtest2(net, test_loader, params, test_only=True, ckpt=None, blackbox=False):
    """Per-epoch robustness check on the current model.

    Runs a clean test, then FGM at ``params['ATTACK']['test'][0]`` scaled by
    ``params['Network']['std']``. While it runs,
    ``params['Network']['dynamic_v']`` is forced to True. The previous value
    is restored afterwards, even if evaluation raises.

    ``test_only`` and ``blackbox`` are accepted for signature parity with
    ``advtest`` and are ignored. ``ckpt`` must be None.

    Returns:
        ``[clean_acc, acc_under_each_attack...]`` in ``attack_types`` order.
    """
    assert ckpt is None, "Checkpoint path should not be provided if not in test_only mode"
    atk_strength = params['ATTACK']['test']
    attack_types = ['fgm']  # full sweep: ['fgm', 'rfgm', 'pgd', 'bim', 'gn']

    prev_dynamic_v = params['Network']["dynamic_v"]
    params['Network']["dynamic_v"] = True
    try:
        print_rank0("Using current model!")
        results = []

        clean_acc, _ = test(net, test_loader, 0, params['Network'], 'clean', 0)
        print_rank0(f"Clean test,  acc = {clean_acc}")
        results.append(clean_acc)

        ori_eps = atk_strength[0]
        # NOTE(review): eval() of a config string such as "8/255"; only safe for trusted configs.
        act_eps = eval(ori_eps) / params["Network"]["std"]
        for atk in attack_types:
            acc, _ = test(net, test_loader, 0, params['Network'], atk, act_eps)
            results.append(acc)
            print_rank0(f"Acc test under {atk}, Ori_eps = {ori_eps}, Act_eps = {round(act_eps, 3)}, acc = {acc}")
        return results
    finally:
        params['Network']["dynamic_v"] = prev_dynamic_v



def log_and_test_attack(net, test_loader, attack_type, ori_eps, act_eps, params, results):
    """Evaluate one attack setting, record its accuracy and log it.

    ``test`` returns ``(acc, avg_mem_trains_acc)``. Only the accuracy is
    appended to ``results[attack_type]`` and returned.

    Returns:
        The accuracy in percent.
    """
    acc, _ = test(net, test_loader, 0, params['Network'], attack_type, act_eps)
    results[attack_type].append(acc)
    print_rank0(f"Acc test under {attack_type}, Ori_eps = {ori_eps}, Act_eps = {round(act_eps, 3)}, acc = {acc}")
    return acc

def advtest(net, test_loader, params, test_only=True, ckpt=None, blackbox=False):
    """Measure clean accuracy and accuracy under every configured attack and strength.

    Args:
        net: model to evaluate. In ``test_only`` mode its weights are loaded
            (non-strictly) from ``ckpt['net']`` first.
        test_loader: clean evaluation loader.
        params: config with 'ATTACK' ('strength', 'attack_types') and 'Network'
            ('std', 'save_target', ...) sections.
        test_only: load weights from ``ckpt`` (required) instead of using the
            current model (``ckpt`` must then be None).
        blackbox: evaluate on adversarial samples previously saved by ``test``
            under ./adv_samples and ./adv_labels, instead of attacking ``net``.

    Returns:
        Dict mapping attack name to ``[clean_acc, acc at each strength...]``.
    """
    atk_strength = params['ATTACK']['strength']
    attack_types = params['ATTACK']['attack_types']
    print_rank0(f"Attack types: {attack_types}")

    results = {atk: [] for atk in attack_types}

    if test_only:
        assert ckpt is not None, "Checkpoint path must be provided in test_only mode"
        checkpoint = torch.load(ckpt, map_location=device)
        incompatible = net.load_state_dict(checkpoint['net'], strict = False)
        # strict=False silently skips mismatching keys; report them so a
        # checkpoint/model mismatch is visible.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print_rank0(f"Warning: missing keys: {incompatible.missing_keys}, "
                        f"unexpected keys: {incompatible.unexpected_keys}")
        print_rank0(f"Adv Testing, using {ckpt}")
    else:
        assert ckpt is None, "Checkpoint path should not be provided if not in test_only mode"
        print_rank0("Using current model!")

    clean_acc, avg_mem_trains_acc = test(net, test_loader, 0, params['Network'], 'clean', 0)
    print_rank0(f"Clean test,  acc = {clean_acc}")
    if params['Network']["save_target"]:
        save_target(avg_mem_trains_acc, test_loader, "test", "clean", 0)

    for atk in attack_types:
        results[atk].append(clean_acc)

    metrics = {f'{atk}_acc': clean_acc for atk in attack_types}
    wandb_log_rank0(metrics, step=1000)
    count = 1000

    for ori_eps in atk_strength:
        # NOTE(review): eval() of a config string such as "8/255"; only safe for trusted configs.
        act_eps = eval(ori_eps) / params["Network"]["std"]
        count += 1
        for atk in attack_types:
            if blackbox:
                # Transfer setting: clean-mode evaluation on previously saved adversarial samples.
                adv_samples = torch.load(f"./adv_samples/{atk}_eps_{act_eps}.pt",map_location=device)
                adv_labels = torch.load(f"./adv_labels/{atk}_eps_{act_eps}.pt",map_location=device)
                adv_dataset = TensorDataset(adv_samples, adv_labels)
                sampler = DistributedSampler(adv_dataset)
                # Use a separate name so the caller's test_loader is never overwritten.
                eval_loader = DataLoader(adv_dataset, batch_size=test_loader.batch_size, shuffle=False, sampler=sampler)
                acc, avg_mem_trains_acc = test(net, eval_loader, 0, params['Network'], 'clean', 0)
            else:
                eval_loader = test_loader
                acc, avg_mem_trains_acc = test(net, eval_loader, 0, params['Network'], atk, act_eps)

            results[atk].append(acc)
            print_rank0(f"Acc test under {atk}, Ori_eps = {ori_eps}, Act_eps = {round(act_eps, 3)}, acc = {acc}")

            metrics[f'{atk}_acc'] = acc
            wandb_log_rank0(metrics, step=count)  # one step per strength

            if params['Network']["save_target"]:
                save_target(avg_mem_trains_acc, eval_loader, "test", atk, act_eps)

    for atk in attack_types:
        print_rank0(f"{atk} result: {results[atk]}")

    return results

def advtrain(net, train_loader, test_loader, train_sampler,
                params, mode):
    """Train ``net`` (clean or adversarial fine-tuning), then run the full adversarial test.

    Args:
        net: wrapped model exposing ``net.module.get_parameters()``.
        train_loader, test_loader: data loaders.
        train_sampler: sampler whose ``set_epoch`` is called every epoch.
        params: config with 'Network', 'Layers', 'ATTACK' and 'DEFAULT' sections.
        mode: "clean" for clean training of all parameters. Any other value
            trains with ``params['ATTACK']['ft_method']`` at strength
            ``params['ATTACK']['train'][0]``, optionally restricted to the
            parameter indices in ``params['ATTACK']['finetune_layer']``.

    After every epoch the model is evaluated with ``advtest2``. A "best"
    checkpoint (by clean test accuracy) and a per-epoch checkpoint are saved.
    Afterwards ``params['Network']['dynamic_v']`` is set to True and
    ``advtest`` runs: on the current model for adversarial modes, on the best
    checkpoint for clean mode.
    """
    error = loss_f.SpikeLoss(params['Network']).to(device)
    learning_rate = params['Network']['lr']
    epochs = params['Network']['epochs']
    tauv_index = params['ATTACK']['tauv']
    weight_decay = float(params['Network']['weight_dacay'])  # config key is spelled 'weight_dacay'

    if mode != "clean":
        method = params['ATTACK']['ft_method']
        ori_eps = params['ATTACK']['train'][0]
        # NOTE(review): eval() of a config string such as "8/255"; only safe for trusted configs.
        act_eps = eval(ori_eps)/params["Network"]["std"]
        print_rank0("Adv train under", method, "Ori_eps = ", ori_eps, "Act_eps = ", round(act_eps, 3))
        finetune_layer = params['ATTACK']['finetune_layer']
        if finetune_layer:
            param_names = [name for name, _ in net.named_parameters()]
            all_params = list(net.parameters())
            print_rank0("Only finetune layer:", finetune_layer)
            for i in finetune_layer:
                print_rank0(param_names[i])
            opt_params = [all_params[i] for i in finetune_layer]
        else:
            print_rank0("Finetune all")
            opt_params = net.module.get_parameters()
    else:
        method = "clean"
        act_eps = 0
        opt_params = net.module.get_parameters()
        print_rank0("Clean Train all!")

    optimizer = torch.optim.AdamW(opt_params,
                                  lr=learning_rate,
                                  betas=(0.9, 0.999),
                                  weight_decay=weight_decay)

    lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose= True)

    # The checkpoint prefix depends only on the config. Building it before the
    # loop keeps it defined for the final evaluation even when epochs == 0.
    ckpt = params['DEFAULT']['dataset'] + params['Network']["model"] + params['Network']["rule"] + mode + "_" + method + "_"

    best_acc = 0
    for epoch in tqdm(range(epochs)):
        print_rank0(f"{mode} training, epoch = ", epoch)
        train_sampler.set_epoch(epoch)
        train_acc, train_loss, avg_mem_trains_acc = train(net, train_loader, optimizer, epoch,
                                      params['Network'], params['Layers'], error,
                                      method, act_eps, tauv_index)
        lr_scheduler.step()

        print_rank0("Train Accuracy: %.3f, Train Loss: %.3f " % (train_acc,train_loss))
        # Index 0 is clean accuracy, index 1 the FGM accuracy.
        adv_acc = advtest2(net, test_loader, params, test_only=False, ckpt=None, blackbox=False)
        test_acc = adv_acc[0]
        print_rank0("Test Accuracy" + str(adv_acc))
        metrics = {'train_acc':train_acc, 'train_loss':train_loss, 'test_acc':adv_acc[0], 'fgm_test_acc': adv_acc[1]}

        display_layer = params['ATTACK']['display_layer']
        if params["Network"]["model"] == "ALIF" and display_layer:
            metrics.update(print_tau(net, display_layer))
        wandb_log_rank0(metrics)

        if test_acc > best_acc:
            best_acc = test_acc
            save_ckpt(net, epoch, ckpt)
        if params['Network']["save_target"]:
            save_target(avg_mem_trains_acc, train_loader, "train", method, act_eps)

        save_ckpt(net, epoch, ckpt + str(epoch))

        dist.barrier()  # all ranks finish the epoch (and checkpoint writes) together

    params['Network']["dynamic_v"] = True
    if mode != "clean":
        advtest(net, test_loader, params, False)
    else:
        advtest(net, test_loader, params, True, './checkpoint/tmp/' + ckpt + 'best.pth')
