import argparse
import os
import numpy as np
import sys

import pandas as pd
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Environment configuration: a YAML file that must define `root_dir` and
# `src_dir`. The path is relative, so it is resolved against the current
# working directory the script is launched from.
# Alternative location kept for reference:
#config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # `root_dir` is used in main() to locate the dataset; `src_dir` is added
    # to the import path below.
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Extend the import path *before* the project imports that follow; the
# imports below presumably rely on these entries to resolve, so the order of
# these statements matters.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from tinyimagenet_utils import transform_train, transform_test, TINdata, DistillTINdata, WarmUpLR, ModelwNorm, transform_train_aug
from tinyimagenet.models.model_selector import get_network

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def compute_norm_dist(testloader, model1, model2, criterion, len_data, args):
    """Compute per-sample feature distances between two models.

    Every batch from ``testloader`` is fed through ``model1.multi_forward``
    and ``model2.multi_forward``; each call must return
    ``(outputs, features)`` where ``features`` holds five tensors. For each
    of the five feature levels the per-sample ``euclidean_distance_norm``
    between the two models' features is stored at the sample's position.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches. It must
            iterate in a fixed order and yield ``len_data`` samples in total.
        model1: first model; switched to eval mode.
        model2: second model; switched to eval mode.
        criterion: unused; kept so existing callers keep working.
        len_data: total number of samples, i.e. the length of each result.
        args: unused; kept so existing callers keep working.

    Returns:
        A tuple of five 1-D ``numpy`` arrays of length ``len_data``, one per
        feature level.

    Raises:
        ValueError: if a model does not return exactly five feature tensors.
    """
    # switch to evaluate mode
    model1.eval()
    model2.eval()

    num_levels = 5
    results = tuple(np.zeros(len_data) for _ in range(num_levels))

    # A running offset (instead of ``batch_ind * args.batch_size``) keeps the
    # results aligned even when the loader was built with a batch size that
    # differs from ``args.batch_size``.
    offset = 0
    for inputs, _targets in testloader:
        inputs = inputs.to(device, torch.float)
        count = inputs.shape[0]

        _, feats_1 = model1.multi_forward(inputs)
        _, feats_2 = model2.multi_forward(inputs)
        if len(feats_1) != num_levels or len(feats_2) != num_levels:
            raise ValueError(
                f"expected {num_levels} feature tensors per model, "
                f"got {len(feats_1)} and {len(feats_2)}"
            )

        for level_result, fea_1, fea_2 in zip(results, feats_1, feats_2):
            dist = euclidean_distance_norm(fea_1, fea_2).detach().cpu().numpy()
            level_result[offset: offset + count] = dist

        offset += count

    return results


def euclidean_distance_norm(vector1, vector2):
    """Return the per-sample root-mean-square difference of two tensors.

    Both tensors are flattened to ``(batch, features)``; for each sample the
    result is ``sqrt(sum((v1 - v2) ** 2) / features)``.

    Args:
        vector1: tensor of shape ``(batch, ...)``. Any number of trailing
            dimensions is accepted (e.g. ``(b, c, h, w)`` or ``(b, d)``).
        vector2: tensor with the same number of elements per sample as
            ``vector1``.

    Returns:
        A 1-D tensor of length ``batch``.
    """
    batch = vector1.size(0)
    # Elements per sample, computed explicitly (rather than reshape(b, -1))
    # so that an empty batch still reshapes cleanly.
    per_sample = 1
    for dim in vector1.shape[1:]:
        per_sample *= dim

    diff = vector1.reshape(batch, per_sample) - vector2.reshape(batch, per_sample)
    return torch.sqrt(torch.sum(diff ** 2, dim=1, keepdim=False) / per_sample)


def create_directory_structure(directory_path):
    """Create ``directory_path`` (including parents) and report the outcome.

    Creation is attempted directly rather than checking for existence first,
    so two processes starting at the same time cannot race between the check
    and the ``makedirs`` call.

    Args:
        directory_path: path of the directory to create.
    """
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def _load_network(checkpoint_dir, args):
    """Build ``args.model`` and restore its weights from ``checkpoint_dir``."""
    net = get_network(args.model, num_classes=args.num_class)
    resume = os.path.join(checkpoint_dir, 'model_last.pth.tar')
    print('==> Resuming from checkpoint' + resume)
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
    if not os.path.isfile(resume):
        raise FileNotFoundError(f'Error: no checkpoint directory found! ({resume})')
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    return net.to(device, torch.float)


