"""
Functions to sample from parametric programs.
"""

import yaml
import inspect
import logging
import random
from fractions import Fraction
import itertools
import ast
import random
from fractions import Fraction
from typing import List, Dict, Tuple
import os
import math
from metagen.util import format_float

def get_function_args_and_positions_and_indent(
    source: str,
    func_name: str
) -> Tuple[List[Dict[str, str]], int, int, str]:
    """
    Parse the given source code, find the function named `func_name`, and return:
      1. A list of argument dicts: {'name': ..., 'type': ..., 'default': ...}, where
         'type' and 'default' are unparsed source text (None when absent) and the
         var-positional / var-keyword names carry a '*' / '**' prefix
      2. The start character index of the "def" keyword
      3. The character index of the ":" that terminates the signature
      4. The indentation string of the first line of the function body (for a one-line
         `def f(): pass` this is the indentation of the def line itself)

    Only plain ``def`` functions are matched (not ``async def``); if several functions
    share the name, the first one reached by ``ast.walk`` is used.

    Raises:
        SyntaxError: if `source` does not parse or the signature's ':' cannot be located.
        ValueError: if no function named `func_name` exists in `source`.
    """
    import io

    tree = ast.parse(source)
    # Split only on CR, LF and CRLF (the line terminators Python recognises) so that
    # line numbers agree with ast/tokenize. str.splitlines() would also break on form
    # feeds, U+2028, etc., shifting every offset that follows such a character.
    lines = io.StringIO(source, newline='').readlines()
    line_starts = [0]
    for line in lines:
        line_starts.append(line_starts[-1] + len(line))

    def unp(n):
        # Source text of an annotation/default node, or None when it is absent.
        return ast.unparse(n) if n is not None else None

    for node in ast.walk(tree):
        if not (isinstance(node, ast.FunctionDef) and node.name == func_name):
            continue

        # Position of "def". col_offset counts UTF-8 bytes, but only whitespace can
        # precede "def" on its line, so bytes == characters here.
        start_pos = line_starts[node.lineno - 1] + node.col_offset

        end_pos = _find_signature_colon(source, line_starts, start_pos)
        if end_pos is None:
            raise SyntaxError(f"Could not locate end of signature for function '{func_name}'")

        # Determine indentation of first body line
        if node.body:
            first_line = lines[node.body[0].lineno - 1]
            indent = first_line[:len(first_line) - len(first_line.lstrip())]
        else:
            indent = ''

        args_list: List[Dict[str, str]] = []
        # Positional-only + regular args; defaults are aligned to the rightmost ones.
        posonly = getattr(node.args, 'posonlyargs', [])
        all_pos = posonly + node.args.args
        defaults = [None] * (len(all_pos) - len(node.args.defaults)) + list(node.args.defaults)
        for arg, default in zip(all_pos, defaults):
            args_list.append({
                'name': arg.arg,
                'type': unp(arg.annotation),
                'default': unp(default)
            })
        # *args
        if node.args.vararg:
            args_list.append({
                'name': f"*{node.args.vararg.arg}",
                'type': unp(node.args.vararg.annotation),
                'default': None
            })
        # Keyword-only args (kw_defaults holds None where no default was given)
        for arg, default in zip(node.args.kwonlyargs, node.args.kw_defaults):
            args_list.append({
                'name': arg.arg,
                'type': unp(arg.annotation),
                'default': unp(default)
            })
        # **kwargs
        if node.args.kwarg:
            args_list.append({
                'name': f"**{node.args.kwarg.arg}",
                'type': unp(node.args.kwarg.annotation),
                'default': None
            })

        return args_list, start_pos, end_pos, indent

    raise ValueError(f"Function '{func_name}' not found in the provided source.")


