import json
import os

import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset


def get_dataset(path_or_id, tokenizer, preprocessing, **kwargs):
    """Build the dataset selected by ``path_or_id`` and ``preprocessing``.

    ``preprocessing == "rm"`` always yields a ``PreferenceDataset`` built from
    ``path_or_id`` itself. Otherwise ``path_or_id`` must be ``"tldr"`` or
    ``"alpaca"``: ``preprocessing == "text"`` returns the plain text records,
    and any other value returns the tokenized dataset, which is prompt-only
    when ``preprocessing == "ppo"``. Extra keyword arguments are forwarded
    unchanged to whichever loader is selected.

    Raises:
        ValueError: if ``preprocessing`` is not ``"rm"`` and ``path_or_id`` is
            neither ``"tldr"`` nor ``"alpaca"``.
    """
    # Preference (reward-model) data is addressed directly by path, so it is
    # handled before the named-dataset lookup.
    if preprocessing == "rm":
        return PreferenceDataset(path_or_id, tokenizer, **kwargs)

    wants_text = preprocessing == "text"
    prompt_only = preprocessing == "ppo"

    if path_or_id == "tldr":
        source = "CarperAI/openai_summarize_tldr"
        if wants_text:
            return tldr_texts(source, **kwargs)
        return TLDRDataset(source, tokenizer, prompt_only=prompt_only, **kwargs)

    if path_or_id == "alpaca":
        # The Alpaca JSON lives under $DATA_DIR (current directory by default).
        data_root = os.getenv("DATA_DIR", ".")
        source = os.path.join(data_root, "data/datasets/alpaca/alpaca_data.json")
        if wants_text:
            return alpaca_texts(source, **kwargs)
        return AlpacaDataset(source, tokenizer, prompt_only=prompt_only, **kwargs)

    raise ValueError("Invalid path_or_id or preprocessing not `rm`")
    
def format_alpaca_sample(sample):
    """Render an Alpaca record as the prompt text that precedes the response.

    ``sample`` must provide an ``"instruction"`` entry. When it also has a
    truthy ``"input"`` entry, the template containing an ``### Input:``
    section is used; otherwise the instruction-only template is used. The
    returned text ends with the ``### Response:`` header followed by a
    newline; the record's ``"output"`` is not included.
    """
    # A missing key and an empty value are treated the same: no input section.
    context = sample["input"] if "input" in sample else None

    if not context:
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n"
            f"{sample['instruction']}\n\n"
            "### Response:\n"
        )

    return (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n"
        f"{sample['instruction']}\n\n"
        "### Input:\n"
        f"{context}\n\n"
        "### Response:\n"
    )

