from statistics import median, mean
import torch
from statistics import mean, median
import numpy as np
from omegaconf import DictConfig
from optimizer import Momentum, cos_scheduler, margin_loss, cross_entorpy_loss, Adam, step_lr_scheduler
from typing import Callable
from typing import Tuple
from torch.nn import functional as F
from utils.general_utils import *
from torch.nn import DataParallel
import os 
import shutil


# Loss used by the NES gradient estimator inside `esal`.
loss_fn = margin_loss

# Seed every random number generator from the single SEED constant so that
# changing it keeps torch (CPU and CUDA) and NumPy in sync.
SEED = 0
torch.random.manual_seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed_all(SEED)

# L2-norm threshold used in `esal` to rescale the estimated gradient.
norm_theshold = 10


def image_save(adv_image, image, orig_label, target_label, dataset, targeted, index, grouping_strategy):
    """Store the clean image, the adversarial image and their absolute difference.

    Both tensors are moved to the CPU, converted to NumPy and passed through
    ``img_transform`` before being handed to ``img_save``.

    Args:
        adv_image (tensor): the adversarial image
        image (tensor): the original image
        orig_label (tensor): the original label (moved to the CPU before saving)
        target_label: label forwarded unchanged to ``img_save``
        dataset (string): dataset name, forwarded to ``img_save``
        targeted (string): attack-mode string, used as prefix of the file tag
        index (int): the index of the current image
        grouping_strategy (string): grouping strategy name, forwarded to ``img_save``
    """
    adversarial_np = img_transform(adv_image.cpu().numpy())
    clean_np = img_transform(image.cpu().numpy())
    label_on_cpu = orig_label.cpu()

    def _store(pixels, tag):
        # The three outputs differ only in their pixel data and their tag.
        img_save(pixels, index, tag, grouping_strategy, dataset, label_on_cpu, target_label)

    _store(clean_np, "origin")
    _store(adversarial_np, targeted + '_adv_')
    _store(np.abs(adversarial_np - clean_np), targeted + "_delta")
    

