import os
import pickle

import torch
from torch import nn

from utils import *
from config import opt
from defense import poison_fun_tensor

def _query_victim_cip(data, victim, cip_ood_counter):
    """Query ``victim`` through the CIP defense wrapper ``poison_fun_tensor``.

    Loads the energy thresholds stored for the current victim/OOD configuration
    from ``<opt.work_dir>/checkpoints/cip_ckpt/cip_energy_dict.pkl`` and passes
    them, together with the query batch, to ``poison_fun_tensor``.

    Raises:
        OSError: if the energy pickle cannot be opened.
        KeyError: if the pickle has no entry for this model name, or the entry
            lacks ``'FPR_energy'`` / ``'open_set_energy'``.
    """
    model_name = f'victim_{opt.victim_model}_{opt.victim_dataset}_ood-{opt.victim_ood_dataset}'
    cip_ckpt_dir = os.path.join(opt.work_dir, 'checkpoints/cip_ckpt')
    cip_energy_path = os.path.join(cip_ckpt_dir, 'cip_energy_dict.pkl')
    # NOTE(security): pickle.load can execute arbitrary code; this file is
    # assumed to be a trusted, locally produced checkpoint.
    with open(cip_energy_path, 'rb') as fin:
        energy_stats = pickle.load(fin)[model_name]
    # NOTE(review): ratio=17 is a fixed value passed to poison_fun_tensor; its
    # meaning is defined in `defense` — confirm before changing.
    return poison_fun_tensor(model=victim, image_tensor=data, transform_test=None,
                             attack=opt.source, ratio=17, dataset=opt.victim_dataset,
                             FPR=energy_stats['FPR_energy'],
                             Open_set_energy=energy_stats['open_set_energy'],
                             trigger_path=None, ood_counter=cip_ood_counter)


def _perturb_victim_probs(probs, return_type):
    """Degrade victim probabilities according to ``return_type``.

    Matching rules, checked in this order:
      * ``'label'`` (exact match): return the argmax index along dim 1.
      * contains ``'top'``:   keep the K largest scores along dim 1 and zero the
        rest, where K is the integer after the last ``'-'``.
      * contains ``'round'``: round every value to K decimal places, where K is
        the integer after the last ``'-'``.
      * contains ``'noise'``: add zero-mean Gaussian noise whose std is the
        float after the last ``'-'``.
    A falsy ``return_type`` (``None`` or ``''``) or any other value returns
    ``probs`` unchanged.
    """
    if not return_type:
        return probs
    if return_type == 'label':
        return probs.argmax(dim=1)
    if 'top' in return_type:
        k = int(return_type.split('-')[-1])
        topk_scores, topk_indices = probs.topk(k, dim=1)
        return torch.zeros_like(probs).scatter_(1, topk_indices, topk_scores)
    if 'round' in return_type:
        k = int(return_type.split('-')[-1])
        scale = 10 ** k
        return torch.round(probs * scale) / scale
    if 'noise' in return_type:
        std = float(return_type.split('-')[-1])
        # Sample directly on probs' device/dtype so this works on CPU as well
        # as on any GPU (no hard-coded .cuda()).
        return probs + torch.randn_like(probs) * std
    return probs


def query_victim(data, victim, cip_ood_counter=None):
    """Query the victim model on ``data`` and return what the attacker observes.

    If ``opt.victim_ood_dataset`` is set, the query is routed through the CIP
    defense (``poison_fun_tensor``) and its result is returned as-is.
    Otherwise the victim is run under ``torch.no_grad()``. The outputs are
    softmaxed along dim 1 and then degraded according to
    ``opt.victim_return_type`` (see ``_perturb_victim_probs``).

    Args:
        data: input batch passed straight to ``victim``.
        victim: the victim model (callable).
        cip_ood_counter: forwarded to ``poison_fun_tensor`` as ``ood_counter``;
            only used on the CIP path.

    Returns:
        The (possibly perturbed) victim response tensor.
    """
    if opt.victim_ood_dataset:  # CIP defense
        return _query_victim_cip(data, victim, cip_ood_counter)

    with torch.no_grad():
        outputs = victim(data)
    # With a watermark dataset configured, the victim output is indexed with
    # [-1]; presumably the model returns several outputs and the last holds
    # the class logits — confirm against the watermarked model definition.
    logits = outputs[-1] if opt.victim_wm_dataset else outputs
    probs = nn.functional.softmax(logits, dim=1)
    return _perturb_victim_probs(probs, opt.victim_return_type)