# Load either a python or json file, run microcodecs to generate a mesh and simulation results,
# check that the material is valid, render the material, then generate descriptions and
# task data.
# Task data generation is probably best left to a separate process, since we may wish to
# vary it. The data and categorizations, however, are probably better kept together.


# For the material descriptions, it may actually be best to talk about the material properties
# and the high-level program structure (e.g. trusses, etc.), rather than the fairly generic
# image descriptors.

from .util import *
from .validation import *
from .rendering import *
import json
import os
import pathspec
import yaml
import subprocess
import argparse
import resource
from tempfile import TemporaryDirectory

import logging
import sys
import time


def get_sim_results(sim_data:dict, round:bool=True, clamp_to_zero:bool=True):
    """Collect elastic properties from a parsed ``structure_info.json``.

    Args:
        sim_data: contents of ``structure_info.json``. Must provide the 6x6
            compliance matrix ``sim_S_matrix`` and the scalar entries
            ``sim_E_VRH``, ``sim_K_VRH``, ``sim_G_VRH``, ``sim_nu_VRH``,
            ``sim_A_UAI`` and ``thickened_occupied_volume_fraction``.
        round: when True, every float value is snapped with
            ``round_to_delineation`` using an increment of 0.01.
        clamp_to_zero: when True, values with magnitude below 1e-8 are
            replaced by 0.0 (applied before rounding).

    Returns:
        dict of floats, in this key order: E, K, G, nu, A, V (copied from the
        file), then E_1, E_2, E_3, G_23, G_31, G_12 (reciprocals of the
        compliance diagonal; NaN where that entry is exactly zero) and
        nu_12, nu_13, nu_23, nu_21, nu_31, nu_32 (nu_ij = -S_ji * E_i).
    """
    compliance = np.array(sim_data["sim_S_matrix"], dtype=float)

    def s(row: int, col: int):
        # 1-indexed access so the formulas below read like the literature.
        return compliance[row - 1][col - 1]

    def inverse_or_nan(value):
        return 1.0 / value if value != 0 else float('nan')

    # Values taken directly from structure_info.
    direct_fields = (
        ('E', 'sim_E_VRH'),    # average young's modulus
        ('K', 'sim_K_VRH'),    # average bulk modulus
        ('G', 'sim_G_VRH'),    # average shear modulus
        ('nu', 'sim_nu_VRH'),  # average poisson ratio
        ('A', 'sim_A_UAI'),    # universal anisotropy index (0 = isotropic, larger = more anisotropic)
        ('V', 'thickened_occupied_volume_fraction'),  # occupied volume fraction
    )
    results = {name: float(sim_data[field]) for name, field in direct_fields}

    # Axis-specific Young's / shear moduli from the compliance diagonal.
    diagonal_moduli = (('E_1', 1), ('E_2', 2), ('E_3', 3), ('G_23', 4), ('G_31', 5), ('G_12', 6))
    for name, index in diagonal_moduli:
        results[name] = inverse_or_nan(s(index, index))

    # Poisson ratios: nu_ij = -S_ji * E_i
    poisson_pairs = (
        ('nu_12', 1, 2), ('nu_13', 1, 3), ('nu_23', 2, 3),
        ('nu_21', 2, 1), ('nu_31', 3, 1), ('nu_32', 3, 2),
    )
    for name, i, j in poisson_pairs:
        results[name] = s(j, i) * -results[f'E_{i}']

    if clamp_to_zero:
        tolerance = 1e-8
        results = {key: (0.0 if abs(value) < tolerance else value) for key, value in results.items()}

    if round:
        # Per-property increments can be added here; everything else uses "default".
        increments = {"default": 0.01}
        for key, value in results.items():
            if isinstance(value, float):
                results[key] = round_to_delineation(value, increments.get(key, increments["default"]))

    return results

