import argparse
import os
import numpy as np
import sys

import pandas as pd
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Read repository locations from the env config. src_dir and its attack/
# and models/ subfolders must be on sys.path before the project-local
# imports that follow.
# NOTE: './env.yml' is resolved against the current working directory,
# not against this file's location.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
root_dir = yamlfile['root_dir']
src_dir = yamlfile['src_dir']

sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from tinyimagenet_utils import transform_train, transform_test, TINdata, DistillTINdata, WarmUpLR, ModelwNorm, transform_train_aug
from tinyimagenet.models.model_selector import get_network

# Use the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


@torch.no_grad()
def compute_norm_dist(testloader, model1, model2, criterion, len_data, args):
    """Extract globally average-pooled intermediate features of ``model1``.

    Despite its name, no distance is computed. For every batch the function
    calls ``model1.multi_forward(inputs)`` and pools each returned
    intermediate feature map with ``global_pooling``. It then writes the
    pooled vectors and the targets into preallocated NumPy arrays.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches. Output rows
            follow the order in which the loader yields samples.
        model1: model exposing ``multi_forward(x) -> (logits, features)``.
            ``features`` is a sequence of feature maps, presumably
            ``(B, C, H, W)``, which ``global_pooling`` requires.
        model2: only switched to eval mode; not otherwise used.
        criterion: unused; kept for call-site compatibility.
        len_data: number of samples the loader yields (rows per array).
        args: unused. Rows are placed by a running offset rather than by
            ``args.batch_size``, so a loader batch size that differs from
            ``args.batch_size`` no longer mis-places rows.

    Returns:
        Tuple of one ``(len_data, C_k)`` float64 array per feature map,
        followed by a ``(len_data,)`` float64 label array. The widths
        ``C_k`` are taken from the model's actual channel counts. With an
        empty loader, the legacy widths (64, 128, 256, 512, 512) are used.
    """
    model1.eval()
    model2.eval()

    # Only used when the loader yields no batches, so the return value keeps
    # the previously hard-coded layout in that case.
    default_dims = (64, 128, 256, 512, 512)

    feature_bufs = None
    labels_np = np.zeros((len_data))
    start = 0
    for inputs, targets in testloader:
        inputs = inputs.to(device, torch.float)
        _, features = model1.multi_forward(inputs)
        pooled = [global_pooling(fea).cpu().numpy() for fea in features]

        if feature_bufs is None:
            # Size the buffers from the observed channel counts instead of
            # hard-coding them, so any backbone from get_network works.
            feature_bufs = [np.zeros((len_data, p.shape[1])) for p in pooled]

        end = start + inputs.shape[0]
        for buf, p in zip(feature_bufs, pooled):
            buf[start:end] = p
        labels_np[start:end] = targets.to(torch.long).cpu().numpy()
        start = end

    if feature_bufs is None:
        feature_bufs = [np.zeros((len_data, d)) for d in default_dims]

    return (*feature_bufs, labels_np)


def global_pooling(x, dim=128):
    """Average ``x`` over axes 2 and 3 (the spatial axes of an NCHW map).

    Args:
        x: tensor with at least four dimensions. For ``(B, C, H, W)``
            input the result has shape ``(B, C)``.
        dim: ignored; kept only so existing callers keep working.

    Returns:
        ``x`` reduced by the mean over axes 2 and 3.
    """
    spatial_axes = (2, 3)
    return x.mean(dim=spatial_axes)


def euclidean_distance_norm(vector1, vector2):
    """Return the per-sample root-mean-square difference of two batches.

    Each sample is flattened and ``sqrt(sum((a - b) ** 2) / n)`` is computed,
    where ``n`` is the number of elements per sample. Any input of rank >= 2
    is accepted. For 4-D ``(B, C, H, W)`` input the result is identical to
    reshaping to ``(B, C*H*W)``.

    Args:
        vector1: tensor of shape ``(B, ...)``.
        vector2: tensor with the same number of elements as ``vector1``. It
            is reshaped to ``vector1``'s flattened ``(B, n)`` shape.

    Returns:
        Tensor of shape ``(B,)``.
    """
    flat1 = vector1.flatten(1)
    flat2 = vector2.reshape(flat1.shape)
    per_sample = flat1.shape[1]
    return torch.sqrt(torch.sum((flat1 - flat2) ** 2, dim=1, keepdim=False) / per_sample)


