import os
import random

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaTokenizer

# ====== Local dataset roots ======
# Root used by the wiki2 / c4 / pg19 / bookcorpus / alpaca / mmlu loaders below
# (they build paths with os.path.join(DATA_DIR, ...)).
DATA_DIR = "/seu_nvme/ogai/datasets/"
# Root for the extra benchmark files listed in DATA_PATHS.
# NOTE(review): this looks like a placeholder path — point it at your local copy.
MY_DATA_DIR = "/TO/MAY/PATH/MyDatasets/"

# Default local files for the extra benchmarks (migrated from 【3】).
# NOTE: if any path differs in your environment, change it here only; the
# get_*_trainenc loaders look their files up via DATA_PATHS[...].
DATA_PATHS = {
    # Parquet files, one per split. Each get_*_trainenc loader reads the
    # "*_train" entry when split == "train" and the "*_test" entry otherwise.
    "arc_challenge_train": os.path.join(MY_DATA_DIR, "ARC-challenge/train-00000-of-00001.parquet"),
    "arc_challenge_test":  os.path.join(MY_DATA_DIR, "ARC-challenge/test-00000-of-00001.parquet"),

    "piqa_train": os.path.join(MY_DATA_DIR, "PIQA/train-00000-of-00001.parquet"),
    "piqa_test":  os.path.join(MY_DATA_DIR, "PIQA/test-00000-of-00001.parquet"),

    "winogrande_train": os.path.join(MY_DATA_DIR, "WinoG/winogrande_xl/train-00000-of-00001.parquet"),
    "winogrande_test":  os.path.join(MY_DATA_DIR, "WinoG/winogrande_xl/test-00000-of-00001.parquet"),

    # IMPORTANT: in 【3】 the HellaSwag entry pointed at ARC-easy (likely a bug).
    # The paths below are a best guess; replace them with your real HellaSwag
    # parquet files if they live elsewhere.
    "hellaswag_train": os.path.join(MY_DATA_DIR, "HellaSwag/train-00000-of-00001.parquet"),
    "hellaswag_test":  os.path.join(MY_DATA_DIR, "HellaSwag/test-00000-of-00001.parquet"),

    "obqa_train": os.path.join(MY_DATA_DIR, "OBQA/main/train-00000-of-00001.parquet"),
    "obqa_test":  os.path.join(MY_DATA_DIR, "OBQA/main/test-00000-of-00001.parquet"),

    # XSum: a directory of CSV files; the loader reads "<split>.csv" from it.
    "xsum_dir": os.path.join(MY_DATA_DIR, "XSum/"),
}

class TokenizerWrapper:
    """Lightweight holder exposing a pre-computed ``input_ids`` attribute.

    Lets an already-sliced token tensor be returned where callers would
    otherwise receive a tokenizer encoding and read ``.input_ids`` from it
    (``get_c4`` uses it for its truncated validation tokens).
    """

    def __init__(self, input_ids):
        # Stored as-is; no copying or validation is performed.
        self.input_ids = input_ids

def get_tokenizer(model):
    """Load the slow (non-fast) tokenizer for ``model``.

    Names containing "llama" (case-insensitive) are loaded with
    ``LlamaTokenizer``; everything else goes through ``AutoTokenizer``.

    Args:
        model: model name or local path passed to ``from_pretrained``.

    Returns:
        The loaded tokenizer instance.
    """
    if "llama" not in model.lower():
        return AutoTokenizer.from_pretrained(model, use_fast=False)

    tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False)
    # Compatibility fix for transformers 4.28.0.dev0: force BOS id 1 / EOS id 2
    # whenever the loaded tokenizer reports something else.
    mismatched = tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2
    if mismatched:
        try:
            tokenizer.bos_token_id = 1
            tokenizer.eos_token_id = 2
        except AttributeError:
            # Best effort: if the ids cannot be assigned (presumably read-only
            # in some versions), keep whatever was loaded.
            pass
    return tokenizer


