import argparse
import requests
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd
import os
import random
from domain_datasets import DomainDatasets

def download_and_load_dataset(data_dir, task_name, temp_dir='./temp', timeout=60):
    """Download train/test/dev JSON-lines files and load them as a ``DatasetDict``.

    Each split is fetched from ``f"{data_dir}{split}.jsonl"``. ``data_dir`` is
    used as a plain string prefix, so it should end with ``/``. The response is
    cached under ``temp_dir`` as ``{task_name}_{split}.jsonl``, parsed with
    pandas and converted to a Hugging Face ``Dataset``. The ``dev`` split is
    exposed under the key ``validation``.

    Args:
        data_dir: URL prefix of the directory holding the ``.jsonl`` files.
        task_name: Used to name the cached files.
        temp_dir: Directory the downloaded files are written to.
        timeout: Seconds passed to ``requests.get``; without a timeout a
            stalled server could block the script indefinitely.

    Returns:
        ``DatasetDict`` with keys ``train``, ``test`` and ``validation``.

    Raises:
        requests.HTTPError: if a split cannot be downloaded (e.g. 404). This
            replaces the old behavior of caching the error page and failing
            later with an opaque JSON parse error.
    """
    # Create the cache directory once, not on every loop iteration.
    os.makedirs(temp_dir, exist_ok=True)
    dataset_dict = {}
    for split in ('train', 'test', 'dev'):
        url = f"{data_dir}{split}.jsonl"
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        file_path = os.path.join(temp_dir, f"{task_name}_{split}.jsonl")
        with open(file_path, 'wb') as f:
            f.write(response.content)
        # Load with pandas, then convert to a Hugging Face dataset.
        df = pd.read_json(file_path, lines=True)
        split_name = 'validation' if split == 'dev' else split
        dataset_dict[split_name] = Dataset.from_pandas(df)
    return DatasetDict(dataset_dict)

def download_dataset(task_name):
    """Download and load the dataset registered under ``task_name``.

    Looks up the ``DomainDatasets`` entry for ``task_name``, reads its
    ``data_dir`` and delegates to ``download_and_load_dataset``.

    Raises:
        ValueError: if ``task_name`` is not present in ``DomainDatasets``.
    """
    if task_name not in DomainDatasets:
        raise ValueError(f"No dataset found for {task_name}")
    data_dir = DomainDatasets[task_name]['data_dir']
    return download_and_load_dataset(data_dir, task_name)

def split_train_for_validation(dataset, validation_split=0.1, seed=None):
    """Randomly carve a validation subset out of ``dataset``.

    Args:
        dataset: Object supporting ``len()`` and ``.select(indices)``, e.g. a
            Hugging Face ``Dataset``.
        validation_split: Fraction in ``[0, 1]`` of rows moved to the
            validation subset. The resulting size is rounded down.
        seed: Optional seed for a private RNG, giving reproducible splits.
            When ``None`` the module-level ``random`` state is used, so a
            global ``random.seed(...)`` still applies.

    Returns:
        ``(train_subset, validation_subset)``: disjoint subsets that together
        cover every row of ``dataset``.

    Raises:
        ValueError: if ``validation_split`` is outside ``[0, 1]``. A negative
            value would otherwise slice with a negative index and silently put
            most rows into validation.
    """
    if not 0 <= validation_split <= 1:
        raise ValueError(
            f"validation_split must be in [0, 1], got {validation_split}"
        )
    rng = random if seed is None else random.Random(seed)
    indices = list(range(len(dataset)))
    rng.shuffle(indices)
    validation_size = int(len(indices) * validation_split)
    train_indices = indices[validation_size:]
    validation_indices = indices[:validation_size]
    return dataset.select(train_indices), dataset.select(validation_indices)

def _sample_k_shot(subset, k_shot):
    """Return a random ``k_shot``-row sample of ``subset`` (unchanged if ``k_shot <= 0`` or it is already small enough)."""
    if k_shot <= 0 or len(subset) <= k_shot:
        return subset
    indices = list(range(len(subset)))
    random.shuffle(indices)
    return subset.select(indices[:k_shot])


def _real_labels(subset):
    """Return the ``label`` column of ``subset`` without the ``-1`` "unlabeled" placeholder."""
    # Some benchmarks (e.g. GLUE) ship test splits whose labels are all -1;
    # counting that value would inflate the number of classes.
    return [label for label in subset['label'] if label != -1]


