import json
import pandas as pd
import random

from collections import Counter
from datasets import load_dataset
from typing import List, Union
from unidiff import PatchSet

"""
Usage
- v1: python -c "from generate_data import generate_easy_v1; generate_easy_v1()"
- v2: python -c "from generate_data import generate_easy_v2; generate_easy_v2()"
- tiny: python -c "from generate_data import generate_easy_tiny; generate_easy_tiny()"
- med: python -c "from generate_data import generate_easy_med; generate_easy_med()"
- 40: python -c "from generate_data import generate_40_v1; generate_40_v1()"
- 40v2: python -c "from generate_data import generate_40_v2; generate_40_v2()"
"""

# Instance ID Lists
removals = [
    "marshmallow-code__marshmallow-1745",
    "pvlib__pvlib-python-1854",
    "pydicom__pydicom-1315",
    "pydicom__pydicom-1694",
    "pylint-dev__astroid-1367",
    "pylint-dev__astroid-1614",
    "pylint-dev__astroid-1866",
    "pylint-dev__astroid-1959",
    "pylint-dev__astroid-1962",
    "pylint-dev__astroid-2206",
    "sqlfluff__sqlfluff-909",
]

additions = [
    "pydicom__pydicom-1139",
    "pydicom__pydicom-1069",
    "pvlib__pvlib-python-1701",
]

tiny = [
    "pvlib__pvlib-python-1224",
    "pvlib__pvlib-python-1239",
    "pvlib__pvlib-python-1738",
    "pydicom__pydicom-1256",
    "sqlfluff__sqlfluff-2419",
]

medium = [
    "pvlib__pvlib-python-1191",
    "pvlib__pvlib-python-1072",
    "pvlib__pvlib-python-1606",
]

def get_hf_data(split: str):
    data = load_dataset("princeton-nlp/SWE-bench", split=split)
    data = data.to_dict()

    transform = []
    for idx in range(len(data["instance_id"])):
        temp = {}
        for key in data.keys():
            temp[key] = data[key][idx]
            # Handle deserialization from string => List for certain values
            if key in ["FAIL_TO_PASS", "PASS_TO_PASS"]:
                temp[key] = json.loads(temp[key])
        transform.append(temp)

    return transform


def generate_easy_v1():
    # Load dev split and convert to list of dicts
    data = load_dataset("princeton-nlp/SWE-bench")
    dev_df = pd.DataFrame.from_records(data["dev"])
    dev_json = json.loads(dev_df.to_json())

    dev = []
    for idx in range(dev_df.shape[0]):
        dev_inst = {}
        for k in dev_json.keys():
            dev_inst[k] = dev_df.iloc[idx, :][k]
            # Handle deserialization from string => List for certain values
            if k in ["FAIL_TO_PASS", "PASS_TO_PASS"]:
                dev_inst[k] = json.loads(dev_inst[k])
        dev.append(dev_inst)

    # Create easy set based on filtering criteria
    easy_set = []
    for t in dev:
        # Problem Statement with 15-30 lines
        ps = t["problem_statement"].split("\n")
        if not (len(ps) >= 15 and len(ps) <= 30):
            continue
        # Patch with < 50 lines
        patch = t["patch"].split("\n")
        if not (len(patch) < 50):
            continue
        t["query"] = t["problem_statement"]
        t["task_id"] = t["instance_id"]
        easy_set.append(t)

    # Save to file
    with open("swe-bench-dev-easy.json", "w") as f:
        json.dump(easy_set, fp=f)

    print("Dev Easy v1")
    print(f"Counts: {Counter([t['repo'] for t in easy_set])}")
    print(f"Total: {len(easy_set)}")


def generate_easy_v2():
    """
    Filtering criteria:
    * At least 150 words
    * Must not have hyperlinks (`http`) in the issue text
    * Must have changed 2 or fewer files
    * Must have 50 or fewer lines in the gold diff
    """
    data = get_hf_data("dev")
    keep = []
    for d in data:
        if d["instance_id"] in removals:
            continue
        if d["instance_id"] in additions:
            keep.append(d)
            continue

        issue_len = len(d["problem_statement"].split())
        diff_obj = PatchSet(d["patch"])
        total_adds = sum([x.added for x in diff_obj])
        total_removes = sum([x.removed for x in diff_obj])

        if issue_len < 150:
            continue
        if "http" in d["problem_statement"]:
            continue
        if (
            len(diff_obj.modified_files + diff_obj.added_files + diff_obj.removed_files)
            > 2
        ):
            continue
        if total_adds + total_removes > 50:
            continue
        # print(total_adds + total_removes)
        keep.append(d)

    print("Dev Easy v2")
    print(f"Counts: {Counter([t['repo'] for t in keep])}")
    print(f"Total: {len(keep)}")

    path = "swe-bench-dev-easy-v2.json"

    with open(path, "w") as f:
        json.dump(keep, fp=f)


def generate_data_helper(
    name: str,
    keep_ids: Union[List, int],
    exclude_ids: List = [],
    seed: int = None,
    split: str = "dev",
):
    data = get_hf_data(split)
    if seed is not None:
        random.shuffle(data)
    data = [d for d in data if d["instance_id"] not in exclude_ids]
    keep = []
    if isinstance(keep_ids, int):
        keep = data[:keep_ids]
    elif isinstance(keep_ids, List):
        for d in data:
            if d["instance_id"] in keep_ids:
                keep.append(d)

    print(f"Dev Easy {name}")
    print(f"Counts: {Counter([t['repo'] for t in keep])}")
    print(f"Total: {len(keep)}")
    path = f"swe-bench-{split}-{name}.json"
    if seed is not None:
        path = path.replace(".json", f"-seed{seed}.json")
    with open(path, "w") as f:
        json.dump(keep, fp=f)

def generate_easy_tiny():
    """Write the hand-picked ``tiny`` dev instances via ``generate_data_helper``."""
    unique_ids = list({*tiny})
    generate_data_helper(name="easy-tiny", keep_ids=unique_ids)

def generate_easy_med():
    """Write the union of the ``tiny``, ``additions`` and ``medium`` dev instances."""
    combined_ids = {*tiny, *additions, *medium}
    generate_data_helper(name="easy-med", keep_ids=list(combined_ids))

def generate_40_v1():
    """Write a 40-instance dev sample, requesting seed 24 (output gets a ``-seed24`` suffix)."""
    sample_size, shuffle_seed = 40, 24
    generate_data_helper("40", keep_ids=sample_size, seed=shuffle_seed)

def generate_40_v2():
    data_40_v1 = json.load(open("swe-bench-dev-40-seed24.json"))
    ids_40_v1 = [x["instance_id"] for x in data_40_v1]
    data_v2 = json.load(open("swe-bench-dev-easy-v2.json"))
    ids_v2 = [x["instance_id"] for x in data_v2]
    generate_data_helper(
        "40",
        keep_ids=40,
        seed=25,
        exclude_ids=tiny + medium + ids_v2 + ids_40_v1,
    )