from pathlib import Path
from typing import Literal, Any, Iterable, Callable, Sequence
import matplotlib.pyplot as plt
import os

from exprutils import Command, parg, karg, run as _run


class MTPCompare(Command):
    """Plot three MTP evaluation files side by side.

    The evaluation files (labelled Direct, Concatenated-State and Causal-State
    MTP) are resolved relative to ``out`` and handed to
    ``exprutils.multfigs.plot_hotmaps``; the figure is saved as
    ``out/plots/mtp-compare.<fmt>``.
    """

    fmt = karg(str, "the format of the image", short='f', default='png')

    def __call__(self):
        """Build the label -> eval-file mapping and render the comparison."""
        # Imported lazily, like every other plotting backend in this file.
        from exprutils.multfigs import plot_hotmaps

        # Second title line appended to every label. The positional
        # placeholders are left unformatted here — presumably plot_hotmaps
        # fills them with the ID / OOD averages (confirm in exprutils).
        result_fmt = "Average (ID/OOD): {0:.1f} / {1:.1f}"
        paths = {
            "Direct MTP": "mult-4m-d/sft/final/greedy.eval.json",
            # Label spelling fixed (was "Concatentated"); it is rendered
            # verbatim in the figure.
            "Concatenated-State MTP": "mult-4m-t/sft/final/greedy.eval.json",
            "Causal-State MTP": "mult-4m/sft-r0/final/greedy-no-reflect.eval.json",
        }
        paths = {k + '\n' + result_fmt: v for k, v in paths.items()}
        plot_hotmaps(
            paths,
            root="out",
            bound=8,
            max_ncol=3,
            col_width=3,
            row_height=3,
            save=f"out/plots/mtp-compare.{self.fmt}",
        )


class SizeCompare(Command):
    """Plot the ``greedy-no-reflect`` evaluations of three model sizes.

    One panel each for the ``mult-1m``, ``mult-4m`` and ``mult-16m`` runs
    (resolved relative to ``out``); the figure is saved as
    ``out/plots/size-compare.<fmt>``.
    """

    fmt = karg(str, "the format of the image", short='f', default='png')

    def __call__(self):
        """Assemble the titled eval-file mapping and pass it to plot_hotmaps."""
        from exprutils.multfigs import plot_hotmaps

        # Second title line; the placeholders stay unformatted here —
        # presumably plot_hotmaps fills them in (confirm in exprutils).
        summary = "Average (ID/OOD): {0:.1f} / {1:.1f}"

        # Insertion order (1M, 4M, 16M) is the panel order handed over.
        titled = {}
        for size in ("1m", "4m", "16m"):
            title = f"{size.upper()}, Non-Reflective\n{summary}"
            titled[title] = f"mult-{size}/sft-r0/final/greedy-no-reflect.eval.json"

        layout = dict(bound=8, max_ncol=3, col_width=3, row_height=3)
        plot_hotmaps(
            titled,
            root="out",
            **layout,
            save=f"out/plots/size-compare.{self.fmt}",
        )


class PlotReflFreq(Command):
    """Plot the ``refl_freq`` entry of the self-reflect evaluations.

    Uses ``plot_hotmaps`` for the ``mult`` task and ``plot_hists`` for
    ``sudoku``, comparing the SFT run against one or two GRPO runs of the
    selected model size.
    """

    task = parg(str, "the task to plot", choices=("mult", "sudoku"))
    model = karg(str, "the size of the model", default="4m", short='m')
    fmt = karg(str, "the format of the image", short='f', default='png')
    compare_temperature = karg(bool, "compare low and high temperatures")

    def __call__(self):
        """Render the figure to ``out/plots/refl-freq-<task>-<model>.<fmt>``."""
        size = self.model
        sources = self._get_paths()
        # NOTE(review): the file name does not encode compare_temperature, so
        # both variants of the plot are written to the same path.
        shared = dict(
            max_ncol=3,
            col_width=3,
            row_height=3,
            save=f"out/plots/refl-freq-{self.task}-{self.model}.{self.fmt}",
            key="refl_freq",
        )
        # Only the backend, the root directory and the bound differ per task;
        # any other task value falls through without plotting.
        if self.task == "mult":
            from exprutils.multfigs import plot_hotmaps
            plot_hotmaps(sources, root=f"out/mult-{size}", bound=8, **shared)
        elif self.task == "sudoku":
            from exprutils.sudokufigs import plot_hists
            plot_hists(sources, root=f"out/sudoku-{size}", bound=54, **shared)

    def _get_paths(self):
        """Return the panel-title -> eval-file mapping (relative to the plot root).

        ``mult`` reads the ``greedy-self-reflect`` evaluations, any other task
        the ``sampling-self-reflect`` ones. With ``compare_temperature`` set,
        the ``grpo-r0.5-t1`` run is added as a third panel and the GRPO titles
        carry a temperature line.
        """
        if self.task == "mult":
            alg = "greedy-self-reflect"
        else:
            alg = "sampling-self-reflect"

        before = f"sft-r0.5-v1/final/{alg}.eval.json"
        after = f"grpo-r0.5-v1/{alg}.eval.json"

        if not self.compare_temperature:
            return {"Before GRPO": before, "After GRPO": after}

        return {
            "Before GRPO": before,
            "After GRPO\nTemperature: 1.25": after,
            "After GRPO\nTemperature: 1.0": f"grpo-r0.5-t1/{alg}.eval.json",
        }


