from datasets.cater.enums import SHAPES, actions_order_dataset, reverse_rel, reverse
import numpy as np
from collections import defaultdict
from itertools import permutations

import json
import os
from pathlib import Path
from tqdm import tqdm
import random

# Duration of each atomic movement, measured in timeline index steps.
SLIDE_LEN = 6
ROTATE_LEN = 4
PICK_PLACE_LEN = 5
CONTAIN_LEN = 5

# Movement name -> duration; looked up when sampling the interval of an event.
durations = {'_slide': SLIDE_LEN, '_rotate': ROTATE_LEN, '_pick_place': PICK_PLACE_LEN, '_contain': CONTAIN_LEN}

# (mean, std) of the confidence estimates from our vision detection of the atomic events;
# use these to add noise to the simulated data
uncertainty = {'_slide': (.88, .1), '_rotate': (.76, .15), '_pick_place': (.95, .07), '_contain': (.95, .07)}

# timeline goes from [0, 300]
LEN = 300
# Sampling bounds; the margins are multiples of SLIDE_LEN, the longest movement duration above.
MIN_START_TIME = SLIDE_LEN
MAX_END_TIME = LEN - SLIDE_LEN
MAX_SUB_END_TIME = LEN - 2*SLIDE_LEN

# Output directory name template; filled positionally by folder_format().
FOLDER_FORMAT = '{}_len_{}_comp_{}_samples_{}_beam'
# File name used for each data split inside the output directory.
TRAIN_FILE = 'train.json'
VAL_FILE = 'val.json'
TEST_FILE = 'test.json'

def generate_rules(n_predicates=1, variable_len=True, max_rules_beam=100):
    """
    Brute-force generate rules and assign each one an integer label.

    A rule is a sorted tuple of predicates taken from actions_order_dataset().
    Longer rules are built by appending one predicate to an existing rule,
    skipping predicates already in the rule or whose reverse is in the rule.

    n_predicates: Generate rules of max length n_predicates
    variable_len: If true create rules up to n_predicates, else rules only of n_predicates
    max_rules_beam: max rules to generate at every rule length

    Returns a dict mapping rule -> label index. With variable_len=False the
    labels are renumbered from 0 and the dict may be empty when no rule of the
    requested length could be built.
    """
    predicate_beam = actions_order_dataset()[:max_rules_beam]

    # can use trie data structure to be more efficient, but here we brute force generate
    rules = {(pred,): i for i, pred in enumerate(predicate_beam)}
    updated_rules = rules.copy()
    for _ in range(n_predicates - 1):
        iteration_beam = 0
        # visit the most recent rules first to build new rules (dicts keep insertion order)
        for rule in reversed(list(rules)):
            for predicate in predicate_beam:
                if predicate in rule or reverse(predicate) in rule:
                    continue
                # keep a consistent order, so the same set of predicates is only stored once
                new_rule = tuple(sorted(rule + (predicate,)))
                if new_rule not in updated_rules:
                    updated_rules[new_rule] = len(updated_rules)
                    iteration_beam += 1
                if iteration_beam >= max_rules_beam:
                    break
            if iteration_beam >= max_rules_beam:
                break

        rules = updated_rules.copy()

    if not variable_len:
        # Only keep rules that really have n_predicates predicates. Taking the
        # last max_rules_beam entries by position is not enough: when the final
        # iteration adds fewer than max_rules_beam rules (or adds left-over
        # shorter rules), the tail of the dict also contains shorter rules.
        target_len = max(n_predicates, 1)
        fixed_len_rules = [rule for rule in rules if len(rule) == target_len]
        rules = {rule: idx for idx, rule in enumerate(fixed_len_rules)}
    print(f'max rule len: {n_predicates}, rule_beam_len: {max_rules_beam}, number of rules: {len(rules)}') 
    return rules

