import click
import concurrent
import itertools
import json
import logging
import random
import string
import time
import tqdm

from fests.matcher.strem import Matcher
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pandas import DataFrame
from pathlib import Path
from statistics import mean, median, mode
from typing import List, Optional

# Disable logging process-wide: with no argument, logging.disable() suppresses
# every call at CRITICAL severity and below on all loggers, including those of
# imported libraries.
logging.disable()

def statistics(values):
    """Summarize a sequence of numbers.

    Returns a dict with the keys ``total`` (the element count), ``mean``,
    ``median``, ``mode``, ``min`` and ``max``. When ``values`` is empty, the
    count is 0 and every other field is reported as 0.
    """

    count = len(values)

    if not values:
        summary = dict.fromkeys(("mean", "median", "mode", "min", "max"), 0)
    else:
        summary = {
            "mean": mean(values),
            "median": median(values),
            "mode": mode(values),
            "min": min(values),
            "max": max(values),
        }

    return {"total": count, **summary}


def reasoning_to_string(reasoning):
    """
    Recursively converts a reasoning dict into a natural-language phrase.

    ``reasoning`` is a dict with an optional ``"statement"`` string and an
    optional ``"children"`` list of nested reasoning dicts. Children whose
    ``"statement"`` is missing or empty are skipped. Rules, applied in order:

    * ``"less than"`` / ``"greater than"`` with two rendered children:
      "<left> is less than <right>" / "<left> is greater than <right>";
    * ``"distance between"`` / ``"intersection"`` with two rendered children:
      "distance between <a> and <b>" / "intersection between <a> and <b>";
    * ``"and"`` / ``"or"`` with rendered children: the children joined by the
      connector;
    * exactly one rendered child: "<statement> <child>";
    * several rendered children: "<statement> <c1> and <c2> ...";
    * no rendered children: the statement itself (a leaf).
    """
    base = reasoning.get("statement", "")
    # ``or []`` also tolerates an explicit ``"children": null``.
    children = reasoning.get("children") or []

    # Recurse to get child descriptions (children without a statement are dropped).
    child_texts = [reasoning_to_string(child) for child in children if child.get("statement")]

    # Two-operand forms. The check is on the *rendered* children so that
    # dropped children can never make the unpacking below fail.
    if len(child_texts) == 2:
        left, right = child_texts
        if base == "less than":
            return f"{left} is less than {right}"
        if base == "greater than":
            return f"{left} is greater than {right}"
        if base == "distance between":
            return f"distance between {left} and {right}"
        if base == "intersection":
            return f"intersection between {left} and {right}"

    # Logical connectives: join the operands; a single operand stands alone
    # rather than producing "and X".
    if base in ("and", "or") and child_texts:
        return f" {base} ".join(child_texts)

    # Single-child modifier
    if len(child_texts) == 1:
        return f"{base} {child_texts[0]}"

    # Fallback: join multiple children naturally
    if child_texts:
        return f"{base} " + " and ".join(child_texts)

    # Leaf node
    return base


def naturalize(json_data):
    """
    Converts the JSON explanation into one or more natural-language sentences.

    ``json_data["statements"]`` (treated as empty when missing) is a list of
    dicts, each with an ``"index"`` and a ``"reasoning"`` tree that is rendered
    with ``reasoning_to_string``. Statements are ordered by index, then:

    * exactly two statements with different descriptions are chained into one
      sentence: "<d1> at index <i1> followed by <d2> at index <i2>.";
    * otherwise statements with identical descriptions are grouped. Each run
      of consecutive indices within a group becomes
      "From index <a> to <b>, <desc>." (or "At index <i>, <desc>." for a run
      of one index).

    Returns an empty string when there are no statements.
    """
    # Build and sort list of (index, description)
    desc_by_index = sorted(
        ((item["index"], reasoning_to_string(item["reasoning"]))
         for item in json_data.get("statements", [])),
        key=lambda pair: pair[0],
    )

    # Special case: exactly two statements with different descriptions
    if len(desc_by_index) == 2 and desc_by_index[0][1] != desc_by_index[1][1]:
        (i1, d1), (i2, d2) = desc_by_index
        return f"{d1} at index {i1} followed by {d2} at index {i2}."

    # Otherwise, group identical descriptions across indices
    grouped = {}
    for idx, desc in desc_by_index:
        grouped.setdefault(desc, []).append(idx)

    fragments = []
    for desc, indices in grouped.items():
        # Repeated indices would otherwise produce duplicate sentences.
        unique = sorted(set(indices))
        # Consecutive integers keep a constant (value - position) difference,
        # so grouping on it splits the sorted indices into runs. The position
        # must come from an independent counter; pairing each value with
        # itself would make the key always 0 and merge non-adjacent indices.
        for _, group in itertools.groupby(unique, key=lambda i, c=itertools.count(): i - next(c)):
            run = list(group)
            if len(run) == 1:
                fragments.append(f"At index {run[0]}, {desc}.")
            else:
                fragments.append(f"From index {run[0]} to {run[-1]}, {desc}.")

    return " ".join(fragments)

