import torch
from torch.utils.data import Dataset
import numpy as np
import os, re, pickle
from sklearn.datasets import fetch_20newsgroups


def regroup_dataset(labels):
    """
    Collapse the 20 original 20 Newsgroups label ids into 7 superclasses.

    Superclass ids (original ids in parentheses):
    0 -> alt.atheism              (0)
    1 -> comp.*                   (1-5)
    2 -> misc.forsale             (6)
    3 -> rec.*                    (7-10)
    4 -> sci.*                    (11-14)
    5 -> soc.religion.christian   (15)
    6 -> talk.*                   (16-19)

    Returns a copy of ``labels`` (the input is left untouched) in which every
    id in 0-19 is replaced by its superclass id; any other value is copied
    through unchanged. Prints the shape of the result.
    """
    superclass_ranges = (
        (0, range(0, 1)),
        (1, range(1, 6)),
        (2, range(6, 7)),
        (3, range(7, 11)),
        (4, range(11, 15)),
        (5, range(15, 16)),
        (6, range(16, 20)),
    )
    # Flatten into a direct original-id -> superclass-id lookup table.
    superclass_of = {}
    for group, members in superclass_ranges:
        for original in members:
            superclass_of[original] = group

    regrouped = labels.copy()
    for position, original in enumerate(labels):
        if original in superclass_of:
            regrouped[position] = superclass_of[original]
    print('Labels regrouped into 7 categories. Shape:', regrouped.shape)
    return regrouped
  


def load_glove_embeddings(file_path, vocab_limit=None):
    """Load GloVe text embeddings into a vocab dict and an embedding matrix.

    Each non-blank line of ``file_path`` is parsed as ``word v1 v2 ... vD``
    separated by single spaces. Row 0 is reserved for ``"<UNK>"``, whose vector
    is drawn from a normal distribution with scale 0.6 using NumPy's global
    RNG. File words follow at rows 1, 2, ... in file order.

    Args:
        file_path: path to a GloVe ``.txt`` file (read as UTF-8).
        vocab_limit: stop after this many lines of the file; ``None`` (or 0)
            means read the whole file.

    Returns:
        ``(vocab, embedding_tensor)``. ``vocab`` maps word -> row index.
        ``embedding_tensor`` is a float32 tensor of shape ``[rows, D]``, where
        ``D`` is the length of the first vector in the file. ``D`` falls back
        to 300 when the file contains no vectors.

    Raises:
        ValueError: if a line's vector length differs from the first vector's.
    """
    vocab = {"<UNK>": 0}
    # Placeholder for the UNK row; it is filled in once D is known so the UNK
    # vector always matches the file's dimensionality (not a fixed 300).
    vectors = [None]

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f):
            if vocab_limit and line_no >= vocab_limit:
                break
            parts = line.rstrip().split(' ')
            if parts == ['']:
                # Blank line (e.g. trailing padding): it would otherwise add an
                # empty vector and make the matrix ragged.
                continue
            vector = [float(x) for x in parts[1:]]
            if len(vectors) > 1 and len(vector) != len(vectors[1]):
                raise ValueError(
                    f"{file_path}:{line_no + 1}: expected {len(vectors[1])} "
                    f"values, got {len(vector)}"
                )
            vocab[parts[0]] = len(vectors)
            vectors.append(vector)

    dim = len(vectors[1]) if len(vectors) > 1 else 300
    vectors[0] = np.random.normal(scale=0.6, size=dim).tolist()  # UNK vector

    embedding_tensor = torch.tensor(vectors, dtype=torch.float)
    return vocab, embedding_tensor


