# If these imports break, I used them wrong
import torch
from sklearn.metrics import confusion_matrix, accuracy_score
import Resnets3D.models.resnet as Resnets3D
import pickle
import os
import datetime
import numpy as np
import random
import shutil
import subprocess
import warnings

# If these imports break, I coded them wrong
import data_hooks
from load_data import *
from options import * 
from inference import transform_evaluate, attach_hooks, report_hook_results
from datasets.data_utils import batch_mean_sub, mean_sub
from create_network_dataset import form_corruption_name
from torch.utils.data import DataLoader

# Seed every RNG the experiment may draw from (torch CPU + CUDA, NumPy,
# Python's `random`) with the same fixed value so runs are repeatable.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
del _seed_fn
# Ask cuDNN for deterministic kernels and turn off its benchmark-mode
# autotuning, which could otherwise select different algorithms per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Checkpoint paths (relative to the working directory) for each model variant.
# load_model() chooses among these based on the parsed CLI args.

# Baseline checkpoints used when no special training variant is requested.
hmdb51_path = "models/hmdb51/2020_05_03_21_27_17.pth"
ucf101_path = "models/ucf101/2020_05_10_23_20_09.pth"
# Adversarially trained HMDB51 checkpoints, looked up with args.adversarial as the
# key (presumably the attack budget — TODO confirm). Two entries point into
# checkpoints/ rather than models/ — presumably mid-training snapshots; confirm.
adversarial_paths = {8: "models/hmdb51/2020_09_13_22_43_14.pth",
        4: "checkpoints/hmdb51/2020_09_14_19_19_17_49.pth",
        2: "checkpoints/hmdb51/2020_09_15_23_01_12_49.pth"}
# HMDB51 checkpoint selected when args.corruption_augmented is set.
augmented = "models/hmdb51/2020_09_29_14_20_41.pth"
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# HMDB51 checkpoints selected by args.bit_corrupt_levels == 'low' / 'high'.
low_augment = 'models/hmdb51/2020_09_23_19_14_02.pth'
high_augment = 'models/hmdb51/2020_09_23_22_18_01.pth'
# Targeted-training HMDB51 checkpoints, keyed by corruption mode, then by rate:
#   'network'              -> indexed by args.packet_loss_rate (experiment_mode 'network')
#   'random'/'contiguous'  -> indexed by args.corrupt_prob, with args.corrupt_mode
#                             choosing the sub-dict (experiment_mode 'file')
# The float keys are matched by exact dict lookup, so the CLI value must equal
# one of these literals exactly.
targeted_dict = {
        'network': {
            0.2: 'models/hmdb51/2020_09_29_12_45_37.pth',
            0.1: 'models/hmdb51/2020_09_24_17_00_12.pth',
            0.01: 'models/hmdb51/2020_09_24_18_27_58.pth',
            0.001: 'models/hmdb51/2020_09_24_19_58_33.pth',
            0.0001:'models/hmdb51/2020_09_24_21_28_46.pth'
            },
        'random': {
            1e-4: 'models/hmdb51/2020_09_28_00_39_00.pth',
            5e-5: 'models/hmdb51/2020_09_29_01_33_29.pth',
            1e-5: 'models/hmdb51/2020_09_27_16_11_35.pth',
            5e-6: 'models/hmdb51/2020_09_26_13_25_51.pth',
            1e-6: 'models/hmdb51/2020_09_26_10_47_14.pth',
            5e-7: 'models/hmdb51/2020_09_26_08_13_16.pth',
            1e-7: 'models/hmdb51/2020_09_26_05_40_58.pth',
            # NOTE(review): this is the same file as 'contiguous'[0.01] below, and its
            # timestamp breaks the ordering of the neighbouring 'random' runs — looks
            # like a copy-paste error; verify which checkpoint 5e-8 should load.
            5e-8: 'models/hmdb51/2020_09_25_21_39_04.pth',
            1e-8: 'models/hmdb51/2020_09_26_00_25_37.pth'
            },
        'contiguous': {
            0.9: 'models/hmdb51/2020_09_25_01_19_13.pth',
            0.75: 'models/hmdb51/2020_09_25_09_39_07.pth',
            0.5: 'models/hmdb51/2020_09_25_13_21_10.pth',
            0.25: 'models/hmdb51/2020_09_25_16_10_29.pth',
            0.1: 'models/hmdb51/2020_09_25_18_56_19.pth',
            0.01: 'models/hmdb51/2020_09_25_21_39_04.pth'
            }
        }