def _find_signature_colon(source: str, line_starts: List[int], start_pos: int):
    """Return the index of the ':' closing the signature of the def at `start_pos`, or None."""
    import io
    import tokenize

    # Working on tokens rather than raw characters means brackets and colons that sit
    # inside string literals or comments are never counted.
    depth = 0
    seen_paren = False
    readline = io.StringIO(source, newline='').readline
    for tok in tokenize.generate_tokens(readline):
        if tok.type != tokenize.OP:
            continue
        row, col = tok.start
        pos = line_starts[row - 1] + col
        if pos < start_pos:
            continue
        if tok.string in ('(', '[', '{'):
            depth += 1
            if tok.string == '(':
                seen_paren = True
        elif tok.string in (')', ']', '}'):
            depth -= 1
        elif tok.string == ':' and seen_paren and depth == 0:
            return pos
    return None


def sample_float(value: float, keep_prob: float) -> str:
    """
    With probability `keep_prob`, keep `value`; otherwise draw a new float uniformly
    from [0, 1). The chosen float is rendered as a string by `format_float`
    (from metagen.util).

    Args:
        value (float): The current (default) value. It is not range-checked, but any
            resampled value lies in [0, 1).
        keep_prob (float): Probability in [0, 1] of keeping `value`.

    Returns:
        str: `format_float(chosen)`.

    Raises:
        ValueError: If `keep_prob` is outside [0, 1].
    """
    if not 0.0 <= keep_prob <= 1.0:
        raise ValueError("`keep_prob` must be between 0 and 1")

    # The first draw decides keep vs. resample; only a resample consumes a second draw.
    chosen = value if random.random() < keep_prob else random.random()
    return format_float(chosen)

