import argparse
import json
import os
import pickle
import random
from pathlib import Path

import numpy as np
from tqdm import tqdm

from config_files.data_paths import *

def _build_arg_parser():
    """Create the command-line parser for this config-generation script."""
    cli = argparse.ArgumentParser(description="Generate random segment configurations.")
    cli.add_argument('--dataset_dir', type=str, required=True, help='Path to the dataset directory')
    cli.add_argument('--dataset_name', type=str, default='interhuman', help='Name of the dataset')
    cli.add_argument('--save_dir', type=str, default='./evaluation/dataset', help='Directory to save the generated JSON config')
    cli.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    cli.add_argument('--interaction', action='store_true', help='If set, process data for two-person interaction')
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Seed Python's and NumPy's global RNGs (in that order) with the same value so
# that runs are reproducible for a given --seed.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(args.seed)

raw_dataset_path = args.dataset_dir
target_fps = int(raw_dataset_path[-2:])

spl = 'val'
with open(Path(raw_dataset_path) / f'{spl}.pkl', 'rb') as f:
    data = pickle.load(f)

seq_cfg_list = []
for idata in tqdm(data):
    sid = idata['seq_name']
    # get frame labels
    frame_labels = {} if args.interaction else []
    if args.interaction:
        for key in ['person1', 'person2', 'interaction']:
            frame_labels[key] = idata[f'frame_labels_{key}']
    else:
        frame_labels = idata['frame_labels']

    # sort frame labels first by start time, then by end time
    def sort_time_points(frame_labels):
        time_points = []
        for seg in frame_labels:
            time_points.append(seg['start_t'])
            time_points.append(seg['end_t'])
        time_points = sorted(list(set(time_points)))
        max_interval = 200 / target_fps
        split_points = []
        for idx in range(len(time_points) - 1):
            split_point = time_points[idx] + max_interval
            while split_point < time_points[idx + 1]:
                split_points.append(split_point)
                split_point += max_interval
        time_points += split_points
        time_points = sorted(list(set(time_points)))
        return time_points
    
    if args.interaction:
        time_points = sort_time_points(frame_labels['interaction'])
    else:
        time_points = sort_time_points(frame_labels)

    seq_cfg = {
        'id': sid,
        "scenario": "in-distribution",
        'text': [],
        'lengths': [],
    }
    for idx in range(len(time_points) - 1):
        start_t = time_points[idx]
        end_t = time_points[idx + 1]
        num_frames = int((end_t - start_t) * target_fps)
        if num_frames < 0.2 * target_fps:  # ignore too short segments, annotator might mean the same frame, but the clicks are not accurate
            continue
        def get_texts(frame_labels):
            texts = []
            for seg in frame_labels:
                if seg['proc_label'] == 'transition':  # ignore transition
                    continue
                overlap_time = min(end_t, seg['end_t']) - max(start_t, seg['start_t'])
                if overlap_time > 1e-6:
                    proc_label = seg['proc_label']
                    texts.append(proc_label)
            if len(texts) == 0:
                return None
            print(sid, start_t, end_t, texts)
            compo_text = ' and '.join(texts)
            random_text = random.choice(texts)
            return random_text
            
        if args.interaction:
            random_text = {}
            for key in ['person1', 'person2', 'interaction']:
                random_text[key] = get_texts(frame_labels[key])
            if random_text['interaction'] is None:
                continue
        else:
            random_text = get_texts(frame_labels)
            if random_text is None:
                continue
        seq_cfg['text'].append(random_text)
        seq_cfg['lengths'].append(max(num_frames, 15))  # at least 15 frames, compatible with flowmdm
    if len(seq_cfg['text']) == 0:
        print(sid, 'no valid segment found')
        continue
    if len(seq_cfg['text']) == 1:
        print(sid, 'not valid for priormdm')
        continue
    print(seq_cfg)
    seq_cfg_list.append(seq_cfg)

random.shuffle(seq_cfg_list)
with open(Path(args.save_dir)/f'{args.dataset_name}_seed{args.seed}.json', 'w') as f:
    json.dump(seq_cfg_list[:], f, indent=4)