# =========================
# Existing loaders (wiki2/c4)
# =========================
def get_wikitext2(nsamples, seed, seqlen, tokenizer, batch_size):
    """Build WikiText-2 calibration batches and the tokenized test split.

    ``nsamples`` windows of ``seqlen`` tokens are cut at random offsets from
    the space-joined training text (the global ``random`` module is seeded
    with ``seed``). Each target is a clone of its window with every position
    except the last set to -100, the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``, so only the final token is supervised.

    Args:
        nsamples: number of windows to draw.
        seed: seed for the global ``random`` module.
        seqlen: tokens per window.
        tokenizer: HF-style tokenizer accepting ``return_tensors='pt'``.
        batch_size: windows per batch; the last batch may be smaller.

    Returns:
        ``(trainloader, testenc)``: ``trainloader`` is a list of
        ``(inputs, targets)`` tensor pairs concatenated along dim 0;
        ``testenc`` is the tokenizer output for the ``"\\n\\n"``-joined test split.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    wikitext_path = os.path.join(DATA_DIR, 'wikitext')
    traindata = load_dataset(wikitext_path, 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset(wikitext_path, 'wikitext-2-raw-v1', split='test')

    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    random.seed(seed)
    samples = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        samples.append((inp, tar))

    # Group consecutive samples into batches of `batch_size` (last one may be short).
    trainloader = []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        trainloader.append((
            torch.cat([inp for inp, _ in chunk]),
            torch.cat([tar for _, tar in chunk]),
        ))

    return trainloader, testenc

def get_c4(nsamples, seed, seqlen, tokenizer, batch_size):
    """Build C4 calibration batches and a truncated validation encoding.

    For each of the ``nsamples`` windows, random training documents are drawn
    until one tokenizes to more than ``seqlen`` tokens, then a random
    ``seqlen``-token window inside it is taken. Targets are clones of the
    windows with every position except the last set to -100 (the default
    ``ignore_index`` of ``torch.nn.CrossEntropyLoss``). The global ``random``
    module is seeded with ``seed``.

    Args:
        nsamples: number of windows to draw.
        seed: seed for the global ``random`` module.
        seqlen: tokens per window.
        tokenizer: HF-style tokenizer accepting ``return_tensors='pt'``.
        batch_size: windows per batch; the last batch may be smaller.

    Returns:
        ``(trainloader, valenc)``: ``trainloader`` is a list of
        ``(inputs, targets)`` tensor pairs concatenated along dim 0;
        ``valenc`` is a ``TokenizerWrapper`` holding the first
        ``256 * seqlen`` tokens of the space-joined first 1100 validation texts.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    c4_path = os.path.join(DATA_DIR, 'allenai/c4')
    traindata = load_dataset(
        c4_path,
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train'
    )
    valdata = load_dataset(
        c4_path,
        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation'
    )

    random.seed(seed)
    samples = []
    for _ in range(nsamples):
        # Rejection-sample a document long enough to hold a full window.
        # NOTE(review): this never terminates if no document exceeds `seqlen` tokens.
        while True:
            doc_idx = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[doc_idx]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        samples.append((inp, tar))

    # Group consecutive samples into batches of `batch_size` (last one may be short).
    trainloader = []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        trainloader.append((
            torch.cat([inp for inp, _ in chunk]),
            torch.cat([tar for _, tar in chunk]),
        ))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = TokenizerWrapper(valenc.input_ids[:, :(256 * seqlen)])

    return trainloader, valenc

def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None, batch_size=1):
    """Return ``(trainloader, test_encoding)`` for a perplexity dataset.

    ``name`` is matched case-insensitively by substring: anything containing
    "wiki2" selects WikiText-2, otherwise anything containing "c4" selects C4.

    Args:
        name: dataset identifier.
        nsamples: number of calibration windows.
        seed: sampling seed.
        seqlen: tokens per window.
        tokenizer: tokenizer forwarded to the loader.
        batch_size: windows per calibration batch.

    Returns:
        Whatever ``get_wikitext2`` / ``get_c4`` returns.

    Raises:
        NameError: if ``name`` matches no supported dataset (same convention
            as ``get_trainloaders``), instead of silently returning None.
    """
    name_l = name.lower()
    if 'wiki2' in name_l:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer, batch_size)
    if 'c4' in name_l:
        return get_c4(nsamples, seed, seqlen, tokenizer, batch_size)
    raise NameError(f"{name} is not implemented.")


