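"""
Plotting utilities for comparing model performance on canonical and diverse poses.

1. Bar plots with error bars across runs (standard error or standard deviation)
   for diverse poses, grouped into 2D and 3D
    - train_diverse, test_diverse, diverse_train_canonical
2. Bar plots for canonical poses
"""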
from typing import List, Optional, Tuple, Dict
from collections import OrderedDict
from sklearn import linear_model

from sklearn.model_selection import learning_curve
from cross_validate import BestRuns, DiverseRun
import plotly.graph_objects as go
from scipy.stats import sem
import numpy as np
from plotly.subplots import make_subplots
import pandas as pd
import matplotlib.pyplot as plt
from final_results import ModelRuns
from sweeps.sweep_results import BEST_CANONICAL_RUNS, DIVERSE_SWEEP_TO_DIR, SWEEP_TO_DIR
from visu import make_results


class Plots:
    """Plotly bar charts of a metric, and of generalization gaps, for a sweep's best runs.

    Metric names are expected to have the form ``f"{stage}_{metric_to_plot}"``
    (e.g. ``test_diverse_2d_top_1_accuracy``); stages are grouped into the
    canonical, diverse_2d and diverse_3d partitions by substring matching.
    """

    def __init__(
        self,
        sweep_dir: str,
        validation_metric: str = "val_canonical_loss",
        metric_to_plot: str = "top_1_accuracy",
        sweep_name: Optional[str] = None,
    ):
        """
        Args:
            sweep_dir: sweep directory, forwarded to ``BestRuns``.
            validation_metric: forwarded to ``BestRuns`` (presumably the metric
                used to select the best runs).
            metric_to_plot: metric suffix to plot, e.g. ``top_1_accuracy``.
            sweep_name: optional prefix for figure titles.
        """
        self.best_runs = BestRuns(
            sweep_dir, validation_metric=validation_metric, load_aposteriori=False
        )
        self.sweep_name = sweep_name

        self.metric_to_plot = metric_to_plot
        # the metric keys recorded for the best run define which stages exist
        metrics = list(self.best_runs.id_to_metrics[self.best_runs.best_run_id].keys())

        self.canonical_stages = self._identify_stages(
            "canonical", metrics, exclude=["diverse_2d", "diverse_3d"]
        )
        self.diverse_2d_stages = self._identify_stages("diverse_2d", metrics)
        self.diverse_3d_stages = self._identify_stages("diverse_3d", metrics)

    def _identify_stages(
        self, partition: str, metrics: List[str], exclude: Optional[List[str]] = None
    ) -> List[str]:
        """Returns matching stages based on word_match.

        Args:
            partition: string such as canonical
            metrics: list of all metrics gathered
            exclude: list of other partition names to exclude

        Returns: the stages (metric names with the ``_{metric_to_plot}`` suffix
            removed) of the metrics for ``metric_to_plot`` that contain
            ``partition`` and none of ``exclude``.
        """
        exclude = exclude or []
        suffix = f"_{self.metric_to_plot}"
        return [
            metric.replace(suffix, "")
            for metric in metrics
            if self.metric_to_plot in metric
            and partition in metric
            and not any(other in metric for other in exclude)
        ]

    def plots(self) -> List[go.Figure]:
        """Returns the gap figure and the metric figure."""
        # sweep_name is optional; avoid rendering "None" in the titles
        title_prefix = self.sweep_name or ""
        gap_fig = self.plot_gaps(title_prefix=title_prefix)
        metric_fig = self.plot_metric(title_prefix=title_prefix)
        return [gap_fig, metric_fig]

    def plot(self) -> None:
        """Shows every figure returned by ``plots``."""
        for p in self.plots():
            p.show()

    def plot_metric(self, title_prefix="", title_suffix="") -> go.Figure:
        """Three stacked horizontal bar charts (canonical / diverse_2d / diverse_3d)
        of the mean metric per stage, with standard-error bars across best runs."""
        canonical_metrics = self.to_metric_names(self.canonical_stages)
        bar_canonical = self._plot_metrics_across_best_runs(
            canonical_metrics, self.canonical_stages, title="canonical"
        )
        diverse_2d_metrics = self.to_metric_names(self.diverse_2d_stages)
        bar_diverse_2d = self._plot_metrics_across_best_runs(
            diverse_2d_metrics, self.diverse_2d_stages, title="diverse_2d"
        )
        diverse_3d_metrics = self.to_metric_names(self.diverse_3d_stages)
        bar_diverse_3d = self._plot_metrics_across_best_runs(
            diverse_3d_metrics, self.diverse_3d_stages, title="diverse_3d"
        )

        fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05)

        fig.add_trace(bar_canonical, row=1, col=1)
        fig.add_trace(bar_diverse_2d, row=2, col=1)
        fig.add_trace(bar_diverse_3d, row=3, col=1)
        if "accuracy" in self.metric_to_plot:
            fig.update_xaxes(tickformat=".0%")

        fig.update_layout(
            height=500, title=f"{title_prefix} {self.metric_to_plot} {title_suffix}"
        )
        return fig

    def plot_gaps(self, title_prefix="") -> go.Figure:
        """Horizontal bar chart of the gaps computed by ``compute_gaps``."""
        bar = self.gaps_bar_plot()
        fig = go.Figure(data=bar)
        if "accuracy" in self.metric_to_plot:
            fig.update_xaxes(tickformat=".0%")
        fig.update_layout(
            height=500, title=f"{title_prefix} {self.metric_to_plot} gaps", bargap=0.5
        )
        return fig

    def gaps_bar_plot(self) -> go.Bar:
        """Bar trace of the gap means with standard-error bars; the first gap is gray."""
        gap_means, gap_stderrs, gap_names = self.compute_gaps()

        bar = go.Bar(
            x=gap_means,
            y=gap_names,
            text=[f"{m:0.1%}" for m in gap_means],
            textposition="outside",
            orientation="h",
            marker_color=["gray"] + (len(gap_means) - 1) * ["indianred"],
            error_x=dict(
                type="data",  # value of error bar given in data coordinates
                array=gap_stderrs,
                visible=True,
                color="gray",
                thickness=0.75,
            ),
        )
        return bar

    def _compute_gap(self, metric1_name: str, metric2_name: str) -> np.ndarray:
        """Returns the per-run difference ``metric1 - metric2`` over the best runs.

        Both arguments are stage names; ``_{metric_to_plot}`` is appended to each.
        Callers reduce the returned array (e.g. mean / standard error).
        """
        values1 = self.best_runs.get_best_runs_metric(
            f"{metric1_name}_{self.metric_to_plot}"
        )
        values2 = self.best_runs.get_best_runs_metric(
            f"{metric2_name}_{self.metric_to_plot}"
        )
        return np.array(values1) - np.array(values2)

    def compute_gaps(self) -> Tuple[List[float], List[float], List[str]]:
        """Returns the means, stderrs, and names of each gap"""
        gap_to_metric_names = OrderedDict(
            [
                ("self-occlusion gap", ("test_diverse_3d", "test_diverse_2d")),
                (
                    "diverse_3d generalization held-out",
                    ("test_diverse_3d", "train_canonical"),
                ),
                (
                    "diverse_3d generalization seen",
                    ("diverse_3d_train_canonical", "train_canonical"),
                ),
                (
                    "diverse_2d generalization held-out",
                    ("test_diverse_2d", "train_canonical"),
                ),
                (
                    "diverse_2d generalization seen",
                    ("diverse_2d_train_canonical", "train_canonical"),
                ),
            ]
        )

        # compute each per-run gap once, then reduce it two ways
        gaps = [self._compute_gap(m1, m2) for m1, m2 in gap_to_metric_names.values()]
        gap_means = [gap.mean() for gap in gaps]
        gap_stderrs = [sem(gap) for gap in gaps]
        gap_names = [
            f"<b> {g} </b> <br> ({m1} - {m2})"
            for g, (m1, m2) in gap_to_metric_names.items()
        ]
        return gap_means, gap_stderrs, gap_names

    def to_metric_names(self, stages: List[str]) -> List[str]:
        """Returns the full metric names based on the stages"""
        return [f"{n}_{self.metric_to_plot}" for n in stages]

    def _plot_metrics_across_best_runs(
        self, names: List[str], stages: List[str], title: str = ""
    ) -> go.Bar:
        """Bar trace of the mean of each metric over the best runs, labelled by stage,
        with standard-error bars."""
        # fetch each metric once and reuse it for both reductions
        values = [self.best_runs.get_best_runs_metric(name) for name in names]
        means = [np.mean(v) for v in values]
        stderr = [sem(v) for v in values]
        bar = go.Bar(
            x=means,
            y=stages,
            name=title,
            text=[f"{m:0.1%}" for m in means],
            textposition="outside",
            orientation="h",
            error_x=dict(
                type="data",  # value of error bar given in data coordinates
                array=stderr,
                visible=True,
                color="gray",
                thickness=0.75,
            ),
        )
        return bar


class CompareModelsPlot:
    """Plots gaps across models"""

    # models compared when no ``model_names`` are given
    DEFAULT_MODEL_NAMES = ("clip", "resnet", "vit", "mlp_mixer", "simclr", "mae")

    def __init__(
        self,
        model_names: Optional[List[str]] = None,
        eval_type: str = "linear_eval",
        training_type: str = "canonical_training",
        sweep_to_dir=SWEEP_TO_DIR,
    ):
        """
        Args:
            model_names: models to compare; defaults to ``DEFAULT_MODEL_NAMES``.
            eval_type: evaluation part of the sweep name, e.g. ``linear_eval``.
            training_type: training part of the sweep name, e.g. ``canonical_training``.
            sweep_to_dir: maps ``f"{model}_{eval_type}_{training_type}"`` to a sweep dir.
        """
        # fresh list per instance instead of a shared mutable default argument
        self.model_names = (
            list(self.DEFAULT_MODEL_NAMES) if model_names is None else model_names
        )
        self.eval_type = eval_type
        self.training_type = training_type

        self.sweep_to_dir = sweep_to_dir

        self.model_name_to_plots: Dict[str, Plots] = self.build_model_name_to_plots()

    def build_model_name_to_plots(self) -> Dict[str, Plots]:
        """Builds one ``Plots`` object per model from its sweep directory."""
        model_name_to_plots = {}
        for model_name in self.model_names:
            sweep_name = f"{model_name}_{self.eval_type}_{self.training_type}"
            model_name_to_plots[model_name] = Plots(self.sweep_to_dir[sweep_name])
        return model_name_to_plots

    def plot_metrics_per_model(self) -> List[go.Figure]:
        """One metric figure per model, titled with the best run's lr and optimizer."""
        figures = []
        for model_name, plot in self.model_name_to_plots.items():
            hyperparameters = plot.best_runs.best_run_hyperparameters
            learning_rate, optimizer = (
                hyperparameters["learning_rate"],
                hyperparameters["optimizer"],
            )
            fig = plot.plot_metric(
                title_prefix=f"{model_name}",
                title_suffix=f"lr={learning_rate} ({optimizer})",
            )
            figures.append(fig)
        return figures

    def plot(self, x_range: Tuple[float, float] = (-0.8, 0.0)) -> go.Figure:
        """Grouped horizontal bar chart of the gaps of every model.

        Args:
            x_range: x-axis limits (gaps are differences of metrics).
        """
        fig = go.Figure()
        for model_name in self.model_names:
            fig.add_trace(self.single_bar_plot(model_name))
        title = (
            f"{self.eval_type.replace('_', ' ')} {self.training_type.replace('_', ' ')}"
        )
        fig.update_layout(
            barmode="group",
            height=700,
            title=title,
            template="plotly_white",
        )
        fig.update_xaxes(range=list(x_range))
        fig.update_layout(xaxis_tickformat=",.0%", xaxis={"tickfont": dict(size=20)})
        return fig

    def single_bar_plot(self, model_name: str) -> go.Bar:
        """Bar trace of one model's gap means with standard-error bars."""
        plots = self.model_name_to_plots[model_name]
        gap_means, gap_stderrs, gap_names = plots.compute_gaps()

        bar = go.Bar(
            x=gap_means,
            y=gap_names,
            orientation="h",
            name=model_name,
            error_x=dict(
                type="data",  # value of error bar given in data coordinates
                array=gap_stderrs,
                visible=True,
                color="gray",
                thickness=0.75,
            ),
        )
        return bar


class PerAnglesPlots(Plots):
    """Matplotlib plots of per-sample results binned by rotation angle (``pose_y``)."""

    # Upper edges of the angle bins. An angle ``a`` counts as ``_MAX_TRANSFO - a``
    # too, so each bin collects angles within ``edge`` of 0 in either direction.
    _BIN_EDGES = range(10, 190, 10)
    _MAX_TRANSFO = 360

    def __init__(self, *args, **kwargs):
        # forward keyword arguments too (e.g. sweep_name=...), not only positionals
        super().__init__(*args, **kwargs)

    def to_result_names(self, stages: List[str]) -> List[str]:
        """Returns the result names (atm they are = names of the stages)"""
        return [f"{n}" for n in stages]

    def plot(self) -> None:
        """One row of binned accuracy plots per partition (canonical, 2d, 3d)."""
        canonical_results = self.to_result_names(self.canonical_stages)
        self._plot_results_across_best_runs(canonical_results, metric="top_1_accuracy")

        diverse_2d_results = self.to_result_names(self.diverse_2d_stages)
        self._plot_results_across_best_runs(diverse_2d_results, metric="top_1_accuracy")

        # TODO: this only works at the moment because the 3 values are the same in all axes (and we select on "pose_y")
        diverse_3d_results = self.to_result_names(self.diverse_3d_stages)
        self._plot_results_across_best_runs(diverse_3d_results, metric="top_1_accuracy")

    def _plot_results_across_best_runs(
        self,
        names: List[str],
        metric: str,
    ) -> None:
        """Draws one binned subplot per result name in a single row."""
        if not names:
            return  # plt.subplots(1, 0) would raise
        # squeeze=False keeps a 2D axes array even when there is only one name
        _, axes = plt.subplots(1, len(names), figsize=(15, 5), squeeze=False)
        for ax, name in zip(axes[0], names):
            df = self._result_to_df(name)
            self.mean_bin_and_plot(df, metric, name, ax)

    def _result_to_df(self, name):
        """Merges the a-posteriori results of all best runs into one DataFrame.

        The frame is indexed by ``image_path``. Per-run columns are prefixed with
        ``f"{run}_"``; ``pose_x``/``pose_y``/``pose_z``/``ground_truth`` are checked
        to agree across runs and kept once without prefix.
        """
        results_all = self.best_runs.get_best_runs_result_apost(name)

        prefixed_dfs = []
        for run, results in enumerate(results_all):
            df = pd.DataFrame.from_dict(results)
            # fov holds (x, y, z) pose values per sample
            df["pose_x"] = df.fov.map(lambda x: x[0])
            df["pose_y"] = df.fov.map(lambda x: x[1])
            df["pose_z"] = df.fov.map(lambda x: x[2])
            df.drop("fov", inplace=True, axis=1)
            df = df.rename(columns={"y": "ground_truth"})  # "y" is misleading
            # converts logits to prediction
            df["pred"] = df.y_hat.map(lambda x: np.argmax(x, axis=-1)).astype(float)
            df["top_1_accuracy"] = df["pred"] == df["ground_truth"]
            df = df.set_index("image_path")
            prefixed_dfs.append(df.add_prefix(f"{run}_"))

        df = pd.concat(prefixed_dfs, axis=1)

        # Check some columns with specific keys (corresponding to data ground_truth FOV) should be the same
        check_keys = ["pose_x", "pose_y", "pose_z", "ground_truth"]
        self.check_all_same(df, check_keys)

        # Replace with first value if check is passed
        df = self.remove_duplicates(df, check_keys)

        return df

    def check_all_same(self, df, check_keys):
        """Raises ValueError unless, per key, all columns containing it are equal."""
        for key in check_keys:
            list_col = [col for col in df.columns if key in col]
            if not df[list_col].eq(df[list_col].iloc[:, 0], axis=0).all().values.all():
                raise ValueError(f"Columns for '{key}' differ across runs")

    def remove_duplicates(self, df, check_keys):
        """Replaces all per-run columns of each key by a single ``key`` column
        holding run 0's values (mutates ``df`` in place and returns it)."""
        for key in check_keys:
            # save the first appearance of that column
            save_col = df[f"0_{key}"]
            # remove all appearances
            list_col = [col for col in df.columns if key in col]
            df.drop(list_col, inplace=True, axis=1)
            # set new column with the saved value
            df[key] = save_col

        return df

    def _metric_columns(self, metric):
        """Per-run column names of ``metric``, one per best run."""
        return [f"{i}_{metric}" for i in range(len(self.best_runs.best_run_ids))]

    @staticmethod
    def _per_run_mean(df, cols):
        """Average over SAMPLES for each run (one value per column)."""
        return df[cols].astype(float).mean(axis=0)

    @staticmethod
    def _std_over_runs(per_run):
        """Sample std (ddof=1) over runs; 0.0 for a single run, where it is NaN."""
        if per_run.shape[0] == 1:
            return 0.0
        return per_run.std(axis=0)

    @staticmethod
    def _draw_binned_bars(ax, bins, err, metric, title):
        """Draws one bar per bin (label -> mean) with ``err`` as error bars."""
        mean_ = np.array(list(bins.values()))
        std_ = np.array(list(err.values()))
        n_bins = len(bins)
        ax.bar(range(n_bins), mean_, yerr=std_)
        if n_bins:
            # pad the y-range so the error bars stay visible
            y_min = mean_.min() - std_.max() - 0.05
            y_max = mean_.max() + std_.max() + 0.05
            ax.set_ylim([y_min, y_max])
        ax.set_xticks(np.arange(n_bins))
        ax.set_xticklabels(list(bins.keys()), ha="center", rotation=90)
        ax.set_ylabel(metric)
        ax.set_xlabel("Angle of rotation values")
        ax.set_title(title)

    def mean_bin_and_plot(self, df, metric, name, ax):
        """Bins samples by ``pose_y`` and plots the mean and std over runs per bin."""
        list_cols = self._metric_columns(metric)

        bins = {}
        err = {}
        prev_tr = -1
        for value in self._BIN_EDGES:
            poses, df_r_prime = self.return_sub_df(
                value, df, prev_tr, self._MAX_TRANSFO
            )
            if poses:
                mean_n = self._per_run_mean(df_r_prime, list_cols)
                # average and std over RUNS
                bins[f"{poses}"] = mean_n.mean(axis=0)
                err[f"{poses}"] = self._std_over_runs(mean_n)
            prev_tr = value

        self._draw_binned_bars(ax, bins, err, metric, name)

    def return_sub_df(self, value, df, prev_tr, max_transfo):
        """Returns the ``pose_y`` values and rows falling in the bin (prev_tr, value],
        where angle ``a`` also counts as ``max_transfo - a``."""
        df_r = df[
            ~((df["pose_y"] <= prev_tr) | ((max_transfo - df["pose_y"]) <= prev_tr))
        ]  # dismiss previously considered points
        df_r = df_r.reset_index()
        df_r_prime = df_r[
            (df_r["pose_y"] <= value) | ((max_transfo - df_r["pose_y"]) <= value)
        ]
        pose_y_considered = [int(i) for i in df_r_prime["pose_y"].unique()]

        return pose_y_considered, df_r_prime

    def plot_gaps(self) -> None:
        """Plots each accuracy gap ``A - B`` in its own row."""
        metric = "top_1_accuracy"
        # (nameA, nameB, binning)
        gap_specs = [
            ("train_diverse_2d", "test_diverse_2d", True),
            ("train_diverse_3d", "test_diverse_3d", True),
            ("test_canonical", "test_diverse_2d", False),
            ("test_canonical", "test_diverse_3d", False),
            ("train_canonical", "diverse_2d_train_canonical", False),
            ("train_canonical", "diverse_3d_train_canonical", False),
        ]
        # one axes per gap so that no two gaps are drawn on top of each other
        fig, axes = plt.subplots(
            len(gap_specs), 1, figsize=(15, 12 * len(gap_specs)), squeeze=False
        )
        for ax, (name_a, name_b, binning) in zip(axes[:, 0], gap_specs):
            self.mean_bin_diff_and_plot(name_a, name_b, metric, ax, binning=binning)
        fig.tight_layout()
        fig.show()

    def mean_bin_diff_and_plot(self, nameA, nameB, metric, ax, binning=False):
        """
        Same as mean_bin_and_plot but as a difference ``nameA - nameB``.

        With ``binning=False`` a single bar of the overall gap is drawn.

        Raises:
            ValueError: if the two results fall into different angle bins.
        """
        dfA = self._result_to_df(nameA)
        dfB = self._result_to_df(nameB)
        title = f"{nameA} - {nameB}"
        list_cols = self._metric_columns(metric)

        if not binning:
            gap = self._per_run_mean(dfA, list_cols) - self._per_run_mean(
                dfB, list_cols
            )
            # average and std over RUNS
            ax.bar(range(1), [gap.mean(axis=0)], yerr=[self._std_over_runs(gap)])
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_ylabel(metric)
            ax.set_title(title)
            return

        bins = {}
        err = {}
        prev_tr = -1
        for value in self._BIN_EDGES:
            posesA, df_r_primeA = self.return_sub_df(
                value, dfA, prev_tr, self._MAX_TRANSFO
            )
            posesB, df_r_primeB = self.return_sub_df(
                value, dfB, prev_tr, self._MAX_TRANSFO
            )

            if posesA != posesB:
                raise ValueError(
                    "Per angle with different bins currently not supported"
                )

            if posesA and posesB:
                gap = self._per_run_mean(df_r_primeA, list_cols) - self._per_run_mean(
                    df_r_primeB, list_cols
                )
                # average and std over RUNS
                bins[f"{posesA}"] = gap.mean(axis=0)
                err[f"{posesA}"] = self._std_over_runs(gap)

            prev_tr = value

        self._draw_binned_bars(ax, bins, err, metric, title)


class DiverseTrainingPlots:
    """Model performance against varying training diversity"""

    # models plotted when no ``model_names`` are given
    DEFAULT_MODEL_NAMES = ("clip", "resnet", "vit", "mlp_mixer", "simclr", "mae")

    def __init__(
        self,
        model_names: Optional[List[str]] = None,
        eval_type: str = "linear_eval",
        training_type: str = "diverse_training",
        run_to_dir=DIVERSE_SWEEP_TO_DIR,
    ):
        """
        Args:
            model_names: models to plot; defaults to ``DEFAULT_MODEL_NAMES``.
            eval_type: evaluation part of the run name, e.g. ``linear_eval``.
            training_type: training part of the run name.
            run_to_dir: maps ``f"{model}_{eval_type}_{training_type}"`` to a run dir.
        """
        # fresh list per instance instead of a shared mutable default argument
        self.model_names = (
            list(self.DEFAULT_MODEL_NAMES) if model_names is None else model_names
        )
        self.eval_type = eval_type
        self.training_type = training_type
        self.run_to_dir = run_to_dir

        self.model_name_to_run: Dict[str, DiverseRun] = self.build_model_name_to_run()

    def build_model_name_to_run(self) -> Dict[str, DiverseRun]:
        """Builds one ``DiverseRun`` per model from its run directory."""
        model_name_to_run = dict()
        for model_name in self.model_names:
            run_name = f"{model_name}_{self.eval_type}_{self.training_type}"
            model_name_to_run[model_name] = DiverseRun(self.run_to_dir[run_name])
        return model_name_to_run

    def show_model_prop_to_metric(self, model_name: str, metric_name: str):
        """Prints, per diverse proportion, the mean and standard error of a metric."""
        run = self.model_name_to_run[model_name]
        for prop, metrics in sorted(run.prop_to_metrics.items()):
            print(prop, np.mean(metrics[metric_name]), sem(metrics[metric_name]))

    def show_model_prop_to_checkpoints(self, model_name: str):
        """Prints, per diverse proportion, the model's run directory and the ids
        stored for that proportion (returns None)."""
        run = self.model_name_to_run[model_name]
        for prop, ids in sorted(run.prop_to_ids.items()):
            print(prop, run.run_dir, ids)

    def show_model_metrics(self, model_name: str, prop: float):
        """Prints mean and standard error of every metric at one diverse proportion."""
        run = self.model_name_to_run[model_name]
        metrics = run.prop_to_metrics[prop]
        for metric_name in metrics:
            print(
                f"{metric_name} {np.mean(metrics[metric_name]):.2%} {sem(metrics[metric_name]):.2%}"
            )

    def make_line_plot(
        self, model_name: str, run: DiverseRun, metric_name: str
    ) -> go.Scatter:
        """Line of the metric's mean (with standard-error bars) vs diverse proportion."""
        props = sorted(run.prop_to_metrics.keys())
        # fetch each proportion's values once for both reductions
        values = [run.prop_to_metrics[p][metric_name] for p in props]
        scatter = go.Scatter(
            x=props,
            y=[np.mean(v) for v in values],
            error_y=dict(type="data", array=[sem(v) for v in values], visible=True),
            name=model_name,
        )
        return scatter

    def plot(self, metric_name: str = "test_diverse_2d_top_1_accuracy") -> go.Figure:
        """One line per model of ``metric_name`` against the diverse proportion."""
        fig = go.Figure()
        for model_name, run in self.model_name_to_run.items():
            fig.add_trace(self.make_line_plot(model_name, run, metric_name))

        fig.update_layout(
            title=f"{self.eval_type} over varying training diversity",
            xaxis_title="proportion of diverse instances",
            yaxis_title=f"{metric_name}",
        )
        return fig


class MakeTable:
    """Table (``self.df``) of the train/eval metrics of one run directory per model."""

    # columns used when no ``metric_names`` are given
    DEFAULT_METRIC_NAMES = (
        "val_diverse_2d_top_1_accuracy",
        "test_diverse_2d_top_1_accuracy",
        "diverse_2d_train_canonical_top_1_accuracy",
        "val_diverse_3d_top_1_accuracy",
        "test_diverse_3d_top_1_accuracy",
        "diverse_3d_train_canonical_top_1_accuracy",
        "val_canonical_top_1_accuracy",
        "test_canonical_top_1_accuracy",
    )

    def __init__(
        self,
        model_to_dir: dict,
        metric_names: Optional[List[str]] = None,
        eval_type: str = "linear_eval",
    ):
        """Loads the metrics of each model's run directory into a DataFrame.

        Args:
            model_to_dir: model name -> run directory.
            metric_names: table columns, in order; defaults to
                ``DEFAULT_METRIC_NAMES``. A metric missing for a model is NaN.
            eval_type: stored as ``self.eval_type``; not otherwise used here.
        """
        self.eval_type = eval_type
        # fresh list per instance instead of a shared mutable default argument
        self.metric_names = (
            list(self.DEFAULT_METRIC_NAMES) if metric_names is None else metric_names
        )
        self.model_to_dir = model_to_dir

        self.model_name_to_metrics: Dict[str, dict] = self.build_model_name_to_metrics()
        # reindex rather than select columns: an empty model_to_dir or an absent
        # metric gives NaN / empty columns instead of a KeyError
        self.df = pd.DataFrame.from_dict(
            self.model_name_to_metrics, orient="index"
        ).reindex(columns=self.metric_names)

    def build_model_name_to_metrics(self) -> Dict[str, dict]:
        """Returns model name -> merged ``train_`` and ``eval_`` metrics JSON.

        On a key collision the ``eval_`` metrics win (they are merged last).
        """
        model_name_to_metrics = dict()
        for model_name, run_dir in self.model_to_dir.items():
            run_metrics = dict()

            train_metrics = BestRuns.load_metrics_json(run_dir, prefix="train_")
            eval_metrics = BestRuns.load_metrics_json(run_dir, prefix="eval_")

            run_metrics.update(train_metrics)
            run_metrics.update(eval_metrics)
            model_name_to_metrics[model_name] = run_metrics
        return model_name_to_metrics


class MAEBarPlot:
    """Horizontal grouped bar plot of the ``LieTable`` metrics, comparing the
    ``lie_model`` runs against the ``baseline_model`` runs."""

    lie_model = "MAE Lie"
    baseline_model = "MAE"
    evaluation_type = "finetuning"
    # only runs whose ``diverse_proportion`` equals this value are plotted
    diverse_proportion = 0.5

    baseline_color = "rgb(211,211,211)"
    lie_color = "rgb(57,116,246)"

    def __init__(self, model_runs: List[ModelRuns]):
        self.model_runs = model_runs
        self.mae_runs = self.filter_runs(model_runs)

        self.lie_table = make_results.LieTable(model_runs=self.mae_runs)

        self.metric_names = self.lie_table.table_metrics
        self.metric_display_names = self.lie_table.metric_display_names

    def filter_runs(self, model_runs):
        """Keeps the runs of the compared models with the configured diverse
        proportion and evaluation type, sorted by name."""
        filtered_runs = [
            run
            for run in model_runs
            if self.is_relevant_model(run)
            and run.diverse_proportion == self.diverse_proportion
            and run.eval_type == self.evaluation_type
        ]
        return sorted(filtered_runs, key=lambda x: x.name)

    def is_relevant_model(self, run):
        """True if the run belongs to the Lie model or its baseline."""
        return run.name in (self.lie_model, self.baseline_model)

    def single_bar_plot(self, model_results: make_results.ModelResults) -> go.Bar:
        """Bar trace of one model's metric means with standard-error bars."""
        means = [model_results.mean(m) for m in self.metric_names]
        stderrs = [model_results.stderr(m) for m in self.metric_names]
        color = (
            self.lie_color
            if model_results.name == self.lie_model
            else self.baseline_color
        )

        bar = go.Bar(
            x=means,
            y=self.metric_display_names,
            orientation="h",
            name=model_results.name,
            marker_color=color,
            error_x=dict(
                type="data",  # value of error bar given in data coordinates
                array=stderrs,
                visible=True,
                color="black",
                thickness=2.0,
            ),
        )
        return bar

    def bars(self) -> go.Figure:
        """Figure with one bar trace per model in the ``LieTable``."""
        fig = go.Figure()
        for model_results in self.lie_table.model_results:
            fig.add_trace(self.single_bar_plot(model_results))

        fig.update_traces(cliponaxis=False)
        fig.update_layout(
            template="plotly_white",
            bargap=0.3,
            width=1200,
            height=800,
            font={"size": 16},
            yaxis=dict(tickfont=dict(size=22)),
        )
        fig.update_layout(xaxis_tickformat=",.0%", xaxis={"tickfont": dict(size=20)})
        return fig


class SimCLRBarPlot(MAEBarPlot):
    """``MAEBarPlot`` for SimCLR: "SimCLR Lie" vs "SimCLR" under linear evaluation."""

    # evaluation type used to filter the runs
    evaluation_type = "linear_eval"

    # the two models being compared
    baseline_model = "SimCLR"
    lie_model = "SimCLR Lie"

    # bar color of the Lie model; the baseline color is inherited
    lie_color = "rgb(255,140,0)"