class AlpacaDataset(Dataset):
    """Alpaca records formatted as text and tokenized once at construction.

    The last 2000 records of the JSON file are held out: ``split="train"``
    uses everything before them, and any other ``split`` value uses only
    those last 2000 records.

    Args:
        dataset_path: path to a JSON file containing a list of records with
            ``"instruction"`` and ``"output"`` entries (``"input"`` optional).
        tokenizer: callable invoked as ``tokenizer(texts, truncation=True,
            max_length=..., padding="max_length", return_tensors="pt")``; its
            result must provide ``"input_ids"`` and ``"attention_mask"``.
        split: ``"train"`` for the leading part, anything else for the tail.
        max_length: length handed to the tokenizer for truncation/padding.
        prompt_only: when true, the record's ``"output"`` is not appended to
            the formatted prompt.
    """

    def __init__(self, dataset_path, tokenizer, split="train", max_length=1024, prompt_only=False):
        # Read through a context manager so the handle is closed as soon as
        # parsing finishes; JSON text is UTF-8, so do not depend on the locale.
        with open(dataset_path, encoding="utf-8") as f:
            records = json.load(f)

        # NOTE: a file with 2000 records or fewer yields an empty train split.
        if split == "train":
            records = records[:-2000]
        else:
            records = records[-2000:]

        self.posts = [
            format_alpaca_sample(record) + ("" if prompt_only else record["output"])
            for record in tqdm.tqdm(records)
        ]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Every sample is tokenized and padded to max_length up front.
        self.encoded_posts = self.tokenizer(
            self.posts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __len__(self):
        """Return the number of formatted samples."""
        return len(self.posts)

    def __getitem__(self, idx):
        """Return the ``input_ids`` and ``attention_mask`` rows for sample ``idx``."""
        return {
            "input_ids": self.encoded_posts["input_ids"][idx],
            "attention_mask": self.encoded_posts["attention_mask"][idx],
        }


class TLDRDataset(Dataset):
    """Samples loaded via ``load_dataset`` and tokenized once at construction.

    Each sample's text is its ``"prompt"`` followed by its ``"label"``, or
    the ``"prompt"`` alone when ``prompt_only`` is true. All texts are
    tokenized together with truncation and padding to ``max_length``.
    """

    def __init__(self, dataset_path, tokenizer, split="train", max_length=1024, prompt_only=False):
        texts = []
        for sample in tqdm.tqdm(load_dataset(dataset_path, split=split)):
            text = sample["prompt"]
            if not prompt_only:
                text = text + sample["label"]
            texts.append(text)

        self.posts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.encoded_posts = self.tokenizer(
            self.posts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.posts)

    def __getitem__(self, idx):
        """Return the ``input_ids`` and ``attention_mask`` rows for sample ``idx``."""
        return {
            field: self.encoded_posts[field][idx]
            for field in ("input_ids", "attention_mask")
        }


class PreferenceDataset(Dataset):
    """Pairs of (chosen, rejected) texts built from comparison JSON files.

    Each file is a JSON collection of samples with ``"prompt"``, ``"label"``
    and ``"output"`` entries plus a preference stored under ``label_name``.
    A preference of ``0`` marks ``"label"`` as the chosen completion and
    ``1`` marks ``"output"`` as chosen. Samples are skipped when the
    preference is ``None``, is not 0 or 1 after ``int()`` conversion, or when
    ``"label"`` and ``"output"`` are identical.

    Args:
        comparison_path: a single path (``str`` or ``os.PathLike``) or an
            iterable of paths to comparison files.
        tokenizer: callable used in ``__getitem__``; must expose ``eos_token``
            and return ``"input_ids"`` / ``"attention_mask"`` when called.
        max_length: length handed to the tokenizer for truncation/padding.
        label_name: key under which each sample stores its preference.
    """

    def __init__(self, comparison_path, tokenizer, max_length=1024, label_name="claude_preference"):
        self.tokenizer = tokenizer
        self.chosen = []
        self.rejected = []
        self.max_length = max_length

        # Accept one path (including str subclasses and pathlib paths) as well
        # as an iterable of paths.
        if isinstance(comparison_path, (str, os.PathLike)):
            comparison_path = [comparison_path]

        for cp in comparison_path:
            # Close each file as soon as it is parsed rather than leaking the
            # handle; JSON text is UTF-8, so do not depend on the locale.
            with open(cp, encoding="utf-8") as f:
                dataset = json.load(f)

            for sample in tqdm.tqdm(dataset, desc=str(cp)):
                if sample[label_name] is None:
                    continue
                choice = int(sample[label_name])
                # Skip unusable preferences and pairs with nothing to compare.
                if choice not in (0, 1) or sample["label"] == sample["output"]:
                    continue
                chosen, rejected = ("label", "output") if choice == 0 else ("output", "label")
                prompt = sample["prompt"].strip()
                self.chosen.append(prompt + " " + sample[chosen].strip() + tokenizer.eos_token)
                self.rejected.append(prompt + " " + sample[rejected].strip() + tokenizer.eos_token)

        print(f"Loaded the dataset, {len(self.chosen)} samples")

    def __len__(self):
        """Return the number of (chosen, rejected) pairs."""
        return len(self.chosen)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` on demand and return its four tensors."""
        chosen_tensor = self.tokenizer(self.chosen[idx], truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        rejected_tensor = self.tokenizer(self.rejected[idx], truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        # The tokenizer is called on a single text, so row 0 is the only row.
        return {
            "input_ids_chosen": chosen_tensor["input_ids"][0],
            "attention_mask_chosen": chosen_tensor["attention_mask"][0],
            "input_ids_rejected": rejected_tensor["input_ids"][0],
            "attention_mask_rejected": rejected_tensor["attention_mask"][0],
        }


def alpaca_texts(dataset_path, split="train", **kwargs):
    """Return Alpaca records as ``{"prompt": ..., "label": ...}`` dicts.

    The last 2000 records of the JSON file are held out: ``split="train"``
    returns everything before them, and any other ``split`` value returns only
    those last 2000 records. ``"prompt"`` is the formatted instruction text
    and ``"label"`` is the record's ``"output"``.

    Args:
        dataset_path: path to a JSON file containing a list of records.
        split: ``"train"`` for the leading part, anything else for the tail.
        **kwargs: accepted and ignored.
    """
    # Read through a context manager so the handle is closed as soon as
    # parsing finishes; JSON text is UTF-8, so do not depend on the locale.
    with open(dataset_path, encoding="utf-8") as f:
        records = json.load(f)

    # NOTE: a file with 2000 records or fewer yields an empty train split.
    records = records[:-2000] if split == "train" else records[-2000:]
    return [{"prompt": format_alpaca_sample(record), "label": record["output"]} for record in records]


def tldr_texts(dataset_path, split="train", **kwargs):
    """Load one split via ``load_dataset`` and return its rows as a list.

    Extra keyword arguments are accepted and ignored.
    """
    rows = load_dataset(dataset_path, split=split)
    return rows.to_list()
