from final_results import ModelRuns, iLab_RESULTS
from typing import List, Dict
from cross_validate import BestRuns
import numpy as np
from scipy.stats import sem
import pandas as pd


class ModelResults:
    """Aggregate the train/eval metrics of every run listed in a ``ModelRuns``.

    Attribute lookups that fail on this object are forwarded to the wrapped
    ``model_runs`` object (e.g. ``result.name``).
    """

    def __init__(self, model_runs: ModelRuns):
        self.model_runs = model_runs

        # metric name -> values collected from each run that reported it
        self.metric_to_values: Dict[str, List[float]] = self.load_metrics()

    def __getattr__(self, name):
        """Return attributes from the wrapped dataclass if they exist.

        Raises:
            AttributeError: if neither this object nor ``model_runs`` has
                ``name``.
        """
        # __getattr__ only runs after normal lookup fails. If ``model_runs``
        # itself is missing (e.g. during copy/unpickling, before __init__ ran),
        # evaluating ``self.model_runs`` below would recurse forever.
        if name == "model_runs":
            raise AttributeError(name)
        try:
            return getattr(self.model_runs, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object and its model_runs have no "
                f"attribute {name!r}"
            ) from None

    def load_metrics(self) -> Dict[str, List[float]]:
        """Collect every run's train and eval metrics into per-metric lists.

        For each directory in ``model_runs.run_dirs`` the metrics returned by
        ``BestRuns.load_metrics_json`` with the ``train_`` and ``eval_``
        prefixes are merged (on a key collision the eval value wins) and
        appended to the result. A metric absent from some runs ends up with a
        shorter list.
        """
        metrics: Dict[str, List[float]] = dict()
        for run_dir in self.model_runs.run_dirs:
            run_metrics = dict()
            train_metrics = BestRuns.load_metrics_json(run_dir, prefix="train_")
            eval_metrics = BestRuns.load_metrics_json(run_dir, prefix="eval_")

            run_metrics.update(train_metrics)
            run_metrics.update(eval_metrics)

            metrics = self._update_metrics(run_metrics, metrics)
        return metrics

    def _update_metrics(
        self, run_metrics: Dict[str, float], metrics: Dict[str, List[float]]
    ) -> Dict[str, List[float]]:
        """Append one run's values to ``metrics`` (mutated in place and returned).

        The ``model_name`` key is skipped.
        """
        for metric_name, value in run_metrics.items():
            if metric_name == "model_name":
                continue
            metrics.setdefault(metric_name, []).append(value)
        return metrics

    def mean(self, metric_name: str) -> float:
        """Mean of ``metric_name`` across runs; KeyError if never reported."""
        return np.mean(self.metric_to_values[metric_name])

    def stderr(self, metric_name: str) -> float:
        """Standard error of the mean across runs (scipy default ddof=1)."""
        return sem(self.metric_to_values[metric_name])