def sample_movement(relation, sub_movement, obj_movement):
    """
    Sample start/end times for a subject and an object event that satisfy a temporal relation.

    relation: one of 'before', 'after', 'during'
    sub_movement: movement name of the subject event (key of durations)
    obj_movement: movement name of the object event (key of durations)

    Returns ((sub_start_time, sub_end_time), (obj_start_time, obj_end_time)).
    Raises ValueError for an unknown relation.
    """
    if relation == 'before':
        sub_span, obj_span = sample_before(sub_movement, obj_movement)
    elif relation == 'after':
        # just like before, but in the reverse order: subject after object
        obj_span, sub_span = sample_before(obj_movement, sub_movement)
    elif relation == 'during':
        sub_span, obj_span = sample_during(sub_movement, obj_movement)
    else:
        raise ValueError(f'unknown relation {relation}')

    # Sanity check: every sampled time must lie on the timeline. The message is
    # a real string; the previous `print(...)` evaluated to None, so the
    # AssertionError carried no information.
    check_min = min(*sub_span, *obj_span)
    check_max = max(*sub_span, *obj_span)
    assert check_min >= 0, f'sampled time {check_min} is before the start of the timeline'
    assert check_max <= LEN, f'sampled time {check_max} is past the end of the timeline ({LEN})'

    return sub_span, obj_span

def sample_before(sub_movement, obj_movement):
    """
    Sample two intervals such that the subject event ends strictly before the object event starts.

    Returns ((sub_start, sub_end), (obj_start, obj_end)).
    """
    # subject first: its start is drawn from the early part of the timeline
    first_start = np.random.randint(0, MAX_SUB_END_TIME - 1)
    first = (first_start, first_start + durations[sub_movement])

    # object second: it starts at least one step after the subject has ended
    second_start = np.random.randint(first[1] + 1, MAX_END_TIME)
    second = (second_start, second_start + durations[obj_movement])

    assert first[1] < second[0]

    return first, second

def sample_during(sub_movement, obj_movement):
    """
    Sample two intervals that overlap in time.

    Returns ((sub_start, sub_end), (obj_start, obj_end)).
    """
    subj_begin = np.random.randint(MIN_START_TIME, MAX_END_TIME)
    subj_finish = subj_begin + durations[sub_movement]

    obj_len = durations[obj_movement]
    # earliest start that still ends after the subject starts, latest start
    # that still begins before the subject ends (upper bound is exclusive)
    lowest_start = subj_begin - obj_len + 1
    highest_start = min(subj_finish - 1, MAX_END_TIME)
    obj_begin = np.random.randint(lowest_start, highest_start)
    obj_finish = obj_begin + obj_len

    assert subj_begin < obj_finish and obj_begin < subj_finish

    return (subj_begin, subj_finish), (obj_begin, obj_finish)

def generate_predictions(start, end, movement=None, noise=True):
    """
    Build a per-step probability track of length LEN + 1 for one event.

    start, end: the event occupies indices [start, end) of the track
    movement: key of uncertainty, read only when noise is True
    noise: if True, event values are drawn from a normal distribution with the
           movement's (mean, std) and the rest of the track holds small random
           values, all clipped to [0, 1]; if False the track is exactly 1
           inside the event and 0 elsewhere

    Returns a numpy array of length LEN + 1.
    """
    n_steps = end - start
    if not noise:
        active = np.ones(n_steps)
        track = np.zeros(LEN + 1)
    else:
        mean, std = uncertainty[movement]
        # draw the event values first, then the background, to keep the RNG sequence stable
        active = np.clip(np.random.normal(mean, std, size=n_steps), 0, 1)
        track = np.clip(np.random.normal(1e-5, 3e-6, size=(LEN + 1)), 0, 1)
    track[start: end] = active
    return track

def flatten_ts(timeline):
    """
    Flatten a {shape: [event, ...]} timeline into one list ordered by start time.

    Each event is a 5-tuple (shape_id, action, start, end, probs); the result
    holds ((shape, action), start, end) entries. Events with equal start times
    keep their original relative order.
    """
    flat = [
        ((shape, action), start, end)
        for shape, events in timeline.items()
        for _shape_id, action, start, end, _probs in events
    ]
    flat.sort(key=lambda entry: entry[1])
    return flat

def get_relation(subj_start, subj_end, obj_start, obj_end):
    """
    Classify the temporal relation of the subject interval relative to the object interval.

    Returns 'before' when the subject ends at or before the object starts,
    'after' when the object ends at or before the subject starts, and 'during'
    when the two intervals overlap.
    Raises RuntimeError when none of the comparisons hold (e.g. NaN times).
    """
    if subj_end <= obj_start:
        return 'before'
    elif obj_end <= subj_start:
        return 'after'
    elif subj_start < obj_end and obj_start < subj_end:
        return 'during'
    else:
        # the stray "f" that used to precede the tuple in this message was a typo
        raise RuntimeError(f'Invalid times {(subj_start, subj_end, obj_start, obj_end)}')

