import datasets
from dateutil import parser
import argparse
import pandas as pd
import numpy as np
from SampleEfficiencyMatrix import SampleEfficiencyMatrix

def reorder_batches(dataset: pd.DataFrame, batch_sequence, batch_size: int) -> pd.DataFrame:
    """Return a copy of *dataset* whose batches follow *batch_sequence*.

    Batch ``b`` is the contiguous row range
    ``[b * batch_size, min((b + 1) * batch_size, len(dataset)))``.
    Each row is tagged (column ``'order'``) with the position of its batch
    in *batch_sequence*, and the rows are then sorted by that tag.

    Args:
        dataset: Rows to reorder; not modified.
        batch_sequence: Iterable of batch ids in the desired visiting order.
        batch_size: Number of rows per batch (the last batch may be shorter).

    Returns:
        A new DataFrame with an added ``'order'`` column, sorted by it, and
        with a fresh ``0..n-1`` index.
    """
    num_rows = len(dataset)
    # Rows not covered by any batch id keep order 0, as before.
    order = np.zeros(num_rows, dtype=np.int64)
    for position, batch_id in enumerate(batch_sequence):
        start = int(batch_id) * batch_size
        end = min(start + batch_size, num_rows)
        order[start:end] = position

    # A stable sort is required: all rows of a batch share the same key, and
    # the default quicksort may shuffle them within the batch.
    reordered = dataset.assign(order=order).sort_values("order", kind="stable")
    # reset_index returns a new frame; the result must be kept. Dropping the
    # old (now permuted) index leaves a plain sequential index on the output.
    return reordered.reset_index(drop=True)


def main() -> None:
    """Reorder a parquet dataset's batches along a Hamilton cycle and save it."""
    # Named arg_parser so the module-level `dateutil.parser` import is not shadowed.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--data_path", type=str, required=True)
    arg_parser.add_argument("--output_path", type=str, required=True)
    arg_parser.add_argument("--weight_path", type=str, required=True)
    arg_parser.add_argument("--batch_size", type=int, default=16)
    arg_parser.add_argument("--type", type=str, default="min")
    args = arg_parser.parse_args()

    batch_size = args.batch_size
    if batch_size <= 0:
        # Guards the batch-count division below.
        arg_parser.error("--batch_size must be a positive integer")

    dataset = pd.read_parquet(args.data_path)

    print(dataset)

    weights = np.load(args.weight_path)

    if args.type == "min":
        # Only a maximising search is called below; negating the weights
        # presumably turns it into a minimum-weight search -- TODO confirm.
        weights = weights * (-1)

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    num_batches = -(-len(dataset) // batch_size)  # integer ceil
    if num_batches != len(weights):
        raise ValueError(
            f"ceil({len(dataset)}/{batch_size}) = {num_batches} != {len(weights)}"
        )

    sem = SampleEfficiencyMatrix(0)
    sem.set_data(weights)
    best_sequence = sem.max_hamilton_cycle()
    if best_sequence is None:
        raise RuntimeError("No hamilton cycle found")
    # Drop the final path entry (presumably the repeated start node that
    # closes the cycle -- verify against SampleEfficiencyMatrix).
    best_sequence = best_sequence['path'][:-1]
    print(f"best sequence: {best_sequence}")

    dataset = reorder_batches(dataset, best_sequence, batch_size)
    print("sorted batch sequence by hamilton cycle")

    dataset.to_parquet(args.output_path)
    print(f"saved sorted dataset: {args.output_path}")


if __name__ == "__main__":
    main()
  