import torch

from batlinet.builders import LABEL_ANNOTATORS
from batlinet.data.battery_data import BatteryData

from .base import BaseLabelAnnotator


@LABEL_ANNOTATORS.register()
class RULLabelAnnotator(BaseLabelAnnotator):
    """Annotate a cell with a cycle-count label and a fixed-length capacity curve.

    The label counts cycles until the per-cycle maximum discharge capacity
    first drops to ``eol_soh`` times the nominal capacity. Alongside the
    label, the normalized capacity curve observed up to that point is
    returned, padded or truncated to a fixed length and then sub-sampled.
    """

    def __init__(self,
                 eol_soh: float = 0.7,
                 pad_eol: bool = True,
                 min_rul_limit: float = 100.0,
                 max_cycles: int = 2560,
                 cycle_stride: int = 10,
                 pad_value: float = 0.7):
        """Store the annotation settings.

        Args:
            eol_soh: Fraction of the nominal capacity; the first cycle whose
                maximum discharge capacity is at or below this fraction
                stops the count.
            pad_eol: When the threshold is never reached, add one more to
                the count if True; otherwise the label becomes NaN.
            min_rul_limit: Labels less than or equal to this value are
                replaced by NaN.
            max_cycles: Length the capacity curve is padded or truncated to
                before sub-sampling.
            cycle_stride: Step used to sub-sample the fixed-length curve.
            pad_value: Value used to fill the curve up to ``max_cycles``.
                The default equals the default ``eol_soh``.
        """
        self.eol_soh = eol_soh
        self.pad_eol = pad_eol
        self.min_rul_limit = min_rul_limit
        self.max_cycles = max_cycles
        self.cycle_stride = cycle_stride
        self.pad_value = pad_value

    def process_cell(self, cell_data: BatteryData):
        """Compute the label and the sub-sampled capacity curve for one cell.

        Args:
            cell_data: Cell record; this method reads ``cycle_data`` (each
                item exposing ``discharge_capacity_in_Ah``) and
                ``nominal_capacity_in_Ah``.

        Returns:
            A ``(label, curve)`` pair. ``label`` is a 0-dim tensor holding
            either the cycle count or NaN. ``curve`` is a float32 tensor
            with one row, containing every ``cycle_stride``-th entry of the
            ``max_cycles``-long normalized capacity curve.
        """
        nominal = cell_data.nominal_capacity_in_Ah
        eol_capacity = nominal * self.eol_soh

        # The count starts at 1 and is incremented for every visited cycle,
        # including the one that crosses the threshold.
        label, found_eol = 1, False
        normalized_capacities = []
        for cycle in cell_data.cycle_data:
            label += 1
            peak_capacity = max(cycle.discharge_capacity_in_Ah)
            normalized_capacities.append(peak_capacity / nominal)
            if peak_capacity <= eol_capacity:
                found_eol = True
                break

        if not found_eol:
            label = label + 1 if self.pad_eol else float('nan')

        # A NaN label compares False here and therefore stays NaN.
        if label <= self.min_rul_limit:
            label = float('nan')

        label = torch.tensor(label)

        # unsqueeze(0) instead of view(1, -1): view cannot infer the -1
        # dimension of a zero-element tensor, so a cell without any cycles
        # would raise instead of yielding a fully padded curve.
        curve = torch.tensor(
            normalized_capacities, dtype=torch.float32).unsqueeze(0)

        n_cycles = curve.shape[1]
        if n_cycles < self.max_cycles:
            curve = torch.nn.functional.pad(
                curve, (0, self.max_cycles - n_cycles),
                "constant", self.pad_value)
        else:
            curve = curve[:, :self.max_cycles]

        return label, curve[:, ::self.cycle_stride]
    
    
