"""
Alpaca training dataloaders

We adopt the original prompt template; goes something like:
```
Below is an instruction that describes a task. 
Write a response that appropriately completes the request.
### Instruction:
{instruction}
 
### Response:
{response}
```
See `PROMPT_DICT` for more. 
"""
from functools import partial
from os.path import join

from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
import evaluate
from .utils import (
    get_lm_loader, get_seq2seq_loader,
    convert_to_hf_dataset, 
    get_tokenizer_from_config,
    download_scrolls_metric as download_metric
)
from .utils.packing import ConcatDataset
from typing import Dict, Optional, Union
import re
# Either a single split (`Dataset`) or a dict of splits (`DatasetDict`).
HFData = Union[Dataset, DatasetDict]

# Task identifiers with a short description of each task type.
# This list is not referenced by any code visible in this file.
# NOTE(review): the names look like LongBench English task ids — confirm.
ENGLISH_TASKS = [
    "hotpotqa",          # multi-doc QA
    "2wikimqa",          # multi-hop QA
    "musique",           # multi-doc QA
    "narrativeqa",       # long narrative QA
    "qasper",            # scientific paper QA
    "multifieldqa_en",   # varied single-doc QA
    "gov_report",        # summarization
    "qmsum",             # meeting summarization
    "multi_news",        # multi-article summarization
    "triviaqa",          # few-shot QA
    "nq",                # few-shot QA
    "trec",              # few-shot classification
    "passage_retrieval_en"  # synthetic retrieval
]


def load_all_longcontext_data_extended(cache_dir) -> Dict[str, HFData]:
    """
    Load the raw training sources from the Hugging Face Hub.

    Args:
        cache_dir: directory forwarded to ``load_dataset`` as ``cache_dir``.

    Returns:
        Dict mapping a short source key to the loaded split, in the order the
        sources are listed below.
    """
    # (key, hub path, config name or None, split expression)
    sources = (
        # Multi-hop / long QA
        ("hotpotqa", "hotpotqa/hotpot_qa", "distractor", "train[:6%]"),
        ("musique", "bdsaglam/musique", None, "train[:16%]"),
        ("narrativeqa", "deepmind/narrativeqa", None, "train[:16%]"),
        ("qasper", "allenai/qasper", None, "train"),
        # Summarization
        ("govreport", "ccdv/govreport-summarization", None, "train[:30%]"),
        ("qmsum", "pszemraj/qmsum-cleaned", None, "train"),
        ("multinews", "alexfabbri/multi_news", None, "train[:10%]"),
        # Short-answer QA
        ("trivia_qa", "mandarjoshi/trivia_qa", "rc", "train[:6%]"),
        ("natural_questions", "sentence-transformers/natural-questions", None, "train[:6%]"),
        # Long instruction following
        ("longalpaca", "Yukang/LongAlpaca-12k", None, "train[:60%]"),
    )
    # Sources that were previously listed here but disabled (not loaded):
    #   kmfoda/booksum, ccdv/arxiv-summarization, ccdv/pubmed-summarization,
    #   tau/scrolls (summ_screen_fd), multidoc2dial, Steefano/LCB,
    #   bigcode/bigcodebench, livecodebench/*, tianyang/repobench-p,
    #   princeton-nlp/SWE-bench_Lite, lex_glue (contract_nli, ledgar),
    #   rbiswasfc/ruler (qa_2_8k; official generator:
    #   https://github.com/NVIDIA/RULER), THUDM/LooGLE.

    data: Dict[str, HFData] = {}
    for key, hub_path, config_name, split in sources:
        positional = (hub_path,) if config_name is None else (hub_path, config_name)
        data[key] = load_dataset(*positional, split=split, cache_dir=cache_dir)
    return data