class IdentifierGenerator:
    """A class for identifier generation.

    Maps arbitrary identifiers to short alphanumeric ones in first-seen order:
    ``"aa"``, ``"ba"``, ..., ``"9a"``, ``"ab"``, ... (the first character varies
    fastest). After all 62 * 62 two-character identifiers have been issued,
    longer identifiers (``"aab"``, ``"bab"``, ...) are produced instead of
    wrapping around. Distinct inputs therefore never share an output.
    """

    def __init__(self) -> None:
        # Source identifier -> generated short identifier.
        self.idmap = {}
        # Alphabet, in order: a-z, A-Z, 0-9.
        self.characters = list(
            string.ascii_lowercase +
            string.ascii_uppercase +
            string.digits
        )

        # Indices into ``characters`` for the next identifier, least-significant first.
        self.positions = [0, 0]

    def increment(self) -> None:
        """Increment the generator.

        Advances ``positions`` as a base-``len(characters)`` counter with carry.
        """

        base = len(self.characters)
        for i in range(len(self.positions)):
            self.positions[i] += 1
            if self.positions[i] < base:
                return
            # This digit overflowed: reset it and carry into the next one.
            self.positions[i] = 0

        # Every digit overflowed. Wrapping back to all-zeros would reissue
        # identifiers that are already in use, so grow by one digit instead.
        self.positions.append(1)

    def identifier(self, identifier: str) -> str:
        """Return the short identifier for ``identifier``.

        The first time an identifier is seen, a new short identifier is
        allocated and remembered. Later calls with the same input return that
        same value.
        """

        if identifier not in self.idmap:
            self.idmap[identifier] = "".join(self.characters[p] for p in self.positions)
            self.increment()

        return self.idmap[identifier]


@dataclass
class Options:
    """A set of options for a dataset-generation run.
    """

    # Input directory searched for ``*.json`` datastream files.
    path: Path
    # Output root; generated examples and ``stats.json`` are written below it.
    output: Path
    # Maximum number of worker processes.
    jobs: int
    # Search ``path`` recursively (``rglob``) instead of only its top level.
    recursive: bool
    # Seed for the module-level ``random`` generator.
    seed: int
    # Frame-window sizes to sample; ``None`` means "all frames" (the CLI default is ``[None]``).
    nframes: List[Optional[int]]
    # Maximum number of files to sample; ``None`` or 0 means no limit.
    limit: Optional[int]

@dataclass
class Query:
    """A query: the matcher query text, its natural-language form and its category.
    """

    # Query text passed to the STREM ``Matcher`` as its ``query`` argument.
    spre: str
    # Natural-language phrasing of the query, embedded in the LLM prompt.
    nl: str
    # Category label; part of each output filename and the key for per-category statistics.
    category: str