# =========================
# Existing "trainenc" style loaders
# =========================
def get_wikitext2_trainenc(seed, nsamples, tokenizer):
    """Tokenize ``nsamples`` shuffled WikiText-2 training lines as one sequence.

    The training split is shuffled with ``seed``, the first ``nsamples``
    lines are joined with blank lines (``"\\n\\n"``), and the result is
    tokenized with ``return_tensors='pt'``.
    """
    corpus = load_dataset(os.path.join(DATA_DIR, 'wikitext'), 'wikitext-2-raw-v1', split='train')
    head = corpus.shuffle(seed=seed)[:nsamples]
    joined = "\n\n".join(head['text'])
    return tokenizer(joined, return_tensors='pt')

def get_c4_trainenc(seed, nsamples, tokenizer):
    """Tokenize ``nsamples`` shuffled C4 documents as one sequence.

    Reads the first English training shard, shuffles it with ``seed``,
    joins the first ``nsamples`` texts with single spaces and tokenizes the
    result with ``return_tensors='pt'``.
    """
    shard_files = {'train': 'en/c4-train.00000-of-01024.json.gz'}
    corpus = load_dataset(os.path.join(DATA_DIR, 'allenai/c4'), data_files=shard_files, split='train')
    head = corpus.shuffle(seed=seed)[:nsamples]
    return tokenizer(' '.join(head['text']), return_tensors='pt')

def get_pg19_bookcorpus_trainenc(seed, nsamples, tokenizer, dataset="pg19"):
    """Tokenize ``nsamples`` shuffled documents from PG-19 or BookCorpus.

    ``dataset`` names a directory under ``DATA_DIR``. The training split is
    shuffled with ``seed``, the first ``nsamples`` texts are joined with
    ``"\\n\\n"``, and the result is tokenized with ``return_tensors='pt'``.

    Note: ``trust_remote_code=True`` lets ``datasets`` execute the dataset's
    own loading script, if it has one.
    """
    corpus = load_dataset(os.path.join(DATA_DIR, dataset), split='train', trust_remote_code=True)
    head = corpus.shuffle(seed=seed)[:nsamples]
    joined = "\n\n".join(head['text'])
    return tokenizer(joined, return_tensors='pt')

def get_alpaca_trainenc(seed, nsamples, tokenizer, seqlen=2048):
    """Build a padded, flattened calibration encoding from Alpaca-cleaned.

    The training split is shuffled with ``seed``; each of the first
    ``nsamples`` rows becomes ``"instruction\\ninput\\noutput"``. Every text
    is padded/truncated to ``seqlen`` tokens, and the batch is flattened.

    Args:
        seed: shuffle seed.
        nsamples: number of rows to use.
        tokenizer: HF-style tokenizer (padding='max_length' is requested).
        seqlen: per-example padded length.

    Returns:
        The tokenizer output with "input_ids" and "attention_mask" reshaped
        to ``(1, -1)``.
    """
    traindata = load_dataset(os.path.join(DATA_DIR, "alpaca-cleaned"), split='train')
    traindata = traindata.shuffle(seed=seed)
    # Slice once: each `traindata[:n]` materializes every column of the first n rows.
    subset = traindata[:nsamples]
    data = [
        "\n".join([instruction, inp, output])
        for instruction, inp, output in zip(subset['instruction'], subset['input'], subset['output'])
    ]
    return _tokenize_as_trainenc(tokenizer, data, seqlen)