def group_and_filter(batch, chunk_size: int):
    """
    Concatenate a batch of token sequences and re-split it into fixed-size chunks.

    Only the columns ``input_ids``, ``labels`` and ``attention_mask`` are
    processed; other columns in ``batch`` are ignored.

    Args:
        batch: mapping from column name to a list of per-example sequences.
        chunk_size: length of every output chunk.

    Returns:
        Dict with the same (recognised) columns, each a list of chunks of
        length ``chunk_size``. The trailing remainder that does not fill a
        whole chunk is dropped. When ``labels`` is present, chunks whose
        labels are all -100 are dropped from every column. Returns ``{}``
        when none of the recognised columns is present.
    """
    # Local imports keep the function self-contained: the module header does
    # not import numpy.
    from itertools import chain

    import numpy as np

    keys = [k for k in ("input_ids", "labels", "attention_mask") if k in batch]
    if not keys:
        return {}

    # 1) concatenate within the batch (linear time, unlike sum(..., []))
    concat = {k: list(chain.from_iterable(batch[k])) for k in keys}

    # 2) chunk; size the output from the first available column so a batch
    #    without "input_ids" does not raise KeyError
    total = (len(concat[keys[0]]) // chunk_size) * chunk_size
    if total == 0:
        return {k: [] for k in keys}

    chunks = {
        k: np.asarray(concat[k][:total], dtype=np.int32).reshape(-1, chunk_size)
        for k in keys
    }

    # 3) drop chunks that carry no supervised token (all labels == -100)
    if "labels" in chunks:
        keep = (chunks["labels"] != -100).any(axis=1)
    else:
        keep = np.ones(chunks[keys[0]].shape[0], dtype=bool)

    return {k: chunks[k][keep].tolist() for k in keys}

def _toklen(s, tok):
    """Return how many ids ``tok`` produces for ``s`` without special tokens."""
    encoded = tok(s, add_special_tokens=False)
    return len(encoded["input_ids"])

def _pack_docs(docs, budget, tokenizer):
    """
    Greedily keep leading documents until the token budget would be exceeded.

    A document is always accepted while no tokens have been counted yet, so a
    single over-long first document is still returned. Selection stops at the
    first document that does not fit; later documents are not considered.
    """
    selected = []
    used = 0
    for doc in docs:
        n_tokens = _toklen(doc, tokenizer)
        if used + n_tokens > budget and used > 0:
            break
        selected.append(doc)
        used += n_tokens
    if selected:
        return selected
    return docs[:1] if docs else []


def _make_prompt(question: str, docs: list[str]) -> str:
    """
    Build the short-answer QA prompt.

    Documents are separated by blank lines and placed between the instruction
    line and the ``Question:`` / ``Answer:`` footer.
    """
    context = "\n\n".join([f"{doc}" for doc in docs])
    return f"Answer with a short phrase.\n\n{context}\n\nQuestion: {question}\nAnswer:"

def map_hotpotqa(ex):
    """
    Convert one HotpotQA example into a ``{"prompt", "answer"}`` pair.

    ``ex["context"]`` is read as a dict of parallel lists:
    ``{"title": [t0, t1, ...], "sentences": [[s, s, ...], [s, ...], ...]}``.
    Each document becomes ``"<title>: <sentences joined without separator>"``.
    No token budget is applied here.

    A missing or empty context yields a prompt with no documents instead of
    raising.
    """
    context = ex.get("context") or {}
    titles = context.get("title") or []
    sentence_lists = context.get("sentences") or []

    docs = []
    for title, sentences in zip(titles, sentence_lists):
        text = "".join(sentences)
        docs.append(f"{str(title)}: {text}")

    prompt = _make_prompt(ex.get("question", ""), docs)
    answer = str(ex.get("answer", "")).strip()
    return {"prompt": prompt, "answer": answer}

def map_musique(ex, tokenizer, max_len):
    """
    Convert one MuSiQue example into a ``{"prompt", "answer"}`` pair.

    All paragraph texts are joined, split into sentences, and leading
    sentences are kept while their token count fits in ``max_len`` minus the
    answer's token count. The first sentence is always kept.
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    question = ex.get("question") or ex.get("Question")
    answer = str(ex.get("answer", "")).strip()
    # Reserve room for the answer; the prompt template itself is not counted.
    budget = max_len - count_tokens(answer)

    paragraph_texts = [
        str(p.get("paragraph_text") if isinstance(p, dict) else p)
        for p in ex.get("paragraphs")
    ]
    sentences = re.split(r'(?<=[.!?])\s+', " ".join(paragraph_texts))

    # Greedy prefix of sentences that fits the budget
    kept = []
    used = 0
    for sentence in sentences:
        n_tokens = count_tokens(sentence)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(sentence)
        used += n_tokens

    if kept:
        doc_text = " ".join(kept)
    else:
        doc_text = sentences[0] if sentences else ""

    return {"prompt": _make_prompt(question, [doc_text]), "answer": answer}

def map_narrativeqa(ex, tokenizer, max_len):
    """
    Convert one NarrativeQA example into a ``{"prompt", "answer"}`` pair.

    The context is the document *summary* text (``ex["document"]["summary"]["text"]``),
    split into sentences and greedily truncated to ``max_len`` minus the
    answer's token count. The first sentence is always kept.

    The answer is the text of the first entry in ``ex["answers"]`` when that
    field is a list. It is extracted *before* being stringified, so the target
    is the answer text rather than the repr of the whole answers list.
    """
    def _toklen(s: str) -> int:
        return len(tokenizer(s, add_special_tokens=False)["input_ids"])

    q = ex.get("question")['text']

    a = ex.get("answers")
    if isinstance(a, list):
        a = a[0]['text'] if a else ""
    answer = str(a if a is not None else "").strip()

    # Reserve room for the answer; the prompt template itself is not counted.
    budget = max_len - _toklen(answer)

    context = ex.get("document", {}).get("summary")['text']
    sents = re.split(r'(?<=[.!?])\s+', context)

    # pack document within budget
    packed, cur = [], 0
    for s in sents:
        n_tokens = _toklen(s)
        if cur + n_tokens > budget and cur > 0:
            break
        packed.append(s)
        cur += n_tokens
    doc_text = " ".join(packed) if packed else (sents[0] if sents else "")

    prompt = _make_prompt(q, [doc_text])
    return {"prompt": prompt, "answer": answer}


def map_qasper(ex, tokenizer, max_len):
    """
    Convert one QASPER paper into a ``{"prompt", "answer"}`` pair.

    Only the first question and the first listed free-form answer for it are
    used. The context is the abstract followed by the non-empty full-text
    sections, first selected section by section with ``_pack_docs`` and then
    truncated sentence by sentence to the same budget (``max_len`` minus the
    answer's token count).

    Returns ``{"prompt": None, "answer": None}`` when the question is empty or
    no text parts were collected.

    NOTE(review): ``free_form_answer`` is used as-is; if it can be empty for
    some rows the resulting target is an empty string — confirm that is
    acceptable.
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    question = ex.get("qas")['question'][0]
    raw_answer = ex.get("qas")['answers'][0]['answer'][0]['free_form_answer']
    answer = str(raw_answer).strip()
    budget = max_len - count_tokens(answer)

    # Abstract first, then every non-empty section of the full text
    parts = []
    abstract = ex.get("abstract")
    if isinstance(abstract, list):
        parts.append(" ".join(abstract))
    elif abstract:
        parts.append(str(abstract))
    sections = ex.get("full_text")['paragraphs']
    parts.extend("".join(section) for section in sections if len(section) != 0)

    if not question or not parts:
        return {"prompt": None, "answer": None}

    selected = _pack_docs(parts, budget, tokenizer)
    sentences = re.split(r'(?<=[.!?])\s+', " ".join(selected))

    # Greedy prefix of sentences that fits the budget
    kept = []
    used = 0
    for sentence in sentences:
        n_tokens = count_tokens(sentence)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(sentence)
        used += n_tokens

    if kept:
        doc_text = " ".join(kept)
    else:
        doc_text = sentences[0] if sentences else ""

    return {"prompt": _make_prompt(question, [doc_text]), "answer": answer}

def map_govreport(ex, tokenizer, max_len):
    """
    Convert one GovReport example into a title-style ``{"prompt", "answer"}`` pair.

    The answer is the first clause of the gold summary (text before the first
    period, newline, colon, dash run or em/en dash), with surrounding quotes
    removed; if that is empty, the first 12 words of the summary are used.
    The report is split into sentences and truncated to ``max_len`` minus the
    answer's token count. The first piece is always kept.
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    source = ex.get("document") or ex.get("report") or ex.get("source") or ""
    summary = ex.get("summary") or ex.get("target") or ""

    # Derive the short "headline" target from the summary
    summary_text = str(summary).strip()
    answer = re.split(r"[.\n]|—|-{2,}|:|\u2014|\u2013", summary_text, maxsplit=1)[0]
    answer = answer.strip().strip('"').strip("'")
    if not answer:
        answer = " ".join(summary_text.split()[:12])
    budget = max_len - count_tokens(answer)

    # A source with a single sentence is packed as the raw, unstripped text
    sentences = re.split(r"(?<=[.!?])\s+", str(source).strip())
    candidates = sentences if len(sentences) > 1 else [str(source)]

    kept = []
    used = 0
    for piece in candidates:
        n_tokens = count_tokens(piece)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(piece)
        used += n_tokens

    if kept:
        doc_text = " ".join(kept)
    else:
        doc_text = sentences[0] if sentences else ""

    prompt = (
        "Answer with a short title (a few words) that best describes the report.\n\n"
        f"{doc_text}\n\nQuestion: What is a concise title for this report?\nAnswer:"
    )
    return {"prompt": prompt, "answer": answer}

def map_qmsum(ex, tokenizer, max_len):
    """
    Convert one QMSum example into a title-style ``{"prompt", "answer"}`` pair.

    Reads the transcript from ``ex["input"]`` and the gold summary from
    ``ex["output"]``. The answer is the first clause of the summary (quotes
    stripped), falling back to its first 12 words when that clause is empty.
    The transcript is truncated sentence by sentence to ``max_len`` minus the
    answer's token count; the first piece is always kept.
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    source = ex.get("input") or ""
    summary = ex.get("output") or ""

    # Derive the short "headline" target from the summary
    summary_text = str(summary).strip()
    answer = re.split(r"[.\n]|—|-{2,}|:|\u2014|\u2013", summary_text, maxsplit=1)[0]
    answer = answer.strip().strip('"').strip("'")
    if not answer:
        answer = " ".join(summary_text.split()[:12])
    budget = max_len - count_tokens(answer)

    # A source with a single sentence is packed as the raw, unstripped text
    sentences = re.split(r"(?<=[.!?])\s+", str(source).strip())
    candidates = sentences if len(sentences) > 1 else [str(source)]

    kept = []
    used = 0
    for piece in candidates:
        n_tokens = count_tokens(piece)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(piece)
        used += n_tokens

    if kept:
        doc_text = " ".join(kept)
    else:
        doc_text = sentences[0] if sentences else ""

    prompt = (
        "Answer with a short title (a few words) that best describes the meeting content.\n\n"
        f"{doc_text}\n\nQuestion: What is a concise title for this report?\nAnswer:"
    )
    return {"prompt": prompt, "answer": answer}

def map_multinews(ex, tokenizer, max_len):
    """
    Convert one Multi-News example into a summarization ``{"prompt", "answer"}`` pair.

    The answer is the first six sentences of the gold summary. The source
    articles are truncated sentence by sentence to ``max_len`` minus the
    answer's token count; the first piece is always kept.
    """
    sentence_cap = 6

    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    source = ex.get("document") or ""
    summary = ex.get("summary") or ""

    summary_sentences = re.split(r"(?<=[.!?])\s+", str(summary).strip())
    if summary_sentences:
        answer = " ".join(summary_sentences[:sentence_cap])
    else:
        answer = str(summary).strip()
    budget = max_len - count_tokens(answer)

    # A source with a single sentence is packed as the raw, unstripped text
    sentences = re.split(r"(?<=[.!?])\s+", str(source).strip())
    candidates = sentences if len(sentences) > 1 else [str(source)]

    kept = []
    used = 0
    for piece in candidates:
        n_tokens = count_tokens(piece)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(piece)
        used += n_tokens

    if kept:
        doc_text = " ".join(kept)
    else:
        doc_text = sentences[0] if sentences else ""

    prompt = (
        f"You are a concise summarizer. Write no more than {sentence_cap} sentences. "
        f"\n\n{doc_text}\n\nSummary:"
    )
    return {"prompt": prompt, "answer": answer}

def map_naturalques(ex):
    """
    Convert one Natural Questions example into a ``{"prompt", "answer"}`` pair.

    Reads ``ex["query"]`` and ``ex["answer"]``; the prompt is built with an
    empty document list, i.e. no supporting context.
    """
    question = ex.get("query")
    raw_answer = ex.get("answer")
    return {
        "prompt": _make_prompt(question, []),
        "answer": str(raw_answer).strip(),
    }


def map_longalpaca(ex, tokenizer, max_len):
    """
    Convert one LongAlpaca example into a ``{"prompt", "answer"}`` pair.

    The prompt is the instruction, truncated sentence by sentence (from the
    start) to ``max_len`` minus the answer's token count; the first piece is
    always kept. The answer is the stripped ``output`` field.
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    instruction = ex.get("instruction") or ""
    answer = str(ex.get("output") or "").strip()
    budget = max_len - count_tokens(answer)

    # An instruction with a single sentence is packed as the raw, unstripped text
    sentences = re.split(r"(?<=[.!?])\s+", str(instruction).strip())
    candidates = sentences if len(sentences) > 1 else [str(instruction)]

    kept = []
    used = 0
    for piece in candidates:
        n_tokens = count_tokens(piece)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(piece)
        used += n_tokens

    if kept:
        prompt = " ".join(kept)
    else:
        prompt = sentences[0] if sentences else ""
    return {"prompt": prompt, "answer": answer}

def map_triviaqa(ex, tokenizer, max_len):
    """
    Convert one TriviaQA example into a ``{"prompt", "answer"}`` pair.

    Context is gathered from ``entity_pages`` (``wiki_context``) and then
    ``search_results`` (``search_context``), each entry rendered as
    ``"<title>\\n<text>"``. The joined context is split into sentences and
    truncated to ``max_len`` minus the answer's token count; kept sentences
    are separated by blank lines in the prompt.

    Returns ``{"prompt": None, "answer": None}`` when the question, the
    answer, or the context is empty.
    """
    skipped = {"prompt": None, "answer": None}

    question = str(ex.get("question") or "").strip()
    if not question:
        return skipped

    raw_answer = ex.get("answer")
    if isinstance(raw_answer, dict):
        # The answer text is read from the 'value' key of the answer dict
        answer = str(ex["answer"].get("value", "")).strip()
    else:
        answer = str(ex.get("answer", "")).strip()
    if not answer:
        return skipped

    def count_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    budget = max_len - count_tokens(answer)

    # Collect "<title>\n<text>" chunks from both context sources, in order
    chunks = []
    for source_key, text_key in (
        ("entity_pages", "wiki_context"),
        ("search_results", "search_context"),
    ):
        if source_key in ex and ex[source_key]:
            source = ex[source_key]
            for idx in range(len(source["title"])):
                title = source.get("title", "")[idx]
                text = source.get(text_key, "")[idx]
                if title or text:
                    chunks.append(f"{title}\n{text}".strip())
    if not chunks:
        return skipped

    sentences = re.split(r"(?<=[.!?])\s+", str(" ".join(chunks)).strip())

    # Greedy prefix of sentences that fits the budget
    kept = []
    used = 0
    for sentence in sentences:
        n_tokens = count_tokens(sentence)
        if used + n_tokens > budget and used > 0:
            break
        kept.append(sentence)
        used += n_tokens

    context = "\n\n".join([f"{piece}" for piece in kept])
    prompt = f"{context}\n\nQuestion: {question}\nAnswer:"
    return {"prompt": prompt, "answer": answer}

def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
              preprocess_config: dict, **loader_kwargs: any):
    """
    Build train / validation / test dataloaders for the long-context mixture.

    Args:
        name: unused here; kept for the shared loader signature.
        dataset_config: must provide ``cache_dir``, ``chunk_size`` (used both
            as the per-example token budget for the ``map_*`` functions and as
            the LM loader / packing length) and ``concat_data``.
        pretrained_model_config: forwarded to ``get_tokenizer_from_config``.
        preprocess_config: unused here; kept for the shared loader signature.
        **loader_kwargs: forwarded to ``get_lm_loader`` / ``get_seq2seq_loader``.

    Returns:
        Dict with keys ``'train'``, ``'validation'`` and ``'test'``. Each
        loader's dataset gets ``tokenizer`` and ``metric`` attributes
        (``metric`` is ``None`` when it could not be loaded).

    Notes:
        After shuffling with seed 42, the first 200 examples form both the
        validation and the test set; everything after them is the train set.
        The mixture must therefore contain more than 200 examples.
    """
    n_eval = 200  # size of the shared validation/test slice

    # Misc. setup
    cache_dir = dataset_config['cache_dir']
    input_len = dataset_config['chunk_size']
    concat_data = dataset_config['concat_data']

    # Setup tokenizer
    tokenizer = get_tokenizer_from_config(pretrained_model_config)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
    tokenizer.padding_side = 'left'  # for decoder-only generation

    # Per-source conversion to {"prompt", "answer"}:
    # (source key, mapper, mapper takes tokenizer/max_len, mapper may emit None rows)
    budget_kwargs = {'tokenizer': tokenizer, 'max_len': input_len}
    converters = (
        ('hotpotqa', map_hotpotqa, False, False),
        ('musique', map_musique, True, False),
        ('narrativeqa', map_narrativeqa, True, False),
        ('qasper', map_qasper, True, True),
        ('govreport', map_govreport, True, False),
        ('qmsum', map_qmsum, True, False),
        ('multinews', map_multinews, True, False),
        ('natural_questions', map_naturalques, False, False),
        ('longalpaca', map_longalpaca, True, False),
        ('trivia_qa', map_triviaqa, True, True),
    )
    dataset = load_all_longcontext_data_extended(cache_dir=cache_dir)
    for key, mapper, takes_budget, may_skip in converters:
        converted = dataset[key].map(
            mapper,
            remove_columns=dataset[key].column_names,
            fn_kwargs=budget_kwargs if takes_budget else None,
        )
        if may_skip:
            # Mappers signal "skip this row" with a None prompt
            converted = converted.filter(lambda ex: ex["prompt"] is not None)
        dataset[key] = converted

    # Mix all sources (in load order) and shuffle deterministically
    dataset = concatenate_datasets(list(dataset.values())).shuffle(seed=42)

    train_set = convert_to_hf_dataset(
        [dataset[ix] for ix in range(n_eval, len(dataset))], cache_dir)
    val_set = convert_to_hf_dataset(
        [dataset[ix] for ix in range(n_eval)], cache_dir)
    test_set = convert_to_hf_dataset(
        [dataset[ix] for ix in range(n_eval)], cache_dir)

    # Convert to dicts of {input_ids, attention_mask, labels}
    text_columns = list(dataset.features)
    train_set = train_set.map(
        partial(tokenize, tokenizer=tokenizer, include_label=True),
        remove_columns=text_columns)
    val_set = val_set.map(
        partial(tokenize, tokenizer=tokenizer, include_label=True),
        remove_columns=text_columns)
    test_set = test_set.map(
        partial(tokenize, tokenizer=tokenizer, include_label=False),
        remove_columns=text_columns)

    # Chunk together train and val sets
    if concat_data:
        train_set = ConcatDataset(train_set, chunk_size=input_len)
        val_set = ConcatDataset(val_set, chunk_size=input_len)

    # Get dataloaders
    dataloaders = {
        'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
        'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
        'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
    }

    # Evaluation metric; best effort, training can proceed without it
    try:
        metric = evaluate.load(download_metric(), 'gov_report')  # hack but we want rouge
    except Exception as e:
        print(f'Error loading metric: {e}')
        metric = None

    # Finishing touches: make tokenizer and metric accessible from each dataset
    for loader in dataloaders.values():
        loader.dataset.tokenizer = tokenizer
        loader.dataset.metric = metric
    return dataloaders


def tokenize(sample, tokenizer, include_label: bool = True):
    """
    Tokenize one ``{"prompt", "answer"}`` sample.

    Args:
        sample: mapping with string fields ``prompt`` and ``answer``.
        tokenizer: object exposing ``encode``, ``eos_token`` and ``pad_token_id``.
        include_label: selects the training or the generation layout.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``.

        * ``include_label=True`` (train / validation): ``input_ids`` is the
          tokenized ``prompt + answer + eos``, right-padded with
          ``tokenizer.pad_token_id`` to a multiple of 128. ``labels`` repeats
          the real tokens and is -100 on the padding so padding is never a
          training target. ``attention_mask`` is all ones over the padded
          length, as before.
        * ``include_label=False`` (test): ``input_ids`` is the tokenized
          prompt only (unpadded) and ``labels`` is the tokenized
          ``answer + eos``, so the model can generate from the prompt and be
          scored against the answer.
    """
    pad_multiple = 128
    prompt, answer = sample["prompt"], sample["answer"]

    if not include_label:
        input_ids = list(tokenizer.encode(prompt, add_special_tokens=False))
        target = tokenizer.encode(f'{answer}{tokenizer.eos_token}',
                                  add_special_tokens=False)
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": target,
        }

    token_ids = list(tokenizer.encode(f'{prompt}{answer}{tokenizer.eos_token}',
                                      add_special_tokens=False))
    # Distance to the nearest multiple of `pad_multiple` >= current length
    pad_len = -len(token_ids) % pad_multiple

    # Build input_ids and labels as separate lists: sharing one list would
    # leak the pad ids into the labels.
    input_ids = token_ids + [tokenizer.pad_token_id] * pad_len
    labels = token_ids + [-100] * pad_len

    # NOTE(review): the mask intentionally stays 1 on padded positions to keep
    # the previous behavior — confirm whether downstream code expects 0 there.
    attn_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attn_mask,
        "labels": labels,
    }