class Formatter:
    """A class to format to LLM prompting.

    Builds a prompt from a fixed preamble, the natural-language query and a
    CSV rendering of the datastream's annotations.
    """

    # Column order of the CSV produced by ``rowify``. Passing it explicitly
    # keeps the header row even when the datastream has no annotations.
    COLUMNS = ("index", "identifier", "class", "xmin", "ymin", "xmax", "ymax")

    def __init__(self, prompt: str) -> None:
        """Initialize the formatter.

        Args:
            prompt: Preamble text placed at the start of every prompt.
        """

        self.prompt = prompt

    def rowify(self, datastream) -> str:
        """Row-ify the datastream.

        Walks ``datastream["frames"][*]["samples"][*]["annotations"]`` and emits
        one CSV row per annotation. Each row holds the frame index, a short
        identifier, the class, and the bounding-box corners computed from the
        box center and dimensions and truncated with ``int``. A fresh
        ``IdentifierGenerator`` is used per call. Within the returned text, the
        same source identifier therefore always maps to the same short one.

        Returns:
            CSV text with a header row (present even when there are no rows).
        """

        generator = IdentifierGenerator()

        rows = []
        for frame in datastream["frames"]:
            for sample in frame["samples"]:
                for annotation in sample["annotations"]:
                    region = annotation["bbox"]["region"]
                    x = region["center"]["x"]
                    y = region["center"]["y"]
                    w = region["dimensions"]["w"]
                    h = region["dimensions"]["h"]

                    rows.append({
                        "index": frame["index"],
                        "identifier": generator.identifier(annotation["identifier"]),
                        "class": annotation["class"],
                        "xmin": int(x - (w / 2)),
                        "ymin": int(y - (h / 2)),
                        "xmax": int(x + (w / 2)),
                        "ymax": int(y + (h / 2)),
                    })

        # Export to a CSV-styled string.
        #
        # This converts the rows into a DataFrame and uses its methods to
        # produce a CSV-formatted string.
        return DataFrame(rows, columns=list(self.COLUMNS)).to_csv(index=False)

    def fmt(self, datastream, query: Query) -> dict:
        """Format the data into an LLM prompt.

        Returns:
            ``{"input": prompt}``. Here ``prompt`` is the preamble followed by
            a ``---`` separator and a ``<root>`` section holding the query's
            natural-language text and the CSV from ``rowify``.
        """

        parts = [
            self.prompt,
            "---\n",
            "<root>\n",
            "\t<query>" + query.nl + "</query>\n",
            "<data>\n",
            self.rowify(datastream),
            "</data>\n",
            "</root>\n",
        ]

        return {
            "input": "".join(parts)
        }
    
