# -*- coding: utf-8 -*-

"""
Created on 2023/12/1

@author: Ruoyu Chen
"""

import argparse

import os
import os
import json
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

from tqdm import tqdm

from sklearn import metrics
from torchvision import transforms
import torch

import csv
from transformers import ViTForImageClassification



def read_csv_to_dict(file_path):
    """Load a CSV file into a dict keyed by the first column of each row.

    Args:
        file_path: Path of the UTF-8 encoded CSV file to read.

    Returns:
        A dict mapping the first cell of every non-empty row to a list of
        ints parsed from the remaining non-empty cells.  A row consisting of
        the key alone maps to ``None``.  When a key occurs more than once the
        last row wins.

    Raises:
        ValueError: If a non-empty value cell cannot be parsed as an int.
    """
    parsed = {}
    with open(file_path, mode='r', newline='', encoding='utf-8') as handle:
        for record in csv.reader(handle):
            # Blank lines come back from csv.reader as empty lists.
            if not record:
                continue
            head, *tail = record
            # Empty cells (e.g. from trailing commas) are dropped, not parsed.
            parsed[head] = [int(cell) for cell in tail if cell] if tail else None
    return parsed


def preprocess(image):
    """Turn an image array into an un-normalised tensor on the GPU.

    The array is wrapped as a PIL image in 'RGB' mode exactly as given (no
    channel reordering is done here) and passed through ``ToTensor``.  No
    resizing, cropping or mean/std normalisation is applied in this function.

    Args:
        image: Image array; presumably H x W x 3 uint8 -- confirm with caller.

    Returns:
        The ``ToTensor`` output moved to the current CUDA device.
    """
    pil_image = Image.fromarray(image, 'RGB')
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    tensor = to_tensor(pil_image)
    return tensor.cuda()


def generate_masks_torch(patch_size, image_size=224, channels=3):
    """Build one binary mask per cell of a ``patch_size`` x ``patch_size`` grid.

    Despite its name, ``patch_size`` is the number of cells along each side of
    the grid, not the pixel size of a cell.  Each cell is
    ``image_size // patch_size`` pixels wide and high.

    Note that when ``image_size`` is not divisible by ``patch_size`` the
    trailing rows/columns (e.g. the last 2 pixels for 224 and 6) are not
    covered by any mask.

    Args:
        patch_size: Number of grid cells per side.
        image_size: Side length in pixels of the square masks (default 224).
        channels: Number of channels each mask is replicated over (default 3).

    Returns:
        A ``torch.uint8`` tensor of shape
        ``(patch_size * patch_size, channels, image_size, image_size)``.
        Mask ``i * patch_size + j`` is 1 inside the cell at grid row ``i`` and
        grid column ``j`` and 0 elsewhere.
    """
    cell = image_size // patch_size

    masks = torch.zeros(
        (patch_size * patch_size, channels, image_size, image_size),
        dtype=torch.uint8,
    )

    for row in range(patch_size):
        top = row * cell
        for col in range(patch_size):
            left = col * cell
            # Write straight into the output; no temporary per-cell mask needed.
            masks[row * patch_size + col, :, top:top + cell, left:left + cell] = 1

    return masks


# model = ViTForImageClassification.from_pretrained('pretrained_model',
#                                                       subfolder='vit-base-patch16-224',
#                                                       ignore_mismatched_sizes=True)
#
# # state_dict = torch.load('checkpoint/b16_224/epoch10-baseline/model_best.pth.tar')['state_dict']
# # state_dict = torch.load('checkpoint/b16_224/eclip-epoch10-cgc-lambda0.5/model_best.pth.tar')['state_dict']
#
# # folder = 'ab-lambda/epoch10-lambda1.0-reg_freq20/'
#
# # folder = 'conf-epoch10-lambda0.5-reg_freq20/'
# # folder = 'conf+cons-epoch10-lambda0.5-reg_freq20/'
# # folder = 'colla-epoch10-lambda0.5-reg_freq20/'
# folder = 'colla+conf-epoch10-lambda0.5-reg_freq20/'
#
# state_dict = torch.load('checkpoint/lima/b16_224/'+folder+'model_best.pth.tar')['state_dict']


# Build the classifier from the local 'pretrained_model' directory
# (sub-folder 'vit-large-patch16-224').
model = ViTForImageClassification.from_pretrained('pretrained_model',
                                                      subfolder='vit-large-patch16-224',
                                                      ignore_mismatched_sizes=True)

# Checkpoint to evaluate; swap the active line to evaluate another run.
# state_dict = torch.load('checkpoint/l16_224/eclip-epoch10-baseline-first/model_best.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/l16_224/eclip-epoch10-cgc-lambda0.5/model_best.pth.tar')['state_dict']
state_dict = torch.load('checkpoint/lima/l16_224/three_loss-epoch10-lambda0.5-reg_freq20/model_best.pth.tar')['state_dict']
# state_dict = torch.load('checkpoint/lima/l16_224/two_loss-epoch10-lambda0.5-reg_freq20/model_best.pth.tar')['state_dict']

# Remove 'module.' from the parameter names so they match the bare model
# (presumably the checkpoint was saved from a DataParallel wrapper -- confirm).
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.cuda()
# from_pretrained already returns the model in evaluation mode; made explicit
# here because every forward pass in this script is inference-only.
model.eval()