class LieTable:
    """Creates a table summarizing Lie performance against the baseline"""

    def __init__(self, model_runs: List[ModelRuns]):
        self.model_runs = model_runs
        self.model_results = [ModelResults(runs) for runs in model_runs]

        # metric keys looked up in each ModelResults
        self.table_metrics = [
            "diverse_2d_train_canonical_top_1_accuracy",
            "test_diverse_2d_top_1_accuracy",
            "test_canonical_top_1_accuracy",
        ]
        # names to use when displaying metrics (paired by position with
        # table_metrics)
        self.metric_display_names = [
            "known instance in new pose",
            "unknown instance",
            "unknown instance in seen pose",
        ]

        self.column_names = self.create_column_names()
        # every column: raw means/stderrs plus the formatted display strings
        self.original_table_df = self.make_table()

        self.display_columns = [
            "model_name",
            "eval_type",
            "diverse_proportion",
        ] + self.metric_display_names
        self.table_df = self.original_table_df[self.display_columns]

    def make_table(self) -> pd.DataFrame:
        """Build the full results DataFrame and annotate Lie rows with deltas."""
        entries = self.gather_entries()
        df = pd.DataFrame(entries, columns=self.column_names)
        return self.add_differences(df)

    def gather_entries(self) -> List[list]:
        """Create one row per model result, laid out as ``create_column_names``.

        Each row is ``[name, eval_type, diverse_proportion]`` followed, for
        every metric in ``table_metrics``, by its mean, its standard error and
        a LaTeX display string ``"<mean> ± {\\scriptsize<stderr>}"`` (both in
        percent with one decimal).
        """
        entries = []
        for result in self.model_results:
            entry = [result.name, result.eval_type, result.diverse_proportion]
            for metric in self.table_metrics:
                mean = result.mean(metric)
                stderr = result.stderr(metric)
                # the backslash must be escaped: "\s" is an invalid escape
                str_rep = f"{mean * 100:.1f} ± {{\\scriptsize{stderr * 100:.1f}}}"
                entry.extend([mean, stderr, str_rep])
            entries.append(entry)
        return entries

    def create_column_names(self) -> List[str]:
        """Return column names matching the row layout of ``gather_entries``."""
        column_names = ["model_name", "eval_type", "diverse_proportion"]
        for metric, display_name in zip(self.table_metrics, self.metric_display_names):
            column_names.extend([metric, f"{metric} stderr", display_name])
        return column_names

    def group_cols(self) -> pd.DataFrame:
        """Pivot ``table_df`` to one row per model.

        Columns become a three-level index
        ``(eval_type, diverse_proportion, metric display name)``.
        """
        table = self.table_df
        table = table.set_index(["eval_type", "diverse_proportion", "model_name"])
        table = table.unstack("eval_type").unstack("diverse_proportion")
        table = table.swaplevel(i=0, j=-1, axis=1).swaplevel(i=0, j=1, axis=1)
        return table

    def add_differences(self, table: pd.DataFrame) -> pd.DataFrame:
        """Annotate each Lie model's display cells with its change vs. baseline.

        ``add_difference`` is applied for every (Lie model, baseline) pair and
        every ``(eval_type, diverse_proportion)`` row of the Lie model.
        """
        model_pairs = [("MAE Lie", "MAE"), ("SimCLR Lie", "SimCLR")]
        for lie_model, baseline_model in model_pairs:
            lie_rows = table[table["model_name"] == lie_model]
            for _, row in lie_rows.iterrows():
                table = self.add_difference(
                    table,
                    row["diverse_proportion"],
                    row["eval_type"],
                    lie_model,
                    baseline_model,
                )
        return table

    def add_difference(
        self,
        table: pd.DataFrame,
        diverse_proportion: float,
        eval_type: str,
        lie_model="MAE Lie",
        baseline_model="MAE",
    ) -> pd.DataFrame:
        """Append absolute and relative change vs. the baseline to Lie cells.

        For each metric, the Lie model's display string in the row matching
        ``eval_type`` and ``diverse_proportion`` becomes
        ``"<current> (<absolute>, <relative>x)"``, where absolute is
        ``100 * (lie - baseline)`` and relative is ``lie / baseline``, both
        bold and colored via ``set_color``. ``table`` is modified in place and
        returned.

        Raises:
            ValueError: if that setting does not have exactly one row for the
                Lie model or exactly one row for the baseline model.
        """
        matches = (table["eval_type"] == eval_type) & (
            table["diverse_proportion"] == diverse_proportion
        )
        # only display-string cells are written below, so these masks stay valid
        lie_mask = (table["model_name"] == lie_model) & matches
        base_mask = (table["model_name"] == baseline_model) & matches

        for display_name, metric_name in zip(
            self.metric_display_names, self.table_metrics
        ):
            lie_metric = self._single_value(
                table, lie_mask, metric_name, lie_model, eval_type, diverse_proportion
            )
            base_metric = self._single_value(
                table,
                base_mask,
                metric_name,
                baseline_model,
                eval_type,
                diverse_proportion,
            )
            absolute = float(lie_metric - base_metric)
            # ratio lie / baseline, expressed as 1 + relative gain
            relative = absolute / base_metric + 1
            color = self.set_color(relative)
            # negative absolute values already carry their own "-" sign
            sign = "" if relative < 1.00 else "+"

            absolute_str = self._bold_colored(color, f"{sign}{100 * absolute:.1f}")
            relative_str = self._bold_colored(color, f"{relative:.2f}x")

            current = table.loc[lie_mask, display_name].item()
            lie_index = table[lie_mask].index[0]
            table.at[lie_index, display_name] = (
                f"{current} ({absolute_str}, {relative_str})"
            )
        return table

    @staticmethod
    def _single_value(table, mask, column, model_name, eval_type, diverse_proportion):
        """Return ``table.loc[mask, column]`` as a scalar, requiring one match."""
        selected = table.loc[mask, column]
        try:
            return selected.item()
        except ValueError as err:
            raise ValueError(
                f"expected exactly one row for model {model_name!r} with "
                f"eval_type={eval_type!r} and "
                f"diverse_proportion={diverse_proportion!r}, found "
                f"{len(selected)}:\n{table[mask]}"
            ) from err

    @staticmethod
    def _bold_colored(color: str, text: str) -> str:
        """Wrap ``text`` as ``\\textbf{\\textcolor{color}{text}}``."""
        return "\\textbf{\\textcolor{" + color + "}{" + text + "}}"

    def set_color(self, relative_change: float) -> str:
        """Map a lie/baseline ratio to a LaTeX color name.

        Below 0.95 -> "Red", below 1.0 -> "Gray", otherwise "DarkGreen".
        """
        if relative_change < 0.95:
            return "Red"
        elif relative_change < 1.0:
            return "Gray"
        return "DarkGreen"

    def to_latex(self, table=None, caption="") -> str:
        """Render a LaTeX table; call via ``print(results.to_latex())``.

        Defaults to ``group_cols()``. ``caption`` is used as both the caption
        and the label.
        """
        if table is None:
            table = self.group_cols()
        latex_table = table.to_latex(
            label=caption,
            index=True,
            # show floats as percentage with 2 decimal values
            float_format="{:.2%}".format,
            # cells already contain LaTeX markup, so do not escape it
            escape=False,
            # for NaN use
            na_rep="-",
            # longtable=True,
            multicolumn_format="c",
            caption=caption,
            # don't repeat multi-column names
            sparsify=True,
        )
        return latex_table