def batch_process_programs(
        src: str,
        dst: str,
        fail_dst: str,
        src_fmt: str | list[str] = "*.json",
        src_exceptions: list[str] | None = None,
        resolution: int = 100,
        worker_id: int = 0,
        n_workers: int = 1,
        memory_limit: int | None = None,
        timeout: int | None = None,
        name_is_parent_dir: bool = False):
    """Process this worker's share of the programs found under ``src``.

    Programs are listed under ``src`` (local path or ``s3://`` URI), filtered
    with gitwildmatch patterns, and split round-robin across ``n_workers``:
    the program at selection index ``i`` is handled by the worker with
    ``i % n_workers == worker_id``. Each one is passed to ``process_material``.

    Args:
        src: directory (or single file) containing the programs.
        dst: root under which successful results are written.
        fail_dst: root under which failed results are written.
        src_fmt: pattern, or list of patterns, selecting program files.
        src_exceptions: patterns of files to skip (default: none).
        resolution: resolution forwarded to ``process_material``.
        worker_id: index of this worker, in ``[0, n_workers)``.
        n_workers: total number of workers sharing the listing.
        memory_limit: per-program memory limit in bytes, or None.
        timeout: per-program time limit in seconds, or None.
        name_is_parent_dir: if True, output directories are named after the
            program's parent directory; otherwise after the program file with
            its extension removed.
    """
    # None instead of a shared mutable [] default; no exceptions by default.
    if src_exceptions is None:
        src_exceptions = []

    # Setup Logging
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    # Normalize paths: local paths become absolute, directories end with '/'
    if not src.startswith('s3://'):
        src = os.path.abspath(src)
    src_is_file = is_file(src)  # queried once; may be a remote call
    if not src_is_file:
        src = src.rstrip('/') + '/'
    if not dst.startswith('s3://'):
        dst = os.path.abspath(dst)
    dst = dst.rstrip('/') + '/'
    if not fail_dst.startswith('s3://'):
        fail_dst = os.path.abspath(fail_dst)
    fail_dst = fail_dst.rstrip('/') + '/'

    logging.info(f'Worker {worker_id} / {n_workers} starting.')
    logging.info(f'Source = {src}')
    logging.info(f'Destination = {dst}')
    logging.info(f'Fail Dest. = {fail_dst}')

    logging.info(f'Listing files from {src}')
    all_files = list_all(src)
    logging.info(f'Found {len(all_files)} files')

    # Patterns match at any depth; for a directory source they are also anchored under src.
    pattern_root = '**/' if src_is_file else src + '**/'
    fmt_patterns = [src_fmt] if isinstance(src_fmt, str) else src_fmt
    fmt_patterns = [path_append(pattern_root, pattern) for pattern in fmt_patterns]
    exception_patterns = [path_append(pattern_root, pattern) for pattern in src_exceptions]

    fmt_spec = pathspec.PathSpec.from_lines('gitwildmatch', fmt_patterns)
    exception_spec = pathspec.PathSpec.from_lines('gitwildmatch', exception_patterns)

    logging.info('Filtering files')
    selected_programs = exception_spec.match_files(fmt_spec.match_files(all_files), negate=True)
    # Round-robin split of the selected programs across workers.
    worker_programs = [p for i, p in enumerate(selected_programs) if i % n_workers == worker_id]
    logging.info(f'Worker has {len(worker_programs)} programs to process')

    if name_is_parent_dir:
        def output_dir(path):
            # Success and Failure directories are named like the program's parent dir
            return path.rsplit('/', 1)[0]
    else:
        def output_dir(path):
            # Success and Failure directories are named like the program minus its
            # extension. Only the final path component is inspected, so dots in
            # directory names (or extension-less programs) don't truncate the path.
            return os.path.splitext(path)[0]

    worker_tuples = [
        (
            program,  # src uri
            output_dir(path_append(dst, program[len(src):])),  # success uri
            output_dir(path_append(fail_dst, program[len(src):])),  # failure uri
        )
        for program in worker_programs
    ]

    num_worker_tasks = len(worker_tuples)
    # The timeout applies to each program individually; it is not split across tasks.
    time_per_task = timeout

    for i, (program_src, dst_success, dst_failure) in enumerate(worker_tuples):
        log_data = {
            'worker_id': worker_id,
            'n_workers': n_workers,
            'worker_task_idx': i,
            'worker_num_tasks': num_worker_tasks,
            'src': program_src,
            'dst_success': dst_success,
            'dst_failure': dst_failure,
            'resolution': resolution,
            'job_time_limit': time_per_task,
            'job_memory_limit': memory_limit
        }
        logging.info(f'Processing {i} / {num_worker_tasks}')
        logging.info(json.dumps(log_data, indent=2))
        process_material(program_src, dst_success, dst_failure, resolution, log_data, memory_limit, time_per_task)

