import os
from functools import partial
from typing import List, Dict, Set, Optional, Tuple

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.distributed import DistributedSampler


# Constants
# Label value written to every position that must not be supervised (see
# collate_autoreg). -100 is also the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss, presumably the loss these labels feed — confirm.
IGNORE_INDEX = -100
# Special tokens of the fused sequence format: "$" + INPUT + "|" + OUTPUT + "#".
# BOS/SEP/EOS are single characters that share the character vocabulary, so
# they should not appear inside raw inputs/outputs (the dataset spot-checks
# that each sample has exactly one SEP). PAD is a multi-character token: it
# gets its own vocabulary entry and is only inserted by batch padding — the
# per-character tokenizer can never produce it from text.
SPECIAL = {
    "BOS": "$",  # prepended to every fused sample
    "EOS": "#",  # appended to every fused sample; supervision stops after it
    "PAD": "<PAD>",  # right-padding token used when batching
    "SEP": "|",  # separates INPUT from OUTPUT; supervision starts after it
}


class ARLMGptDataset(Dataset):
    """Character-level dataset of fused ``$INPUT|OUTPUT#`` sequences.

    Each (input, output) pair is fused by ``fuse_to_autoregressive_strings``
    and tokenized one character at a time through ``stoi``. Items are dicts
    ``{"ids": 1-D torch.long tensor}`` suitable for ``collate_autoreg``.

    Args:
        input_strings: Source strings, one per sample.
        output_strings: Target strings, aligned index-by-index with
            ``input_strings``.
        stoi: Token -> id map. Must contain every character that occurs in
            the data as well as all ``SPECIAL`` tokens.
        itos: Id -> token map; stored for callers, not used by this class.
        hard_cap: Optional maximum sample length in characters. Longer samples
            keep their first ``hard_cap - 1`` characters and are re-terminated
            with EOS. A small cap can cut off SEP, which leaves the sample with
            no supervised positions after collation.

    Raises:
        KeyError: If ``stoi`` lacks one of the special tokens.
        ValueError: If ``hard_cap`` is smaller than 1, or if one of the first
            200 samples does not contain exactly one SEP or lacks EOS.
    """

    def __init__(
        self,
        input_strings: List[str],
        output_strings: List[str],
        stoi: Dict[str,int],
        itos: Dict[int,str],
        hard_cap: Optional[int] = None
    ):
        # A cap below 1 would make ``s[: hard_cap - 1]`` slice from the end and
        # yield a sequence longer than the cap.
        if hard_cap is not None and hard_cap < 1:
            raise ValueError(f"hard_cap must be >= 1 or None, got {hard_cap}")

        self.samples = fuse_to_autoregressive_strings(input_strings, output_strings)
        self.stoi, self.itos = stoi, itos
        self.hard_cap = hard_cap

        # Special-token ids; a KeyError here means stoi was built without SPECIAL.
        self.pad_id = self.stoi[SPECIAL["PAD"]]
        self.sep_id = self.stoi[SPECIAL["SEP"]]
        self.eos_id = self.stoi[SPECIAL["EOS"]]
        self.bos_id = self.stoi[SPECIAL["BOS"]]

        # Spot-check the first 200 samples (kept cheap for large datasets).
        # Raised explicitly instead of ``assert`` so the checks survive
        # ``python -O``.
        for s in self.samples[:200]:
            if s.count(SPECIAL["SEP"]) != 1:
                raise ValueError(f"Bad sample (SEP !=1): {s!r}")
            if SPECIAL["EOS"] not in s:
                raise ValueError(f"Bad sample (missing EOS): {s!r}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        if self.hard_cap is not None and len(s) > self.hard_cap:
            # Reserve the last slot for EOS so the result is exactly hard_cap long.
            s = s[: self.hard_cap - 1] + SPECIAL["EOS"]
        try:
            ids = [self.stoi[ch] for ch in s]
        except KeyError as err:
            # A bare KeyError on one character is hard to trace back to the
            # data; report which sample and character are out of vocabulary.
            raise KeyError(
                f"sample {idx} contains character {err.args[0]!r} "
                "that is not in the vocabulary"
            ) from None
        return {"ids": torch.tensor(ids, dtype=torch.long)}
    

# Helper functions
def fuse_to_autoregressive_strings(inp: List[str], out: List[str]) -> List[str]:
    """Turn aligned (input, output) pairs into single sequences ``$INPUT|OUTPUT#``.

    Each result is ``BOS + input + SEP + output + EOS`` using the ``SPECIAL``
    tokens (no spaces are inserted).

    Args:
        inp: Input strings.
        out: Output strings, aligned index-by-index with ``inp``.

    Returns:
        One fused string per pair, in the original order.

    Raises:
        ValueError: If ``inp`` and ``out`` have different lengths. Without this
            check ``zip`` would silently drop the unmatched tail and hide
            misaligned data files.
    """
    if len(inp) != len(out):
        raise ValueError(
            f"Input/output length mismatch: {len(inp)} inputs vs {len(out)} outputs"
        )
    bos, sep, eos = SPECIAL["BOS"], SPECIAL["SEP"], SPECIAL["EOS"]
    return [bos + x + sep + y + eos for x, y in zip(inp, out)]

def build_vocab_from_strings(strings: List[str]) -> Tuple[Dict[str,int], Dict[int,str]]:
    """Build a character vocabulary covering ``strings`` plus every SPECIAL token.

    Tokens are sorted before ids are assigned, so the same token set always
    yields the same ids.

    Returns:
        ``(stoi, itos)``: token -> id and id -> token maps.
    """
    alphabet: Set[str] = set(SPECIAL.values()).union(*strings)
    ordered = sorted(alphabet)
    stoi = dict(zip(ordered, range(len(ordered))))
    itos = dict(enumerate(ordered))
    return stoi, itos

def load_split(name: str, split: str, root: str = "dataset") -> Tuple[List[str], List[str]]:
    """Read the aligned input/target text files for one split of dataset ``name``.

    Files are read from ``<root>/<name>/``: ``train`` uses input.txt and
    target.txt, and ``valN`` uses input_valN.txt and target_valN.txt for N in
    0..2. Each line becomes one entry. Note that ``str.strip()`` removes all
    surrounding whitespace, not just the newline.

    Args:
        name: Dataset directory name under ``root``.
        split: One of ``"train"``, ``"val0"``, ``"val1"``, ``"val2"``.
        root: Base directory containing all datasets. The default ``"dataset"``
            is relative to the current working directory.

    Returns:
        ``(inputs, outputs)``, one list entry per line of each file.

    Raises:
        ValueError: If ``split`` is not a known split name.
        OSError: If either file cannot be opened.
    """
    split_map = {
        "train": ("input.txt", "target.txt"),
        "val0": ("input_val0.txt", "target_val0.txt"),
        "val1": ("input_val1.txt", "target_val1.txt"),
        "val2": ("input_val2.txt", "target_val2.txt"),
    }

    if split not in split_map:
        raise ValueError(f"Unknown split: {split}")

    in_file, out_file = split_map[split]
    base = os.path.join(root, name)

    def _read_lines(filename: str) -> List[str]:
        # Explicit UTF-8 so results do not depend on the platform locale.
        with open(os.path.join(base, filename), encoding="utf-8") as f:
            return [line.strip() for line in f]

    return _read_lines(in_file), _read_lines(out_file)



def collate_autoreg(batch, pad_id: int, sep_id: int, eos_id: int):
    """Pad a batch of ``{"ids": tensor}`` items and build output-only labels.

    Items are expected as produced by ``ARLMGptDataset`` (1-D long tensors).
    Every sequence is made to end in EOS (appended if missing), then the batch
    is right-padded with ``pad_id``.

    Labels copy the token ids of the OUTPUT region only: from just after the
    first SEP through the first EOS that follows it (inclusive), never past
    padding. Every other position (BOS, INPUT, SEP, padding) is
    ``IGNORE_INDEX``. Rows without a SEP get no supervised positions. Labels
    are aligned with ``input_ids`` (not shifted); any next-token shift is
    presumably applied by the model/loss — confirm against the caller.

    Returns:
        dict with ``input_ids`` (B, T), ``attention_mask`` (B, T; 1 where the
        token is not PAD) and ``labels`` (B, T).
    """
    seqs = []
    for item in batch:
        ids = item["ids"]
        if ids.numel() == 0 or ids[-1].item() != eos_id:
            # new_tensor keeps the sequence's dtype and device.
            ids = torch.cat([ids, ids.new_tensor([eos_id])])
        seqs.append(ids)

    input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
    attention_mask = (input_ids != pad_id).long()
    labels = torch.full_like(input_ids, IGNORE_INDEX)

    T = input_ids.size(1)
    for b, row in enumerate(input_ids):
        sep_pos = (row == sep_id).nonzero(as_tuple=True)[0]
        if sep_pos.numel() == 0:
            continue
        start = sep_pos[0].item() + 1
        stop = T
        # Look for EOS only after SEP: an EOS character inside INPUT would
        # otherwise give stop < start and silently drop all supervision.
        eos_pos = (row[start:] == eos_id).nonzero(as_tuple=True)[0]
        if eos_pos.numel() > 0:
            stop = start + eos_pos[0].item() + 1
        pad_pos = (row[start:stop] == pad_id).nonzero(as_tuple=True)[0]
        if pad_pos.numel() > 0:
            stop = start + pad_pos[0].item()
        if start < stop:
            labels[b, start:stop] = row[start:stop]

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# ----------------------
# Dataloaders
# ----------------------
def create_dataloaders(batch_size: int, model: str, ddp=False, local_rank=0,
                       num_workers: int = 0):
    """Build the training loader and three validation loaders for dataset ``model``.

    The character vocabulary is built from the fused training split only and
    shared by all four datasets.

    Args:
        batch_size: Batch size for every loader.
        model: Dataset directory name passed to ``load_split``.
        ddp: If True, the training loader uses a ``DistributedSampler``.
            Per the PyTorch docs the caller must call
            ``train_dl.sampler.set_epoch(epoch)`` each epoch to reshuffle.
            Validation loaders are not sharded, so every rank sees the full
            validation sets.
        local_rank: Accepted for API compatibility; not used here.
        num_workers: DataLoader worker processes (default 0, as before).

    Returns:
        ``(train_dl, [val0_dl, val1_dl, val2_dl], train_ds,
        (val0_ds, val1_ds, val2_ds))``.

    Raises:
        ValueError: If a validation split contains characters missing from the
            training vocabulary. They would otherwise raise a KeyError only
            once evaluation reaches them.
    """
    tr_inp, tr_out = load_split(model, "train")
    val_splits = [load_split(model, f"val{i}") for i in range(3)]

    train_fused = fuse_to_autoregressive_strings(tr_inp, tr_out)
    stoi, itos = build_vocab_from_strings(train_fused)

    # Fail fast instead of crashing in the middle of the first evaluation.
    for i, (inp, out) in enumerate(val_splits):
        unseen = (set("".join(inp)) | set("".join(out))).difference(stoi)
        if unseen:
            raise ValueError(
                f"Split val{i} contains characters absent from the training "
                f"vocabulary: {sorted(unseen)!r}"
            )

    train_ds = ARLMGptDataset(tr_inp, tr_out, stoi, itos)
    val_dss = tuple(ARLMGptDataset(inp, out, stoi, itos) for inp, out in val_splits)

    # All datasets share ``stoi``, so one collate function serves every loader.
    # A ``partial`` over the module-level ``collate_autoreg`` is picklable,
    # unlike a nested closure, so it also works with worker processes.
    collate_fn = partial(
        collate_autoreg,
        pad_id=train_ds.pad_id,
        sep_id=train_ds.sep_id,
        eos_id=train_ds.eos_id,
    )
    # Pinning only helps (and only avoids a warning) when CUDA is present.
    pin = torch.cuda.is_available()

    train_sampler = DistributedSampler(train_ds, shuffle=True) if ddp else None
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=train_sampler is None,  # sampler and shuffle are mutually exclusive
        sampler=train_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin,
    )

    val_dls = [
        DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
                   num_workers=num_workers, pin_memory=pin)
        for ds in val_dss
    ]

    return train_dl, val_dls, train_ds, val_dss