def find_labels(timeline, rules, split):
    """
    Find the labels of every rule that is satisfied by the events on a timeline.

    Every ordered pair of events contributes a predicate
    ((subj_event, obj_event), (relation,)) plus its reverse; a rule is active
    when all of its predicates are among them.

    timeline: {shape: [(shape_id, action, start, end, probs), ...]}
    rules: dict mapping rule (tuple of predicates) -> label
    split: accepted for interface compatibility; every split is labelled the same way

    Returns the list of unique active labels.
    """
    ts = flatten_ts(timeline)
    observed = set()
    for subj_idx, (subj_event, subj_start, subj_end) in enumerate(ts):
        # only pair each event with the ones sorted after it; the reverse
        # predicate covers the opposite direction
        for obj_event, obj_start, obj_end in ts[subj_idx + 1:]:
            relation = get_relation(subj_start, subj_end, obj_start, obj_end)
            observed.add(((subj_event, obj_event), (relation,)))
            observed.add(((obj_event, subj_event), (reverse_rel(relation),)))

    latent_labels = [
        label
        for rule, label in rules.items()
        if all(predicate in observed for predicate in rule)
    ]

    return list(set(latent_labels))  # return the unique ones

def add_timeline_noise(timeline):
    """
    Shift every event on the timeline by a random integer offset drawn from [0, 5).

    The event is updated in place only when the shifted interval still lies
    within [0, LEN]; otherwise it is left untouched. Returns the same timeline object.
    """
    for events in timeline.values():
        for idx, (obj_id, movement, start, end, probs) in enumerate(events):
            shift = int(np.random.randint(0, 5))
            shifted_start, shifted_end = start + shift, end + shift
            if not (shifted_start >= 0 and shifted_end <= LEN):
                continue
            events[idx] = (obj_id, movement, shifted_start, shifted_end, probs)

    return timeline


MAX_SHAPE = 6
def sample_shape(shape, shape_ids):
    """
    Return an object id for the given shape.

    While fewer than MAX_SHAPE ids exist for the shape, a new id
    '<shape>_<count>' is created and recorded in shape_ids; once exactly
    MAX_SHAPE ids exist, one of them is picked at random.
    """
    known_ids = shape_ids[shape]
    if len(known_ids) != MAX_SHAPE:
        fresh_id = f'{shape}_{len(known_ids)}'
        shape_ids[shape].append(fresh_id)
        return fresh_id
    return random.choice(known_ids)

def generate_ts(rule_len, rules, split, j_samples=3, m_ts=50000, noise=True):
    """
    Generate m_ts simulated timelines together with their active rule labels.

    rule_len: not used by the function body
    rules: dict mapping rule (tuple of predicates) -> label
    split: passed through to find_labels
    j_samples: number of rules sampled (with replacement) per timeline
    m_ts: number of timelines to generate
    noise: passed through to generate_predictions

    Returns a list of (timeline, labels) pairs, where timeline maps
    shape -> [(shape_id, movement, start, end, probabilities), ...].
    """
    # index rules by label for easy sampling
    label_to_rule = {label: rule for rule, label in rules.items()}
    generated = []
    for _ in tqdm(range(m_ts)):
        timeline = defaultdict(list)
        sampled_labels = np.random.randint(0, high=len(label_to_rule), size=j_samples).tolist()
        sampled_rules = [label_to_rule[label] for label in sampled_labels]
        shape_ids = {s: [] for s in SHAPES}

        # realise every predicate of every sampled rule as a pair of events
        for predicate in (p for rule in sampled_rules for p in rule):
            (sub_action, obj_action), (relation, ) = predicate
            sub_shape, sub_movement = sub_action
            obj_shape, obj_movement = obj_action

            (sub_start, sub_end), (obj_start, obj_end) = sample_movement(relation, sub_movement, obj_movement)
            sub_probs = generate_predictions(sub_start, sub_end, movement=sub_movement, noise=noise)
            obj_probs = generate_predictions(obj_start, obj_end, movement=obj_movement, noise=noise)

            sub_id = sample_shape(sub_shape, shape_ids)
            obj_id = sample_shape(obj_shape, shape_ids)

            timeline[sub_shape].append((sub_id, sub_movement, sub_start, sub_end, sub_probs.tolist()))
            timeline[obj_shape].append((obj_id, obj_movement, obj_start, obj_end, obj_probs.tolist()))

        # the sampled events can also satisfy rules that were not sampled, so
        # the label set is recomputed from the finished timeline
        labels = find_labels(timeline, rules, split)
        generated.append((timeline, labels))

    label_counts = [len(labels) for _, labels in generated]
    print(f'mean number of active labels {np.mean(label_counts)}')
    return generated

