import json
from pathlib import Path
from tqdm import tqdm
import concurrent.futures
import numpy as np
import argparse
from audiocraft.modules.conditioners import AlphabetTokenizer, BPETokenizer
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns


def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with the attributes ``json_path_base`` (Path),
        ``min_duration`` / ``max_duration`` (float), ``g2p_tokenizer`` (str),
        ``tokenizer_file`` (str or None) and ``update_weight`` (bool).
    """
    # Each entry is (positional names/flags, keyword options) for add_argument;
    # the order here is the order in which the options are registered.
    option_specs = (
        (("json_path_base",), {"type": Path}),
        (("--min_duration",), {"type": float, "default": 3.0}),
        (("--max_duration",), {"type": float, "default": 90.0}),
        (
            ("--g2p_tokenizer",),
            {"default": "bpe", "type": str, "help": "G2P tokenizer."},
        ),
        (
            ("--tokenizer_file",),
            {
                "default": None,
                "help": "Path to tokenizer file of BPETokenizer.",
                "type": str,
            },
        ),
        (("--update_weight",), {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for names, options in option_specs:
        cli.add_argument(*names, **options)
    return cli.parse_args()


def process_json_data(
    json_path, tokenizer,
    counts=None, bin_edges=None,
    min_duration=3.0, max_duration=90.0,
    update_weight=False
):
    """Fill in missing "phoneme" / "weight" fields of one JSON-lines file.

    The file at ``json_path`` holds one JSON object per line. Every record
    must have a "duration"; records without "phoneme" must have a
    "transcript". When at least one record was changed, the whole file is
    rewritten in place with all records (changed or not).

    Args:
        json_path: Path of the JSON-lines file to read and possibly rewrite.
        tokenizer: Callable invoked as ``tokenizer([transcript])``; the first
            element of its return value must hold exactly one item exposing
            ``.numpy()``. Only used for records lacking "phoneme", so it may
            be None when every record already has that field.
        counts: Per-bin sample counts of the duration histogram. When None,
            no weights are computed.
        bin_edges: Bin edges matching ``counts`` (as from ``np.histogram``).
        min_duration: Records shorter than this get ``weight = None``.
        max_duration: Records longer than this get ``weight = None``.
        update_weight: If true, recompute "weight" even when already present.

    Returns:
        ``(durations, text_lengths)``: two lists with one entry per record,
        or None if anything went wrong (the error and the path are printed).
    """
    try:
        with open(json_path, "r", encoding="utf-8") as g:
            # Skip blank lines so a trailing newline does not abort the file.
            json_data = [json.loads(line) for line in g if line.strip()]

        durations_local = []
        text_length_local = []
        # Track changes explicitly: a file in which only some records need a
        # new field must still be written back completely.
        modified = False

        for data in json_data:
            _duration = data["duration"]
            durations_local.append(_duration)

            if counts is not None and (update_weight or "weight" not in data):
                if "phoneme" not in data:
                    raise ValueError(
                        'record without "phoneme" encountered while computing weights'
                    )
                if _duration < min_duration or _duration > max_duration:
                    weight = None
                else:
                    bin_idx = int(np.digitize(_duration, bin_edges, right=False)) - 1
                    # A duration inside [min_duration, max_duration] can still
                    # lie outside the histogram range; clamp to the edge bins
                    # instead of wrapping to counts[-1] or raising IndexError.
                    bin_idx = min(max(bin_idx, 0), len(counts) - 1)
                    bin_count = counts[bin_idx]
                    if bin_count > 0:
                        # Inverse bin frequency, scaled by the duration.
                        weight = float(1 / bin_count * _duration)
                    else:
                        # Empty bin: avoid writing a non-finite weight.
                        weight = None
                data["weight"] = weight
                modified = True

            if "phoneme" not in data:
                tokens, _ = tokenizer([data["transcript"]])
                if len(tokens) != 1:
                    raise ValueError(
                        f"tokenizer returned {len(tokens)} sequences for one transcript"
                    )
                data["phoneme"] = ",".join(str(p) for p in tokens[0].numpy())
                modified = True

            text_length_local.append(len(data["phoneme"].split(",")))

        if modified:
            with open(json_path, "w", encoding="utf-8") as g:
                g.writelines(json.dumps(data) + "\n" for data in json_data)
        return durations_local, text_length_local
    except Exception as e:
        # Best effort per file: report the failure and let the caller skip it.
        print(e)
        print(json_path)
        return None


def save_histogram(data, output_path):
    """Plot a 100-bin histogram of ``data`` and save it to ``output_path``.

    Args:
        data: Values to plot along the x axis.
        output_path: Destination passed to matplotlib's ``savefig``.
    """
    sns.set_theme(style="whitegrid")

    # Work on explicit Figure/Axes handles rather than pyplot's implicit state.
    figure, axes = plt.subplots(figsize=(10, 6))
    sns.histplot(x=data, bins=100, ax=axes)
    axes.set_title("Histogram")
    axes.set_ylabel("Number of samples")

    figure.savefig(output_path)
    plt.close(figure)


if __name__ == "__main__":
    # Script flow:
    #   1. first pass over every *.json file: add missing "phoneme" fields and
    #      collect durations / text lengths,
    #   2. print statistics, let the user pick a text-length trimming ratio,
    #   3. save histograms and build the duration histogram,
    #   4. second pass: write per-sample "weight" fields.
    args = get_args()

    if args.g2p_tokenizer == "alphabet":
        tokenizer = AlphabetTokenizer(use_g2p=True)
    elif args.g2p_tokenizer == "bpe":
        tokenizer = (
            BPETokenizer(use_g2p=True)
            if args.tokenizer_file is None
            else BPETokenizer(use_g2p=True, tokenizer_file=args.tokenizer_file)
        )
    else:
        raise ValueError(f"Unknown g2p tokenizer: {args.g2p_tokenizer}")

    durations = []
    text_length = []

    json_pathes = list(args.json_path_base.glob("*.json"))

    # Pass 1: counts/bin_edges are left at None, so no weights are written yet.
    _process_data = partial(process_json_data, tokenizer=tokenizer)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(_process_data, json_pathes), total=len(json_pathes)))

    for result in results:
        if result is None:
            # process_json_data already printed the error for this file.
            continue
        durations.extend(result[0])
        text_length.extend(result[1])

    # reduce durations
    durations = np.array(durations)
    text_length = np.array(text_length)
    available_durations_idx = (durations >= args.min_duration) & (durations <= args.max_duration)
    print(f"Number of samples: {len(durations)}")
    durations = durations[available_durations_idx]
    text_length = text_length[available_durations_idx]
    print(f"Number of samples after filtering by duration: {len(durations)}")
    if len(durations) == 0:
        # np.max / np.min below would fail with an opaque zero-size-array error.
        raise SystemExit("No samples left after filtering by duration.")

    # The ordering does not depend on the ratio, so sort only once.
    text_order = np.argsort(text_length)
    n_samples = len(text_order)

    # Preview what trimming the shortest/longest texts by each ratio would keep.
    for ratio in [0.01, 0.02, 0.03, 0.04, 0.05]:
        idx = text_order[int(n_samples * ratio) : int(n_samples * (1 - ratio))]
        print(f"Ratio: {ratio}")
        if len(idx) == 0:
            print("No samples left at this ratio.")
            continue
        print(f"Max text length: {np.max(text_length[idx])}")
        print(f"Min text length: {np.min(text_length[idx])}")
        print(f"Max duration: {np.max(durations[idx])}")
        print(f"Min duration: {np.min(durations[idx])}")
        print(f"Sum of durations: {np.sum(durations[idx]) / 3600:.2f} hours")

    print("Choose the ratio:")
    target_ratio = float(input())
    if not 0.0 <= target_ratio < 0.5:
        # The ratio is trimmed from both ends; 0.5 or more leaves nothing.
        raise ValueError(f"Ratio must be in [0, 0.5), got {target_ratio}")
    text_idx = text_order[int(n_samples * target_ratio) : int(n_samples * (1 - target_ratio))]
    if len(text_idx) == 0:
        raise SystemExit("No samples left after trimming by text length.")
    durations = durations[text_idx]
    text_length = text_length[text_idx]
    print(f"Max text length: {np.max(text_length)}")
    print(f"Min text length: {np.min(text_length)}")
    print(f"Max duration: {np.max(durations)}")
    print(f"Min duration: {np.min(durations)}")
    print(f"Sum of durations: {np.sum(durations) / 3600:.2f} hours")

    print("Save histograms")
    save_histogram(durations, f"durations_histogram_{args.json_path_base.stem}.png")
    save_histogram(text_length, f"text_length_histogram_{args.json_path_base.stem}.png")

    counts, bin_edges = np.histogram(durations, bins=100)
    # Widen the last edge so the maximum duration falls inside the last bin
    # when process_json_data looks it up with np.digitize(..., right=False).
    bin_edges[-1] += 0.1

    print("Calc weights")
    # Pass 2: no tokenizer is supplied here, so every record is expected to
    # already carry a "phoneme" field.
    _process_data = partial(
        process_json_data, tokenizer=None, counts=counts, bin_edges=bin_edges,
        min_duration=args.min_duration, max_duration=args.max_duration,
        update_weight=args.update_weight
    )
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Wrap the map itself so the bar advances while files are processed.
        results = list(tqdm(executor.map(_process_data, json_pathes), total=len(json_pathes)))

    n_failed = sum(result is None for result in results)
    if n_failed:
        print(f"Failed to update weights for {n_failed} of {len(results)} files")