def esal(images: np.ndarray, labels: np.ndarray, model: torch.nn.Module, cfg: DictConfig):
    """Perform the esal_0+inf black-box attack on the target model.

    The model is first evaluated on ``images``; only the correctly classified
    samples are attacked, one image at a time. Each attack estimates the
    gradient with NES, takes an Adam-transformed step, keeps the perturbation
    only on the ``k`` best pixel groups and clips it to the ``epsilon`` box.
    Adversarial images are clipped to ``[-0.5, 0.5]``.

    Args:
        images (np.ndarray): input (clean) images
        labels (np.ndarray): corresponding labels
        model (torch.nn.Module): target model to attack
        cfg (DictConfig): setup for the algorithm; the ``general_setup``,
            ``algorithm`` and ``optimizer`` sections are read

    Returns:
        dict: the collected results
        {'query' : list with the query count of every attacked image,
         'acc'   : attack success rate (successes / correctly classified images),
         'l0'    : mean l0 norm over the successful attacks,
         'l2'    : mean l2 norm over the successful attacks,
         'linf'  : mean linf norm over the successful attacks,
         'psnr'  : list of psnr values of the successful attacks,
         'ssim'  : list of ssim values of the successful attacks}

    Raises:
        ValueError: if ``cfg.algorithm.grouping_strategy`` is not ``'standard'``.
        AssertionError: if ``cfg.general_setup.load_pre_groups`` is not set
            (freshly generated groups are saved, but the attack is not run).
    """

    # init hyper parameters
    log_iters = cfg.general_setup.log_iters
    epsilon = cfg.algorithm.epsilon  # correspond to epsilon in the paper
    samples_per_draw = cfg.algorithm.samples_per_draw
    batch_size = cfg.algorithm.batch_size  # max batch size to evaluate the output
    max_iters = cfg.algorithm.max_iters
    max_query = cfg.algorithm.max_query
    sigma = cfg.algorithm.sigma  # refers to sigma in the paper

    image_size = torch.tensor(cfg.algorithm.image_size)  # image size
    num_labels = cfg.algorithm.num_labels
    plateu_length = cfg.algorithm.plateu_length
    filtersize = torch.tensor(cfg.algorithm.filterSize)
    stride = cfg.algorithm.stride
    channels = cfg.algorithm.num_channels
    grouping_strategy = cfg.algorithm.grouping_strategy
    dataset = cfg.algorithm.dataset_name
    momentum = cfg.optimizer.momentum
    optimizer = cfg.optimizer
    max_learning_rate = cfg.optimizer.max_lr
    if_overlap = cfg.algorithm.if_overlap
    targeted = cfg.algorithm.targeted
    drop_epoch = cfg.optimizer.drop_epoch
    is_img_save = cfg.general_setup.save_image
    model_name = cfg.algorithm.model_name
    k_init = cfg.algorithm.k  # top k groups, used in the paper, also refer to s in the paper
    load_pre_groups = cfg.general_setup.load_pre_groups  # whether to load pre-computed groups

    # Dataset-specific caps: max_per_draw bounds samples_per_draw and
    # max_perturb bounds the number of selected groups k (see `attack`).
    if dataset == 'imagenet':
        max_per_draw = 128
        max_perturb = 90
        if model_name == 'inceptionv3':
            image_size = torch.tensor(299)
        elif model_name == 'VT':
            image_size = torch.tensor(224)
    else:
        if if_overlap == "overlapping":
            max_perturb = 18  # original note: cifar mnist untarget: 16 target: 20
        else:
            max_perturb = 25  # original note: cifar mnist untarget: 16 target: 20
        max_per_draw = 64

    # Flattened image length. Computed only after the model-specific
    # image_size override above so that both always agree.
    d = channels * image_size * image_size

    # switch to evaluation mode
    model.eval()
    parallel_device_num = cfg.general_setup.parallel_device_num
    device_num = torch.cuda.device_count()
    parallel_device_num = min(device_num, parallel_device_num)
    device_ids = list(np.arange(parallel_device_num))

    if parallel_device_num > 1:
        print('using data parallel on', device_ids)
        model = model.to(device_ids[0])
        model = DataParallel(model, device_ids)
    else:
        model = model.to(device_ids[0])
        print('using single gpu')

    # test model acc, get correct index and attack correct images
    images = torch.tensor(images).to(device_ids[0]).float()
    labels = torch.tensor(labels).to(device_ids[0])

    if targeted == 'targeted':
        print(f"It performs targeted attack on {dataset} dataset!")
        target_class = torch.tensor(
            [pseudorandom_target(index, num_labels, orig_label)
             for index, orig_label in enumerate(labels)], dtype=torch.int64
        ).to(device_ids[0])
    else:
        print(f"It performs untargeted attack on {dataset} dataset!")
        target_class = labels

    host = torch.device(device_ids[0])

    # validation on model
    with torch.no_grad():
        score = model(images)
        pred = torch.argmax(score, dim=-1)
        torch.cuda.empty_cache()
        correct_idx = pred == labels

        acc = torch.mean(correct_idx.float())
    print(f'model acc: {acc.cpu()}')

    # get correct samples to attack
    images = images[correct_idx]
    labels = labels[correct_idx]
    target_class = target_class[correct_idx]
    print(f"Corrrectly classified images are {images.shape[0]}.")

    def compute_loss(evaluate_img: torch.Tensor, target_labels: torch.Tensor,):
        """Return the margin loss of ``evaluate_img`` with the leading axis squeezed.

        Not called by the NES path below; usable for a true-gradient check,
        e.g. with ``torch.func.grad_and_value``.
        """
        score = model(evaluate_img)
        loss = margin_loss(score, target_labels,targeted)
        return torch.squeeze(loss, 0)

    def get_grad_estimation(evaluate_img: torch.Tensor, target_labels: torch.Tensor, sample_per_draw: int,
                            batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a gradient estimation on evaluate_img using NES.

        Args:
            evaluate_img (torch.Tensor): image to evaluate, with a leading axis of size 1
            target_labels (torch.Tensor): replicated target labels for computing the loss
            sample_per_draw (int): total number of samples used for one gradient estimate
            batch_size (int): number of samples evaluated per model call

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: gradient, loss
        """
        total_batch_num = sample_per_draw // batch_size
        grads = []
        total_loss = []
        for _ in range(total_batch_num):
            noise = torch.normal(mean=0.0, std=1.0, size=(batch_size // 2,) + evaluate_img.shape[1:], device=host)
            noise = torch.concat([noise, -noise], dim=0)  # generate + delta and - delta for evaluation
            evaluate_imgs = evaluate_img + noise * sigma
            score = model(evaluate_imgs)
            loss = loss_fn(score, target_labels, targeted)
            # broadcast the per-sample loss to the image shape
            loss = loss.reshape(-1, 1, 1, 1).repeat(evaluate_img.shape)
            grads.append(torch.mean(loss * noise / sigma, dim=0, keepdim=True))
            total_loss.append(torch.mean(loss))
            torch.cuda.empty_cache()

        grads = torch.mean(torch.concat(grads), dim=0, keepdim=True)
        # Amplify very small gradients, then cap the L2 norm at norm_theshold.
        if torch.norm(grads,p=2) < 0.1 * norm_theshold:
            grads *= norm_theshold
        if torch.norm(grads,p=2) > norm_theshold:
            grads = grads/torch.norm(grads,p=2)*norm_theshold
        total_loss = torch.mean(torch.tensor(total_loss, device=host), dim=0)
        return grads, total_loss

    def select_optimal_group(fh: torch.Tensor, masks: torch.Tensor):
        """Select the group with the lowest summed h value.

        The selected row of ``masks`` is zeroed in place, so callers should
        pass a copy if the masks are needed again.

        Args:
            fh (torch.Tensor): flattened h
            masks (torch.Tensor): one mask row per group, [G, d]

        Returns:
            torch.Tensor : mask row of the selected group (a copy)
        """
        Group_h = torch.mm(masks.float(), fh.reshape(d, 1))  # [G,1]
        min_h_group_num = torch.argmin(Group_h)
        optimal_group = masks[min_h_group_num].clone()
        masks[min_h_group_num]=0

        return optimal_group

    def greedy_project(h: torch.Tensor, delta: torch.Tensor, masks: torch.Tensor, groups_num) -> torch.Tensor:
        """Select groups using greedy projection, i.e. keep the perturbation on
        the ``groups_num`` groups with the lowest h value.

        NOTE: ``torch.flatten`` may return views, so the in-place updates below
        can modify ``h`` and ``delta``; ``masks`` is modified by
        ``select_optimal_group``. Callers must not reuse these tensors.

        Args:
            h (torch.Tensor): the score used for greedy selection
            delta (torch.Tensor): the original perturbation
            masks (torch.Tensor): one mask row per group
            groups_num (int): number of groups to select

        Returns:
            torch.Tensor: the projected perturbation, shaped
            (channels, image_size, image_size)
        """
        flatten_h = torch.flatten(h)
        flatten_u = torch.flatten(delta)
        flatten_v = torch.zeros_like(flatten_u)

        for _ in range(groups_num):
            onehot_optimal_group = select_optimal_group(flatten_h, masks)
            flatten_v += onehot_optimal_group * flatten_u
            # remove the chosen coordinates from further consideration
            keep = onehot_optimal_group.logical_not()
            flatten_h *= keep
            flatten_u *= keep
        # reshape replaces the deprecated non-in-place Tensor.resize
        return flatten_v.reshape(channels, image_size, image_size)

    def standard_grouping():
        """Grouping with fixed square windows.

        Flat positions are laid out as
        ``row * image_size * channels + col * channels + channel``.

        Returns:
            masks (Tensor) : [Groups_nums, d] int64 tensor with 1 on the
            positions belonging to each group
        """
        size = int(image_size)
        window = int(filtersize)
        per_axis = int(torch.floor((image_size - filtersize) / stride) + 1)
        run_len = window * channels  # one window row covers a contiguous run
        masks = torch.zeros((per_axis * per_axis, int(d)), dtype=torch.int64)
        group = 0
        for r in range(per_axis):
            row_offset = r * stride * size * channels
            for c in range(per_axis):
                for i in range(window):
                    start = c * stride * channels + i * size * channels + row_offset
                    masks[group, start:start + run_len] = 1
                group += 1
        return masks

    def attack(target_img: torch.Tensor, target_label: torch.Tensor, original_label: torch.Tensor,
               samples_per_draw, batch_size, masks : torch.Tensor, max_learning_rate):
        """Perform the esal attack on one image.

        Args:
            target_img (torch.Tensor): image to attack
            target_label (torch.Tensor): target label (the original label for an untargeted attack)
            original_label (torch.Tensor): original label of the image
            samples_per_draw (int): number of samples per gradient estimate
            batch_size (int): number of samples evaluated per model call
            masks (torch.Tensor): group masks, one row per group
            max_learning_rate (float): initial step size

        Returns:
            tuple: (adv_image, success flag, num_queries, l0 norm, l2 norm,
            linf norm, label), where label is the last prediction for an
            untargeted attack and the target label otherwise
        """
        with torch.no_grad():
            # initialization
            target_img = target_img.reshape((1,) + target_img.shape)
            # NOTE(review): gradient_transformer and scheduler are constructed
            # but not consumed below (the step size is max_learning_rate).
            gradient_transformer = Momentum(variabels=target_img, momentum=momentum)
            admm_transformer = Adam(target_img)
            scheduler = cos_scheduler(num_iter=max_iters, **optimizer)
            scheduler = step_lr_scheduler(max_lr=optimizer.max_lr, min_lr=optimizer.min_lr,drop_epoch=drop_epoch)
            num_queries = 0
            delta = torch.rand(channels,image_size,image_size).to(device_ids[0])
            flag = False
            lower_bond = torch.maximum(torch.tensor(-epsilon).to(device_ids[0]), -0.5 -target_img)
            uppder_bond = torch.minimum(torch.tensor(epsilon).to(device_ids[0]), 0.5- target_img)
            adv_image = target_img.detach().clone()

            # Defaults so that the success message and the return statement
            # are well defined even if the very first check succeeds or
            # max_query is 0 (no perturbation has been applied at that point).
            loss = torch.zeros((), device=host)
            l0_norm = torch.zeros((), device=host)
            l2_norm = torch.zeros((), device=host)
            linf_norm = torch.zeros((), device=host)
            pred = original_label

            last_ls = []
            k = k_init
            k_hat = k_init
            last_query = max_query

            for iters in range(max_query):

                # check if we can make an early stopping
                pred = torch.argmax(model(adv_image, ))
                if is_sucessful(target_label, pred):
                    flag = True
                    print(f'[succ] Iter: {iters}, groups: {k}, query: {num_queries}, loss: {loss.cpu():.3f}, l0:{l0_norm.cpu():.0f}, l2:{l2_norm.cpu():.1f}, linf:{linf_norm.cpu():.2f}, prediction: {pred}, target_label:{target_label}')
                    break

                # create one-hot target labels as (batch_size, num_class)
                target_labels = F.one_hot(target_label, num_classes=num_labels).repeat(batch_size, 1)

                # estimate the gradient
                grads, loss = get_grad_estimation(evaluate_img=adv_image,
                                                  target_labels=target_labels,
                                                  sample_per_draw=samples_per_draw,
                                                  batch_size=batch_size)

                # On a loss plateau: draw more samples, allow more groups and
                # shrink the step size.
                last_ls.append(loss)
                last_ls = last_ls[-plateu_length:]
                if last_ls[-1] >= last_ls[0] and len(last_ls) == plateu_length:
                    samples_per_draw += batch_size
                    samples_per_draw = min(samples_per_draw, max_per_draw)
                    batch_size = samples_per_draw
                    k += round(k_hat * 0.9)
                    k = min(k, max_perturb)
                    k_hat *= 0.9
                    if max_learning_rate > cfg.optimizer.min_lr:
                        max_learning_rate = max(max_learning_rate * 0.9 , cfg.optimizer.min_lr)
                    last_ls = []

                grads = admm_transformer.apply_gradient(grad=grads)
                lr=max_learning_rate
                delta = delta - lr * grads
                pro_delta = torch.clip(delta, lower_bond, uppder_bond)  # clip the delta to satisfy l_inf norm

                # element-wise score; the groups with the smallest summed h are kept
                h = pro_delta ** 2 - 2 * pro_delta * delta

                # masks is cloned because the projection zeroes selected rows
                unclip_delta = greedy_project(h, delta, masks.clone(), k)

                delta = torch.clip(unclip_delta, lower_bond, uppder_bond)  # clip the delta to satisfy l_inf norm
                adv_image = torch.clip(target_img + delta, -0.5, 0.5)

                l0_norm = torch.sum((delta != 0).float())
                l2_norm = torch.norm(delta)
                linf_norm = torch.max(torch.abs(delta))
                # samples_per_draw estimator queries plus one success check
                num_queries += samples_per_draw+1

                # stop when the remaining budget cannot pay for another iteration
                last_query -= samples_per_draw+1
                if last_query -samples_per_draw-1 <0:
                    break
                # parenthesised: `iters+1 % log_iters` would bind as
                # `iters + (1 % log_iters)` and never trigger the log
                if (iters + 1) % log_iters == 0:
                    print('attack iter {}, loss: {:.5f}, query: {}, l0 norm:{:.5f}, l2 norm: {:.5f}, lr:{:.5f}, prediction: {}, target_label:{}'.format(
                        iters, loss.cpu(), samples_per_draw, l0_norm, l2_norm.cpu(), lr, pred.cpu(), target_label))

            # Report every failure, including the one caused by running out
            # of query budget (which leaves the loop through `break`).
            if not flag:
                print("Fail Attack!")

            if targeted == 'untargeted':
                return adv_image, flag, num_queries, l0_norm.cpu(), l2_norm.cpu(), linf_norm.cpu(), pred.cpu()
            else:
                return adv_image, flag, num_queries, l0_norm.cpu(), l2_norm.cpu(), linf_norm.cpu(), target_label.cpu()

    def is_sucessful(target_label, pred):
        """Return True if ``pred`` satisfies the configured attack goal."""
        return (targeted == 'untargeted' and pred != target_label) or (targeted == 'targeted' and pred == target_label)

    num_queries_list = []
    l0_norm_list = []
    l2_norm_list = []
    linf_norm_list = []
    psnr_list = []
    ssim_list = []
    acc_count = 0

    if grouping_strategy == 'standard':
        # Standard grouping
        print("Standard Grouping!")
        if load_pre_groups:
            if if_overlap == "nonoverlapping":
                masks = torch.load(f"Group/{dataset}/onehot_index_{dataset}_standard_{if_overlap}_{filtersize}.pth")
            else :
                masks = torch.load(f"Group/{dataset}/onehot_index_{dataset}_standard_{if_overlap}_{filtersize}{stride}.pth")
        else:
            masks = standard_grouping()
            if model_name == 'VT':
                torch.save(masks, f"Group/{model_name}/onehot_index_{dataset}_standard_{if_overlap}.pth")
            else:
                torch.save(masks, f"Group/{dataset}/onehot_index_{dataset}_standard_{if_overlap}.pth")
        # Move the channel axis first to match the (C, H, W) perturbation.
        # NOTE(review): transpose(1, 3) also swaps the two spatial axes;
        # confirm this is intended if the window grid is ever non-symmetric.
        masks = masks.reshape(-1, image_size, image_size, channels).transpose(1,3).flatten(1)
        print("Grouping Completion!")
        print("masks shape", masks.shape)
    else:
        # Fail early with a clear message instead of a NameError on `masks`.
        raise ValueError(f"unsupported grouping_strategy: {grouping_strategy!r}")

    # Original behaviour kept: freshly generated groups are only saved; the
    # attack itself runs only with pre-computed groups.
    assert load_pre_groups == True

    masks = masks.to(device_ids[0])  # loop invariant, moved out of the loop
    index_fail = []
    for position, (image, orig_label, target_label) in enumerate(zip(images, labels, target_class)):
        print("No. ",position)
        i = position + 1  # 1-based index used for saved files and the failure list

        adv_image, flag, num_queries, l0_norm, l2_norm, linf_norm, target_label = attack(
            target_img=image.to(device_ids[0]),
            target_label=target_label.to(device_ids[0]),
            original_label=orig_label.to(device_ids[0]),
            samples_per_draw=samples_per_draw,
            batch_size=batch_size,
            masks=masks,
            max_learning_rate=max_learning_rate,
        )

        if flag:
            orig_img = img_transform(image.cpu().numpy())
            adv_img = img_transform(adv_image[0].cpu().numpy())
            psnr_list.append(calculate_psnr(orig_img,adv_img, dataset))
            ssim_list.append(calculate_ssim(orig_img,adv_img, dataset))
            l0_norm_list.append(l0_norm)
            l2_norm_list.append(l2_norm)
            linf_norm_list.append(linf_norm)
            acc_count += 1

            if is_img_save:
                image_save(adv_image[0], image, orig_label, target_label, dataset, targeted, i, grouping_strategy)

        else:
            index_fail.append(i)

        num_queries_list.append(num_queries)

    acc = acc_count / torch.sum(correct_idx.float())
    print("fail_index:\n",index_fail)
    print("acc_count:",acc_count, "len(list)", torch.sum(correct_idx.float()))
    result = {
        'query': num_queries_list,
        'acc': acc,
        'l0': torch.mean(torch.tensor(l0_norm_list)),
        'l2': torch.mean(torch.tensor(l2_norm_list)),
        'linf': torch.mean(torch.tensor(linf_norm_list)),
        'psnr': psnr_list,
        'ssim': ssim_list,
    }

    return result