class LieLinearEvalTable(LieTable):
    """Lie vs. baseline table restricted to a single ``evaluation_type``."""

    evaluation_type = "linear_eval"

    def __init__(self, model_runs: List[ModelRuns]):
        super().__init__(model_runs)
        # From here on "original" is the display-column table for all eval
        # types (the raw mean/stderr columns built by LieTable are dropped).
        self.original_table_df = self.table_df

        self.table_df = self.original_table_df[
            self.original_table_df["eval_type"] == self.evaluation_type
        ]
        # drop unknown instance in seen pose column
        self.table_df = self.table_df.loc[
            :, self.table_df.columns != "unknown instance in seen pose"
        ]
        # drop eval_type column
        self.table_df = self.table_df.loc[:, self.table_df.columns != "eval_type"]
        self.table_grouped_cols_df = self.group_cols()

    def group_cols(self) -> pd.DataFrame:
        """Pivot to one row per model with (diverse_proportion, metric) columns."""
        table = self.table_df
        table = table.set_index(["diverse_proportion", "model_name"])
        table = table.unstack("diverse_proportion")
        # set diversity as top level column
        # NOTE(review): columns stay metric-major after the swap, so equal
        # diversity values are not contiguous; sort_index(axis=1) would group
        # them if that is the intent.
        table = table.swaplevel(i=0, j=-1, axis=1)
        # sort rows by model name so baselines are next to Lie (model names
        # are unique here, since unstack requires unique index pairs)
        return table.sort_index()

    def reformat_table(self, table) -> pd.DataFrame:
        """Order rows for display and bold / rename selected model labels.

        Raises KeyError (from ``.loc``) if any model in the order is missing.
        """
        model_order = [
            "MAE Lie",
            "MAE",
            "SimCLR Lie",
            "SimCLR",
            "SimCLR Frames",
            "SimCLR Frames (matched # params)",
        ]
        table = table.loc[model_order]

        # NOTE(review): for the default grouped table these two renames match
        # no column label (model_name is the index name, diverse_proportion a
        # column level name); they only apply to caller-supplied tables.
        table = table.rename(columns={"model_name": "model"})
        table = table.rename(columns={"diverse_proportion": "diverse proportion"})
        table = table.rename(
            index={
                "SimCLR Lie": "\\textbf{SimCLR Lie}",
                "MAE Lie": "\\textbf{MAE Lie}",
                # "\\#" escapes # for LaTeX; a bare "\#" is an invalid escape
                "SimCLR Frames (matched # params)": "SimCLR Frames \n (matched \\# params)",
            }
        )
        return table

    def to_latex(self, table=None):
        """Render the (default: grouped) table as LaTeX, captioned by eval type."""
        pd.set_option("display.max_colwidth", None)
        if table is None:
            table = self.table_grouped_cols_df
        table = self.reformat_table(table)
        return super().to_latex(
            table=table,
            caption=self.evaluation_type.replace("_", " "),
        )


