"""
Here explain what is M to N (M2N) attack.

M2N attack, means M patterns to N targets where M <= N.

"""

import os

import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from torchvision.transforms import Compose, ToTensor, PILToTensor, RandomHorizontalFlip, ToPILImage, Resize
from torchvision.models import resnet18, resnet34
from torchvision.utils import save_image

import core
from upload_model import ExperimentModel
import argparse
from core.attacks import Arcueid
import numpy as np
import pandas as pd

# Every split is read as an image folder with one sub-directory per class.
dataset = torchvision.datasets.DatasetFolder

# ---- command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet'])
parser.add_argument('--model', type=str, default='resnet18', choices=['resnet18', 'resnet34', 'vgg13_bn', 'ViT', 'SimpleViT'])  # CCT is no longer offered
parser.add_argument('--gpu', type=str, default='0')
args = parser.parse_args()

# One result row per run is appended to this CSV file.
csv_name = 'experiment_data/Arcueid_M2N.csv'

dataset_name, model_name, GPU = args.dataset, args.model, args.gpu

# Model/dataset handed to Arcueid for trigger optimization (optimize_model /
# optimize_dataset). Both always differ from the victim's:
#   dataset: cifar100, or tinyimagenet when the victim itself trains on cifar100;
#   model:   another architecture from the same family as the victim model.
attacker_dataset_name = 'tinyimagenet' if dataset_name == 'cifar100' else 'cifar100'
if model_name in ('resnet18', 'resnet34', 'vgg13_bn'):
    # CNN victim -> CNN attacker model (resnet18 only when the victim is resnet34).
    attacker_model_name = 'resnet18' if model_name == 'resnet34' else 'resnet34'
elif model_name in ('ViT', 'SimpleViT'):
    # ViT victim -> ViT attacker model (ViT only when the victim is SimpleViT).
    attacker_model_name = 'ViT' if model_name == 'SimpleViT' else 'SimpleViT'

# Root folder of each dataset split, relative to the working directory.
# Tiny ImageNet's held-out split is stored under 'val/' rather than 'test/'.
DATASET_PATH = dict(
    cifar10=dict(
        train='../.local/share/cifar10/train/',
        test='../.local/share/cifar10/test/',
    ),
    cifar100=dict(
        train='../.local/share/cifar100/train/',
        test='../.local/share/cifar100/test/',
    ),
    tinyimagenet=dict(
        train='../.local/share/tinyimagenet/train/',
        test='../.local/share/tinyimagenet/val/',
    ),
)

# Target-label configurations ("objections") evaluated for the chosen dataset.
# Each configuration is a list with one entry per class index; entry i is the
# target label assigned to class i.