def build_news_pickle(glove_txt: str = "data/glove.6B/glove.6B.300d.txt",
                      out_path: str = "data/20news-bydate/news.pkl",
                      max_len: int = 1000):
    """Build a pre-tokenized 20 Newsgroups pickle for the NEWS experiment.

    The steps are:
      1. Load GloVe vectors with ``load_glove_embeddings``.
      2. Fetch the ``subset="all"`` 20 Newsgroups split with headers, footers
         and quotes removed.
      3. Regroup the labels with ``regroup_dataset``.
      4. Map every document to a fixed-length row of vocabulary IDs.

    The resulting pickle contains a tuple ``(embedding_weights, data_ids, labels)``:
      - embedding_weights: np.ndarray (float32) [vocab_size, emb_dim]; row 0 is
        the random ``<UNK>`` vector.
      - data_ids: np.ndarray (int64) [N, max_len] token IDs. Out-of-vocabulary
        tokens and right-padding both use ID 0 (``<UNK>``).
      - labels: np.ndarray (int64) [N], regrouped into 7 superclasses.

    Args:
        glove_txt: path to a GloVe ``.txt`` embeddings file.
        out_path: destination pickle; missing parent directories are created.
        max_len: tokens kept per document. Longer documents are truncated and
            shorter ones are padded with 0.

    Raises:
        ValueError: if ``max_len`` is not positive.
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 1) Load GloVe embeddings and vocab
    vocab, emb = load_glove_embeddings(glove_txt)

    # 2) Fetch 20 Newsgroups (strip headers/footers/quotes)
    ng = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))
    texts = ng.data
    labels = regroup_dataset(np.array(ng.target, dtype=np.int64))

    # 3) Simple regex tokenizer mapped to vocab IDs (UNK=0, also used as padding)
    unk_id = 0
    token_pattern = re.compile(r"\b\w+\b")

    def to_ids(s: str):
        toks = token_pattern.findall(s.lower())[:max_len]
        ids = [vocab.get(t, unk_id) for t in toks]
        ids.extend([unk_id] * (max_len - len(ids)))
        return ids

    # reshape keeps the matrix 2-D even when there are no documents.
    ids_mat = np.array([to_ids(t) for t in texts], dtype=np.int64).reshape(len(texts), max_len)

    # 4) Save. NOTE: consumers read this back with pickle, so only load files
    # produced by a trusted build.
    with open(out_path, "wb") as f:
        pickle.dump((emb.numpy(), ids_mat, labels), f)
    print(f"Saved {out_path} (N={ids_mat.shape[0]}, max_len={max_len}, vocab={len(vocab)})")


class NewsDataset(Dataset):
    """Map-style dataset that serves pre-tokenized documents and their labels.

    ``__getitem__`` returns ``(data[idx], labels[idx])`` exactly as stored. No
    tensor conversion, padding or truncation is performed here. Presumably
    ``data`` holds rows of the ID matrix written by ``build_news_pickle``;
    confirm with the callers.

    Args:
        dataset_type: tag stored as ``self.dataset_type``. It is not used by
            this class; presumably a split name such as "train"/"test"
            (TODO confirm).
        data: indexable sequence of documents; ``len(data)`` is the dataset size.
        labels: indexable sequence of targets aligned with ``data``. It is also
            exposed as ``self.targets``.
        transform: stored as ``self.transform`` but NOT applied to items.
        max_length: stored as ``self.max_length`` but NOT applied. Items keep
            whatever length they already have.
    """

    def __init__(self, dataset_type, data, labels, transform=None, max_length=50):
        self.dataset_type = dataset_type
        self.data = data
        self.labels = labels
        # Alias, presumably for code that reads a torchvision-style ``.targets``.
        self.targets = self.labels
        # Previously these were silently discarded. Keep them inspectable, but
        # do not apply them, so item contents stay unchanged for callers.
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]
        

if __name__ == "__main__":
    import argparse

    # Command-line entry point: `--build` writes the NEWS pickle; without it,
    # only the usage text is printed.
    cli = argparse.ArgumentParser(description="Utilities for the NEWS (20 Newsgroups) dataset.")
    cli.add_argument(
        "--build",
        action="store_true",
        help="Build the news.pkl file for the NEWS experiment.",
    )
    cli.add_argument(
        "--glove",
        type=str,
        default="data/glove.6B/glove.6B.300d.txt",
        help="Path to GloVe .txt embeddings (e.g., glove.6B.300d.txt)",
    )
    cli.add_argument(
        "--out",
        type=str,
        default="data/20news-bydate/news.pkl",
        help="Output pickle path.",
    )
    cli.add_argument(
        "--max_len",
        type=int,
        default=1000,
        help="Max tokens per document.",
    )
    opts = cli.parse_args()

    if not opts.build:
        cli.print_help()
    else:
        build_news_pickle(glove_txt=opts.glove, out_path=opts.out, max_len=opts.max_len)