def run_batch_process_programs():
    """Command-line entry point for ``batch_process_programs``.

    Parses the CLI arguments. The worker layout is read from the
    ``AWS_BATCH_JOB_ARRAY_SIZE`` / ``AWS_BATCH_JOB_ARRAY_INDEX`` environment
    variables, defaulting to a single worker with index 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', type=str, default='s3://metagen-datasets/seeds/graph_seeds')
    parser.add_argument('-d', '--dest', type=str, default='s3://metagen-datasets/data/graph_v1/processed/seeds')
    parser.add_argument('-f', '--fail', type=str, default='s3://metagen-datasets/data/graph_v1/failed/seeds')
    parser.add_argument('-r', '--resolution', type=int, default=100)
    parser.add_argument('-p', '--pattern', nargs='+', default=['*.json'])
    parser.add_argument('-e', '--exceptions', nargs='+', default=['_maybe/', '**/fully_occupied_cube.json', '_invalid/'])
    parser.add_argument('-m', '--memory_limit', type=float, default=None, help='maximum memory usage in GB')
    # The timeout is applied to each program separately (see batch_process_programs).
    parser.add_argument('-t', '--timeout', type=int, default=None, help='timeout per program in seconds')
    parser.add_argument('-n', '--name_is_parent_dir', action='store_true', default=False, help='Use this flag if the source programs are named by their parent folder')
    args = parser.parse_args()

    n_workers = int(os.getenv('AWS_BATCH_JOB_ARRAY_SIZE', '1'))
    worker_id = int(os.getenv('AWS_BATCH_JOB_ARRAY_INDEX', '0'))

    # CLI takes gigabytes (10**9 bytes); downstream limits are in bytes.
    memory_limit_bytes = int(10**9 * args.memory_limit) if args.memory_limit is not None else None

    batch_process_programs(
        args.source,
        args.dest,
        args.fail,
        src_fmt=args.pattern,
        src_exceptions=args.exceptions,
        resolution=args.resolution,
        worker_id=worker_id,
        n_workers=n_workers,
        memory_limit=memory_limit_bytes,
        timeout=args.timeout,
        name_is_parent_dir=args.name_is_parent_dir,
    )

def run_process_material():
    """Command-line entry point: process a single program with ``process_material``.

    Usage: ``-i <program> -o <output dir> [-r <resolution>]``. ``'/dev/null'``
    is passed as the failure destination.
    """
    parser = argparse.ArgumentParser()
    # -i / -o are required: process_material cannot work with a missing path.
    parser.add_argument('-i', type=str, required=True, help='program file (.py DSL code, otherwise a graph JSON)')
    parser.add_argument('-o', type=str, required=True, help='output directory for a successful result')
    parser.add_argument('-r', type=int, default=100, help='resolution (default: 100)')

    args = parser.parse_args()
    program_path = args.i
    output_dir = args.o
    resolution = args.r
    process_material(program_path, output_dir, '/dev/null', resolution)


def code_to_graph(code):
    """
    Run a python DSL program and return the value it assigns to ``graph``.

    The program runs via ``exec`` in a fresh global namespace. NOTE: this
    executes arbitrary code and must only be given trusted programs.
    If the program raises after assigning ``graph``, that value is still
    returned (best effort) and the error is logged.

    TODO - actual validation and extraction

    Raises:
        KeyError: if the program never assigned ``graph``. If that happened
            because the program raised, the original exception is attached
            as ``__cause__``.
    """
    env = {}
    error = None
    try:
        exec(code, env)
    except Exception as e:
        # Keep the best-effort behaviour, but don't hide why the program failed.
        error = e
        logging.warning(f'DSL program raised {type(e).__name__}: {e}')
    if 'graph' not in env:
        raise KeyError('graph') from error
    return env['graph']

def code_to_kernel(graph, tmp):
    """Write a dict graph to ``<tmp>/graph.json``.

    Only dict graphs are handled; they are serialized with ``json.dump``
    using UTF-8. String inputs are accepted but currently ignored,
    whether they name an existing file or not.

    Args:
        graph: graph as a dict (written) or a str (ignored for now).
        tmp: directory in which ``graph.json`` is created.

    Returns:
        None.
    """
    destination = os.path.join(tmp, 'graph.json')
    if isinstance(graph, str):
        # TODO: handle a path to an existing graph file, or raw graph text.
        return
    if isinstance(graph, dict):
        with open(destination, 'w', encoding='utf-8') as handle:
            json.dump(graph, handle)


def run_program(program_uri: str, output_path: str) -> bool:
    """
    Runs a DSL program and saves the output json to output_path.

    The program source is prefixed with ``from metagen import *``. The program
    must define ``make_structure()``; its result is wrapped in
    ``ProcMetaTranslator`` and saved. The code runs via ``exec``. NOTE: only
    use this on trusted programs.

    Note: Program uri can be local or remote, but output_path
    must be local because ProcMetaTranslator.save uses local
    paths only. This could be changed in the future by switching
    to string or dict output and handling writing here.

    Returns:
        True if the program ran and a file exists at ``output_path``,
        False otherwise (any exception raised by the program is logged).
    """
    code = read_text_file(program_uri)
    wrapped = f"from metagen import *\n{code}\noutput = ProcMetaTranslator(make_structure())"
    try:
        code_env = {}
        exec(wrapped, code_env)
        code_env['output'].save(output_path)
    except Exception:  # TODO - can / should this exception be more specific?
        # A failing program is an expected outcome; record why instead of hiding it.
        logging.exception(f'DSL program {program_uri} failed')
        return False

    # TODO - do more validation here using the type system once implemented
    return bool(is_file(output_path))

def process_material(
        input: str,
        dst_success: str,
        dst_failure: str,
        resolution: int,
        log_data: dict | None = None,
        memlimit: int | None = None,
        timeout: int | None = None,
        color = None
        ) -> bool:
    """Generate, simulate, render and validate a single material program.

    ``input`` is DSL code if it ends in ``.py``; otherwise it is treated as a
    graph JSON. Sidecar files ``<stem>.txt`` (description) and
    ``<stem>.parents`` (one parent per line) are picked up if present. All
    work happens in a temporary directory. That directory, including a
    ``processing_info.json`` summary, is copied to ``dst_success`` or
    ``dst_failure``.

    Already-processed materials are skipped:
    - If ``dst_success`` exists, the material is skipped.
    - If ``dst_failure`` exists, it is skipped unless the earlier run timed out
      and the new ``timeout`` is larger or unlimited, or the earlier run likely
      ran out of memory and ``memlimit`` is larger. In those cases the old
      failure directory is removed and the material is retried.

    Args:
        input: program path/URI.
        dst_success: destination directory for a successful result.
        dst_failure: destination directory for a failed result.
        resolution: resolution forwarded to ``generate_and_simulate``.
        log_data: extra fields recorded in ``processing_info.json``.
        memlimit: memory limit in bytes for the generator, or None.
        timeout: time limit in seconds for the generator, or None.
        color: forwarded to ``run_render``.

    Returns:
        True if the material is (now or previously) processed successfully,
        False otherwise.
    """
    # None instead of a shared mutable {} default; it is only copied below.
    if log_data is None:
        log_data = {}

    # Skip Already Processed Materials
    if directory_exists(dst_success):
        logging.info('Skipping (previous success)')
        return True
    if directory_exists(dst_failure):
        if _should_retry_failure(dst_failure, memlimit, timeout):
            logging.info('Retrying previous failure with more memory or time')
            remove_dir(dst_failure)
        else:
            logging.info('Skipping (previous failure)')
            return False

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
        processing_info = log_data.copy()

        # Final Status Variables
        success = True
        reason = 'success'

        # Temporary File Locations
        tmp_graph = path_append(tmp, 'graph.json')
        tmp_code = path_append(tmp, 'code.py')
        tmp_description = path_append(tmp, 'description.txt')
        tmp_parents = path_append(tmp, 'parents.txt')
        tmp_vox = path_append(tmp, 'vox_active_cells.txt')
        tmp_vox_surface = path_append(tmp, 'vox_surface.obj')
        tmp_mesh = path_append(tmp, 'thickened_mc.obj')
        tmp_info = path_append(tmp, 'processing_info.json')
        tmp_sim = path_append(tmp, 'structure_info.json')

        # Input Paths to check. Sidecar files share the program's stem. Only the
        # final path component's extension is removed, so dots in directory
        # names (or an extension-less program) don't truncate the path.
        input_stem = os.path.splitext(input)[0]
        input_description = input_stem + '.txt'
        input_parents = input_stem + '.parents'

        # Check What info we have
        program_is_code = input.lower().endswith('.py')
        has_description = is_file(input_description)
        has_parents = is_file(input_parents)

        # Copy Files that we do have
        if has_description:
            copy_file(input_description, tmp_description)
        if has_parents:
            copy_file(input_parents, tmp_parents)
        if program_is_code:
            copy_file(input, tmp_code)
        else:
            copy_file(input, tmp_graph)

        for _ in (True,): # Hack: single-iteration loop so any stage can `break` on failure

            # Run the DSL Code if it exists
            if program_is_code:
                success = run_program(tmp_code, tmp_graph)
                if not success:
                    reason = 'code'
                    break

            # Generate and Simulate from Graph
            gen_results = generate_and_simulate(tmp_graph, resolution, tmp, mem_limit=memlimit, timeout=timeout)
            success = gen_results['gen_and_sim_success']
            processing_info = merge_defaults(processing_info, gen_results)
            if not success:
                reason = 'graph'
                break

            # Render Material
            run_render(tmp_mesh, tmp, color)

            # Validate Material
            with open(tmp_vox, 'r') as f:
                voxels = load_voxels(f)
            success = validate_voxels(voxels)
            processing_info['valid'] = success
            if not success:
                reason = 'validation'
                break
            success = validate_simulation(tmp_sim)
            if not success:
                reason = 'simulation'
                break

            # Check if Clone
            if has_parents:
                parents = [p.strip() for p in read_text_file(tmp_parents).strip().split('\n')]
                cloned_parents = renders_equal(tmp, *parents)
                if any(cloned_parents):
                    success = False
                    reason = 'clone'
                    processing_info['cloned_parents'] = [p for p,clone in zip(parents, cloned_parents) if clone]

        # Write summary of processing
        processing_info['success'] = success
        processing_info['reason'] = reason
        with open(tmp_info, 'w') as f:
            json.dump(processing_info, f)

        # Remove large voxel file before copying
        if os.path.isfile(tmp_vox_surface):
            os.remove(tmp_vox_surface)

        if success:
            copy_dir(tmp, dst_success)
        else:
            copy_dir(tmp, dst_failure)

    return success


def _should_retry_failure(dst_failure: str, memlimit, timeout) -> bool:
    """Decide from a previous failure's processing_info.json whether to retry with the new limits."""
    fail_info = path_append(dst_failure, 'processing_info.json')
    if not is_file(fail_info):
        return False
    with open_file(fail_info) as f:
        info = json.load(f)
    logging.info('Previous Failure: Checking for Retry')

    retry = False
    # A timeout is worth retrying when the new limit is unlimited or strictly larger.
    # A stored limit of None cannot be compared, so it never counts as "smaller".
    if info.get('gen_and_sim_reason') == 'timeout' and 'job_time_limit' in info:
        failed_time_limit = info['job_time_limit']
        if timeout is None or (failed_time_limit is not None and failed_time_limit < timeout):
            logging.info('Retrying with more time')
            retry = True
    # A likely OOM is worth retrying only under a strictly larger, finite memory limit.
    if memlimit is not None and info.get('gen_and_sim_likely_oom') and 'job_memory_limit' in info:
        failed_memory_limit = info['job_memory_limit']
        if failed_memory_limit is not None and failed_memory_limit < memlimit:
            retry = True
    return retry


