from abc import ABC
import wandb
import pandas as pd
from collections import defaultdict

class SolverBase(ABC):
    """Base class for solvers that records per-sample update statistics.

    Statistics are buffered in ``self.hist_data`` (a mapping from column name
    to a flat list with one entry per logged sample) by ``log_xt_update`` and
    sent to Weights & Biases as scatter plots by ``on_end``.
    """

    def __init__(self):
        # Column name -> list of per-sample values. All columns must stay the
        # same length so they can be turned into a DataFrame in ``on_end``.
        self.hist_data = defaultdict(list)

    def step(self, *args, **kwargs):
        """Perform one solver step; subclasses are expected to override this.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def log_xt_update(self, t, xt_diff):
        """Record the mean and sum of ``xt_diff`` for every sample in a batch.

        Args:
            t: The time value(s) for this update. Either a single scalar,
                which is repeated for every sample, or a per-sample sequence
                (list, tuple, or any object exposing ``tolist()`` such as a
                tensor or ndarray) whose length equals the batch size.
            xt_diff: Tensor whose first dimension is treated as the batch
                dimension; all remaining dimensions are flattened together
                before reducing.

        Raises:
            ValueError: if a per-sample ``t`` does not have exactly one entry
                per sample in ``xt_diff``.
        """
        batch_size = xt_diff.shape[0]
        # ``reshape`` (unlike ``view``) also works on non-contiguous tensors,
        # e.g. ones produced by a transpose/permute.
        flat_diff = xt_diff.reshape(batch_size, -1)

        # Tensors / arrays are converted to plain Python numbers so that the
        # values are hashable and comparable when grouped in ``on_end``.
        if hasattr(t, "tolist"):
            t = t.tolist()
        if isinstance(t, (list, tuple)):
            t_values = list(t)
        else:
            t_values = [t] * batch_size

        # Fail here rather than later in ``on_end``, where mismatched column
        # lengths would only surface as an opaque DataFrame construction error.
        if len(t_values) != batch_size:
            raise ValueError(
                f"t has {len(t_values)} entries but xt_diff has "
                f"{batch_size} samples"
            )

        self.hist_data["solver/t"].extend(t_values)
        self.hist_data["solver/xt_diff/mean"].extend(
            flat_diff.mean(1).detach().cpu().tolist()
        )
        self.hist_data["solver/xt_diff/sum"].extend(
            flat_diff.sum(1).detach().cpu().tolist()
        )

    def on_end(self):
        """Aggregate the recorded statistics per ``t`` and log them to wandb.

        For every distinct ``solver/t`` value the mean and the standard
        deviation of the recorded columns are computed and logged as four
        scatter plots. Does nothing if no updates were recorded.
        """
        # ``.get`` avoids inserting the key into the defaultdict; with no
        # data there is no "solver/t" column to group by.
        if not self.hist_data.get("solver/t"):
            return

        data = pd.DataFrame(self.hist_data)
        grouped = data.groupby("solver/t")

        diff_mean_table = wandb.Table(data=grouped.mean().reset_index())
        # NOTE: pandas' default std is the sample std (ddof=1), so a ``t``
        # value with a single recorded sample yields NaN here.
        diff_std_table = wandb.Table(data=grouped.std().reset_index())

        wandb.log({
            "solver/xt_diff/mean": wandb.plot.scatter(
                diff_mean_table,
                x="solver/t",
                y="solver/xt_diff/mean",
                title="Mean Update Norm"
            ),
            "solver/xt_diff/sum": wandb.plot.scatter(
                diff_mean_table,
                x="solver/t",
                y="solver/xt_diff/sum",
                title="Sum Update Norm"
            ),
            "solver/xt_diff/mean_std": wandb.plot.scatter(
                diff_std_table,
                x="solver/t",
                y="solver/xt_diff/mean",
                title="Mean Update Norm STD"
            ),
            "solver/xt_diff/sum_std": wandb.plot.scatter(
                diff_std_table,
                x="solver/t",
                y="solver/xt_diff/sum",
                title="Sum Update Norm STD"
            ),
        })