import csv

import pandas as pd
from torch import Tensor

from smlm.activations.default_values import COLUMNS
from smlm.activations.io import WriterInterface


def read_csv(filepath: str, columns: list) -> pd.DataFrame:
    """Load selected columns of a CSV file and relabel them as ``COLUMNS``.

    Header names in the file and the names passed in ``columns`` are both
    stripped of spaces and lower-cased before they are matched, so
    ``"X [nm]"`` and ``"x[nm]"`` refer to the same column.

    Args:
        filepath: Path of the CSV file, forwarded to ``pandas.read_csv``.
        columns: The five source column names, ordered as
            ``[frame, x, y, z, n]``.

    Returns:
        A DataFrame holding only the requested columns, in the order given,
        with its column labels replaced by ``COLUMNS``.

    Raises:
        ValueError: If ``columns`` does not have exactly five entries.
        KeyError: If a requested column is not present in the file (raised
            by the pandas column selection).
    """
    if len(columns) != 5:
        raise ValueError(
            "columns must contain the five column names for [frame, x, y, z, n]"
        )

    def _normalise(name):
        # Same normalisation as applied to the file header below.
        return name.replace(" ", "").lower()

    table = pd.read_csv(filepath)
    table.columns = table.columns.str.replace(" ", "", regex=False).str.lower()

    wanted = [_normalise(name) for name in columns]
    selected = table[wanted]
    # Relabel positionally: the i-th requested column gets the i-th COLUMNS name.
    selected.columns = COLUMNS
    return selected


class CSVWriter(WriterInterface):
    """Writer that stores tensor rows as integers in a CSV file.

    ``open`` creates (or truncates) the target file and writes the header
    row; each ``_write`` call then appends rows below it; ``close`` releases
    the file handle.
    """

    def __init__(self, filepath: str, columns: list = COLUMNS):
        """Remember the destination path and header; no file is touched yet.

        Args:
            filepath: Path of the CSV file that ``open`` will create.
            columns: Header names written as the first row of the file.
                Defaults to ``COLUMNS``.
        """
        self.columns = columns
        self.filepath = filepath
        # Define the handles up front so close() is a safe no-op when open()
        # was never called or failed before assigning them.
        self.file = None
        self.writer = None

    def open(self):
        """Create/truncate the file at ``filepath`` and write the header row."""
        # newline="" is what the csv module expects so it can control line
        # endings itself.
        self.file = open(self.filepath, mode="w", newline="")
        self.writer = csv.writer(self.file)
        self.writer.writerow(self.columns)

    def close(self):
        """Close the underlying file if one has been opened."""
        if self.file is not None:
            self.file.close()

    def _write(self, data: Tensor):
        """Append the rows of ``data`` to the CSV file as integers.

        Values are rounded, moved to the CPU and cast with ``Tensor.int()``
        before being written.

        Args:
            data: Tensor of values to write. Assumes a 2-D tensor with one
                row per record, matching the header width — TODO confirm
                against callers.
        """
        # NOTE(review): Tensor.int() casts to 32-bit integers; values outside
        # that range would not survive the cast — confirm inputs stay in range.
        rows = data.round().cpu().int().tolist()
        self.writer.writerows(rows)