def format_time(seconds):
    """Format a duration given in seconds as a human-readable string.

    Durations under one minute keep 4 significant digits
    (e.g. ``'12.35 seconds'``). Longer durations are rounded to whole seconds
    and shown as ``'M minutes, S seconds'``. From one hour upwards they are
    shown as ``'H hours, M minutes, S seconds'``.
    """
    if seconds < 60:
        # float() so int inputs work too: '.4' precision is invalid for ints.
        return f'{float(seconds):.4} seconds'
    # int() guards against numpy floats, whose round() returns a float.
    minutes, secs = divmod(int(round(seconds)), 60)
    if minutes < 60:
        return f'{minutes} minutes, {secs} seconds'
    hours, minutes = divmod(minutes, 60)
    return f'{hours} hours, {minutes} minutes, {secs} seconds'
        

def generate_and_simulate(graph_path: str, resolution: int, outdir: str, mem_limit: int = None, timeout: int = None, simulate: bool = True):
    """
    Runs the metamaterial kernel to construct and simulate the material.

    This only checks that the code runs and produces all outputs, it does
    not validate those outputs.

    Args:
        graph_path: graph JSON passed to ``evaluateJSONMetagen``.
        resolution: resolution argument passed to ``evaluateJSONMetagen``.
        outdir: directory the kernel writes its outputs into.
        mem_limit: address-space limit in bytes applied to the child process
            via ``RLIMIT_AS``, or None for no limit.
        timeout: wall-clock limit in seconds for the child process, or None.
        simulate: when False, an extra ``'0'`` argument is passed
            (presumably skipping simulation -- confirm against evaluateJSONMetagen).

    Returns:
        dict of ``gen_and_sim_*`` entries. It always contains the success flag,
        the reason (``'success'``, ``'nonzero_exit_code'``, ``'timeout'`` or
        ``'missing_outputs'``), the likely-OOM flag, and a map of which
        expected output files exist. When the process finished, it also
        contains the return code, stdout, stderr and the run time.
    """

    def limit_virtual_memory():
        # Runs in the child process before exec (subprocess preexec_fn).
        if mem_limit is not None:
            resource.setrlimit(resource.RLIMIT_AS, (mem_limit, mem_limit))

    outdir = outdir.rstrip('/') + '/'
    command = ["evaluateJSONMetagen", graph_path, str(resolution), outdir]
    if not simulate:
        command.append('0')

    success = True
    reason = 'success'
    likely_oom = False
    log_info = {}
    try:
        command_str = ' '.join(command)
        logging.info(f'executing "{command_str}"')
        # monotonic clock: immune to wall-clock adjustments while measuring durations
        start_time = time.monotonic()
        p = subprocess.run(
            command,
            capture_output=True,
            text=True,
            preexec_fn=limit_virtual_memory,
            timeout=timeout,
        )
        execution_time = time.monotonic() - start_time
        formatted_time = format_time(execution_time)
        logging.info(f'evaluateJSONMetagen exited with code {p.returncode} in {formatted_time}')

        success = (p.returncode == 0)
        if not success:
            reason = 'nonzero_exit_code'

        log_info['gen_and_sim_return_code'] = p.returncode
        log_info['gen_and_sim_stdout'] = p.stdout
        log_info['gen_and_sim_stderr'] = p.stderr
        log_info['gen_and_sim_time_seconds'] = execution_time
        log_info['gen_and_sim_time_fmt'] = formatted_time

        # Check for likely OOM (heuristics based on observed exit codes / output)
        if p.returncode == -6:
            likely_oom = True # Based on observations on WSL
            logging.info('evaluateJSONMetagen likely OOM (code -6 SIGABRT)')
        if p.returncode == -11:
            # A segfault while the last logged stage is one of these is treated as an OOM.
            last_line = p.stdout.strip().split('\n')[-1].strip()
            oom_stages = (
                ('using solvePeriodic2', 'solver'),
                ('Smoothing mesh', 'Smoothing mesh'),
                ('Generating object', 'Generating object'),
            )
            for marker, stage in oom_stages:
                if marker in last_line:
                    likely_oom = True
                    logging.info(f'evaluateJSONMetagen likely OOM (code -11 SIGSEGV and in {stage})')

    except subprocess.TimeoutExpired:
        success = False
        reason = 'timeout'
        logging.info(f'evaluateJSONMetagen timed out after {timeout} seconds')

    expected_outputs = {
        'material_sphere.obj',
        'structure_info.json',
        'thickened_mc.obj',
        'vox_active_cells.txt',
        'vox_surface.obj',
        'vox_surface_boundary_pairs.txt'
    }
    produced_outputs = {
        fname: os.path.exists(os.path.join(outdir, fname))
        for fname in expected_outputs
    }

    if success and not all(produced_outputs.values()):
        success = False
        reason = 'missing_outputs'
        logging.info('evaluateJSONMetagen missing expected outputs')

    log_info['gen_and_sim_likely_oom'] = likely_oom
    log_info['gen_and_sim_success'] = success
    log_info['gen_and_sim_reason'] = reason

    log_info['gen_and_sim_outputs'] = produced_outputs

    return log_info

