"""Module for generating TableDeltaSpecs to transform a clean table to an unclean table."""

import difflib
import io
import warnings

import pandas as pd
from typing import List, Union

from radar.data import datamodel

# Suppress FutureWarnings.
# NOTE: this filter is installed when the module is imported and is not scoped
# to the functions below, so it silences every FutureWarning raised afterwards.
warnings.simplefilter(action="ignore", category=FutureWarning)


def generate_transform_spec(
    df_clean: Union[pd.DataFrame, List[pd.DataFrame]],
    df_perturbed: pd.DataFrame,
    convert_to_float: bool = True,
) -> datamodel.TableDeltaSpec:
    """Builds the TableDeltaSpec that turns a clean table into the perturbed one.

    Args:
      df_clean: The clean DataFrame, or a list of DataFrames. When a list is
        given, only its first element is diffed; the remaining entries are
        ignored.
      df_perturbed: The perturbed DataFrame.
      convert_to_float: Forwarded unchanged to generate_transform_spec_helper.
    Returns:
      The TableDeltaSpec produced by generate_transform_spec_helper.
    """
    # Normalise the argument to a single DataFrame, then delegate once.
    clean_table = df_clean[0] if isinstance(df_clean, list) else df_clean
    return generate_transform_spec_helper(clean_table, df_perturbed, convert_to_float)


def _numeric_columns_to_float(df: pd.DataFrame) -> pd.DataFrame:
    """Returns a copy of `df` whose int/float columns are cast to float."""
    numeric_cols = df.select_dtypes(include=["int", "float"]).columns
    # astype() returns a new frame: the caller's DataFrame is left untouched and
    # the dtype really changes (an in-place `df.loc[:, col] = ...` assignment
    # may keep the existing integer dtype).
    return df.astype({col: float for col in numeric_cols})


def _parse_csv_row(header: str, line: str, **read_csv_kwargs) -> pd.Series:
    """Parses a single CSV data line under `header` and returns it as a Series."""
    return pd.read_csv(io.StringIO(f"{header}\n{line}"), **read_csv_kwargs).iloc[0]


def _insert_row_values(header: str, line: str) -> dict:
    """Parses a CSV data line into a {column: value} dict with NaN mapped to None."""
    values = _parse_csv_row(header, line).to_dict()
    return {k: None if pd.isna(v) else v for k, v in values.items()}


def generate_transform_spec_helper(
    df_clean: pd.DataFrame,
    df_perturbed: pd.DataFrame,
    convert_to_float: bool = True,
) -> datamodel.TableDeltaSpec:
    """Generates a TableDeltaSpec to transform a clean table to an unclean table.

    Both tables are serialised to CSV (without the index) and their data lines
    are diffed with difflib.SequenceMatcher. Lines that only exist in the
    perturbed table become InsertRow entries; lines that were replaced are
    compared cell by cell and every differing cell becomes an OverwriteCell.
    Rows that exist only in the clean table ("delete" opcodes, or the surplus
    clean rows of a "replace") are not represented in the result.

    Args:
      df_clean: The clean DataFrame. It is not modified.
      df_perturbed: The perturbed DataFrame. It is not modified. Its lines are
        parsed using the clean table's header.
      convert_to_float: Whether to cast int/float columns of both tables to
        float before diffing, so that e.g. 1 and 1.0 serialise identically.
    Returns:
      A TableDeltaSpec that describes the changes needed to transform the clean
      table to the perturbed table. OverwriteCell.row is a 0-based row position
      in the clean table; InsertRow.index is a 0-based row position in the
      perturbed table.
    """
    if convert_to_float:
        df_clean = _numeric_columns_to_float(df_clean)
        df_perturbed = _numeric_columns_to_float(df_perturbed)
    clean_lines = df_clean.to_csv(index=False).strip().splitlines()
    perturbed_lines = df_perturbed.to_csv(index=False).strip().splitlines()

    header = clean_lines[0]
    # Parse the header as CSV instead of splitting on "," so that quoted column
    # names containing commas are kept intact. keep_default_na=False keeps
    # names such as "NA" as plain strings.
    header_frame = pd.read_csv(
        io.StringIO(header), header=None, dtype=str, keep_default_na=False
    )
    columns = header_frame.iloc[0].tolist()

    clean_rows = clean_lines[1:]
    perturbed_rows = perturbed_lines[1:]

    sm = difflib.SequenceMatcher(None, clean_rows, perturbed_rows)
    insert_rows = []
    overwrite_cells = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "insert":
            # Every perturbed line of the span is new.
            first_inserted = j1
        elif tag == "replace":
            # Pair clean and perturbed lines one-to-one and diff them per cell.
            for offset in range(min(i2 - i1, j2 - j1)):
                i = i1 + offset
                clean_row = _parse_csv_row(header, clean_rows[i], dtype=str)
                perturbed_row = _parse_csv_row(
                    header, perturbed_rows[j1 + offset], dtype=str
                )
                # Compare positionally so duplicate column names are handled.
                for col, val1, val2 in zip(columns, clean_row, perturbed_row):
                    if pd.isna(val1) and pd.isna(val2):
                        continue
                    if val1 != val2:
                        overwrite_cells.append(
                            datamodel.OverwriteCell(
                                row=i,
                                col=col,
                                new_value=None if pd.isna(val2) else val2,
                            )
                        )
            # Perturbed lines beyond the paired ones are insertions. The range
            # below is empty when the clean span is the longer one.
            first_inserted = j1 + (i2 - i1)
        else:
            # "equal" needs no change; "delete" cannot be expressed in the spec.
            continue
        for j in range(first_inserted, j2):
            insert_rows.append(
                datamodel.InsertRow(
                    index=j, row=_insert_row_values(header, perturbed_rows[j])
                )
            )
    return datamodel.TableDeltaSpec(
        insert_rows=insert_rows, overwrite_cells=overwrite_cells
    )


def apply_transform_spec(
    df_clean: pd.DataFrame, transform_spec: datamodel.TableDeltaSpec
) -> pd.DataFrame:
    """Applies a TableDeltaSpec to a DataFrame and returns the result.

    The input is deep-copied first, so `df_clean` itself is left unchanged.
    Cell overwrites are applied before any row is inserted, and a value of
    None (for an overwritten cell or an inserted field) is written as "".

    Args:
      df_clean: The DataFrame to transform.
      transform_spec: The spec whose `overwrite_cells` and `insert_rows` are
        applied.
    Returns:
      A new DataFrame with the overwrites and insertions applied.
    """
    result = df_clean.copy(deep=True)

    # Overwrites address rows by index label (`.at`), not by position.
    for edit in transform_spec.overwrite_cells:
        replacement = edit.new_value
        if replacement is None:
            replacement = ""
        result.at[edit.row, edit.col] = replacement

    # Insert in ascending index order so each index refers to a position in
    # the table as it grows.
    ordered_inserts = sorted(transform_spec.insert_rows, key=lambda ins: ins.index)
    for new_row in ordered_inserts:
        record = {}
        for name, cell_value in new_row.row.items():
            record[name] = "" if cell_value is None else cell_value
        position = new_row.index
        # Splice positionally; ignore_index renumbers the rows from 0.
        pieces = [
            result.iloc[:position],
            pd.DataFrame([record]),
            result.iloc[position:],
        ]
        result = pd.concat(pieces, ignore_index=True)
    return result
