import os
from param import *
from utils import *
from spark_env.task import *
from spark_env.node import *
from spark_env.job_dag import *
import pickle
def load_job(file_path, query_size, query_idx, wall_time, np_random, job_count):
    query_path = file_path + query_size + '/'
    
    adj_mat = np.load(
        query_path + 'adj_mat_' + str(query_idx) + '.npy', allow_pickle=True)
    task_durations = np.load(
        query_path + 'task_duration_' + str(query_idx) + '.npy', allow_pickle=True).item()
    
    assert adj_mat.shape[0] == adj_mat.shape[1]
    assert adj_mat.shape[0] == len(task_durations)

    num_nodes = adj_mat.shape[0]
    nodes = []
    for n in range(num_nodes):
        task_duration = task_durations[n]
        e = next(iter(task_duration['first_wave']))

        num_tasks = len(task_duration['first_wave'][e]) + \
                    len(task_duration['rest_wave'][e])

        # remove fresh duration from first wave duration
        # drag nearest neighbor first wave duration to empty spots
        pre_process_task_duration(task_duration)
        rough_duration = np.mean(
            [i for l in task_duration['first_wave'].values() for i in l] + \
            [i for l in task_duration['rest_wave'].values() for i in l] + \
            [i for l in task_duration['fresh_durations'].values() for i in l])

        # generate tasks in a node
        tasks = []
        for j in range(num_tasks):
            task = Task(j, rough_duration, wall_time)
            tasks.append(task)

        # generate a node
        node = Node(n, tasks, task_duration, wall_time, np_random)
        nodes.append(node)

    # parent and child node info
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adj_mat[i, j] == 1:
                nodes[i].child_nodes.append(nodes[j])
                nodes[j].parent_nodes.append(nodes[i])

    # initialize descendant nodes
    for node in nodes:
        if len(node.parent_nodes) == 0:  # root
            node.descendant_nodes = recursive_find_descendant(node)

    # generate DAG
    job_dag = JobDAG(nodes, adj_mat,
        args.query_type + '-' + query_size + '-' + str(query_idx) + '-' + str(job_count))

    return job_dag


def pre_process_task_duration(task_duration):
    # remove fresh durations from first wave
    clean_first_wave = {}
    for e in task_duration['first_wave']:
        clean_first_wave[e] = []
        fresh_durations = SetWithCount()
        # O(1) access
        for d in task_duration['fresh_durations'][e]:
            fresh_durations.add(d)
        for d in task_duration['first_wave'][e]:
            if d not in fresh_durations:
                clean_first_wave[e].append(d)
            else:
                # prevent duplicated fresh duration blocking first wave
                fresh_durations.remove(d)

    # fill in nearest neighour first wave
    last_first_wave = []
    for e in sorted(clean_first_wave.keys()):
        if len(clean_first_wave[e]) == 0:
            clean_first_wave[e] = last_first_wave
        last_first_wave = clean_first_wave[e]

    # swap the first wave with fresh durations removed
    task_duration['first_wave'] = clean_first_wave


def recursive_find_descendant(node):
    if len(node.descendant_nodes) > 0:  # already visited
        return node.descendant_nodes
    else:
        node.descendant_nodes = [node]
        for child_node in node.child_nodes:  # terminate on leaves automatically
            child_descendant_nodes = recursive_find_descendant(child_node)
            for dn in child_descendant_nodes:
                if dn not in node.descendant_nodes:  # remove dual path duplicates
                    node.descendant_nodes.append(dn)
        return node.descendant_nodes


def generate_alibaba_jobs(np_random, timeline, wall_time):
    pass

def generate_tpch_jobs(np_random, timeline, wall_time):
  

    
    job_dags = OrderedSet()
    t = 0
    job_count = 0
    
    # Build storage file path
    trace_dir = args.trace_dir
    os.makedirs(trace_dir, exist_ok=True)
    trace_file = f"{trace_dir}/{args.num_init_dags}_{args.num_stream_dags}_{args.first_job_query_size}_{args.first_job_query_idx}_{args.second_job_query_size}_{args.second_job_query_idx}.pkl"
    
    # Data structure to store generation information for each job
    job_params = {}
    
    if not args.deterministic or not os.path.exists(trace_file):
        # Prepare two predefined job types
        predefined_jobs = [
            (args.first_job_query_size, args.first_job_query_idx),
            (args.second_job_query_size, args.second_job_query_idx)
        ]
        
        # Generate parameters for initial DAGs
        init_job_params = []
        for i in range(args.num_init_dags):
            # Cycle through predefined jobs based on index
            job_type_idx = i % len(predefined_jobs)
            query_size, query_idx = predefined_jobs[job_type_idx]
            init_job_params.append((query_size, query_idx, 0))  # Save size, idx and timestamp
            
        # Generate parameters for streaming DAGs
        stream_job_params = []
        stream_t = 0
        for i in range(args.num_stream_dags):
            # Poisson process
            interval = int(np_random.exponential(args.stream_interval))
            stream_t += interval
            
            # Cycle through predefined jobs based on index
            job_type_idx = i % len(predefined_jobs)
            query_size, query_idx = predefined_jobs[job_type_idx]
            
            stream_job_params.append((query_size, query_idx, stream_t))  # Save size, idx and timestamp
            
        job_params = {"init": init_job_params, "stream": stream_job_params}
        
        # Save the generated parameters to file
        if args.deterministic:
            with open(trace_file, 'wb') as f:
                pickle.dump(job_params, f)
                print(f"Successfully saved experiment trace to {trace_file}")
    else:
        # Read from file mode
        with open(trace_file, 'rb') as f:
            job_params = pickle.load(f)
            print(f"Successfully read experiment trace from {trace_file}")
    
    # Generate initial DAGs based on parameters
    for query_size, query_idx, _ in job_params["init"]:
        job_dag = load_job(
            args.job_folder, query_size, query_idx, wall_time, np_random, job_count)
        job_dag.start_time = t
        job_dag.arrived = True
        job_dags.add(job_dag)
        job_count += 1
    
    # Generate streaming DAGs based on parameters
    for query_size, query_idx, stream_t in job_params["stream"]:
        job_dag = load_job(
            args.job_folder, query_size, query_idx, wall_time, np_random, job_count)
        job_dag.start_time = stream_t
        timeline.push(stream_t, job_dag)
        job_count += 1
    
    return job_dags


def generate_jobs(np_random, timeline, wall_time):
    if args.query_type == 'tpch':
        job_dags = generate_tpch_jobs(np_random, timeline, wall_time)

    elif args.query_type == 'alibaba':
        job_dags = generate_alibaba_jobs(np_random, timeline, wall_time)

    else:
        print('Invalid query type ' + args.query_type)
        exit(1)

    return job_dags