from dataclasses import dataclass
from typing import List
import pandas as pd
import hues

from utils.csv_tool import read_csv


@dataclass
class DataRow:
    """One labelled example read from the dataset CSV."""

    id: str  # row identifier (Dataset assigns the CSV row index, as a string)
    input: str  # Text features
    target: str  # Classification Label
    set_type: str  # split tag from the CSV: 'train' for the training set; Dataset treats any other value as test


class Dataset:
    """Train/test dataset loaded from a CSV file.

    Each record returned by ``read_csv`` must provide the keys ``input``,
    ``target`` and ``set_type``. Rows whose ``set_type`` is exactly
    ``'train'`` go to ``train_set``; every other row goes to ``test_set``.
    The same data is also exposed as pandas DataFrames (``rows_df``,
    ``train_set_df``, ``test_set_df``), so callers can use either form.
    """

    def __init__(self, path: str):
        """Load the dataset at ``path`` and build the DataFrame views.

        Args:
            path: location of the CSV file, passed through to ``read_csv``.
        """
        self.rows: List[DataRow] = []
        self.train_set: List[DataRow] = []
        self.test_set: List[DataRow] = []

        self.__read_csv(path)  # read data from disk

        self.__transfer_2_pandas()

    def __read_csv(self, path: str):
        """Append every CSV record to ``rows`` and partition it into train/test."""
        for idx, line in enumerate(read_csv(path)):
            row = DataRow(
                id=str(idx),  # positional index of the record in the file
                input=line['input'],
                target=line['target'],
                set_type=line['set_type'],
            )
            self.rows.append(row)
            # Exact string match: any value other than 'train' counts as test data.
            if row.set_type == 'train':
                self.train_set.append(row)
            else:
                self.test_set.append(row)
        hues.info(f"Read datasets from [{path}], datasets total size = {len(self.rows)}, train-set size = {len(self.train_set)}, test-set size = {len(self.test_set)}")

    def __transfer_2_pandas(self):
        """Mirror the row lists as DataFrames, so you can use either pandas or the raw lists."""
        self.rows_df: pd.DataFrame = self._to_frame(self.rows)
        self.train_set_df: pd.DataFrame = self._to_frame(self.train_set)
        self.test_set_df: pd.DataFrame = self._to_frame(self.test_set)

    @staticmethod
    def _to_frame(rows: List["DataRow"]) -> pd.DataFrame:
        """Convert a list of rows into a DataFrame with one column per DataRow field.

        ``pd.DataFrame([])`` has no columns at all, so an empty split would
        make column access such as ``df['target']`` fail. An empty input
        therefore produces an empty frame that still has the DataRow columns.
        """
        if rows:
            return pd.DataFrame(rows)
        from dataclasses import fields

        return pd.DataFrame(columns=[f.name for f in fields(DataRow)])

    def print_rows(self, limit=3):
        """Print the first ``limit`` rows (default 3) for a quick inspection."""
        for row in self.rows[:limit]:
            print(row)