def run_metagen_program():
    """
    Run a metagen program from stdin and output as json to stdout.

    The program is prefixed with ``from metagen import *`` and must define
    ``make_structure()``. Its geometry is generated without simulation, at
    resolution 64, and then rendered.

    The printed JSON is ``{"content": [{"type": ..., "value": ...}]}``. An
    ``"image"`` entry holds a ``data:image/png;base64,...`` URI. A ``"text"``
    entry explains why no image was produced, including the message of any
    exception raised.

    The code is executed with ``exec``; only pipe in trusted programs.
    """
    import base64
    from contextlib import redirect_stdout

    code = sys.stdin.read()

    with TemporaryDirectory() as tmp:
        graph_path = os.path.join(tmp, 'graph.json')
        mesh_path = os.path.join(tmp, 'thickened_mc.obj')
        try:
            # stdout carries the JSON result; keep anything the program prints off it.
            with redirect_stdout(sys.stderr):
                wrapped = f"from metagen import *\n{code}\noutput = ProcMetaTranslator(make_structure())"
                code_env = {}
                exec(wrapped, code_env)
                code_env['output'].save(graph_path)
                generate_and_simulate(graph_path, 64, tmp, simulate = False)
                if os.path.exists(mesh_path):
                    # Render the material, then embed the PNG as a base64 data URI
                    render_path = os.path.join(tmp, 'top_right.png')
                    run_render(mesh_path, tmp)
                    with open(render_path, 'rb') as f:
                        img_b64 = base64.b64encode(f.read()).decode('utf-8')
                    result = [{'type': 'image', 'value': f'data:image/png;base64,{img_b64}'}]
                else:
                    result = [{'type': 'text', 'value': 'No render produced'}]
        except Exception as e:
            # Report failures in the same JSON envelope on stdout (not as a return value).
            result = [{'type': 'text', 'value': str(e)}]
    print(json.dumps({'content': result}))