# =========================
# MMLU (existing)
# =========================
def format_mmlu_example(dataset, include_answer=True):
    """Render MMLU rows as multiple-choice prompts.

    Each prompt is the question, one ``"\\n<letter>. <option>"`` line per
    choice (letters A, B, C, ...), then ``"\\nAnswer:"``.

    Args:
        dataset: mapping or HF Dataset with parallel columns "question",
            "choices" (list of option strings per row) and "answer"
            (integer index into that row's choices).
        include_answer: if True, append ``" <letter>. <option text>\\n\\n"``
            for the gold answer.

    Returns:
        list[str]: one prompt per row.
    """
    data = []
    for ques, cho, ans in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        letters = [chr(ord("A") + i) for i in range(len(cho))]
        prompt = ques + "".join(f"\n{letter}. {text}" for letter, text in zip(letters, cho))
        prompt += "\nAnswer:"
        if include_answer:
            # Show the gold letter followed by its option text, e.g. " C. z".
            prompt += f" {letters[ans]}. {cho[ans]}\n\n"
        data.append(prompt)
    return data

def get_mmlu_trainenc(seed, nsamples, tokenizer, seqlen=2048, num_tasks=None):
    """Build a calibration encoding from MMLU questions spread across subjects.

    ``nsamples`` questions are split as evenly as possible over the MMLU
    subjects (or a random subset of ``num_tasks`` of them); which subjects
    receive the one extra question is random. Subject selection and quota
    assignment use a private ``random.Random(seed)``, so they are reproducible
    for a given ``seed`` and the global ``random`` state is not touched.
    Each subject's questions come from its shuffled "train" split, falling
    back to "validation" if loading "train" fails. Prompts are rendered
    without answers, padded/truncated to ``seqlen`` tokens and flattened.

    Args:
        seed: seed for subject/quota selection and per-subject shuffling.
        nsamples: total number of questions.
        tokenizer: HF-style tokenizer (padding='max_length' is requested).
        seqlen: per-example padded length.
        num_tasks: if given, only this many randomly chosen subjects are used.

    Returns:
        The tokenizer output with "input_ids" and "attention_mask" reshaped
        to ``(1, -1)``.

    Raises:
        ValueError: if ``num_tasks`` leaves no subjects to sample from.
    """
    from tqdm import tqdm

    subclass = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]

    # Seeded private RNG: the same `seed` always picks the same subjects and quotas.
    rng = random.Random(seed)
    if num_tasks is not None:
        rng.shuffle(subclass)
        subclass = subclass[:num_tasks]
    if not subclass:
        raise ValueError(f"num_tasks={num_tasks} selects no MMLU subjects")

    keys = ["question", "subject", "choices", "answer"]
    subnum, extra = divmod(nsamples, len(subclass))
    num_list = [subnum + 1] * extra + [subnum] * (len(subclass) - extra)
    rng.shuffle(num_list)

    mmlu_path = os.path.join(DATA_DIR, "mmlu")
    traindata = {key: [] for key in keys}
    for num, subject in tqdm(zip(num_list, subclass), total=len(subclass), desc="Loading the subclass in MMLU"):
        if num <= 0:
            # Nothing to draw from this subject; skip the dataset load entirely.
            continue
        try:
            split_ds = load_dataset(mmlu_path, subject, split="train")
        except Exception:
            # The local per-subject configs may not provide a "train" split.
            split_ds = load_dataset(mmlu_path, subject, split="validation")
        data = split_ds.shuffle(seed=seed)[:num]
        for key in keys:
            traindata[key].extend(data[key])

    all_data = format_mmlu_example(traindata, include_answer=False)
    return _tokenize_as_trainenc(tokenizer, all_data, seqlen)


# ============================================================
# Migrated loaders from 【3】 -> "trainenc" style (for your 【1】)
# ============================================================

def _tokenize_as_trainenc(tokenizer, texts, seqlen):
    enc = tokenizer(texts, return_tensors='pt', max_length=seqlen, padding='max_length', truncation=True)
    enc["input_ids"] = enc["input_ids"].reshape(1, -1)
    enc["attention_mask"] = enc["attention_mask"].reshape(1, -1)
    return enc

