import os
import yaml
import gzip
import shutil
import tarfile
import argparse
import urllib.request
from pathlib import Path

from syntherela.metadata import Metadata
from syntherela.data import load_tables, remove_sdv_columns, save_tables

from gretel_client import configure_session
from gretel_client import create_or_get_unique_project
from gretel_client.config import get_session_config
from gretel_client.rest_v1.api.connections_api import ConnectionsApi
from gretel_client.rest_v1.api.workflows_api import WorkflowsApi
from gretel_client.rest_v1.models import (
    CreateWorkflowRunRequest,
    CreateWorkflowRequest,
)
from gretel_client.workflows.logs import print_logs_for_workflow_run


# Helpers for running workflows from the notebook


def run_workflow(config: str):
    """Create a workflow, and workflow run from a given yaml config. Blocks and
    prints log lines until the workflow reaches a terminal state.

    Args:
        config: The workflow config to run.
    """
    print("Validating actions in the config...")
    config_dict = yaml.safe_load(config)

    for action in config_dict["actions"]:
        print(f"Validating action {action['name']}")
        response = workflow_api.validate_workflow_action(action)
        print(f"Validation response: {response}")

    workflow = workflow_api.create_workflow(
        CreateWorkflowRequest(
            project_id=project.project_guid,
            config=config_dict,
            name=config_dict["name"],
        )
    )

    workflow_run = workflow_api.create_workflow_run(
        CreateWorkflowRunRequest(workflow_id=workflow.id)
    )

    print(f"workflow: {workflow.id}")
    print(f"workflow run id: {workflow_run.id}")

    print_logs_for_workflow_run(workflow_run.id, session)


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--dataset-name", type=str, default="Biodegradability_v1")
    args.add_argument("--real-data-path", type=str, default="data/original")
    args.add_argument("--synthetic-data-path", type=str, default="data/synthetic")
    args.add_argument("--connection-uid", type=str, required=True)
    args.add_argument("--model", type=str, required=True, choices=["lstm", "actgan"])
    args.add_argument("--run-id", type=str, default="1")
    args = args.parse_args()

    dataset_name = args.dataset_name
    real_data_path = args.real_data_path
    synthetic_data_path = args.synthetic_data_path
    input_connection_uid = args.connection_uid
    model = args.model
    run_id = args.run_id

    # Log into Gretel
    configure_session(api_key="prompt", cache="yes", validate=True)

    # Load real data
    metadata = Metadata().load_from_json(
        Path(real_data_path) / f"{dataset_name}/metadata.json"
    )
    real_data = load_tables(Path(real_data_path) / f"{dataset_name}", metadata)
    real_data, metadata = remove_sdv_columns(real_data, metadata)
    metadata.validate_data(real_data)

    # ## Designate Project for your Relational Workflow
    table_names = list(real_data.keys())
    dataset_name_gretel = dataset_name.replace("_", "-")

    session = get_session_config()
    connection_api = session.get_v1_api(ConnectionsApi)
    workflow_api = session.get_v1_api(WorkflowsApi)

    project = create_or_get_unique_project(
        name=f"Synthesize-{dataset_name_gretel}-{model}"
    )

    # Configure and Run your Relational Workflow
    # Gretel Workflows provide an easy to use, config driven API for automating and operationalizing synthetic data. A Gretel Workflow is constructed by actions that are composed to create a pipeline for processing data with Gretel. To learn more about Gretel Workflows, check out [our docs](https://docs.gretel.ai/reference/workflows).

    ### Define Source Data via Connector
    connection_type = connection_api.get_connection(input_connection_uid).dict()["type"]

    # ### Define Workflow configuration
    workflow_config = f"""\
    name: {dataset_name_gretel}-{connection_type}-workflow-{model}

    actions:
    - name: {connection_type}-read
        type: {connection_type}_source
        connection: {input_connection_uid}
        config:
        sync:
            mode: full

    - name: model-train-run
        type: gretel_tabular
        input: {connection_type}-read
        config:
        project_id: {project.project_guid}
        train:
            model: "synthetics/tabular-{model}"
            dataset: "{{{connection_type}-read.outputs.dataset}}"
        run:
            num_records_multiplier: 1.0

    """
    print(workflow_config)

    # ### Run Workflow

    run_workflow(workflow_config)

    # ## View Results

    path = f"./data_{dataset_name}_{model}"
    os.makedirs(path, exist_ok=True)

    # Download the output artifacts
    urllib.request.urlretrieve(
        project.get_artifact_link(project.artifacts[-1]["key"]),
        f"./data_{dataset_name}_{model}/workflow-output.tar.gz",
    )

    with gzip.open(
        f"./data_{dataset_name}_{model}/workflow-output.tar.gz", "rb"
    ) as f_in:
        with open(f"./data_{dataset_name}_{model}/workflow-output.tar", "wb") as f_out:
            f_out.write(f_in.read())

    with tarfile.open(f"{path}/workflow-output.tar") as tar:
        tar.extractall(path)

    path_synthetic = (
        f"{synthetic_data_path}/{dataset_name}/GRETEL_{model.upper()}/{run_id}/sample1"
    )
    os.makedirs(path_synthetic, exist_ok=True)
    for table in table_names:
        shutil.copy(f"{path}/synth_{table}.csv", f"{path_synthetic}/{table}.csv")

    # Postprocess synthetic data (some categorical columns are generated as floats)
    def is_float(value):
        if value is None or value == "":
            return False
        try:
            float(value)
            return True
        except:  # noqa: E722
            return False

    synthetic_data = load_tables(Path(path_synthetic), metadata)
    metadata.validate_data(synthetic_data)

    for table in metadata.get_tables():
        table_meta = metadata.get_table_meta(table)
        for column, column_info in table_meta["columns"].items():
            if column_info["sdtype"] == "categorical":
                values = synthetic_data[table][column].unique()
                numeric = False
                for value in values:
                    if is_float(str(value)) and value == value:
                        numeric = True
                        break
                if numeric:
                    synthetic_data[table][column] = synthetic_data[table][
                        column
                    ].astype("object")
                    for i, value in synthetic_data[table][column].items():
                        if value != value or not is_float(str(value)):
                            # skip NaN
                            continue
                        synthetic_data[table].at[i, column] = int(float(value))
                    synthetic_data[table][column] = synthetic_data[table][
                        column
                    ].astype("object")

    # Save synthetic data
    save_tables(synthetic_data, Path(path_synthetic))
