import argparse
import os
import numpy as np
import sys

import pandas as pd
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Repository paths are read from a YAML config. The path is resolved against the
# current working directory, so the script has to be launched from the folder
# that contains env.yml. The commented-out line is an alternative location.
#config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']  # base directory; main() reads <root_dir>/tinyimagenet/data
    src_dir = yamlfile['src_dir']  # source tree that is put on sys.path below

# Make modules under src_dir (and its attack/ and models/ subfolders) importable.
# These appends must run before the project imports that follow.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from tinyimagenet_utils import transform_train, transform_test, TINdata, DistillTINdata, WarmUpLR, ModelwNorm, transform_train_aug
from tinyimagenet.models.model_selector import get_network

# Run on the default CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def compute_norm_dist(testloader, model1, model2, criterion, len_data, args):
    """Return per-sample feature distances between two models, one array per stage.

    Both models are put in eval mode and run on every batch of ``testloader``
    through ``multi_forward``, which must return ``(outputs, features)`` with
    exactly four feature tensors. For each stage, ``euclidean_distance_norm``
    between the two models' features is recorded for every sample.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches, consumed in order.
        model1, model2: models exposing ``multi_forward``.
        criterion: unused; kept for call-site compatibility.
        len_data: number of samples the loader yields (length of each output array).
        args: unused; kept for call-site compatibility. Rows are placed by a
            running sample offset, so the loader's batch size need not match
            ``args.batch_size``.

    Returns:
        Tuple of four float64 numpy arrays of length ``len_data``. Entry ``i`` is
        the distance for the ``i``-th sample produced by ``testloader``. Slots
        past the last yielded sample stay 0.

    Raises:
        ValueError: if the loader yields more than ``len_data`` samples, or if
            ``multi_forward`` does not return exactly four feature tensors.
    """
    # switch to evaluate mode
    model1.eval()
    model2.eval()

    num_stages = 4
    results = [np.zeros(len_data) for _ in range(num_stages)]

    offset = 0  # index in the output arrays where the current batch starts
    for inputs, _targets in testloader:
        inputs = inputs.to(device, torch.float)
        count = inputs.shape[0]
        # A running offset (instead of batch_ind * args.batch_size) keeps rows
        # aligned even when the loader's batch size differs from args.batch_size.
        if offset + count > len_data:
            raise ValueError(f'testloader yielded more than len_data={len_data} samples')

        _, feats1 = model1.multi_forward(inputs)
        _, feats2 = model2.multi_forward(inputs)
        feats1, feats2 = tuple(feats1), tuple(feats2)
        if len(feats1) != num_stages or len(feats2) != num_stages:
            raise ValueError(f'multi_forward must return {num_stages} feature tensors')

        for stage_result, fea_a, fea_b in zip(results, feats1, feats2):
            # The function runs under torch.no_grad(), so no .detach() is needed.
            stage_result[offset:offset + count] = euclidean_distance_norm(fea_a, fea_b).cpu().numpy()
        offset += count

    return tuple(results)


def euclidean_distance_norm(vector1, vector2):
    """Return the size-normalised Euclidean distance between two batched tensors.

    Every dimension after the first (batch) dimension is flattened, so 4-D
    ``(b, c, h, w)`` feature maps and 2-D ``(b, d)`` vectors are both accepted.
    For each sample this computes ``sqrt(sum((a - b) ** 2) / num_features)``,
    i.e. the root-mean-square element difference.

    Args:
        vector1: tensor with at least two dimensions; dim 0 is the batch.
        vector2: tensor with the same number of elements as ``vector1``,
            reshaped to ``vector1``'s flattened ``(batch, num_features)`` shape.

    Returns:
        1-D tensor of shape ``(batch,)``.
    """
    flat1 = vector1.flatten(start_dim=1)
    batch, num_features = flat1.shape
    # An explicit shape (not -1) keeps a size mismatch an error instead of a broadcast.
    flat2 = vector2.reshape(batch, num_features)
    return torch.sqrt(torch.sum((flat1 - flat2) ** 2, dim=1) / num_features)