def _maybe_subsample_hf(ds, nsamples, seed):
    if nsamples is None:
        return ds
    if nsamples >= len(ds):
        return ds
    ds = ds.shuffle(seed=seed)
    return ds.select(list(range(nsamples)))

# ---- ARC-Challenge ----
def _format_arc_prompt(question, choice_texts, choice_labels):
    lines = [f"{lab}. {txt}" for lab, txt in zip(choice_labels, choice_texts)]
    options_str = "\n".join(lines)
    return (
        "### Task:\nChoose the best answer to the following question.\n\n"
        f"### Question:\n{question}\n\n"
        f"### Options:\n{options_str}\n\n"
        "### Answer:"
    )

def get_arc_challenge_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from ARC-Challenge: prompt plus the gold answer key.

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: "train" reads the train parquet; any other value reads the test one.

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.
    """
    path_key = "arc_challenge_train" if split == "train" else "arc_challenge_test"
    # A lone parquet file is exposed by `datasets` under the split name "train".
    ds = load_dataset("parquet", data_files=DATA_PATHS[path_key], split="train")
    ds = _maybe_subsample_hf(ds, nsamples, seed)

    texts = []
    for row in ds:
        question = row["question"]
        option_texts = row["choices"]["text"]
        option_labels = row["choices"]["label"]
        gold = str(row["answerKey"]).strip()
        texts.append(_format_arc_prompt(question, option_texts, option_labels) + " " + gold)
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)

# ---- PIQA ----
def _format_piqa_prompt(goal, choices_2):
    return (
        "### Task:\nChoose the most physically plausible solution to achieve the goal.\n\n"
        f"### Goal:\n{goal}\n\n"
        "### Options:\n"
        f"A. {choices_2[0]}\n"
        f"B. {choices_2[1]}\n\n"
        "### Answer:"
    )

def get_piqa_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from PIQA: prompt plus the gold letter ("A"/"B").

    Rows are read as "question" (the goal) and "choices" (the solutions).
    The gold letter comes from "answer_index" when that field is present and
    not None (0 -> "A", anything else -> "B"); otherwise from "answer", where
    only the string "A" (case/whitespace-insensitive) maps to "A".

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: "train" reads the train parquet; any other value reads the test one.

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.
    """
    path_key = "piqa_train" if split == "train" else "piqa_test"
    # A lone parquet file is exposed by `datasets` under the split name "train".
    ds = load_dataset("parquet", data_files=DATA_PATHS[path_key], split="train")
    ds = _maybe_subsample_hf(ds, nsamples, seed)

    texts = []
    for row in ds:
        goal, options = row["question"], row["choices"]
        answer_index = row.get("answer_index")
        if answer_index is not None:
            picks_first = int(answer_index) == 0
        else:
            # NOTE(review): a numeric "0" in "answer" maps to "B" here; confirm
            # that this column really holds letters.
            picks_first = str(row["answer"]).strip().upper() == "A"
        letter = "A" if picks_first else "B"
        texts.append(_format_piqa_prompt(goal, options) + " " + letter)
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)

# ---- Winogrande ----
def _format_winogrande_prompt(sentence, option1, option2):
    return (
        "### Task:\nChoose the correct option to fill in the blank (\"_\") in the sentence.\n\n"
        f"### Sentence:\n{sentence}\n\n"
        "### Options:\n"
        f"A. {option1}\n"
        f"B. {option2}\n\n"
        "### Answer:"
    )

