import pprint
import numpy as np
import torch
from utils.basic_utils import load_jsonl
from eval.eval import eval_submission
from tqdm import tqdm


class PostProcessorDETR:
    def __init__(self, clip_length=2, min_ts_val=0, max_ts_val=150,
                 min_w_l=2, max_w_l=70, move_window_method="center",
                 process_func_names=("clip_window_l", "clip_ts", "round_multiple")):
        self.clip_length = clip_length
        self.min_ts_val = min_ts_val
        self.max_ts_val = max_ts_val
        self.min_w_l = min_w_l
        self.max_w_l = max_w_l
        self.move_window_method = move_window_method
        self.process_func_names = process_func_names
        self.name2func = dict(
            clip_ts=self.clip_min_max_timestamps,
            round_multiple=self.round_to_multiple_clip_lengths,
            clip_window_l=self.clip_window_lengths
        )

    def __call__(self, lines):
        processed_lines = []
        for line in tqdm(lines, desc=f"convert to multiples of clip_length={self.clip_length}"):
            windows_and_scores = torch.tensor(line["pred_relevant_windows"])
            windows = windows_and_scores[:, :2]
            for func_name in self.process_func_names:
                windows = self.name2func[func_name](windows)
            line["pred_relevant_windows"] = torch.cat(
                [windows, windows_and_scores[:, 2:3]], dim=1).tolist()
            line["pred_relevant_windows"] = [e[:2] + [float(f"{e[2]:.4f}")] for e in line["pred_relevant_windows"]]
            processed_lines.append(line)
        return processed_lines

    def clip_min_max_timestamps(self, windows):
        """
        windows: (#windows, 2)  torch.Tensor
        ensure timestamps for all windows is within [min_val, max_val], clip is out of boundaries.
        """
        return torch.clamp(windows, min=self.min_ts_val, max=self.max_ts_val)

    def round_to_multiple_clip_lengths(self, windows):
        """
        windows: (#windows, 2)  torch.Tensor
        ensure the final window timestamps are multiples of `clip_length`
        """
        return torch.round(windows / self.clip_length) * self.clip_length

    def clip_window_lengths(self, windows):
        """
        windows: (#windows, 2)  np.ndarray
        ensure the final window duration are within [self.min_w_l, self.max_w_l]
        """
        window_lengths = windows[:, 1] - windows[:, 0]
        small_rows = window_lengths < self.min_w_l
        if torch.sum(small_rows) > 0:
            windows = self.move_windows(
                windows, small_rows, self.min_w_l, move_method=self.move_window_method)
        large_rows = window_lengths > self.max_w_l
        if torch.sum(large_rows) > 0:
            windows = self.move_windows(
                windows, large_rows, self.max_w_l, move_method=self.move_window_method)
        return windows

    @classmethod
    def move_windows(cls, windows, row_selector, new_length, move_method="left"):
        """
        Args:
            windows:
            row_selector:
            new_length:
            move_method: str,
                left: keep left unchanged
                center: keep center unchanged
                right: keep right unchanged

        Returns:

        """
        # import ipdb;
        # ipdb.set_trace()
        if move_method == "left":
            windows[row_selector, 1] = windows[row_selector, 0] + new_length
        elif move_method == "right":
            windows[row_selector, 0] = windows[row_selector, 1] - new_length
        elif move_method == "center":
            center = (windows[row_selector, 1] + windows[row_selector, 0]) / 2.
            windows[row_selector, 0] = center - new_length / 2.
            windows[row_selector, 1] = center + new_length / 2.
        return windows