def load_model(args):
    """Build a 3D ResNet-18 and load the checkpoint selected by ``args``.

    For HMDB51 the checkpoint is chosen in this priority order:
    adversarial > corruption-augmented > bit-corruption level ('low'/'high')
    > targeted > baseline. UCF101 always uses the baseline checkpoint.

    Args:
        args: parsed experiment options. Reads ``dataset``, ``adversarial``,
            ``corruption_augmented``, ``bit_corrupt_levels``, ``targeted``,
            ``experiment_mode``, ``corrupt_mode``, ``corrupt_prob``,
            ``packet_loss_rate``, ``n_classes``, ``load_width`` and
            ``max_frames``.

    Returns:
        The model with weights loaded, moved to ``device`` and in eval mode.

    Raises:
        NotImplementedError: for an unsupported dataset, or for a targeted
            HMDB51 model whose ``experiment_mode`` is neither 'file' nor
            'network'.
        KeyError: if no checkpoint is registered for the requested
            adversarial budget / corruption mode / rate.
    """
    if args.dataset == 'hmdb51':
        if args.adversarial:
            ckpt_path = adversarial_paths[args.adversarial]
        elif args.corruption_augmented:
            ckpt_path = augmented
        elif args.bit_corrupt_levels == 'low':
            ckpt_path = low_augment
        elif args.bit_corrupt_levels == 'high':
            ckpt_path = high_augment
        elif args.targeted:
            if args.experiment_mode == 'file':
                ckpt_path = targeted_dict[args.corrupt_mode][args.corrupt_prob]
            elif args.experiment_mode == 'network':
                ckpt_path = targeted_dict['network'][args.packet_loss_rate]
            else:
                # Without this branch ckpt_path would be unbound below and the
                # caller would see an opaque UnboundLocalError.
                raise NotImplementedError(
                    "No targeted checkpoints for experiment_mode {!r}".format(args.experiment_mode))
        else:
            ckpt_path = hmdb51_path
    elif args.dataset == 'ucf101':
        ckpt_path = ucf101_path
    else:
        raise NotImplementedError()
    model = Resnets3D.resnet18(num_classes=args.n_classes, sample_size=args.load_width, sample_duration=args.max_frames)
    # map_location lets checkpoints saved from a GPU load on CPU-only hosts,
    # where `device` falls back to "cpu".
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model = model.to(device)
    model.eval()
    return model

def _count_files(root):
    """Return how many regular files are under ``root``, following symlinks.

    Equivalent to ``find -L root -type f | wc -l``, but runs one child process
    and waits for it, so no un-reaped ``find`` process or open pipe is left
    behind. ``find``'s exit status is ignored (as in the shell pipeline), so a
    missing directory counts as 0 files; its error message still reaches stderr.
    """
    listing = subprocess.run(['find', '-L', root, '-type', 'f'], stdout=subprocess.PIPE)
    # `wc -l` counts newline characters; do the same on find's raw output.
    return listing.stdout.count(b'\n')


def get_network_corruption_dataloaders(args):
    """Build the clean test dataloader and its network-corrupted counterpart.

    Corrupted videos are read from
    ``<args.base_path>/network_corruptions/<form_corruption_name(args)>/<args.dataset>``.
    A warning is emitted if that directory contains no files.

    Args:
        args: parsed experiment options. As a side effect ``args.mode`` is set
            to ``'sequence'`` (image mode is not supported).

    Returns:
        ``(test_dataloader, aug_dataloader)``: the clean test loader from
        ``build_dataloader`` and a batch-size-1, unshuffled loader over the
        corrupted 'test' split.

    Raises:
        NotImplementedError: if ``args.dataset`` is not 'hmdb51' or 'ucf101'.
    """
    data = load_data(args)
    _, test_dataloader, _, _ = build_dataloader(data, args)
    fpath = os.path.join(args.base_path, args.img_folder)
    corruption_name = form_corruption_name(args)
    new_path = os.path.join(args.base_path, 'network_corruptions', corruption_name)
    video_root = os.path.join(new_path, args.dataset)
    dataset_args = action_recognition_kwargs(args)
    print("Loading network-corrupted dataset from root directory", new_path)
    # Sanity-check that the corrupted videos are actually on disk.
    n_files = _count_files(video_root)
    print(n_files, "files found")
    if n_files == 0:
        warnings.warn("There are no files in {}; double-check your video file path if this is not intentional.".format(video_root))
    if args.dataset == 'hmdb51':
        aug = HMDB51Dataset(video_root, os.path.join(fpath, "train_test_splits"), 'test', args.load_width, args.load_height, **dataset_args)
    elif args.dataset == 'ucf101':
        aug = UCF101Dataset(video_root, os.path.join(fpath, "data"), 'test', args.load_width, args.load_height, **dataset_args)
    else:
        raise NotImplementedError()

    aug_dataloader = DataLoader(aug, batch_size=1, num_workers=args.num_workers, shuffle=False, collate_fn=validation_collator)
    args.mode = 'sequence'  # image not supported
    return test_dataloader, aug_dataloader