def get_winogrande_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from Winogrande: prompt plus the gold letter.

    "answer" is read as "1" or "2"; "1" maps to "A" and any other integer to "B".

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: "train" reads the train parquet; any other value reads the test one.

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.
    """
    path_key = "winogrande_train" if split == "train" else "winogrande_test"
    # A lone parquet file is exposed by `datasets` under the split name "train".
    ds = load_dataset("parquet", data_files=DATA_PATHS[path_key], split="train")
    ds = _maybe_subsample_hf(ds, nsamples, seed)

    texts = []
    for row in ds:
        sentence, first, second = row["sentence"], row["option1"], row["option2"]
        # NOTE(review): an empty "answer" (e.g. an unlabeled test split) makes
        # int() raise ValueError — confirm the test parquet is labeled.
        letter = "A" if int(str(row["answer"]).strip()) == 1 else "B"
        texts.append(_format_winogrande_prompt(sentence, first, second) + " " + letter)
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)

# ---- HellaSwag ----
def _format_hellaswag_prompt(ctx, endings):
    labels = ["A", "B", "C", "D"]
    option_lines = [f"{lab}. {txt}" for lab, txt in zip(labels, endings)]
    options_str = "\n".join(option_lines)
    return (
        "### Task:\nChoose the most plausible continuation of the following context.\n\n"
        f"### Context:\n{ctx}\n\n"
        f"### Options:\n{options_str}\n\n"
        "### Answer:"
    )

def get_hellaswag_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from HellaSwag: prompt plus the gold letter (A-D).

    "label" is parsed as an integer index into "ABCD".

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: "train" reads the train parquet; any other value reads the test one.

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.
    """
    path_key = "hellaswag_train" if split == "train" else "hellaswag_test"
    # A lone parquet file is exposed by `datasets` under the split name "train".
    ds = load_dataset("parquet", data_files=DATA_PATHS[path_key], split="train")
    ds = _maybe_subsample_hf(ds, nsamples, seed)

    texts = []
    for row in ds:
        context, endings = row["ctx"], row["endings"]
        letter = "ABCD"[int(str(row["label"]).strip())]
        texts.append(_format_hellaswag_prompt(context, endings) + " " + letter)
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)

# ---- XSum ----
def _read_xsum_csv(xsum_dir_or_file, split):
    import pandas as pd
    if os.path.isdir(xsum_dir_or_file):
        csv_file = os.path.join(xsum_dir_or_file, f"{split}.csv")
    else:
        csv_file = xsum_dir_or_file
    if not os.path.exists(csv_file):
        raise ValueError(f"XSum split file not found: {csv_file}")
    df = pd.read_csv(csv_file)
    return df

def _format_xsum_prompt(dialogue_text):
    return (
        "### Task:\nSummarize the following dialogue in one sentence.\n\n"
        f"### Dialogue:\n{dialogue_text}\n\n"
        "### Summary:\n"
    )

