import concurrent
import functools
import gc
import json
import logging
import multiprocessing
import numpy as np
import os
import pandas as pd
import pathlib
import pdb
import sys
import warnings

# Silence every warning for the rest of the process. This runs before the
# remaining imports below, so warnings emitted while importing them are
# suppressed as well.
warnings.filterwarnings("ignore")


from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from fastparquet import ParquetFile
from functools import partial
from generate import inflate_edit_path
from helpers import *
from pylint import run_pylint
from tqdm import tqdm

####
# Locate the repository root from the current working directory: keep
# everything up to and including the last occurrence of REPO_NAME in the
# resolved cwd.
# NOTE(review): if REPO_NAME does not occur in the cwd, rfind() returns -1 and
# the slice silently yields the first len(REPO_NAME) - 1 characters of the
# path instead of failing — confirm the script is always launched from inside
# the repository.
REPO_NAME = "lintseq_submission"
base_path = str(pathlib.Path().resolve())
PROJECT_PATH = base_path[: base_path.rfind(REPO_NAME) + len(REPO_NAME)]
####
# Put <PROJECT_PATH>/src at the front of the import search path so that the
# star import that follows resolves against it first.
sys.path.insert(0, os.path.join(PROJECT_PATH, "src"))
from utils import *


def stop_process_pool(executor):
    """Forcefully terminate an executor's worker processes, then shut it down.

    Args:
        executor: A ``concurrent.futures.ProcessPoolExecutor`` (or any object
            exposing a ``_processes`` mapping of pid -> process and a
            ``shutdown()`` method).

    The private ``_processes`` attribute is used because the public executor
    API offers no way to kill running workers.
    """
    # ``_processes`` is None once the executor has been shut down; treat that
    # (or a missing attribute) as "nothing left to terminate".
    processes = getattr(executor, "_processes", None) or {}
    # Iterate over a snapshot: the executor's own bookkeeping may drop entries
    # from the dict while workers are exiting.
    for process in list(processes.values()):
        process.terminate()
    executor.shutdown()


def worker_fn(did, datum, minimal_edit_path_length=8, diff_token=DIFF_TOKEN):
    """Turn one source record into a diff-sequence example.

    Args:
        did: Identifier stored under the ``"id"`` key of the result.
        datum: Mapping holding the source text under ``"content"``; every
            other key is copied into the result's ``"metadata"``.
        minimal_edit_path_length: Minimum number of non-empty diffs required;
            records producing fewer are discarded.
        diff_token: Marker that, preceded by a newline, is written in front of
            every diff and twice more after the last one.

    Returns:
        A dict with the keys ``"metadata"``, ``"text"``, ``"source"`` and
        ``"id"``, or ``None`` when ``lintseq`` fails or returns ``None``, or
        when the diff sequence is too short.
    """
    code_as_text = datum["content"]
    try:
        edit_path = lintseq(
            code_as_text,
            children_per_round=1,
            top_k=1,
            max_population_size=1,
            max_depth=512,
            indent_bias_sampling_factor=1,
            ignore_imports=False,
            ignore_comments=True,
            ignore_global_defs=True,
            verbose=False,
            ignore_init_errors=False,
        )
    except Exception:
        # Best effort: a record that lintseq cannot process is skipped.
        # KeyboardInterrupt / SystemExit are deliberately not caught so the
        # worker can still be interrupted.
        return None

    if edit_path is None:
        return None

    _, diff_seq = inflate_edit_path(code_as_text, edit_path)
    diff_seq = [diff for diff in diff_seq if len(diff) > 0]

    if len(diff_seq) < minimal_edit_path_length:
        return None

    # The leading empty string puts a separator in front of the first diff;
    # the two trailing empty strings append two separators after the last.
    separator = "\n" + diff_token
    diff_text = separator.join(["", *diff_seq, "", ""])

    metadata = {key: value for key, value in datum.items() if key != "content"}
    return {
        "metadata": metadata,
        "text": diff_text,
        "source": "the-stack",
        "id": did,
    }


# NumPy scalar types that are unwrapped to plain Python numbers before
# serialization. np.longdouble stands in for np.float128: the latter name does
# not exist on platforms without a 128-bit long double, where referencing it
# raises AttributeError; where it does exist it is the same type.
_NUMPY_NUMERIC_TYPES = (
    np.float64,
    np.longdouble,
    np.float32,
    np.int64,
    np.int16,
    np.int32,
    np.int8,
)


