import argparse
import os
import numpy as np
import sys

import pandas as pd
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Load project paths from the YAML environment file. The path is relative, so
# it is resolved against the current working directory at import time.
# config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # root_dir is used by main() to locate the dataset; src_dir is the source
    # tree added to sys.path below.
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Extend sys.path with the source tree and two of its sub-directories before
# the project-local imports that follow.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from tinyimagenet_utils import transform_train, transform_test, TINdata, DistillTINdata, WarmUpLR, ModelwNorm, \
    transform_train_aug
from tinyimagenet.models.model_selector import get_network

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of feature entries compared per sample. main() overwrites this global
# based on the selected model before compute_norm_dist() reads it.
num_fea = 0


@torch.no_grad()
def compute_norm_dist(testloader, model1, model2, criterion, len_data, args):
    """Compute per-sample feature distances between two models over a loader.

    For every batch, both models' ``last_forward`` is called and the first
    ``num_fea`` (module-level global) feature entries of each are compared
    with ``euclidean_distance_norm``.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches. Results
            are stored in iteration order, so the loader should not shuffle
            if the positions are meant to line up with the dataset.
        model1: first model; must expose ``last_forward(inputs)`` returning
            ``(outputs, features)``.
        model2: second model, same interface as ``model1``.
        criterion: unused; kept so existing callers keep working.
        len_data: total number of samples the loader yields; sizes the
            result arrays.
        args: unused; kept so existing callers keep working.

    Returns:
        A list of ``num_fea`` NumPy arrays of length ``len_data``; entry ``i``
        holds the distance between the ``i``-th features of the two models
        for every sample.
    """
    # switch to evaluate mode
    model1.eval()
    model2.eval()

    res_np = [np.zeros(len_data) for _ in range(num_fea)]

    # Track the write position from the batches actually received rather than
    # assuming every batch holds args.batch_size samples.
    offset = 0
    for inputs, _ in testloader:
        inputs = inputs.to(device, torch.float)

        _, feas1 = model1.last_forward(inputs)
        _, feas2 = model2.last_forward(inputs)

        feas1 = list(feas1)
        feas2 = list(feas2)

        end = offset + inputs.shape[0]
        for i in range(num_fea):
            dist_i = euclidean_distance_norm(feas1[i], feas2[i]).detach().cpu().numpy()
            res_np[i][offset:end] = dist_i
        offset = end

    return res_np


def euclidean_distance_norm(vector1, vector2):
    """Return the per-sample root-mean-square distance between two batches.

    Both tensors are flattened to ``(batch, features)``; for each sample the
    squared differences are summed, divided by the number of features, and
    square-rooted.

    Args:
        vector1: batch-first tensor with at least two dimensions, e.g.
            ``(b, c, h, w)`` or ``(b, d)``.
        vector2: tensor with the same number of elements as ``vector1``; it
            is reshaped to match the flattened ``vector1``.

    Returns:
        A 1-D tensor of length ``batch`` with one distance per sample.
    """
    # Flatten everything after the batch dimension so feature maps of any
    # rank (not only 4-D conv activations) are supported.
    flat1 = vector1.flatten(start_dim=1)
    flat2 = vector2.reshape(flat1.shape)
    num_elements = flat1.shape[1]
    squared_sum = torch.sum((flat1 - flat2) ** 2, dim=1, keepdim=False)
    return torch.sqrt(squared_sum / num_elements)