def main():
    """Average per-sample feature distances between two models and save them.

    For every run ``i`` in ``1..num_runs`` a model trained on the full
    training set (``tinyimagenet/.../undefend``) and one trained on the first
    half (``tinyimagenet_half1/.../undefend``) are loaded from
    ``<load_path>/.../<i>/model_last.pth.tar``. ``compute_norm_dist`` is run
    on the first training half, the second training half and the test set;
    the five per-level distance vectors are averaged over the runs and
    written to ``stat_fea<level>_<split>_<num_runs>.csv`` (single ``value``
    column) under ``<save_path>/tinyimagenet/<model>/stat_fea/aug``.

    Raises:
        FileNotFoundError: if a required checkpoint file is missing.
    """
    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=200, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    # NOTE: with type=bool argparse treats every non-empty string as True;
    # the value is overridden below in any case.
    parser.add_argument('--data_aug', type=bool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    num_runs = args.num_runs
    if num_runs < 1:
        # Averaging over zero (or a negative number of) runs would silently
        # write meaningless CSV files.
        parser.error('--num_runs must be at least 1')

    batch_size = args.batch_size
    num_worker = args.num_worker
    # Data augmentation is forced on regardless of the command-line value,
    # so the 'aug' checkpoint directories are always used.
    args.data_aug = True

    aug_dir = 'aug' if args.data_aug else 'no_aug'
    DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
    checkpoint_path = os.path.join(args.save_path, 'tinyimagenet', args.model, 'stat_fea', aug_dir)
    checkpoint_path1 = os.path.join(args.load_path, 'tinyimagenet', args.model, 'undefend', aug_dir)
    checkpoint_path2 = os.path.join(args.load_path, 'tinyimagenet_half1', args.model, 'undefend', aug_dir)

    print(checkpoint_path)

    partition_dir = os.path.join(DATASET_PATH, 'partition')
    train_data_1 = np.load(os.path.join(partition_dir, 'train_data_half1.npy'))
    train_label_1 = np.load(os.path.join(partition_dir, 'train_label_half1.npy'))
    train_data_2 = np.load(os.path.join(partition_dir, 'train_data_half2.npy'))
    train_label_2 = np.load(os.path.join(partition_dir, 'train_label_half2.npy'))
    all_test_data = np.load(os.path.join(partition_dir, 'all_test_data.npy'))
    all_test_label = np.load(os.path.join(partition_dir, 'all_test_label.npy'))

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_1[:20])
    print(train_label_2[:20])

    # Keys double as the split tag in the output file names.
    splits = {
        'half1': TINdata(train_data_1, train_label_1, transform_train),
        'half2': TINdata(train_data_2, train_label_2, transform_train),
        'half_test': TINdata(all_test_data, all_test_label, transform_test),
    }
    # shuffle=False keeps the sample order fixed so that distances from
    # different runs line up index by index.
    loaders = {
        split: torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=num_worker)
        for split, dataset in splits.items()
    }

    create_directory_structure(checkpoint_path)

    # One accumulator per (split, feature level); compute_norm_dist returns
    # five distance vectors per call.
    num_levels = 5
    totals = {
        split: [np.zeros(len(dataset)) for _ in range(num_levels)]
        for split, dataset in splits.items()
    }

    # compute_norm_dist accepts a criterion argument; build it once.
    criterion = nn.CrossEntropyLoss(reduction='none').to(device, torch.float)
    for run in range(1, num_runs + 1):
        net_1 = _load_network(os.path.join(checkpoint_path1, str(run)), args)
        net_2 = _load_network(os.path.join(checkpoint_path2, str(run)), args)

        for split, dataset in splits.items():
            run_dists = compute_norm_dist(loaders[split], net_1, net_2, criterion,
                                          len(dataset), args)
            for total, dist in zip(totals[split], run_dists):
                total += dist  # in-place add on the accumulator array

    for split, level_totals in totals.items():
        for level, total in enumerate(level_totals, start=1):
            total /= num_runs
            out_file = os.path.join(checkpoint_path, f'stat_fea{level}_{split}_{num_runs}.csv')
            pd.DataFrame({'value': total}).to_csv(out_file, index=False)


if __name__ == '__main__':
    main()