class Processor:
    """Process one datastream file for a single (nframes, query) combination.
    """

    def __init__(self, options: Options, nframes: Optional[int], query: Query, formatter: Formatter) -> None:
        """Initialize the processor.

        Args:
            options: Run options; ``options.output`` is the export root.
            nframes: Size of the frame window to sample, or ``None`` to use
                every frame (``slice`` resolves it to the stream length).
            query: The query to match and to embed in the prompt.
            formatter: Builds the LLM prompt from the sampled datastream.
        """

        self.options = options
        self.nframes = nframes

        self.query = query
        self.formatter = formatter

    def slice(self, frames):
        """Return a random window of ``self.nframes`` consecutive frames.

        If ``self.nframes`` is ``None``, it is set to ``len(frames)`` and the
        whole stream is returned. If the stream is shorter than
        ``self.nframes``, the whole stream is returned rather than raising.

        NOTE(review): the window start is drawn from the module-level
        ``random`` state of the process this runs in. Reproducibility under a
        process pool therefore depends on how that state is seeded in the
        worker.
        """

        if self.nframes is None:
            self.nframes = len(frames)

        # Clamp the upper bound so a short stream does not make randint raise
        # on an empty range; the slice below then returns every frame.
        start = random.randint(0, max(0, len(frames) - self.nframes))

        return frames[start:start + self.nframes]

    def process(self, path: Path):
        """Process the data.

        Loads the JSON datastream at ``path`` and replaces its ``"frames"``
        with a random window (see ``slice``). Runs the STREM ``Matcher`` with
        ``self.query.spre``, writes the resulting example via ``export``, and
        returns a JSON-serializable summary (timing, path, query, match
        statistics).
        """
        t = time.perf_counter()

        # Only the load needs the file handle; release it before the matcher runs.
        with open(path, "r", encoding="utf-8") as f:
            datastream = json.load(f)

        # Sample (slice) frames.
        #
        # Randomly select a window of N consecutive frames from the data.
        datastream["frames"] = self.slice(datastream["frames"])

        # Run the Matcher.
        #
        # This runs the STREM framework to retrieve the set of matches
        # alongside the explanations. The result is materialized because it is
        # traversed several times below.
        matches = list(Matcher(
            query=self.query.spre,
            datastream=datastream,
        ).run())

        # (first, last) frame index of each match, inclusive.
        spans = [(m["frames"][0]["index"], m["frames"][-1]["index"]) for m in matches]

        # Export the data.
        #
        # This transforms and writes the data into a structure prepared for
        # LLM training/finetuning.
        data = {
            # NOTE(review): ``Formatter.fmt`` returns a dict, so the prompt is
            # stored at data["input"]["input"]; confirm consumers expect this.
            "input": self.formatter.fmt(datastream, self.query),
            "output": [list(range(first, last + 1)) for first, last in spans],
            "explanations": [naturalize(m["explanation"]) for m in matches],
        }

        self.export(data, path)

        return {
            "elapsed": time.perf_counter() - t,
            "path": str(path),
            "frames": {
                "nframes": self.nframes,
                "percents": {
                    # NOTE(review): constant, not computed from the data.
                    "non-empty": 1.0,
                }
            },
            "query": {
                "spre": self.query.spre,
                "category": self.query.category,
            },
            "statistics": {
                "num-matches": len(matches),
                "match-length": statistics([last - first + 1 for first, last in spans]),
            }
        }

    def export(self, data, path: Path) -> None:
        """Write ``data`` as JSON to
        ``<output>/data/<nframes:03>/<parent dir name>/<stem>-<category>.json``.

        NOTE(review): the filename depends only on the input file and the
        query category. Two queries sharing a category would overwrite each
        other's output for the same file; confirm categories are unique.
        """

        output = self.options.output / "data" / f"{self.nframes:03}" / path.parent.name
        output.mkdir(parents=True, exist_ok=True)

        # Set filename.
        #
        # This includes the channel sourced from the dataset used as well as the
        # category.
        filename = output / f"{path.stem}-{self.query.category}.json"

        with open(filename, "w") as f:
            json.dump(data, f)
        