def create_directory_structure(directory_path):
    """Create ``directory_path`` (including missing parents) and report what happened.

    Prints whether the directory was newly created or already existed.

    Raises:
        FileExistsError: if ``directory_path`` exists but is not a directory.
    """
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        # Trying first and then handling the error avoids the check-then-create
        # race, where another process creates the directory after an exists() test.
        if not os.path.isdir(directory_path):
            raise
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def _load_checkpointed_model(checkpoint_root, run_id, args):
    """Build an ``args.model`` network and load the weights saved for run ``run_id``.

    The checkpoint ``<checkpoint_root>/<run_id>/model_last.pth.tar`` is loaded on
    the CPU, and its ``'state_dict'`` entry is copied into the network. The network
    is then moved to ``device`` as float32.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    net = get_network(args.model, num_classes=args.num_class)
    resume = os.path.join(checkpoint_root, str(run_id)) + '/model_last.pth.tar'
    print('==> Resuming from checkpoint' + resume)
    # An explicit raise (not assert) so the check survives ``python -O``.
    if not os.path.isfile(resume):
        raise FileNotFoundError(f'Error: no checkpoint file found at {resume}')
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    return net.to(device, torch.float)


def _save_distance_csvs(directory, distances, suffix):
    """Write the k-th array (1-based) of ``distances`` to ``<directory>/stat_fea<k>_full<suffix>.csv``.

    Each CSV has a single ``value`` column and no index column.
    """
    for level, values in enumerate(distances, start=1):
        frame = pd.DataFrame({'value': values})
        frame.to_csv(f'{directory}/stat_fea{level}_full{suffix}.csv', index=False)


def main():
    """Dump per-sample feature distances between two undefended TinyImageNet models.

    Runs ``1`` and ``2`` of ``args.model`` are loaded from
    ``<load_path>/tinyimagenet/<model>/undefend/aug``. The four per-stage feature
    distances (``compute_norm_dist``) are computed on both training halves
    (concatenated, half 1 first) and on the test split. They are written as
    ``stat_fea<k>_full.csv`` / ``stat_fea<k>_full_test.csv`` under
    ``<save_path>/tinyimagenet/<model>/stat_fea/aug``.
    """
    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=200, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    # NOTE: type=bool turns any non-empty string (even "False") into True;
    # the parsed value is overridden below in any case.
    parser.add_argument('--data_aug', type=bool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_worker = args.num_worker
    # Forced on: the training loaders below always use transform_train.
    args.data_aug = True
    aug_dir = 'aug' if args.data_aug else 'no_aug'

    DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
    partition_dir = os.path.join(DATASET_PATH, 'partition')
    # Output folder for the CSVs, and input folder holding the undefended runs "1" and "2".
    checkpoint_path = os.path.join(args.save_path, 'tinyimagenet', args.model, 'stat_fea', aug_dir)
    checkpoint_path1 = os.path.join(args.load_path, 'tinyimagenet', args.model, 'undefend', aug_dir)
    print(checkpoint_path)

    train_data_1 = np.load(os.path.join(partition_dir, 'train_data_half1.npy'))
    train_label_1 = np.load(os.path.join(partition_dir, 'train_label_half1.npy'))
    train_data_2 = np.load(os.path.join(partition_dir, 'train_data_half2.npy'))
    train_label_2 = np.load(os.path.join(partition_dir, 'train_label_half2.npy'))
    all_test_data = np.load(os.path.join(partition_dir, 'all_test_data.npy'))
    all_test_label = np.load(os.path.join(partition_dir, 'all_test_label.npy'))

    # Print the first 20 labels of each half so the partition can be
    # cross-checked against other experiments.
    print(train_label_1[:20])
    print(train_label_2[:20])

    trainset1 = TINdata(train_data_1, train_label_1, transform_train)
    trainset2 = TINdata(train_data_2, train_label_2, transform_train)
    alltestset = TINdata(all_test_data, all_test_label, transform_test)

    # shuffle=False keeps output rows in the same order as the saved partitions.
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_worker)
    trainloader1 = torch.utils.data.DataLoader(trainset1, **loader_kwargs)
    trainloader2 = torch.utils.data.DataLoader(trainset2, **loader_kwargs)
    alltestloader = torch.utils.data.DataLoader(alltestset, **loader_kwargs)

    create_directory_structure(checkpoint_path)

    criterion = nn.CrossEntropyLoss(reduction='none').to(device, torch.float)
    net_1 = _load_checkpointed_model(checkpoint_path1, 1, args)
    net_2 = _load_checkpointed_model(checkpoint_path1, 2, args)

    half1_dists = compute_norm_dist(trainloader1, net_1, net_2, criterion, len(trainset1), args)
    half2_dists = compute_norm_dist(trainloader2, net_1, net_2, criterion, len(trainset2), args)
    # Concatenating (instead of slicing a preallocated buffer) places the half-2
    # rows correctly even when the two halves differ in size.
    train_dists = [np.concatenate((d1, d2)) for d1, d2 in zip(half1_dists, half2_dists)]
    test_dists = compute_norm_dist(alltestloader, net_1, net_2, criterion, len(alltestset), args)

    _save_distance_csvs(checkpoint_path, train_dists, '')
    _save_distance_csvs(checkpoint_path, test_dists, '_test')


if __name__ == '__main__':
    main()
