import os
import torch
import itertools
import glob as gb
import numpy as np
import json
from datetime import datetime
from data import FrameData

def setup_path(args):
    """Create the output directories for this run and return their paths.

    The run directory is ``<args.output_path>/<dir_name>``. ``dir_name`` encodes the
    current date and hour, the dataset, the flow encoding ('rgb' when
    ``args.flow_to_rgb`` is truthy, otherwise 'uv'), the slot count, the training
    length, the batch size and the learning rate.

    - Inference (``args.inference`` truthy): only
      ``<run>/results/<last component of args.resume_path>`` is created.
    - Training: ``<run>/log`` and ``<run>/model`` are created. On ``args.rank == 0``,
      ``args.__dict__`` is also written as JSON to ``<run>/model/running_command.txt``.

    Side effect: assigns the module-level globals ``dt_string``, ``logPath``,
    ``modelPath`` and ``resultsPath``.

    Returns:
        ``[logPath, modelPath, resultsPath]``. ``resultsPath`` is ``None`` in training
        mode. In inference mode ``logPath``/``modelPath`` are returned but not created.
    """
    flow_to_rgb_text = 'rgb' if args.flow_to_rgb else 'uv'

    global dt_string, logPath, modelPath, resultsPath
    # Hour granularity: runs with identical settings started within the same hour
    # share (and may overwrite files in) the same directory.
    dt_string = datetime.now().strftime("%Y_%m_%d_%H")
    dir_name = (f'{dt_string}-{args.dataset}-{flow_to_rgb_text}-slots_{args.num_slots}-VGGNet_D256-'
                f'len_{args.train_len}-bs_{args.batch_size}-lr_{args.lr}')
    run_dir = os.path.join(args.output_path, dir_name)

    logPath = os.path.join(run_dir, 'log')
    modelPath = os.path.join(run_dir, 'model')

    if args.inference:
        # normpath drops a trailing separator, so 'ckpts/run/' yields 'run' rather than ''
        # (which would dump results straight into 'results/').
        checkpoint_name = os.path.basename(os.path.normpath(args.resume_path))
        resultsPath = os.path.join(run_dir, 'results', checkpoint_name)
        os.makedirs(resultsPath, exist_ok=True)
    else:
        os.makedirs(logPath, exist_ok=True)
        os.makedirs(modelPath, exist_ok=True)
        resultsPath = None

        # Save all the experiment settings; only rank 0 writes, so parallel workers
        # do not race on the same file.
        if args.rank == 0:
            with open(os.path.join(modelPath, 'running_command.txt'), 'w') as f:
                json.dump(args.__dict__, f, indent=2)

    return [logPath, modelPath, resultsPath]


def _collect_flow_pairs(basepath, res, pairs):
    """Pair up the flow files that exist for both gaps of every 2-combination of ``pairs``.

    Sequence folders are discovered under ``<basepath>/Flows_gap1/<res>/``. For each
    ``(p1, p2)`` in ``itertools.combinations(pairs, 2)`` and each sequence folder, the
    file names present in both ``Flows_gap{p1}`` and ``Flows_gap{p2}`` are paired.

    Returns:
        dict mapping ``'gap_{p1}_{p2}'`` to a list holding one ``(N, 2)`` array of path
        strings per sequence folder (folders in sorted order). Column 0 holds the ``p1``
        path and column 1 the ``p2`` path. ``N`` may be 0.
    """
    # Sorted so the per-sequence list order is reproducible; glob order is arbitrary.
    folders = sorted(
        os.path.basename(x)
        for x in gb.glob(os.path.join(basepath, 'Flows_gap1/{}/*'.format(res)))
    )

    def seq_dir(gap, folder):
        # Keep the original string layout. With res == '' this yields
        # '<basepath>/Flows_gap<gap>//<folder>'; downstream path parsing may rely on it.
        return os.path.join(basepath, 'Flows_gap{}/{}/{}'.format(gap, res, folder))

    # Each gap appears in several pairs, so list every (gap, folder) directory only once.
    listings = {
        (gap, folder): {os.path.basename(x) for x in gb.glob(os.path.join(seq_dir(gap, folder), '*'))}
        for gap in pairs
        for folder in folders
    }

    flow_dir = {}
    for p1, p2 in itertools.combinations(pairs, 2):
        flowpairs = []
        for folder in folders:
            path1, path2 = seq_dir(p1, folder), seq_dir(p2, folder)
            common = sorted(listings[(p1, folder)] & listings[(p2, folder)])
            # reshape keeps the documented (N, 2) shape even when N == 0.
            flowpair = np.array(
                [[os.path.join(path1, name), os.path.join(path2, name)] for name in common],
                dtype=str,
            ).reshape(-1, 2)
            flowpairs.append(flowpair)
        flow_dir['gap_{}_{}'.format(p1, p2)] = flowpairs
    return flow_dir