class FinetuningTable(LieLinearEvalTable):
    """Same layout as ``LieLinearEvalTable`` but for the finetuning results."""

    evaluation_type = "finetuning"

    def to_latex(self, table=None):
        """Render as LaTeX; defaults to the grouped-column table.

        ``table`` is optional so this override accepts the same arguments as
        ``LieLinearEvalTable.to_latex``.
        """
        pd.set_option("display.max_colwidth", None)
        if table is None:
            table = self.table_grouped_cols_df
        return super().to_latex(
            table=table,
        )


class SimCLRFramesLinearTable(LieLinearEvalTable):
    """Compare the SimCLR variants listed in ``model_names`` at one diversity."""

    evaluation_type = "linear_eval"
    diverse_proportion = 0.5
    # rows to keep, in display order
    model_names = [
        "SimCLR Lie",
        "SimCLR",
        "SimCLR Frames",
        "SimCLR Frames (matched # params)",
    ]

    def __init__(self, model_runs: List[ModelRuns]):
        super().__init__(model_runs)
        # display-column table for all eval types, as set by LieLinearEvalTable
        self.all_rows_table_df = self.original_table_df

        self.table_df = self.prepare_table(self.all_rows_table_df)
        # a single diversity level is shown, so no column grouping is applied
        self.table_grouped_cols_df = self.table_df

    def prepare_table(self, full_table):
        """Filter rows and columns, then sort by model name (descending)."""
        table = self.filter_table_rows(full_table)
        table = self.filter_table_columns(table)
        table = table.sort_values(by=["model_name"], ascending=False)
        return table

    def filter_table_rows(self, table: pd.DataFrame):
        """Keep rows for this eval type, ``model_names`` and diversity level."""
        table = table[table["eval_type"] == self.evaluation_type]
        # keep only the SimCLR models listed in model_names
        table = table[table["model_name"].isin(self.model_names)]
        table = table[table["diverse_proportion"] == self.diverse_proportion]
        return table

    def filter_table_columns(self, table: pd.DataFrame) -> pd.DataFrame:
        """Keep display columns, minus ``eval_type`` and ``diverse_proportion``."""
        table = table[self.display_columns]
        # drop eval_type column
        table = table.loc[:, table.columns != "eval_type"]
        # drop diverse column
        table = table.loc[:, table.columns != "diverse_proportion"]
        return table

    def reformat_table(self, table) -> pd.DataFrame:
        """Order rows as ``model_names`` and apply LaTeX styling to model names."""
        # reorder rows
        table = table.set_index("model_name").loc[self.model_names].reset_index()

        # rename
        table = table.rename(columns={"model_name": "model"})
        table = table.rename(columns={"diverse_proportion": "diverse proportion"})
        # "\\rowcolor" must be escaped: "\r" would insert a carriage return
        table = table.replace("SimCLR", "\\rowcolor{LightLightGray} SimCLR")
        table = table.replace("SimCLR Lie", "\\textbf{SimCLR Lie}")
        table = table.replace(
            "SimCLR Frames (matched # params)", "SimCLR Frames \n (matched \\# params)"
        )
        return table

    def to_latex(self, table=None):
        """Render as a LaTeX table without the index; defaults to ``table_df``."""
        pd.set_option("display.max_colwidth", None)
        if table is None:
            table = self.table_df
        table = self.reformat_table(table)

        latex_table = table.to_latex(
            label=self.__class__.__name__,
            index=False,
            # show floats as percentage with 2 decimal values
            float_format="{:.2%}".format,
            # cells already contain LaTeX markup, so do not escape it
            escape=False,
            # for NaN use
            na_rep="-",
            # longtable=True,
            multicolumn_format="c",
            caption=self.evaluation_type.replace("_", " "),
            # don't repeat multi-column names
            sparsify=True,
        )
        return latex_table


class SimCLRFramesFinetuneTable(SimCLRFramesLinearTable):
    """SimCLR-variant comparison table built from the finetuning results."""

    evaluation_type: str = "finetuning"


class iLabTable:
    """Summarize iLab test top-1 accuracy (mean and standard error) per model."""

    def __init__(self, model_runs: List[ModelRuns] = iLab_RESULTS):
        self.model_runs = model_runs
        self.model_results = list(map(ModelResults, model_runs))

    def show_results(self):
        """Print name, mean and stderr of test top-1 accuracy, then a blank line."""
        metric = "test_top_1_accuracy"
        for model_result in self.model_results:
            mean_str = f"{model_result.mean(metric):.1%}"
            stderr_str = f"{model_result.stderr(metric):.2%}"
            print(model_result.name, mean_str, stderr_str)
            print()