def get_file_corruption_dataloaders(args):
    """Build the clean test dataloader and a file-corruption counterpart.

    The second dataset reads the original videos (``<fpath>/hmdb51`` for
    HMDB51, ``<fpath>/videos`` for UCF101, where
    ``fpath = <args.base_path>/<args.img_folder>``) using the dataset kwargs
    from ``action_recognition_kwargs`` after forcing ``args.apply_prob = 1.``
    (presumably so the corruption is applied to every clip — confirm in
    ``action_recognition_kwargs``).

    Args:
        args: parsed experiment options. Mutated: ``args.apply_prob`` is set to
            ``1.`` and ``args.mode`` to ``'sequence'`` (image mode is not
            supported).

    Returns:
        ``(test_dataloader, aug_dataloader)``: the clean test loader from
        ``build_dataloader`` and a batch-size-1, unshuffled loader over the
        'test' split of the second dataset.

    Raises:
        NotImplementedError: if ``args.dataset`` is not 'hmdb51' or 'ucf101'.
    """
    data = load_data(args)
    _, test_dataloader, _, _ = build_dataloader(data, args)

    fpath = os.path.join(args.base_path, args.img_folder)
    args.apply_prob = 1.
    # Computed once for both datasets (previously recomputed for UCF101).
    dataset_args = action_recognition_kwargs(args)
    if args.dataset == 'hmdb51':
        aug = HMDB51Dataset(os.path.join(fpath, "hmdb51"),
            os.path.join(fpath, "train_test_splits"), 'test', args.load_width,
            args.load_height, **dataset_args)
    elif args.dataset == 'ucf101':
        aug = UCF101Dataset(os.path.join(fpath, "videos"), os.path.join(fpath, "data"),
            'test', args.load_width, args.load_height, **dataset_args)
    else:
        raise NotImplementedError()
    args.mode = 'sequence'  # image not supported
    aug_dataloader = DataLoader(aug, batch_size=1, num_workers=args.num_workers, shuffle=False, collate_fn=validation_collator)
    return test_dataloader, aug_dataloader


def run_experiment(test_dataloader, aug_dataloader, model, fn, data_hooks=None, examples=None):
    """Evaluate ``model`` via ``transform_evaluate`` and build a confusion matrix.

    Calls ``transform_evaluate`` with ``restore_training=False`` and
    ``mean_sub=True``, forwarding the dataloaders, model, transform ``fn``,
    ``data_hooks`` and the ``examples`` limit unchanged.

    Returns:
        ``(cm, preds, y, stats)`` where ``preds``, ``y`` and ``stats`` are the
        2nd, 3rd and 5th of the five values returned by ``transform_evaluate``,
        and ``cm`` is ``sklearn.metrics.confusion_matrix(y, preds)``.
    """
    evaluation = transform_evaluate(
        test_dataloader,
        aug_dataloader,
        model,
        fn,
        data_hooks=data_hooks,
        restore_training=False,
        examples=examples,
        mean_sub=True,
    )
    # Strict 5-way unpack: fails loudly if transform_evaluate's output changes shape.
    _first, predictions, labels, _fourth, stats = evaluation
    matrix = confusion_matrix(labels, predictions)
    return matrix, predictions, labels, stats

if __name__ == '__main__':
    # Entry point: parse options, load the selected checkpoint, build the clean
    # and corrupted test dataloaders, evaluate, then (unless --lightweight)
    # pickle the results to perturbation_experiments/<timestamp>.pkl.
    args = get_args()
    print(prettyprint_args(args))
    t = get_transform(args)
    print("Transformation function:", t)
    model = load_model(args)
    print("Model:", type(model))
    if args.experiment_mode == 'file':
        test_dataloader, aug_dataloader = get_file_corruption_dataloaders(args)
    elif args.experiment_mode == 'network':
        test_dataloader, aug_dataloader = get_network_corruption_dataloaders(args)
    else:
        raise NotImplementedError()
    hooks = attach_hooks(args)

    print("Starting experiment")
    cm, preds, y, stats = run_experiment(test_dataloader, aug_dataloader, model, t, data_hooks=hooks, examples=args.limit)
    results = {
        'cm': cm,
        'preds': preds,
        'y': y,
        'stats': stats,
        'args': args,
        'aug_info': repr(t),
    }
    # Store each hook's collected output alongside the evaluation stats so it
    # is saved in the same pickle.
    for hook_name, hook in hooks.items():
        results['stats'][hook_name] = hook.results
    print("=" * 70)
    if not args.lightweight:
        time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        save_dir = 'perturbation_experiments/'
        # open() does not create parent directories; without this a fresh
        # checkout crashes after the (expensive) evaluation has finished.
        os.makedirs(save_dir, exist_ok=True)
        fname = os.path.join(save_dir, time_str + ".pkl")
        with open(fname, 'wb') as f:
            pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
        print("Results log saved to ", fname)
    report_hook_results(hooks, results)