def setup_dataset(args):
    """Build the training and validation ``FrameData`` datasets for ``args.dataset``.

    Supported values of ``args.dataset``:
      - 'DAVIS': rooted at ``args.basepath``; 480p data; flow gaps +-1, +-2.
      - 'FBMS': hard-coded '/path/to/...' placeholder roots (edit before use);
        flow gaps +-3, +-6; separate FBMS_val tree for validation.
      - 'STv2': hard-coded '/path/to/...' placeholder roots (edit before use);
        flow gaps +-1, +-2.

    Returns:
        ``[trn_dataset, val_dataset, resolution, in_out_channels, use_flow,
        loss_scale, ent_scale, cons_scale]``. ``in_out_channels`` is 3 when
        ``args.flow_to_rgb`` is truthy and 2 otherwise; ``use_flow`` is
        ``not args.flow_to_rgb``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    resolution = args.resolution  # h,w
    res = ""
    # NOTE(review): with_gt is set per dataset but never passed on below - confirm
    # whether FrameData should receive it.
    with_gt = True
    pairs = [1, 2, -1, -2]
    if args.dataset == 'DAVIS':
        basepath = args.basepath
        img_dir = basepath + '/JPEGImages/480p'
        gt_dir = basepath + '/Annotations/480p'

        val_flow_dir = basepath + '/Flows_gap1/480p'
        val_seq = ['dog', 'cows', 'goat', 'camel', 'libby', 'parkour', 'soapbox', 'blackswan', 'bmx-trees',
                   'kite-surf', 'car-shadow', 'breakdance', 'dance-twirl', 'scooter-black', 'drift-chicane',
                   'motocross-jump', 'horsejump-high', 'drift-straight', 'car-roundabout', 'paragliding-launch']
        val_data_dir = [val_flow_dir, img_dir, gt_dir]
        res = "480p"

    elif args.dataset == 'FBMS':
        basepath = '/path/to/FBMS_clean'
        img_dir = '/path/to/FBMS_clean/JPEGImages/'
        gt_dir = '/path/to/FBMS_clean/Annotations/'

        val_flow_dir = '/path/to/FBMS_val/Flows_gap1/'
        val_seq = ['camel01', 'cars1', 'cars10', 'cars4', 'cars5', 'cats01', 'cats03', 'cats06',
                   'dogs01', 'dogs02', 'farm01', 'giraffes01', 'goats01', 'horses02', 'horses04',
                   'horses05', 'lion01', 'marple12', 'marple2', 'marple4', 'marple6', 'marple7', 'marple9',
                   'people03', 'people1', 'people2', 'rabbits02', 'rabbits03', 'rabbits04', 'tennis']
        val_img_dir = '/path/to/FBMS_val/JPEGImages/'
        val_gt_dir = '/path/to/FBMS_val/Annotations/'
        val_data_dir = [val_flow_dir, val_img_dir, val_gt_dir]
        with_gt = False
        pairs = [3, 6, -3, -6]

    elif args.dataset == 'STv2':
        basepath = '/path/to/SegTrackv2'
        img_dir = '/path/to/SegTrackv2/JPEGImages'
        gt_dir = '/path/to/SegTrackv2/Annotations'

        val_flow_dir = '/path/to/SegTrackv2/Flows_gap1/'
        val_seq = ['drift', 'birdfall', 'girl', 'cheetah', 'worm', 'parachute', 'monkeydog',
                   'hummingbird', 'soldier', 'bmx', 'frog', 'penguin', 'monkey', 'bird_of_paradise']
        val_data_dir = [val_flow_dir, img_dir, gt_dir]
    else:
        raise ValueError('Unknown Setting.')

    # flow_dir maps 'gap_{p1}_{p2}' to a list with one (N, 2) array of flow-path
    # pairs per sequence folder.
    flow_dir = _collect_flow_pairs(basepath, res, pairs)
    data_dir = [flow_dir, img_dir, gt_dir]
    trn_dataset = FrameData(data_dir=data_dir, resolution=resolution, dataset=args.dataset, to_rgb=args.flow_to_rgb,
                            train=True, val_seq=None)
    val_dataset = FrameData(data_dir=val_data_dir, resolution=resolution, dataset=args.dataset, pair_list=pairs,
                            to_rgb=args.flow_to_rgb, train=False, val_seq=val_seq)

    in_out_channels = 3 if args.flow_to_rgb else 2
    use_flow = not args.flow_to_rgb
    loss_scale = 1e2
    ent_scale = 1e-2
    cons_scale = 1e-2

    return [trn_dataset, val_dataset, resolution, in_out_channels, use_flow, loss_scale, ent_scale, cons_scale]
