import logging
from collections import namedtuple
from typing import Dict

import click
import torch
import transformers
from transformers import BitsAndBytesConfig
from tqdm import tqdm
from sklearn.cluster import KMeans
import numpy as np

from data_utils import load_dataset_from_data_metastr, ListDataset, save_dataset_to_disk

# Module-level logger. It is named after the module (``__name__``) rather than
# the file path: logger names are dot-separated hierarchies, so a path such as
# ".../selector.py" would be split at the ".py" suffix into a bogus parent/child
# pair and could not be configured by module name.
logger = logging.getLogger(__name__)


class BaseSelector:
    """Empty marker base class; it defines no attributes or methods itself."""


class IdentitySelector(BaseSelector):
    # Selector that applies no filtering of its own: whatever the loader
    # returns for the given metastring is printed and written back to disk.
    # (Comments are used instead of a docstring on ``cli`` so that the
    # click-generated ``--help`` text is left untouched.)

    @staticmethod
    @click.command()
    @click.option("-d", "--data_metastr", type=str, required=True)
    @click.option("-o", "--output_name", type=str, required=True)
    @click.option("-s", "--seed", type=int, default=None)
    def cli(data_metastr, output_name, seed):
        # The seed is forwarded verbatim to the loader; how it is used is
        # decided by ``load_dataset_from_data_metastr``.
        dataset = load_dataset_from_data_metastr(data_metastr, seed=seed)
        print(dataset)

        # Output location is fixed under ./data/identity_selector/.
        target_path = "./data/identity_selector/{}".format(output_name)
        save_dataset_to_disk(dataset, target_path)


class RuleSubmitSelector(BaseSelector):
    """Rank JSON records by a fixed linear rule and keep the lowest-scoring ones.

    The input file is a JSON document that ``pandas.DataFrame.from_records``
    can consume; every record must provide the numeric fields ``reward``,
    ``understandability``, ``naturalness`` and ``coherence``.
    """

    @staticmethod
    @click.command()
    @click.option("-d", "--input_path", type=str, required=True)
    @click.option("-o", "--output_name", type=str, required=True)
    @click.option("-n", "--num_examples", type=int, default=2000)
    def cli(input_path, output_name, num_examples):
        """Write the NUM_EXAMPLES lowest-scoring records to ./data/dolly_rule/."""
        import json
        import os
        import pandas as pd

        # Read through a context manager so the file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(input_path) as input_file:
            js_ds = json.load(input_file)
        ds = pd.DataFrame.from_records(js_ds)

        # Rule:
        # 0.0274 - 0.0078 * reward + 0.4421 * understandability
        #        - 0.3212 * naturalness - 0.1520 * coherence
        # NOTE: an earlier revision read these features from the columns
        # 'oa_rm21_pythia_14b', 'unievalunderstandability',
        # 'unievalnaturalness' and 'unievalcoherence'.
        def rule_func(example):
            return (
                0.0274
                - 0.0078 * example['reward']
                + 0.4421 * example['understandability']
                - 0.3212 * example['naturalness']
                - 0.1520 * example['coherence']
            )

        ds['rule'] = ds.apply(rule_func, axis=1)

        # Ascending sort: the smallest rule values come first and are kept.
        ds = ds.sort_values("rule")
        print(ds)
        ds = ds.iloc[:num_examples, :]
        print(ds)

        # to_json does not create missing parent directories, so make sure the
        # destination exists before writing.
        output_path = './data/dolly_rule/{}'.format(output_name)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        ds.to_json(output_path, orient='records', lines=False)


class IndicatorSelector(BaseSelector):
    """Save fixed-size windows of a dataset after sorting it by one column.

    The dataset is sorted by ``indicator_name`` and ``split_count + 1`` windows
    of ``num_examples`` rows are taken, with left edges spread evenly between
    index 0 and ``len(ds) - num_examples``. Each distinct window is saved to
    its own directory and the mean indicator value of the window is printed.
    """

    @staticmethod
    def _compute_range_tuples(len_ds, num_examples, split_count):
        """Return the sorted, de-duplicated ``(left, right)`` index windows.

        Args:
            len_ds: Number of rows in the sorted dataset.
            num_examples: Number of rows per window.
            split_count: Number of equal steps between the first and the last
                window; ``split_count + 1`` left edges are generated.

        Returns:
            A list of half-open ``(left, right)`` index pairs in ascending order.

        Raises:
            ValueError: If windows have to be computed and ``split_count`` is
                smaller than 1.
        """
        # The whole dataset fits in one window; split_count is irrelevant here.
        if len_ds <= num_examples:
            return [(0, len_ds)]

        # split_count is a divisor below: 0 would raise ZeroDivisionError and a
        # negative value would silently yield no windows at all.
        if split_count < 1:
            raise ValueError(
                "split_count must be >= 1, got {}".format(split_count)
            )

        max_left_index = len_ds - num_examples
        # Several steps can truncate to the same integer left edge when the
        # dataset is small; the set removes those duplicates.
        left_indices = {
            int(1.0 * i / split_count * max_left_index)
            for i in range(split_count + 1)
        }
        return [(left, left + num_examples) for left in sorted(left_indices)]

    @staticmethod
    @click.command()
    @click.option("-d", "--data_metastr", type=str, required=True)
    @click.option("-i", "--indicator_name", type=str, required=True)
    @click.option("-o", "--output_name", type=str, required=True)
    @click.option("-n", "--num_examples", type=int, required=True)
    @click.option("-s", "--split_count", type=int, required=True)
    def cli(data_metastr, indicator_name, output_name, num_examples, split_count):
        """Save windows of NUM_EXAMPLES rows from the indicator-sorted dataset."""
        ds = load_dataset_from_data_metastr(data_metastr)

        # Fail early with a readable message if the column is missing.
        if indicator_name not in ds.column_names:
            raise ValueError(
                "Indicator {} not found in dataset {}".format(
                    indicator_name, data_metastr
                )
            )

        ds = ds.sort(indicator_name)
        range_tuples = IndicatorSelector._compute_range_tuples(
            len(ds), num_examples, split_count
        )

        for range_left, range_right in range_tuples:
            # Keep only the rows whose position falls inside the current window.
            ds_subset = ds.filter(
                lambda example, idx: range_left <= idx < range_right, with_indices=True
            )
            # The path encodes the window bounds and the CLI settings so that
            # different runs do not overwrite one another.
            save_dataset_to_disk(
                ds_subset,
                "./data/indicator_selector/{}/{}/{}_{}_of_ne{}_sc{}".format(
                    indicator_name,
                    output_name,
                    range_left,
                    range_right,
                    num_examples,
                    split_count,
                ),
            )
            print(
                f"average {indicator_name} =",
                np.mean([x[indicator_name] for x in ds_subset]),
            )


# Sub-command name -> selector class. Each class exposes its click command as
# the ``cli`` attribute.
_SELECTOR_NAMED_MAP = dict(
    identity=IdentitySelector,
    indicator=IndicatorSelector,
    rule_submit=RuleSubmitSelector,
)


@click.group()
def cli(**kwargs):
    # Root command group. It performs no work itself and only serves as the
    # attachment point for sub-commands. A comment is used instead of a
    # docstring so the generated help text stays as it was.
    return None


# Attach every selector's click command to the root group under the key it has
# in _SELECTOR_NAMED_MAP. This runs at import time; ``name`` and ``selector``
# remain defined as module-level names after the loop.
for name, selector in _SELECTOR_NAMED_MAP.items():
    cli.add_command(selector.cli, name=name)


# Invoke the command group only when the file is executed as a script.
if __name__ == "__main__":
    cli()