def sample_int(value: int, keep_prob: float) -> str:
    """
    With probability `keep_prob`, keep `value`; otherwise draw a new uniform random int
    from a range padded out to the next multiple of 5 beyond `value`:

      * value >= 0: [0, 5 * (value // 5) + 5]
      * value <  0: [-(5 * ((-value) // 5) + 5), 5 * (value // 5) + 5]

    Both ranges contain `value`.

    Args:
        value (int): An integer.
        keep_prob (float): Probability in [0, 1] of keeping `value`.

    Returns:
        str: Potentially resampled integer as a string, e.g. "4".

    Raises:
        ValueError: If `keep_prob` is outside [0, 1].
    """
    if not 0.0 <= keep_prob <= 1.0:
        raise ValueError("`keep_prob` must be between 0 and 1")

    # decide whether to keep or resample
    if random.random() < keep_prob:
        chosen = value
    else:
        # The lower bound must negate the whole padded magnitude. The previous
        # `- (-value) // 5 * 5 + 5` parsed as `value // 5 * 5 + 5` (unary minus binds
        # tighter than //), making both bounds equal for negative values.
        low = 0 if value >= 0 else -((-value) // 5 * 5 + 5)
        high = value // 5 * 5 + 5
        chosen = random.randint(low, high)

    return str(chosen)

def sample_programs(data_path, generator_path, keep_prob = 0.5, count = 10):
    """
    Generate variations of a parametric program by resampling `make_structure` defaults.

    Reads `os.path.join(data_path, generator_path)` and keeps only the text before
    END_STRUCTURE_DESCRIPTION. It then draws `m` candidate values for every parameter of
    `make_structure`, using `sample_float` for float parameters and `sample_int` otherwise.
    Every combination of candidates (cartesian product), except one equal to all defaults,
    becomes a program whose `make_structure` signature carries the sampled values as
    annotated defaults. Each such program gets its header updated by `update_header`.
    A candidate is kept only if exec'ing it and calling `make_structure()` raises nothing.

    NOTE(review): parameter defaults are `eval`'d and candidate programs are `exec`'d,
    so this must only be run on trusted datasets.

    Args:
        data_path (str): Dataset base directory.
        generator_path (str): Program path relative to `data_path`.
        keep_prob (float): Probability, per draw, of keeping a parameter's default.
        count (int): Target expected number of samples (duplicates not accounted for).

    Returns:
        dict: Maps a filename stem to the sampled program source. The stem is
        `name[0] + '_' + value` per parameter (with '/' replaced by '-'), joined by '_'.
        The dict is empty when there is nothing to vary: no parameters,
        `keep_prob == 1`, or `count <= 0`.

    Raises:
        ValueError: If `keep_prob` is outside [0, 1], END_STRUCTURE_DESCRIPTION or
            `make_structure` is missing, or a parameter has no default value.
    """
    if not 0.0 <= keep_prob <= 1.0:
        raise ValueError("`keep_prob` must be between 0 and 1")
    logger = logging.getLogger(__name__)

    full_path = os.path.join(data_path, generator_path)
    script_path = '/' + generator_path.rstrip('/')
    with open(full_path, 'r') as f:
        source = f.read()
    source = source[:source.index(END_STRUCTURE_DESCRIPTION)]
    args_list, start_pos, end_pos, _indent = get_function_args_and_positions_and_indent(source, 'make_structure')
    header = source[:start_pos]
    footer = source[end_pos + 1:]

    # Each of the k parameters gets m draws, each equal to its default with probability
    # (roughly) keep_prob. A combination therefore reproduces the defaults, and is
    # skipped, with probability p = keep_prob**k. Choosing m so that (1 - p) * m**k == count
    # yields `count` samples in expectation. This ignores duplicates, which are
    # non-negligible because integer draws come from a small range.
    k = len(args_list)
    p = keep_prob ** k
    if k == 0 or p >= 1.0 or count <= 0:
        # Nothing can differ from the defaults (and the formula below would divide by zero).
        return {}
    m = math.ceil(math.pow(count / (1 - p), 1 / k))

    default_arg_vals = []  # used to skip combinations that just reproduce the defaults
    arg_types = []
    for arg in args_list:
        if arg['default'] is None:
            raise ValueError(
                f"Parameter '{arg['name']}' of make_structure in {full_path} has no default value")
        arg_val = eval(arg['default'])
        default_arg_vals.append(arg_val)
        arg_type = type(arg_val).__name__ if arg['type'] is None else arg['type']
        arg_types.append(arg_type)
        sampler = sample_float if arg_type == 'float' else sample_int
        arg['samples'] = [sampler(arg_val, keep_prob) for _ in range(m)]

    samples = {}
    used_tags = set()
    for arg_set in itertools.product(*(arg['samples'] for arg in args_list)):
        sampled_arg_values = [eval(a) for a in arg_set]
        if all(a1 == a2 for a1, a2 in zip(sampled_arg_values, default_arg_vals)):
            continue  # identical to the original program
        assignments = []
        tags = []
        name_parts = []
        formatted_parameters = {}
        for arg, arg_type, value in zip(args_list, arg_types, arg_set):
            name = arg['name']
            slug = value.replace('/', '-')
            assignments.append(name + ':' + arg_type + '=' + value)
            tags.append(name + '_' + slug)
            name_parts.append(name[0] + '_' + slug)
            formatted_parameters[name] = value
        tag = '_'.join(tags)
        if tag in used_tags:
            continue
        used_tags.add(tag)

        args = ', '.join(assignments)
        sample_code = header + f"def make_structure({args}) -> Structure:" + footer
        sample_code = update_header(sample_code, script_path, formatted_parameters)
        try:
            exec(sample_code + "\nstruct = make_structure()", {})
        except Exception as e:
            # Invalid parameter combinations are expected; drop them but leave a trace.
            logger.debug("Discarding sample %s of %s: %r", tag, generator_path, e)
            continue
        samples['_'.join(name_parts)] = sample_code

    return samples

# Code snippet (kept as a string) that saves `ProcMetaTranslator(struct)` to a
# temporary 'graph.png' and asserts the file was written.
# NOTE(review): not referenced elsewhere in this chunk; `ProcMetaTranslator` and
# `struct` must already exist in whatever namespace this snippet is executed in.
test_harness = """
from tempfile import TemporaryDirectory
import os
with TemporaryDirectory() as tmpdir:
    graph_path = os.path.join(tmpdir, 'graph.png')
    ProcMetaTranslator(struct).save(graph_path)
    assert os.path.exists(graph_path)
"""

# Marker line ending the structure-description section of a program file;
# sample_programs truncates the source here (read_code searches for the same text).
END_STRUCTURE_DESCRIPTION = "# --- END: structure description ---"
# Marker line starting the scratch section; read_code returns the text after it.
BEGIN_SCRATCH_SPACE = "# --- BEGIN: scratch space ---"

def split_header(code: str) -> Tuple[str, str]:
    """
    Split `code` around its first ''' ... ''' block.

    Returns:
        (header, program): the text between the first two ''' delimiters, and the
        text after the closing delimiter. Anything before the opening ''' is dropped.

    Raises:
        ValueError: if fewer than two ''' delimiters are present.
    """
    delimiter = "'''"
    body_start = code.index(delimiter) + len(delimiter)
    body_end = code.index(delimiter, body_start)
    return code[body_start:body_end], code[body_end + len(delimiter):]


def update_header(code, script_path, parameters):
    """
    Stamp generator provenance into the YAML header of a program.

    The header is the text between the first pair of ''' delimiters (via split_header).
    Its `file_info` mapping receives:
        generator_info: {'script': script_path, 'parameters': parameters}
        parametric_sample: True
    The header is then re-serialized with `yaml.safe_dump` and re-wrapped in '''.

    Args:
        code (str): Program source containing a ''' header block.
        script_path (str): Value recorded as generator_info.script.
        parameters (dict): Value recorded as generator_info.parameters.

    Returns:
        str: The rewritten header followed by the program text, with leading and
        trailing whitespace stripped. Text before the opening ''' is not kept.
    """
    header, program = split_header(code)
    # An empty header loads as None, and `file_info` may be absent or empty. Both used
    # to crash (TypeError/KeyError); start from empty mappings instead.
    metadata = yaml.safe_load(header) or {}
    file_info = metadata.get('file_info') or {}
    metadata['file_info'] = file_info
    file_info['generator_info'] = {
        'script': script_path,
        'parameters': parameters,
    }
    file_info['parametric_sample'] = True
    header = "'''\n" + yaml.safe_dump(metadata) + "'''"
    return (header + program).strip()


def get_parameters(fn):
    """
    Describe the parameters of a callable.

    Args:
        fn: Any object accepted by `inspect.signature`.

    Returns:
        dict: parameter name -> (annotation, default), in signature order. A missing
        annotation or default is reported as `inspect.Parameter.empty`.
    """
    described = {}
    for param in inspect.signature(fn).parameters.values():
        described[param.name] = (param.annotation, param.default)
    return described

def header_block(code):
    """
    Locate the first ''' ... ''' block in `code`.

    Returns:
        (header_text, end_index): the text between the first two ''' delimiters,
        and the index just past the closing delimiter.

    Raises:
        ValueError: if fewer than two ''' delimiters are present.
    """
    marker = "'''"
    inner_start = code.index(marker) + len(marker)
    inner_stop = code.index(marker, inner_start)
    return code[inner_start:inner_stop], inner_stop + len(marker)

def read_code(path):
    """
    Load a program file and split it into header, structure code, and scratch footer.

    Args:
        path (str): Path of the program file.

    Returns:
        tuple: (header, code, footer), where
            header: the first ''' ... ''' block parsed with `yaml.safe_load`;
            code: the text after that block, up to (not including)
                END_STRUCTURE_DESCRIPTION;
            footer: the text after BEGIN_SCRATCH_SPACE.

    Raises:
        ValueError: if the ''' delimiters or either section marker is missing.
    """
    with open(path, 'r') as f:
        code_file = f.read()
    header_text, header_end_idx = header_block(code_file)
    header = yaml.safe_load(header_text)
    code_and_footer = code_file[header_end_idx:]
    # Use the shared marker constants so these sections stay in sync with sample_programs.
    code = code_and_footer[:code_and_footer.index(END_STRUCTURE_DESCRIPTION)]
    footer_start_idx = code_and_footer.index(BEGIN_SCRATCH_SPACE) + len(BEGIN_SCRATCH_SPACE)
    footer = code_and_footer[footer_start_idx:]
    return header, code, footer

def sample(generator_path, datapath, n_samples: int = 10):
    """
    Build parameter dictionaries for the `make_structure` function of a program file.

    The structure code (as split by `read_code`) is exec'd to obtain `make_structure`.
    Currently every sample is just its default parameter assignment; parameters without
    a default map to None. No actual resampling happens yet (see the TODO below).

    Args:
        generator_path (str): Path to the program file, read with `read_code`.
        datapath: Unused.
        n_samples (int): Number of parameter dictionaries to return.

    Returns:
        list[dict] | None: `n_samples` separate dicts mapping parameter name to value,
        or None if the code fails to execute or defines no `make_structure`.
    """
    _header, code, _footer = read_code(generator_path)
    env = {}
    try:
        exec(code, env)
        make_structure = env['make_structure']
    except Exception as e:
        # TODO - handle properly
        logging.getLogger(__name__).warning(
            "Could not load make_structure from %s: %r", generator_path, e)
        return None

    # TODO - add valid ranges to the header block and sample from that
    defaults = {
        name: None if default is inspect.Parameter.empty else default
        for name, (_param_type, default) in get_parameters(make_structure).items()
    }
    # A fresh dict per sample so that callers can mutate one without affecting others.
    return [dict(defaults) for _ in range(n_samples)]

def run_sample_parametric_programs():
    """
    Command-line entry point: parse arguments and call `sample_parametric_programs`.

    --basedir, --searchdir and --outdir are required. Omitting one used to surface
    later as a TypeError on None rather than as an argparse usage error.
    """
    from argparse import ArgumentParser
    parser = ArgumentParser(description='Sample parameter variations of parametric programs.')
    parser.add_argument('--basedir', '-b', type=str, required=True, help='Path to dataset base directory.')
    parser.add_argument('--searchdir', '-s', type=str, required=True, help='Path within dataset base directory to search for parametric programs.')
    parser.add_argument('--outdir', '-o', type=str, required=True, help='Path within dataset base directory to write sampled programs.')
    parser.add_argument('--keep_prob', '-p', type=float, default=0.5, help='Probability of not changing a default parameter.')
    parser.add_argument('--number', '-n', type=int, default=50, help='Target number of samples to take.')
    args = parser.parse_args()
    sample_parametric_programs(args)


def sample_parametric_programs(args):
    """
    Sample every `program.py` found under `<basedir>/<searchdir>` and write the results.

    For a program at `<searchdir>/<rel>/program.py`, each sample returned by
    `sample_programs` is written to `<outdir>/<rel>/<name>_<tag>/program.py`, where
    <name> is the last path component of `<outdir>/<rel>`.

    Args:
        args: Namespace with `basedir`, `searchdir`, `outdir`, `keep_prob` and `number`
            (as produced by `run_sample_parametric_programs`).
    """
    from metagen.util import list_all_filtered
    searchdir = os.path.join(args.basedir, args.searchdir.lstrip('/'))
    outdir = os.path.join(args.basedir, args.outdir.lstrip('/'))
    basedir = args.basedir or os.curdir

    for program in list_all_filtered(searchdir, ['**/program.py']):
        # relpath instead of prefix slicing: robust to trailing slashes and to
        # differently-spelled (e.g. absolute vs. relative) but equivalent prefixes.
        basestem = os.path.relpath(program, basedir)
        stem = os.path.relpath(program, searchdir)
        progdir = os.path.join(outdir, os.path.dirname(stem))
        # normpath drops a trailing separator, so a program directly in searchdir
        # is named after outdir instead of getting an empty name.
        prog_name = os.path.basename(os.path.normpath(progdir))

        samples = sample_programs(args.basedir, basestem, args.keep_prob, args.number)
        for tag, code in samples.items():
            out_dir = os.path.join(progdir, prog_name + '_' + tag)
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, 'program.py'), 'w') as f:
                f.write(code)
    