def _to_json_safe(record):
    """Return a copy of *record* that keeps only JSON-serializable values."""
    safe = {}
    for key, value in record.items():
        if type(value) in _NUMPY_NUMERIC_TYPES:
            value = value.item()
        try:
            json.dumps(value)
        except (TypeError, ValueError, RecursionError):
            # Not representable as JSON: drop the field, keep the record.
            continue
        safe[key] = value
    return safe


def _load_row_group(pf, index):
    """Return row group *index* of *pf* without holding on to the others."""
    total = len(pf.row_groups)
    # Negative indices count from the end, like list indexing.
    position = index + total if index < 0 else index
    if not 0 <= position < total:
        raise IndexError(
            f"row group index {index} out of range for {total} row groups"
        )
    for current, group in enumerate(pf.iter_row_groups()):
        if current == position:
            return group
    raise IndexError(f"row group {index} was not produced by the parquet reader")


def main(args, diff_token=DIFF_TOKEN, minimal_edit_path_length=10, n_workers=60):
    """Process one row group of one parquet file and append results as JSONL.

    The files in ``args.path_to_pt_data`` are sorted by name and the one at
    position ``args.file_id`` is selected. Every row of its row group
    ``args.group_id`` is submitted to ``worker_fn`` in a process pool, and each
    non-``None`` result is appended as one JSON line to a file of the same
    name (``.parquet`` replaced by ``.jsonl``) in
    ``args.path_to_linted_pt_data``.

    Args:
        args: Namespace with ``path_to_pt_data``, ``path_to_linted_pt_data``,
            ``file_id`` and ``group_id``.
        diff_token: Forwarded to ``worker_fn``.
        minimal_edit_path_length: Forwarded to ``worker_fn``.
        n_workers: Number of worker processes in the pool.

    Raises:
        IndexError: If ``args.file_id`` or ``args.group_id`` is out of range.

    If collecting or writing any result fails, the error is logged, the pool
    is terminated and the function returns; lines written so far are kept.
    """
    source_files = sorted(os.listdir(args.path_to_pt_data))
    f = source_files[args.file_id]
    save_fn = os.path.join(
        args.path_to_linted_pt_data, f.replace(".parquet", ".jsonl")
    )
    j = args.group_id

    pf = ParquetFile(os.path.join(args.path_to_pt_data, f))
    df = _load_row_group(pf, j)
    total_row_groups = len(pf.row_groups)
    print(f"processing {f} row group {args.group_id} / {total_row_groups}")

    # File name without its last extension; a name without any dot is kept
    # whole rather than losing its final character.
    stem = f[: f.rfind(".")] if "." in f else f

    futures = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as executor:
        with tqdm(total=len(df)) as pbar:
            for line_id in range(len(df)):
                pdatum = _to_json_safe(dict(df.iloc[line_id, :]))
                did = f"{stem}_{j}_{line_id}"
                futures.append(
                    executor.submit(
                        worker_fn, did, pdatum, minimal_edit_path_length, diff_token
                    )
                )
                pbar.update(1)

        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                # as_completed only yields finished futures, so result()
                # returns (or re-raises the worker's error) immediately.
                outline = future.result()
                if outline is not None:
                    # Append per result so finished lines survive an abort.
                    with open(save_fn, "a") as outfile:
                        outfile.write(json.dumps(outline) + "\n")
            except Exception:
                logging.exception(
                    "aborting %s row group %s: failed to collect or write a result",
                    f,
                    j,
                )
                stop_process_pool(executor)
                return


def parse_args():
    """Parse the command line and make sure the output directory exists.

    Returns:
        The parsed ``argparse.Namespace`` with ``path_to_pt_data``,
        ``path_to_linted_pt_data``, ``group_id`` and ``file_id``.

    Both path options are mandatory: without them the script has nothing to
    read or nowhere to write, and argparse now reports that directly instead
    of letting ``os.makedirs(None)`` fail with a ``TypeError``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--path_to_pt_data",
        type=str,
        required=True,
        help="directory containing the source parquet files",
    )
    parser.add_argument(
        "--path_to_linted_pt_data",
        type=str,
        required=True,
        help="directory the .jsonl output is written to (created if missing)",
    )
    parser.add_argument(
        "--group_id",
        default=0,
        type=int,
        help="index of the row group to process within the selected file",
    )
    parser.add_argument(
        "--file_id",
        default=0,
        type=int,
        help="index of the file to process in the sorted directory listing",
    )
    args = parser.parse_args()
    os.makedirs(args.path_to_linted_pt_data, exist_ok=True)
    return args


if __name__ == "__main__":
    # Script entry point: parse the command line, then process the selected
    # file / row group.
    main(parse_args())