def create_directory_structure(directory_path):
    """Create ``directory_path`` (including missing parents) and report it.

    Prints whether the directory was created or already existed. The
    function attempts the creation and handles ``FileExistsError`` instead
    of checking ``os.path.exists`` first. A path that appears between a
    check and ``os.makedirs``, or a dangling symlink, therefore no longer
    raises.

    Args:
        directory_path: directory to create.
    """
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def _load_undefended_model(model_name, num_class, run_dir):
    """Build ``model_name`` and restore its weights from ``run_dir``.

    Loads ``<run_dir>/model_last.pth.tar``, which must contain a
    ``'state_dict'`` entry, onto the CPU. The network is then moved to
    ``device`` as ``torch.float``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    net = get_network(model_name, num_classes=num_class)
    resume = run_dir + '/model_last.pth.tar'
    print('==> Resuming from checkpoint' + resume)
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if not os.path.isfile(resume):
        raise FileNotFoundError(f'Error: no checkpoint file found at {resume}')
    # torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    return net.to(device, torch.float)


def _save_features(out_dir, split, features, labels):
    """Write ``fea{k}_{split}.npy`` (k = 1, 2, ...) and ``label_{split}.npy`` into ``out_dir``."""
    for k, fea in enumerate(features, start=1):
        np.save(os.path.join(out_dir, f'fea{k}_{split}.npy'), fea)
    np.save(os.path.join(out_dir, f'label_{split}.npy'), labels)


def main():
    """Extract and save pooled features of the undefended TinyImageNet models.

    For each run ``i`` in ``1..--num_runs`` this loads
    ``<load_path>/tinyimagenet/<model>/undefend/aug/<i>/model_last.pth.tar``.
    It extracts pooled features for both training halves and the test set
    with ``compute_norm_dist`` and saves them under
    ``<save_path>/tinyimagenet_fea/<model>/pooling_fea/aug/<i>/``.
    The ``*_train.npy`` files hold the half1 rows followed by the half2 rows.
    """
    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=200, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    parser.add_argument('--data_aug', type=bool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_worker = args.num_worker
    # NOTE(review): --data_aug is overridden here, so the 'aug' folders are
    # always used. Note that argparse's type=bool would map any non-empty
    # string (including "False") to True anyway.
    args.data_aug = True
    aug_dir = 'aug' if args.data_aug else 'no_aug'

    DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
    partition_dir = os.path.join(DATASET_PATH, 'partition')
    checkpoint_path = os.path.join(args.save_path, 'tinyimagenet_fea', args.model, 'pooling_fea', aug_dir)
    checkpoint_path1 = os.path.join(args.load_path, 'tinyimagenet', args.model, 'undefend', aug_dir)

    print(checkpoint_path)

    train_data_1 = np.load(os.path.join(partition_dir, 'train_data_half1.npy'))
    train_label_1 = np.load(os.path.join(partition_dir, 'train_label_half1.npy'))
    train_data_2 = np.load(os.path.join(partition_dir, 'train_data_half2.npy'))
    train_label_2 = np.load(os.path.join(partition_dir, 'train_label_half2.npy'))
    all_test_data = np.load(os.path.join(partition_dir, 'all_test_data.npy'))
    all_test_label = np.load(os.path.join(partition_dir, 'all_test_label.npy'))

    # Print the first 20 labels of each half so the partition can be
    # cross-checked against other experiments.
    print(train_label_1[:20])
    print(train_label_2[:20])

    # NOTE(review): both training halves go through transform_train, so any
    # augmentation it applies is active during feature extraction, while the
    # test set uses transform_test. Confirm this is intended.
    trainset1 = TINdata(train_data_1, train_label_1, transform_train)
    trainset2 = TINdata(train_data_2, train_label_2, transform_train)
    alltestset = TINdata(all_test_data, all_test_label, transform_test)

    # shuffle=False keeps the saved rows in dataset order.
    trainloader1 = torch.utils.data.DataLoader(trainset1, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    trainloader2 = torch.utils.data.DataLoader(trainset2, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    alltestloader = torch.utils.data.DataLoader(alltestset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    create_directory_structure(checkpoint_path)

    for i in range(1, args.num_runs + 1):
        criterion = nn.CrossEntropyLoss(reduction='none').to(device, torch.float)
        net_1 = _load_undefended_model(args.model, args.num_class, os.path.join(checkpoint_path1, str(i)))
        # The second network is never restored from a checkpoint. It is only
        # passed through to compute_norm_dist to satisfy its signature.
        net_2 = get_network(args.model, num_classes=args.num_class)

        *fea_half1, labels_half1 = compute_norm_dist(trainloader1, net_1, net_2, criterion, len(trainset1), args)
        *fea_half2, labels_half2 = compute_norm_dist(trainloader2, net_1, net_2, criterion, len(trainset2), args)
        *fea_test, labels_test = compute_norm_dist(alltestloader, net_1, net_2, criterion, len(alltestset), args)

        # Training features: half1 rows first, then half2 rows.
        fea_train = [np.concatenate([a, b], axis=0) for a, b in zip(fea_half1, fea_half2)]
        labels_train = np.concatenate([labels_half1, labels_half2], axis=0)

        run_dir = os.path.join(checkpoint_path, str(i))
        os.makedirs(run_dir, exist_ok=True)
        _save_features(run_dir, 'train', fea_train, labels_train)
        _save_features(run_dir, 'test', fea_test, labels_test)


if __name__ == '__main__':
    # Only run the extraction when executed as a script, not on import.
    main()