def get_xsum_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from XSum CSVs: prompt followed by the reference summary.

    The source text column is the first present of "dialogue", "document",
    "article". When ``nsamples`` is smaller than the number of rows, a
    ``random.Random(seed)`` permutation picks which rows are kept. Empty
    cells (read by pandas as NaN) become "" rather than the string "nan".

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: CSV name inside the XSum directory (``<split>.csv``).

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.

    Raises:
        ValueError: if the CSV lacks a "summary" column or any known text column.
    """
    import pandas as pd

    df = _read_xsum_csv(DATA_PATHS["xsum_dir"], split)

    if "summary" not in df.columns:
        raise ValueError(f"XSum csv missing 'summary': columns={list(df.columns)}")

    # First matching column wins, in this priority order.
    text_col = next((c for c in ("dialogue", "document", "article") if c in df.columns), None)
    if text_col is None:
        raise ValueError(f"XSum csv missing text col: columns={list(df.columns)}")

    if nsamples is not None and nsamples < len(df):
        rng = random.Random(seed)
        idx = list(range(len(df)))
        rng.shuffle(idx)
        df = df.iloc[idx[:nsamples]].reset_index(drop=True)

    def _cell_text(value):
        # pandas parses empty CSV cells as NaN; render them as "" instead of "nan".
        return "" if pd.isna(value) else str(value)

    texts = [
        _format_xsum_prompt(_cell_text(src)) + _cell_text(summ)
        for src, summ in zip(df[text_col], df["summary"])
    ]
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)

# ---- OpenBookQA (OBQA) ----
def _format_obqa_prompt(question_stem, choice_texts, choice_labels):
    lines = [f"{lab}. {txt}" for lab, txt in zip(choice_labels, choice_texts)]
    options_str = "\n".join(lines)
    return (
        "### Task:\nChoose the best answer to the following science question.\n\n"
        f"### Question:\n{question_stem}\n\n"
        f"### Options:\n{options_str}\n\n"
        "### Answer:"
    )

def get_obqa_trainenc(seed, nsamples, tokenizer, seqlen=2048, split="train"):
    """Build calibration text from OpenBookQA: prompt plus the gold answer key.

    Args:
        seed: subsampling seed.
        nsamples: rows to keep (None keeps all).
        tokenizer: HF-style tokenizer.
        seqlen: per-example padded length.
        split: "train" reads the train parquet; any other value reads the test one.

    Returns:
        The flattened encoding produced by ``_tokenize_as_trainenc``.
    """
    path_key = "obqa_train" if split == "train" else "obqa_test"
    # A lone parquet file is exposed by `datasets` under the split name "train".
    ds = load_dataset("parquet", data_files=DATA_PATHS[path_key], split="train")
    ds = _maybe_subsample_hf(ds, nsamples, seed)

    texts = []
    for row in ds:
        stem = row["question_stem"]
        option_texts = row["choices"]["text"]
        option_labels = row["choices"]["label"]
        gold = str(row["answerKey"]).strip()
        texts.append(_format_obqa_prompt(stem, option_texts, option_labels) + " " + gold)
    return _tokenize_as_trainenc(tokenizer, texts, seqlen)


# =========================
# Unified entry for your training calibration ("trainenc" object)
# =========================
def get_trainloaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048, batch_size=1, num_tasks=None):
    """Dispatch ``name`` to the matching calibration "trainenc" builder.

    Matching is case-insensitive and substring-based, checked in a fixed
    order; the first match wins.

    Args:
        name: dataset identifier, e.g. "wiki2", "c4", "alpaca", "mmlu", "piqa".
        tokenizer: tokenizer forwarded to the selected builder.
        nsamples: number of examples/documents to draw.
        seed: seed forwarded to the builder.
        seqlen: padded length, forwarded only to builders that accept it.
        batch_size: accepted for interface compatibility; unused here.
        num_tasks: MMLU only — number of subjects to sample.

    Returns:
        Whatever the selected builder returns.

    Raises:
        NameError: if no builder matches ``name``.
    """
    key = name.lower()

    # Ordered (predicate, builder) pairs. Order matters because matching is by substring.
    routes = (
        (lambda n: 'wiki2' in n,
         lambda: get_wikitext2_trainenc(seed, nsamples, tokenizer)),
        (lambda n: 'c4' in n,
         lambda: get_c4_trainenc(seed, nsamples, tokenizer)),
        (lambda n: 'pg19' in n,
         lambda: get_pg19_bookcorpus_trainenc(seed, nsamples, tokenizer, dataset="pg19")),
        (lambda n: 'bookcorpus' in n,
         lambda: get_pg19_bookcorpus_trainenc(seed, nsamples, tokenizer, dataset="bookcorpus")),
        (lambda n: 'alpaca' in n,
         lambda: get_alpaca_trainenc(seed, nsamples, tokenizer, seqlen=seqlen)),
        (lambda n: 'mmlu' in n,
         lambda: get_mmlu_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, num_tasks=num_tasks)),
        # Benchmarks migrated from 【3】; all use their training split.
        (lambda n: 'arc-challenge' in n or 'arc_challenge' in n or n == 'arc',
         lambda: get_arc_challenge_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
        (lambda n: 'piqa' in n,
         lambda: get_piqa_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
        (lambda n: 'winogrande' in n or n.startswith('wino'),
         lambda: get_winogrande_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
        (lambda n: 'hellaswag' in n,
         lambda: get_hellaswag_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
        (lambda n: 'xsum' in n,
         lambda: get_xsum_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
        (lambda n: 'obqa' in n or 'openbookqa' in n,
         lambda: get_obqa_trainenc(seed, nsamples, tokenizer, seqlen=seqlen, split="train")),
    )
    for matches, build in routes:
        if matches(key):
            return build()

    raise NameError(f"{name} is not implemented.")