def folder_format(rule_len, events, samples, max_rules_beam):
    """Return the output directory name for a configuration, built from FOLDER_FORMAT."""
    fields = (rule_len, events, samples, max_rules_beam)
    return FOLDER_FORMAT.format(*fields)

def folder_path(rule_len, events, samples, max_rules_beam, data_path=None):
    """
    Return the full path of the output directory for a configuration.

    rule_len, events, samples, max_rules_beam: forwarded to folder_format
    data_path: root directory; when omitted, the module-level data_path set by
               the command-line entry point is used

    Raises NameError when data_path is neither passed nor defined at module level.
    """
    if data_path is None:
        # Backward compatibility: this function used to read the global that
        # only exists when the file is run as a script.
        data_path = globals().get('data_path')
        if data_path is None:
            raise NameError("name 'data_path' is not defined; pass data_path explicitly")
    dir_format = folder_format(rule_len, events, samples, max_rules_beam)
    return os.path.join(data_path, dir_format)
                
def get_path(split, data_path):
    """
    Return the file path of a split inside data_path.

    'train' and 'val' map to their own files; any other split name maps to the test file.
    """
    if split == 'train':
        file_name = TRAIN_FILE
    elif split == 'val':
        file_name = VAL_FILE
    else:
        file_name = TEST_FILE
    return os.path.join(data_path, file_name)

def save_split(split, samples, data_path):
    """Write the samples of one split as JSON to the split's file inside data_path."""
    destination = get_path(split, data_path)
    print(f'saving to: {destination}')
    with open(destination, 'w') as handle:
        json.dump(samples, handle)

if __name__ == '__main__':
    # Command-line entry point: generate the rule set, then the train/val/test
    # timelines, and write each split as JSON into a configuration-specific folder.
    import argparse

    parser = argparse.ArgumentParser(description='Data Generation')

    # generation args
    parser.add_argument('--rule_len', type=int, default=1, help='number of predicates in the rule')
    parser.add_argument('--co_occurring_events', type=int, default=1, help='number of composite events to sample in each video')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--train_samples', type=int, default=10000)
    parser.add_argument('--val_samples', type=int, default=2500)
    parser.add_argument('--test_samples', type=int, default=2500)
    # the flag turns noise OFF; the help text used to describe the opposite
    parser.add_argument('--no_noise', action='store_true', help='do not vary the probability distribution of the sampled events')
    parser.add_argument('--max_rules_beam', type=int, default=100, help='max rules to generate at every rule length')
    parser.add_argument('--fixed_len', action='store_true', help='only use rules of len rule_len')
    parser.add_argument('--gen_path', type=str, default='/localscratch/cater_dataset/generated_data/', help='directory where to store the data')

    args = parser.parse_args()
    rule_len = args.rule_len
    co_occurring_events = args.co_occurring_events
    train_samples = args.train_samples
    val_samples = args.val_samples
    test_samples = args.test_samples
    noise = not args.no_noise
    data_path = args.gen_path
    variable_len = not args.fixed_len
    max_rules_beam = args.max_rules_beam

    # seed both generators: the sampling code uses `random` as well as `np.random`
    random.seed(args.seed)
    np.random.seed(args.seed)

    # named run_dir_name rather than `format` so the builtin is not shadowed
    run_dir_name = folder_format(rule_len, co_occurring_events, train_samples, max_rules_beam)
    sim_data_path = os.path.join(data_path, run_dir_name)
    Path(sim_data_path).mkdir(parents=True, exist_ok=True)

    rules = generate_rules(n_predicates=rule_len, variable_len=variable_len, max_rules_beam=max_rules_beam)

    # generate and save one split at a time, in the fixed order train -> val -> test
    split_sizes = (('train', train_samples), ('val', val_samples), ('test', test_samples))
    for split, n_samples in split_sizes:
        split_data = generate_ts(rule_len, rules, split, j_samples=co_occurring_events, m_ts=n_samples, noise=noise)
        save_split(split, split_data, sim_data_path)