class Table(Command):

    task = parg(str, "the task to tabulate", choices=("mult", "sudoku"))
    root = karg(str, "the root directory of experiment outputs", default="out/")
    metric = karg(str, "the metric being tabulated", default="score",
                  choices=("score", "e+", "e-", "mu"), short='m')
    display = karg(str, "the way to display the table", choices=("pprint", "markdown", "pandas"), default="markdown")
    include = karg(str, "the substrings that the experiment path must include, divided by ','", default="")
    exclude = karg(str, "the substrings that the experiment path must exclude, divided by ','", default="")
    
    type Settinng = Literal["path"]

    def __call__(self):
        from exprutils.figutils import tabulate_results, Result
        
        settings: list[Table.Settinng] = ["path"]
        if self.task == "mult":
            colmap = self.colmap_mult
        elif self.task == "sudoku":
            colmap = self.colmap_sudoku
        else:
            assert False

        stat: str | Callable[[Result], float]
        weight: str | Callable[[Result], float]

        if self.metric == "score":
            stat = "score"
            weight = "count"
        elif self.metric == "e+":
            stat = "refl_freq_fp"
            weight = lambda r: (neg * r.count if (neg := r.get("pi_freq_neg")) is not None else 0)
        elif self.metric == "e-":
            stat = "refl_freq_fn"
            weight = lambda r: ((1 - neg) * r.count if (neg := r.get("pi_freq_neg")) is not None else 0)
        elif self.metric == "mu":
            stat = lambda r: ((1 - neg) if (neg := r.get("pi_freq_neg")) is not None else float('nan'))
            weight = "count"
        else:
            assert False

        paths = sorted(self.iter_expriments())
        results = tabulate_results(
            settings=settings,
            colmap=colmap,
            experiments=[{"path": path} for path in paths],
            stat=stat,
            weight=weight,
            percentage=True,
        )
        if self.display == "pprint":
            from pprint import pprint
            pprint(results)
        else:
            import pandas as pd
            df = pd.DataFrame(results)
            if self.display == "markdown":
                print(df.to_markdown())
            elif self.display == "pandas":
                print(df)
            elif self.display == "hist":
                df.hist()

    def iter_expriments(self):
        include = list(filter(bool, self.include.split(','))) + [self.task]
        exclude = list(filter(bool, self.exclude.split(',')))

        def is_experiment_path(path: str):
            return path.endswith(".eval.json") and \
                all(s in path for s in include) and not any(s in path for s in exclude)

        for root, dirs, files in os.walk(self.root):
            for file in files:
                path = os.path.join(root, file)
                if is_experiment_path(path):
                    yield path
    
    @staticmethod
    def colmap_mult(k: str):
        from exprutils.multfigs import key2digits
        from configs.mult import difficulty
        try:
            dx, dy = key2digits(k)
        except ValueError:
            return
        yield difficulty(dx, dy)

    @staticmethod
    def colmap_sudoku(k: str):
        from exprutils.sudokufigs import key2blanks
        from configs.sudoku import difficulty
        try:
            b = key2blanks(k)
        except ValueError:
            return
        yield difficulty(b)


# Script entry point: pass this module's globals to exprutils' ``run``
# (presumably it discovers the Command subclasses in this module and
# dispatches on the command line — confirm in exprutils).
if __name__ == "__main__":
    _run(globals())