# Directory holding the search results and receiving the score CSVs.
# It must correspond to the checkpoint selected above.
# result_path = 'result_search/b16_224/baseline/'  # todo
# result_path = 'result_search/b16_224/cgc/'
# result_path = 'result_search/' + folder

# result_path = 'result_search/l16-224/baseline/'
# result_path = 'result_search/l16-224/cgc/'
result_path = 'result_search/l16-224/three_loss/'

# Image name -> ranked list of patch indices.
# NOTE: this name shadows the builtin ``dict``; it is kept because main()
# reads this global under exactly this name.
dict = read_csv_to_dict(result_path+'result_rest.csv')


# Image names from the CSV are appended directly to this root.
image_root = "/mnt/huawei/jiaoxh/data/ImageNet100/val/"

print('   ******  ', result_path)

def main():
    """Compute mean insertion and deletion AUC over the searched images.

    For every entry of the module-level ``dict`` (image name -> ranked patch
    indices) the image is read from ``image_root``, resized to 224 x 224 and
    split with a 6 x 6 grid of masks.  Patches are revealed cumulatively in
    ranked order; the softmax score of the class the model predicts for the
    full image is recorded on each partially revealed image (insertion) and
    on its complement (deletion).  The per-step scores are appended to
    ``consistency_score_full.csv`` and ``collaboration_score_full.csv`` under
    ``result_path``, and the mean AUC of both curves is printed at the end.

    Raises:
        OSError: If an image cannot be read by OpenCV.
        ValueError: If an image has no ranked patch indices.
    """
    grid_size = 6
    num_patches = grid_size * grid_size

    insertion_aucs = []
    deletion_aucs = []

    # x-axis of the curves: fraction of grid patches revealed after each step.
    insertion_area = np.arange(0, num_patches + 1) / num_patches
    deletion_area = 1 - insertion_area
    deletion_x = 1 - deletion_area

    # Loop-invariant objects: build them once instead of once per image.
    patches_tensor = generate_masks_torch(grid_size).cuda()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    for image_name, ranked_patches in tqdm(dict.items()):

        image_path = image_root + image_name
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("Could not read image: {}".format(image_path))
        if not ranked_patches:
            raise ValueError("No ranked patch indices for image: {}".format(image_name))
        image = cv2.resize(image, (224, 224))

        # NOTE(review): cv2.imread yields BGR channel order while preprocess()
        # labels the array 'RGB' without swapping, and the mean/std below are
        # the usual RGB ImageNet statistics.  Left unchanged on purpose so the
        # scores stay comparable with the stage that produced the ranking --
        # confirm whether that stage uses the same channel order.
        image_tensor = preprocess(image)
        element_sets_V = patches_tensor * image_tensor

        with torch.no_grad():
            output = model(normalize(image_tensor).unsqueeze(0))
            # The model's own prediction on the full image is the class tracked
            # along both curves.
            target_label = output.logits.cpu().numpy().argmax()

        # Row k of the batch is the image with the first k+1 ranked patches
        # visible and everything else zeroed.
        order = torch.as_tensor(ranked_patches, dtype=torch.long, device=element_sets_V.device)
        selected_batch = torch.stack(
            [element_sets_V[order[:step]].sum(0) for step in range(1, len(ranked_patches) + 1)],
            dim=0,
        )

        # Complement: the image with those same patches removed.
        reverse_selected_batch = image_tensor - selected_batch

        selected_batch = normalize(selected_batch)
        reverse_selected_batch = normalize(reverse_selected_batch)

        with torch.no_grad():
            predicted_scores = torch.softmax(model(selected_batch).logits, dim=-1)
            consistency_scores = predicted_scores[:, target_label]

            predicted_scores = torch.softmax(model(reverse_selected_batch).logits, dim=-1)
            collaboration_scores = 1 - predicted_scores[:, target_label]

        insertion_score = consistency_scores.cpu().numpy().tolist()
        deletion_score = collaboration_scores.cpu().numpy().tolist()

        # Append per image so finished rows survive an interrupted run.
        with open(result_path + 'consistency_score_full.csv', 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([image_name] + insertion_score)

        with open(result_path + 'collaboration_score_full.csv', 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([image_name] + deletion_score)

        # Prepend the step-0 points.  Insertion starts from the score with all
        # ranked patches removed (1 - last collaboration score); deletion
        # starts from the score with all ranked patches present (last
        # insertion score).  The second line relies on insertion_score having
        # already been converted to an array whose last entry is unchanged.
        insertion_score = np.array([1 - deletion_score[-1]] + insertion_score)
        deletion_score = 1 - np.array([1 - insertion_score[-1]] + deletion_score)

        insertion_aucs.append(metrics.auc(insertion_area, insertion_score))
        deletion_aucs.append(metrics.auc(deletion_x, deletion_score))

    insertion_auc_score = np.array(insertion_aucs).mean()
    deletion_auc_score = np.array(deletion_aucs).mean()
    print("Insertion AUC Score: {:.4f}\nDeletion AUC Score: {:.4f}".format(insertion_auc_score, deletion_auc_score))


# Run the evaluation loop only when this file is executed as a script.
if __name__ == "__main__":
    main()