from typing import List

import pandas as pd

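"""
Cell layouts returned by ``Decoder.decode`` (20 cells each).

1) Standard:
6 * [n] + [r] + 6 * [n] + [r] + 6 * [n]
where n is decoded from cell_idx 0, 1, 3, 4, 6, 7 pooled together and
r from cell_idx 2 and 5 pooled together.

2) Agnas:
6 * [n0] + [r2] + 6 * [n3] + [r5] + 6 * [n6]
where the digit is the single cell_idx the cell is decoded from.
"""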


class Decoder:
    """Decode a 20-cell layout from a table of per-operation coefficients.

    The table is read from the parquet file at ``cfg["stats_df_path"]`` and is
    expected to provide the columns ``cell_idx``, ``target``, ``source``,
    ``op_idx`` and ``coeff`` (these are the only columns this class reads).

    A decoded cell is a list with one entry per target ``0..3``; each entry is
    ``[[source, op_idx], [source, op_idx]]`` for the two sources whose best
    operation has the highest mean coefficient.
    """

    # Number of target indices (0 .. _NUM_TARGETS - 1) decoded per cell.
    _NUM_TARGETS = 4

    def __init__(self, cfg):
        """Store ``cfg`` and load the statistics table it points to.

        Args:
            cfg: Mapping-like config; only ``cfg["stats_df_path"]`` is read.

        Raises:
            ValueError: If ``cfg["stats_df_path"]`` is ``None``.
        """
        self.cfg = cfg
        if cfg["stats_df_path"] is None:
            raise ValueError("Path to stats_df_path not provided!")
        self.df = pd.read_parquet(cfg["stats_df_path"])

    def decode(self, dec_type: str) -> List[List[List[List[int]]]]:
        """Return the list of 20 decoded cells for the requested scheme.

        Args:
            dec_type: ``"standard"`` or ``"agnas"``.

        Returns:
            A list of 20 cells laid out as 6 + 1 + 6 + 1 + 6. Every cell is an
            independent list, so mutating one does not affect the others.

        Raises:
            ValueError: If ``dec_type`` is not a supported scheme.
        """
        if dec_type == "standard":
            return self._decode_mean_standard()
        if dec_type == "agnas":
            return self._decode_mean_agnas()
        raise ValueError(f"Unknown decoding type: {dec_type!r}")

    def _decode_mean_standard(self) -> List[List[List[List[int]]]]:
        """Decode one shared ``n`` cell and one shared ``r`` cell.

        ``n`` pools cell_idx 0, 1, 3, 4, 6, 7 and ``r`` pools cell_idx 2, 5
        (presumably the normal and reduction cells -- confirm with the
        producer of the statistics table).
        """
        n = self._decode_cells([0, 1, 3, 4, 6, 7])
        r = self._decode_cells([2, 5])
        return (
            self._replicate(n, 6)
            + self._replicate(r, 1)
            + self._replicate(n, 6)
            + self._replicate(r, 1)
            + self._replicate(n, 6)
        )

    def _decode_mean_agnas(self) -> List[List[List[List[int]]]]:
        """Decode each stage from a single cell_idx (0, 2, 3, 5 and 6)."""
        n0 = self._decode_cells([0])
        r2 = self._decode_cells([2])
        n3 = self._decode_cells([3])
        r5 = self._decode_cells([5])
        n6 = self._decode_cells([6])
        return (
            self._replicate(n0, 6)
            + self._replicate(r2, 1)
            + self._replicate(n3, 6)
            + self._replicate(r5, 1)
            + self._replicate(n6, 6)
        )

    @staticmethod
    def _replicate(cell: List[List[List[int]]], count: int) -> List[List[List[List[int]]]]:
        """Return ``count`` independent deep copies of a decoded cell.

        ``count * [cell]`` would repeat the same nested list object, so an
        in-place edit of one returned cell would leak into all its siblings.
        """
        return [
            [[list(edge) for edge in node] for node in cell]
            for _ in range(count)
        ]

    def _decode_cells(self, cells: List[int]) -> List[List[List[int]]]:
        """Decode one cell from the rows whose ``cell_idx`` is in ``cells``.

        For every target, the coefficients are averaged per
        ``(source, op_idx)``, the best operation is picked for each source,
        and the two sources with the highest such coefficient are kept.

        Args:
            cells: ``cell_idx`` values whose rows are pooled together.

        Returns:
            One ``[[source, op_idx], [source, op_idx]]`` entry per target,
            ordered by descending mean coefficient.

        Raises:
            IndexError: If a target has fewer than two distinct sources in the
                selected rows.
        """
        # The cell filter does not depend on the target, so build it once.
        in_cells = self.df["cell_idx"].isin(cells)
        decoded = []
        for target in range(self._NUM_TARGETS):
            rows = self.df.loc[
                in_cells & (self.df["target"] == target),
                ["source", "op_idx", "coeff"],
            ]
            mean_coeff = rows.groupby(["source", "op_idx"]).mean().reset_index()
            if mean_coeff["source"].nunique() < 2:
                raise IndexError(
                    f"Need at least two sources for target {target} in cells "
                    f"{list(cells)}, found {mean_coeff['source'].nunique()}"
                )
            # idxmax yields index *labels*, so select with .loc rather than .iloc.
            best_labels = mean_coeff.groupby("source")["coeff"].idxmax()
            best_per_source = mean_coeff.loc[best_labels]
            top_two = best_per_source.sort_values(by="coeff", ascending=False).head(2)
            decoded.append(
                [
                    [int(source), int(op_idx)]
                    for source, op_idx in zip(top_two["source"], top_two["op_idx"])
                ]
            )
        return decoded