def _grouped_targets(num_classes, num_targets):
    """Map class indices 0..num_classes-1 onto ``num_targets`` contiguous groups.

    Returns a list of length ``num_classes`` whose i-th entry is
    ``i * num_targets // num_classes``. Exactly ``num_targets`` distinct labels
    (0..num_targets-1) appear, each covering a contiguous run of classes. When
    ``num_targets`` divides ``num_classes`` the groups are equal-sized, e.g.
    ``_grouped_targets(10, 2) == [0] * 5 + [1] * 5``.

    Raises:
        ValueError: if ``num_targets`` is not in ``[1, num_classes]``.
    """
    if not 1 <= num_targets <= num_classes:
        raise ValueError(
            f'num_targets must be in [1, {num_classes}], got {num_targets}'
        )
    # Scale-then-divide yields `num_targets` groups. The previous `i // k`
    # form yielded `num_classes / k` groups, so "100 -> 20" produced only 5
    # distinct targets.
    return [i * num_targets // num_classes for i in range(num_classes)]


# dataset -> (number of classes, distinct-target counts to evaluate).
# A count of 1 is the all-to-one special case; a count equal to the number of
# classes is the all-to-all case (every class is its own target).
_TARGET_CONFIGS = {
    'cifar10': (10, (1, 2, 5, 10)),           # 10 -> 1, 2, 5, 10
    'cifar100': (100, (1, 10, 20, 100)),      # 100 -> 1, 10, 20, 100
    'tinyimagenet': (200, (1, 50, 100, 200)), # 200 -> 1, 50, 100, 200
}

_num_classes, _target_counts = _TARGET_CONFIGS[dataset_name]
targets_test = [_grouped_targets(_num_classes, count) for count in _target_counts]

# Loading pipeline shared by every split:
#   image file -> cv2.imread -> numpy.ndarray (H x W x C) -> ToPILImage
#   -> Resize to 32x32 -> ToTensor -> torch.Tensor (C x H x W)
#   -> RandomHorizontalFlip (training transform only) -> network input
# NOTE(review): cv2.imread returns channels in BGR order and nothing here
# reorders them, so every model sees BGR images. This is consistent across all
# three datasets below; confirm it matches how the pretrained attacker model
# was trained.
transform_train = Compose([
    ToPILImage(),
    Resize((32, 32)),
    ToTensor(),
    RandomHorizontalFlip(),
])
transform_test = Compose([ToPILImage(), Resize((32, 32)), ToTensor()])


def _image_folder(root, transform):
    """Build a DatasetFolder over ``root`` reading png/jpeg files with cv2.imread."""
    return dataset(
        root=root,
        loader=cv2.imread,
        extensions=('png', 'jpeg'),
        transform=transform,
        target_transform=None,
        is_valid_file=None,
    )


trainset = _image_folder(DATASET_PATH[dataset_name]['train'], transform_train)
testset = _image_folder(DATASET_PATH[dataset_name]['test'], transform_test)
# Dataset used by Arcueid's trigger optimization; it is loaded with the same
# (augmenting) transform as the victim's training set.
attacker_dataset = _image_folder(DATASET_PATH[attacker_dataset_name]['train'], transform_train)

def _make_blend_trigger(trigger_size=32, trigger_weight=0.15, image_size=32):
    """Build the ``(pattern, weight)`` tensors passed to Arcueid as ``trigger_info``.

    Both tensors have shape ``(3, image_size, image_size)`` and target images
    already scaled to [0.0, 1.0] by ``ToTensor``.

    - ``pattern`` is 1.0 in the bottom-right ``trigger_size`` x ``trigger_size``
      square and 0.0 elsewhere.
    - ``weight`` is ``1 - trigger_weight`` in that square and 1.0 elsewhere.

    With the defaults the square covers the whole 32x32 image. How Arcueid
    combines pattern and weight is defined in ``core.attacks`` (not visible here).
    """
    pattern = torch.zeros((3, image_size, image_size), dtype=torch.float32)
    pattern[:, -trigger_size:, -trigger_size:] = 1.0
    weight = torch.ones((3, image_size, image_size), dtype=torch.float32)
    weight[:, -trigger_size:, -trigger_size:] = (1 - trigger_weight)
    return pattern, weight


def _build_schedule(model_name, dataset_name, gpu):
    """Return the training-schedule dict passed to ``Arcueid.train``.

    CNN models get an SGD-style schedule (lr, momentum, lr milestones in
    'schedule'). ViT-family models get an AdamW + cosine-annealing schedule.
    The key semantics are interpreted by ``core`` (not visible here).

    Raises:
        ValueError: if no schedule is defined for ``model_name``.
    """
    if model_name in ('resnet18', 'resnet34', 'vgg13_bn'):
        return {
            'device': 'GPU',
            'CUDA_SELECTED_DEVICES': gpu,
            'GPU_num': 1,

            'benign_training': False,
            'batch_size': 1024,
            'num_workers': 16,

            'lr': 0.1,
            'momentum': 0.9,
            'weight_decay': 5e-4,
            'gamma': 0.1,
            'schedule': [150, 180],

            'epochs': 200,

            'log_iteration_interval': 100,
            'test_epoch_interval': 10,
            'save_epoch_interval': 10,

            'save_dir': 'experiments',
            # NOTE(review): the CNN and ViT names order "Arcueid" and the model
            # name differently; kept as-is so existing output paths stay stable.
            'experiment_name': f'train_poisoned_DatasetFolder-{dataset_name}_{model_name}_Arcueid'
        }
    if model_name in ('ViT', 'SimpleViT', 'CCT'):
        return {
            'device': 'GPU',
            'CUDA_SELECTED_DEVICES': gpu,
            'GPU_num': 1,

            'benign_training': False,
            'batch_size': 1024,
            'num_workers': 4,

            'optimizer': 'AdamW',
            'lr': 3e-4,
            'weight_decay': 0.01,
            'gamma': 0.1,
            'scheduler': 'cosine_annealing',
            'min_lr': 1e-6,
            'warmup_epoch': 5,

            'epochs': 200,

            'log_iteration_interval': 100,
            'test_epoch_interval': 10,
            'save_epoch_interval': 10,

            'save_dir': 'experiments',
            'experiment_name': f'train_poisoned_DatasetFolder-{dataset_name}_Arcueid_{model_name}'
        }
    raise ValueError(f'No training schedule defined for model {model_name!r}')


def _append_record(record_dict, path):
    """Append ``record_dict`` as one CSV row to ``path``.

    The header is written only when the file does not exist yet. The parent
    directory is created if missing. Otherwise the write would fail only after
    the whole (long) training run had finished, and its results would be lost.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # `header` is evaluated before to_csv opens (and possibly creates) the file.
    pd.DataFrame([record_dict]).to_csv(
        path, mode='a', header=not os.path.exists(path), index=False
    )


print(f'targets_test: {targets_test}')
for target_labels in targets_test:
    # Fresh trigger tensors for every run, in case Arcueid modifies them.
    pattern, weight = _make_blend_trigger(trigger_size=32, trigger_weight=0.15)

    print(f'target_labels: {target_labels}')
    arcueid = Arcueid(
        train_dataset=trainset,
        test_dataset=testset,
        model=ExperimentModel(model_name, dataset_name, pretrained=False)(),
        loss=nn.CrossEntropyLoss(),
        poisoned_rate_per_id=0.0001,
        trigger_info={
            'pattern': pattern,
            'weight': weight
        },
        target_labels=target_labels,

        # Trigger-optimization arguments: the attacker's pretrained model and
        # dataset, chosen to differ from the victim's.
        train_steps=10,
        train_scale=0.3,
        optimize_model=ExperimentModel(attacker_model_name, attacker_dataset_name, pretrained=True)(),
        optimize_dataset=attacker_dataset,
        optimize_device=torch.device(f'cuda:{GPU}')
    )
    schedule = _build_schedule(model_name, dataset_name, GPU)

    _, acc_list, _ = arcueid.train(schedule)
    asr_list = arcueid.compute_asr()

    # One row per configuration. ASR is summarized as "mean%±std%" over the
    # values returned by compute_asr().
    record_dict = {
        'dataset_name': dataset_name,
        'model_name': model_name,
        'objections': f'{len(target_labels)} -> {len(set(target_labels))}',
        'acc': acc_list[-1],  # last entry of the accuracy list from train()
        'asr': f'{np.mean(asr_list)*100:.1f}%±{np.std(asr_list)*100:.1f}%'
    }
    _append_record(record_dict, csv_name)