from datasets.frame_dataset import MOTDataset
from options import get_mot15_args
import os
from data_hooks import TruncatedL2
from tqdm import tqdm
from multiprocessing import Pool

def form_corruption_name(args):
    """Return the relative folder name of the corrupted dataset variant.

    Args:
        args: parsed options. Always reads ``experiment_mode``; reads
            ``corrupt_mode``, ``p`` and ``version`` in 'file' mode, and
            ``p`` and ``version`` in 'network' mode.

    Returns:
        ``"file_corruptions/<corrupt_mode>_<p>_ver<version>"`` for 'file' mode, or
        ``"network_corruptions/loss_uplink_<p>_ver<version>"`` for 'network' mode.

    Raises:
        NotImplementedError: for any other ``experiment_mode``; the message
            names the offending value so a mistyped option is easy to spot.
    """
    if args.experiment_mode == 'file':
        return "file_corruptions/{}_{}_ver{}".format(args.corrupt_mode, args.p, args.version)
    if args.experiment_mode == 'network':
        return "network_corruptions/loss_uplink_{}_ver{}".format(args.p, args.version)
    raise NotImplementedError(
        "Unsupported experiment_mode {!r}; expected 'file' or 'network'".format(args.experiment_mode))

def load_single_sequence(i):
    """Fetch item ``i // 2`` from the clean dataset (even ``i``) or the corrupted
    dataset (odd ``i``).

    The datasets are taken from the module-level globals ``clean`` and
    ``corrupt``. The function is used as a process-pool worker, so a single
    flat index can interleave loads from both datasets.
    """
    idx, use_corrupt = divmod(i, 2)
    dataset = corrupt if use_corrupt else clean
    return dataset[idx]

def _init_l2_worker(clean_ds, corrupt_ds):
    """Pool initializer: install the datasets as the module globals that
    ``load_single_sequence`` reads inside each worker process."""
    global clean, corrupt
    clean, corrupt = clean_ds, corrupt_ds


def l2_comparison(clean, corrupt, num_workers=12):
    """Accumulate a truncated-L2 distance between paired clean/corrupted sequences.

    Sequences are loaded in parallel by a process pool and fed pairwise,
    ``(clean[i], corrupt[i])``, into a ``TruncatedL2`` hook.

    Args:
        clean: indexable dataset of clean sequences. Each item is presumably a
            ``(tensor, path)`` pair (inferred from the unpacking below); the
            tensor must support ``.permute(1, 0, 2, 3)``.
        corrupt: dataset aligned index-by-index with ``clean``.
        num_workers: number of worker processes used for loading (default 12,
            the previously hard-coded value).

    Returns:
        The ``TruncatedL2`` hook after ``update`` has been called once per pair.

    Raises:
        ValueError: if the two datasets have different lengths.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if len(clean) != len(corrupt):
        raise ValueError(
            "clean and corrupt datasets differ in length: {} vs {}".format(len(clean), len(corrupt)))
    hook = TruncatedL2(ndim=4, device='cpu')
    # load_single_sequence reads `clean`/`corrupt` as module globals. Install
    # *these* arguments there in every worker. Otherwise the workers would use
    # whatever __main__ happened to bind (or nothing at all under the 'spawn'
    # start method), not the datasets passed to this function.
    with Pool(num_workers, initializer=_init_l2_worker, initargs=(clean, corrupt)) as pool:
        # Even task indices load clean[i // 2] and odd ones load corrupt[i // 2].
        # imap preserves order, so zip(it, it) yields (clean_i, corrupt_i)
        # pairs, and each pair can be released as soon as it is compared.
        loaded = pool.imap(load_single_sequence, range(2 * len(clean)))
        for (X, _clean_path), (Xt, _corrupt_path) in zip(loaded, loaded):
            l2 = hook(X.permute(1, 0, 2, 3), Xt.permute(1, 0, 2, 3), None, None, None)
            hook.update(l2)
    return hook

def load_mot15(args):
    """Build the (clean, corrupted) MOT15 dataset pair described by ``args``.

    The clean copy lives under ``<base_path>/FairMOT/src/data/MOT15``. The
    corrupted copy lives under ``<base_path>/<corruption name>/MOT15``, where
    the corruption name comes from ``form_corruption_name(args)``. Both
    locations are printed when ``args.verbose`` is truthy.

    Returns:
        A ``(MOTDataset, MOTDataset)`` tuple: the clean dataset first, then
        the corrupted one.
    """
    corruption_dir = form_corruption_name(args)
    root = args.base_path
    clean_dir = os.path.join(root, "FairMOT/src/data/MOT15")
    corrupt_dir = os.path.join(root, corruption_dir,  "MOT15")
    if args.verbose:
        print("Vanilla dataset location:", clean_dir)
        print("Corrupted dataset location:", corrupt_dir)
    clean_ds = MOTDataset(clean_dir)
    corrupt_ds = MOTDataset(corrupt_dir)
    return clean_ds, corrupt_ds

if __name__ == '__main__':
    # Parse CLI options, load both datasets, then print the accumulated L2 report.
    # `clean` and `corrupt` must stay module-level names: load_single_sequence
    # reads them as globals.
    clean, corrupt = load_mot15(get_mot15_args())
    l2_comparison(clean, corrupt).report_results(None, None)