def load_and_prepare_dataset(task_name, k_shot, val_split, benchmark=None):
    """Load a task's dataset and return its (optionally k-shot sampled) splits.

    The dataset comes from ``load_dataset(benchmark, task_name)`` when
    ``benchmark`` is given, and otherwise from ``download_dataset(task_name)``.
    MNLI uses its matched/mismatched validation and test splits. Every other
    task uses ``train``/``validation``/``test``. Splits missing from the
    dataset are skipped.

    If a non-MNLI dataset has no ``validation`` split and ``val_split > 0``, a
    validation split is carved out of ``train`` *before* k-shot sampling, so
    the two never share rows.

    Args:
        task_name: Dataset/config name (e.g. ``"imdb"``, ``"mnli"``).
        k_shot: If > 0, every returned split larger than ``k_shot`` is randomly
            down-sampled to ``k_shot`` rows. This includes a validation split
            generated from ``train``. Values <= 0 keep full splits.
        val_split: Fraction of ``train`` held out as validation when the
            dataset has no validation split.
        benchmark: Optional benchmark name passed to ``load_dataset``
            (e.g. ``"glue"``).

    Returns:
        ``(splits, num_labels)``. ``splits`` maps split name to dataset.
        ``num_labels`` is the number of distinct labels across the full
        (pre-sampling) splits, ignoring the ``-1`` placeholder.
    """
    dataset = load_dataset(benchmark, task_name) if benchmark is not None else download_dataset(task_name)
    is_mnli = task_name.lower() == "mnli"
    if is_mnli:
        split_keys = ["train", "validation_matched", "validation_mismatched", "test_matched", "test_mismatched"]
    else:
        split_keys = ["train", "validation", "test"]

    splits = {}
    unique_labels = set()
    for split in split_keys:
        if split not in dataset:
            continue
        subset = dataset[split]
        # Create a validation split from train when the dataset lacks one.
        if split == "train" and not is_mnli and "validation" not in dataset and val_split > 0:
            subset, validation_subset = split_train_for_validation(subset, val_split)
            unique_labels.update(_real_labels(validation_subset))
            splits["validation"] = _sample_k_shot(validation_subset, k_shot)
        # Count labels on the full split: a small k-shot sample can miss classes.
        unique_labels.update(_real_labels(subset))
        splits[split] = _sample_k_shot(subset, k_shot)
    return splits, len(unique_labels)

def save_dataset_splits(output_path, task_name, dataset_splits, unique_labels_count):
    """Write every split to ``<output_path>/<task_name>/`` plus a summary CSV.

    Each split is saved as ``{split}.json`` through its own ``to_json`` method.
    ``summary.csv`` has one row per split, with columns ``Split``, ``Length``
    and ``Labels``. ``Labels`` repeats ``unique_labels_count`` on every row.
    The task directory is created if needed.
    """
    task_dir = os.path.join(output_path, task_name)
    os.makedirs(task_dir, exist_ok=True)

    lengths = {}
    for split_name, split_data in dataset_splits.items():
        split_data.to_json(os.path.join(task_dir, f"{split_name}.json"))
        lengths[split_name] = len(split_data)

    summary = pd.DataFrame({
        "Split": list(lengths),
        "Length": list(lengths.values()),
        "Labels": [unique_labels_count] * len(lengths),
    })
    summary.to_csv(os.path.join(task_dir, "summary.csv"), index=False)

def main(args):
    """Build the dataset splits described by ``args`` and write them to disk."""
    splits, label_count = load_and_prepare_dataset(
        args.task_name, args.k_shot, args.val_split, args.benchmark
    )
    save_dataset_splits(args.output_path, args.task_name, splits, label_count)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate datasets for different tasks.")
    parser.add_argument("--output_path", type=str, default="./dataset", help="Path to save the datasets")
    # required=True means a default could never be used, so none is given.
    parser.add_argument("--task_name", type=str, required=True, help="Name of the task (e.g., 'imdb')")
    parser.add_argument("--k_shot", type=int, default=-1, help="Number of samples per split; values <= 0 (default -1) keep the full dataset")
    parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of train (0 <= x < 1) to split off as validation if the original dataset has no validation set.")
    parser.add_argument("--benchmark", type=str, help="Benchmark of tasks, e.g. glue")

    args = parser.parse_args()
    # Reject nonsensical fractions up front instead of producing broken splits.
    if not 0 <= args.val_split < 1:
        parser.error(f"--val_split must be in [0, 1), got {args.val_split}")
    print(f"==================Processing Dataset {args.task_name}==================")
    main(args)
    print(f"==================Dataset Processing Completed for {args.task_name}==================")