from abc import ABC
from typing import Tuple, Union

from .raw_unit_base import RawUnitBase


class UnitBase(ABC):
    """Base class for a unit that wraps an extracted value and its encoder.

    Subclasses are expected to override ``get_raw_unit_class``,
    ``convert_raw_to_value``, ``get_vector`` and ``vector_dim``; the versions
    defined here raise ``NotImplementedError``.

    NOTE(review): the meaning of the positional extraction arguments
    (``midi_dir``, ``pos_info``, ``bars_chords``, ...) is not visible in this
    file; they are forwarded untouched to the raw unit classes.
    """

    def __init__(self, value, encoder=None):
        """Store ``value`` and the (optional) ``encoder`` on the instance."""
        self.encoder = encoder
        self.value = value

    @classmethod
    def get_raw_unit_class(cls):
        """Return the raw unit class, or a tuple of them, this unit relies on.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @classmethod
    def new(cls, encoder, *args, **kwargs):
        """Build an instance whose value comes from ``cls.extract``.

        All arguments are forwarded to ``extract``; ``encoder`` is also kept
        on the resulting instance.
        """
        return cls(cls.extract(encoder, *args, **kwargs), encoder=encoder)

    @classmethod
    def extract_raw(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Run ``extract`` on every raw unit class and collect the results.

        Returns:
            dict mapping each raw unit class name, with its leading
            ``'RawUnit'`` stripped, to the value that class's ``extract``
            returned for the forwarded arguments.

        Raises:
            AssertionError: if a declared class is not a ``RawUnitBase``
                subclass or its name does not start with ``'RawUnit'``.
        """
        declared = cls.get_raw_unit_class()
        # A single class is accepted as shorthand for a one-element tuple.
        candidates = declared if isinstance(declared, tuple) else (declared,)

        # Phase 1: validate every class before any extraction is attempted.
        prefix_len = len('RawUnit')
        classes_by_label = {}
        for candidate in candidates:
            assert issubclass(candidate, RawUnitBase)
            name = candidate.__name__
            assert name.startswith('RawUnit')
            classes_by_label[name[prefix_len:]] = candidate

        # Phase 2: extract once per label, in first-seen label order.
        return {
            label: raw_cls.extract(
                encoder, midi_dir, midi_path,
                pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
            )
            for label, raw_cls in classes_by_label.items()
        }

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Turn the dict produced by ``extract_raw`` into this unit's value.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Extract the raw data, then convert it into the unit value.

        The same arguments are handed to both ``extract_raw`` and
        ``convert_raw_to_value``; the latter's result is returned.
        """
        converter = cls.convert_raw_to_value
        raw_data = cls.extract_raw(
            encoder, midi_dir, midi_path,
            pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
        )
        return converter(
            raw_data,
            encoder, midi_dir, midi_path,
            pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
        )

    @classmethod
    def repr_value(cls, value):
        """Return the representation of ``value``; the base class returns it unchanged."""
        return value

    @classmethod
    def derepr_value(cls, rep_value):
        """Return the value for ``rep_value``; the base class returns it unchanged."""
        return rep_value

    def get_vector(self, use=True, use_info=None) -> list:
        """Return this unit as a list.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def vector_dim(self) -> Union[int, Tuple[int, int]]:
        """Dimension of the vector, as an int or a pair of ints.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