class Application:
    """A class to represent an application.

    Fans out one ``Processor`` task per (file, nframes, query) combination and
    writes aggregate statistics to ``<output>/stats.json``.
    """

    def __init__(self, options: Options, queries, prompt: str) -> None:
        """Initialize the application.

        Args:
            options: Run options.
            queries: Sequence of query dicts; each must provide ``"spre"``,
                ``"nl"`` and ``"category"`` keys.
            prompt: Prompt preamble handed to every ``Formatter``.
        """

        self.options = options
        self.queries = queries
        self.prompt = prompt

    def run(self) -> None:
        """The entrypoint for the application.
        """
        t = time.perf_counter()
        random.seed(self.options.seed)

        # Create the output directory (no error if it already exists).
        self.options.output.mkdir(parents=True, exist_ok=True)

        # Find the files.
        #
        # Collect ``*.json`` files either recursively or only at the top
        # level. They are sorted because glob order is filesystem-dependent,
        # and the seeded sampling below should select the same files for the
        # same seed.
        search = self.options.path.rglob if self.options.recursive else self.options.path.glob
        filenames = sorted(search("*.json"))

        # Limit the number of scenes.
        #
        # Randomly select at most ``self.options.limit`` scenes from the set
        # of collected scenes.
        if self.options.limit:
            filenames = random.sample(filenames, min(self.options.limit, len(filenames)))

        # Process the filenames.
        #
        # Submit one task per (file, nframes, query) combination. Each task
        # transforms the data and writes a new file.
        dataset = []
        with ProcessPoolExecutor(max_workers=self.options.jobs) as executor:
            futures = []
            for filename in filenames:
                for nframes in self.options.nframes:
                    for query in self.queries:
                        processor = Processor(
                            options=self.options,
                            nframes=nframes,
                            query=Query(
                                spre=query["spre"],
                                nl=query["nl"],
                                category=query["category"],
                            ),
                            formatter=Formatter(
                                prompt=self.prompt
                            ),
                        )
                        futures.append(executor.submit(processor.process, filename))

            # Collect results as tasks complete (so ``dataset`` is in
            # completion order); ``result()`` re-raises any task exception.
            for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
                dataset.append(future.result())

        # Post-process stats.
        #
        # This collects and collates the statistics for the data generated from
        # the set path provided.
        stats = {
            "elapsed": time.perf_counter() - t,
            "seed": self.options.seed,
            "path": str(self.options.path),
            "statistics": self.statistics(dataset),
            "processed": dataset,
        }

        with open(self.options.output / "stats.json", "w") as f:
            json.dump(stats, f, indent=4)

    def statistics(self, dataset):
        """Compute the statistics on the dataset.

        ``dataset`` is the list of per-task summaries returned by
        ``Processor.process``. Returns overall match statistics, the fraction
        of tasks with at least one match, and the same figures per query
        category (plus each category's share of all tasks).
        """

        def non_empty_fraction(items):
            # Fraction of items with at least one match (0 when there are none).
            return sum(1 for x in items if x["statistics"]["num-matches"] > 0) / len(items) if items else 0

        by_category = {}
        for data in dataset:
            by_category.setdefault(data["query"]["category"], []).append(data)

        categories = {}
        for key, values in by_category.items():
            categories[key] = {
                # Each result is in exactly one category, so the total is len(dataset).
                "proportion": len(values) / len(dataset),
                "num-matches": statistics([x["statistics"]["num-matches"] for x in values]),
                "percents": {
                    "non-empty": non_empty_fraction(values),
                }
            }

        return {
            "dataset": {
                "num-files": len(dataset),
                "num-matches": statistics([data["statistics"]["num-matches"] for data in dataset]),
                "percents": {
                    "non-empty": non_empty_fraction(dataset),
                },
                "categories": categories,
            },
        }


@click.command()
@click.argument("path", type=Path)
@click.option("-o", "--output", type=Path, default=Path("output"), help="A directory to output contents to.")
@click.option("-r", "--recursive", is_flag=True, help="Process the files recursively.")
@click.option("-j", "--jobs", type=int, default=1, help="Number of jobs to run in parallel.")
@click.option("-s", "--seed", type=int, help="The number to seed the processor.")
@click.option("-n", "--num-frames", "nframes", type=int, multiple=True, default=[None], help="The number of frames to sample.")
@click.option("-c", "--context", type=Path, default=Path("./"), help="The directory containing queries.json and prompt.txt.")
@click.option("-l", "--limit", type=int, default=None, help="The limit on the number of samples.")
def main(path, output, recursive, jobs, seed, nframes, context, limit) -> None:
    """A tool to process STREM data into an LLM dataset.
    """

    # Generate seed if necessary.
    #
    # If no seed is provided, draw one at random. It is then used and
    # recorded just like a user-supplied seed.
    if seed is None:
        seed = random.randint(0, 2 ** 32 - 1)

    # Load special files.
    #
    # ``queries.json`` (only its "train" entry is used) and ``prompt.txt`` are
    # read from the context directory on every run; both are required.
    with open(context / "queries.json", "r", encoding="utf-8") as f:
        queries = json.load(f)

    with open(context / "prompt.txt", "r", encoding="utf-8") as f:
        prompt = f.read()

    # Run the application.
    #
    # ``-n`` may be repeated; the default ``[None]`` means a single pass that
    # uses all frames of each stream.
    Application(
        options=Options(
            path=path,
            output=output,
            jobs=jobs,
            recursive=recursive,
            seed=seed,
            nframes=list(nframes),
            limit=limit
        ),
        queries=queries["train"],
        prompt=prompt,
    ).run()


if __name__ == "__main__":
    # Script entry point; click parses the command line.
    main()