def create_directory_structure(directory_path):
    """Create ``directory_path`` (including parents) and report the outcome.

    Args:
        directory_path: path of the directory to create.

    Prints whether the directory was created or already existed.
    """
    # Attempt the creation directly instead of checking first, so a directory
    # created by another process between check and creation does not crash.
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def _load_trained_network(model_name, num_classes, checkpoint_dir):
    """Build a network and load ``model_last.pth.tar`` from ``checkpoint_dir``.

    Args:
        model_name: architecture name passed to ``get_network``.
        num_classes: number of output classes passed to ``get_network``.
        checkpoint_dir: directory expected to contain ``model_last.pth.tar``.

    Returns:
        The network with the checkpoint's ``state_dict`` loaded, moved to the
        module-level ``device`` as float.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    net = get_network(model_name, num_classes=num_classes)
    resume = os.path.join(checkpoint_dir, 'model_last.pth.tar')
    print('==> Resuming from checkpoint' + resume)
    # Explicit raise instead of assert so the check survives `python -O`.
    if not os.path.isfile(resume):
        raise FileNotFoundError(f'Error: no checkpoint found at {resume}')
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    return net.to(device, torch.float)


def main():
    """Compute and save averaged feature distances between two model families.

    For each run index ``1..num_runs`` a model from the ``tinyimagenet``
    checkpoint tree and one from the ``tinyimagenet_half1`` tree are loaded,
    and ``compute_norm_dist`` is evaluated on the two training halves and the
    test set. The per-sample distances are averaged over the runs and written
    as one CSV per feature index and split under the ``stat_fea_last``
    checkpoint directory.

    Raises:
        FileNotFoundError: if a required checkpoint file is missing.
    """
    global num_fea

    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=200, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    # NOTE: with type=bool any non-empty string parses as True; the value is
    # overridden to True below regardless of what was passed.
    parser.add_argument('--data_aug', type=bool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    num_runs = args.num_runs
    # Averaging divides by num_runs, so zero or negative values are rejected.
    if num_runs < 1:
        parser.error('--num_runs must be at least 1')

    batch_size = args.batch_size
    num_worker = args.num_worker
    args.data_aug = True
    # The number of compared feature entries depends on the architecture.
    num_fea = 4 if args.model == 'vgg11_bn' else 8

    aug_dir = 'aug' if args.data_aug else 'no_aug'
    DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
    checkpoint_path = os.path.join(args.save_path, 'tinyimagenet', args.model, 'stat_fea_last', aug_dir)
    checkpoint_path1 = os.path.join(args.load_path, 'tinyimagenet', args.model, 'undefend', aug_dir)
    checkpoint_path2 = os.path.join(args.load_path, 'tinyimagenet_half1', args.model, 'undefend', aug_dir)

    print(checkpoint_path)

    partition_dir = os.path.join(DATASET_PATH, 'partition')
    train_data_1 = np.load(os.path.join(partition_dir, 'train_data_half1.npy'))
    train_label_1 = np.load(os.path.join(partition_dir, 'train_label_half1.npy'))
    train_data_2 = np.load(os.path.join(partition_dir, 'train_data_half2.npy'))
    train_label_2 = np.load(os.path.join(partition_dir, 'train_label_half2.npy'))
    all_test_data = np.load(os.path.join(partition_dir, 'all_test_data.npy'))
    all_test_label = np.load(os.path.join(partition_dir, 'all_test_label.npy'))

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_1[:20])
    print(train_label_2[:20])

    # Split name (also used in the output file names) -> dataset. Both
    # training halves use transform_train; the test set uses transform_test.
    datasets_by_split = {
        'half1': TINdata(train_data_1, train_label_1, transform_train),
        'half2': TINdata(train_data_2, train_label_2, transform_train),
        'half_test': TINdata(all_test_data, all_test_label, transform_test),
    }
    # shuffle=False keeps results aligned with dataset order.
    loaders_by_split = {
        split: torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=num_worker)
        for split, dataset in datasets_by_split.items()
    }

    create_directory_structure(checkpoint_path)

    # Running sums of the per-sample distances, one array per feature index.
    dist_sums = {
        split: [np.zeros(len(dataset)) for _ in range(num_fea)]
        for split, dataset in datasets_by_split.items()
    }

    # The loss object is only passed through to compute_norm_dist; it does not
    # depend on the run, so it is built once.
    criterion = nn.CrossEntropyLoss(reduction='none').to(device, torch.float)

    for run in range(1, num_runs + 1):
        net_1 = _load_trained_network(args.model, args.num_class,
                                      os.path.join(checkpoint_path1, str(run)))
        net_2 = _load_trained_network(args.model, args.num_class,
                                      os.path.join(checkpoint_path2, str(run)))

        for split, loader in loaders_by_split.items():
            run_dists = compute_norm_dist(loader, net_1, net_2, criterion,
                                          len(datasets_by_split[split]), args)
            for fea_idx, dist in enumerate(run_dists):
                dist_sums[split][fea_idx] += dist

    # Average over the runs and write one single-column CSV per feature/split.
    for split, per_feature_sums in dist_sums.items():
        for fea_idx, total in enumerate(per_feature_sums):
            mean_dist = total / num_runs
            pd.DataFrame({'value': mean_dist}).to_csv(
                f'{checkpoint_path}/stat_fea{fea_idx + 1}_{split}_{num_runs}.csv', index=False)


if __name__ == '__main__':
    main()
