import os
import torch 
import itertools
import numpy as np
import collections
import pandas as pd
from typing import List, Tuple
import jax.numpy as jnp
import jax 
from torch.utils.data import DataLoader
# -----------------------------
# Data Processing
# -----------------------------
# Module-level defaults; load_agnews_data declares its own parameter
# defaults with the same values.
vocab_size, max_seq_len = 15_000, 100

# Reserved token ids (same values load_agnews_data uses for padding and
# for words missing from the vocabulary).
PAD, UNK = 0, 1

def pad_sequences(sequences: List[np.ndarray], pad_value: int = 0) -> np.ndarray:
    """Right-pad variable-length 1-D sequences to a common length and stack them.

    Args:
        sequences: 1-D arrays of possibly different lengths.
        pad_value: value appended after every sequence shorter than the longest.

    Returns:
        Array of shape ``(len(sequences), max_len)`` where ``max_len`` is the
        length of the longest input sequence. An empty ``sequences`` yields an
        empty ``(0, 0)`` int32 array instead of raising.
    """
    # ``len`` rather than truthiness so an ndarray of sequences also works.
    if len(sequences) == 0:
        # Both ``max()`` over an empty iterable and ``np.stack([])`` raise;
        # an empty batch is the natural result here.
        return np.zeros((0, 0), dtype=np.int32)
    max_len = max(len(seq) for seq in sequences)
    return np.stack([
        np.pad(seq, (0, max_len - len(seq)), constant_values=pad_value)
        for seq in sequences
    ])

class AGNewsDataset:
    """Map-style dataset pairing encoded token sequences with their labels.

    ``ds[i]`` returns ``(features[i], labels[i])``; ``len(ds)`` is the number
    of labels.
    """

    def __init__(self, features: List[np.ndarray], labels: np.ndarray):
        self.features, self.labels = features, labels

    def __len__(self):
        n_items = len(self.labels)
        return n_items

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]
        return feature, label

def numpy_collate(batch: List[Tuple[np.ndarray, int]], pad_value=0):
    """Turn a list of ``(tokens, label)`` pairs into padded JAX arrays.

    Token sequences are right-padded with ``pad_value`` via ``pad_sequences``.

    Returns:
        ``(tokens, labels)`` as ``jnp`` arrays.
    """
    texts_column, labels_column = zip(*batch)
    token_matrix = pad_sequences(list(texts_column), pad_value=pad_value)
    jax_tokens = jnp.array(token_matrix)
    jax_labels = jnp.array(labels_column)
    return jax_tokens, jax_labels

def create_numpy_data_loader(dataset, batch_size, pad_value=0, shuffle=True, rng=None):
    """Yield padded JAX batches from ``dataset`` using ``numpy_collate``.

    This generator used to be named ``create_data_loader``. A later definition
    of that name in this module (the torch ``DataLoader`` factory) rebinds it,
    so this version was unreachable. It has its own name so it can be called.

    Args:
        dataset: indexable object with ``len()`` whose items are
            ``(tokens, label)`` pairs.
        batch_size: maximum number of items per batch; the last batch may be
            smaller.
        pad_value: value used to right-pad token sequences.
        shuffle: if True, visit items in a random order.
        rng: optional ``np.random.Generator`` used for shuffling. When None,
            the global NumPy RNG (``np.random.shuffle``) is used.

    Yields:
        ``(tokens, labels)`` JAX arrays, as produced by ``numpy_collate``.
    """
    indices = np.arange(len(dataset))
    if shuffle:
        if rng is None:
            np.random.shuffle(indices)
        else:
            rng.shuffle(indices)
    for start in range(0, len(dataset), batch_size):
        batch_indices = indices[start:start + batch_size]
        batch = [dataset[i] for i in batch_indices]
        yield numpy_collate(batch, pad_value=pad_value)
def load_agnews_data(data_path, vocab_size=15000, max_seq_len=100, is_torch=False):
    """Load the AG News CSVs, build a vocabulary from the training split and encode both splits.

    Args:
        data_path: directory containing ``train.csv`` and ``test.csv`` with
            ``'Class Index'`` and ``'Description'`` columns.
        vocab_size: total number of token ids including the reserved ``PAD``
            (0) and ``UNK`` (1) ids. At most ``vocab_size - 2`` of the most
            frequent training tokens get ids starting at 2.
        max_seq_len: each description is truncated to this many tokens
            (before the vocabulary is counted).
        is_torch: if True, every encoded text is a ``torch.long`` tensor;
            otherwise an ``np.int32`` array.

    Returns:
        ``((x_train, y_train), (x_val, y_test), num_classes, PAD)``:
        ``x_*`` are lists of variable-length id sequences, ``y_*`` are label
        arrays shifted to start at zero (``Class Index - 1``), ``num_classes``
        is the number of distinct training labels, and ``PAD`` is the padding
        id. ``x_val``/``y_test`` come from ``test.csv``.
    """
    train_data = pd.read_csv(os.path.join(data_path, "train.csv"))
    test_data = pd.read_csv(os.path.join(data_path, "test.csv"))

    y_train = np.array(train_data['Class Index'].values - 1)
    y_test = np.array(test_data['Class Index'].values - 1)

    def tokenizer(text: str) -> List[str]:
        # Lower-case, turn backslashes into spaces, split on whitespace.
        return text.lower().replace("\\", " ").split()

    def tokenize_column(descriptions) -> List[List[str]]:
        # pandas reads empty CSV cells as NaN (a float); treat them as empty text.
        return [tokenizer(text)[:max_seq_len]
                for text in descriptions.fillna("").astype(str)]

    train_texts = tokenize_column(train_data['Description'])
    test_texts = tokenize_column(test_data['Description'])

    # Vocabulary comes from the training split only; ids 0 and 1 are reserved.
    counter = collections.Counter(itertools.chain.from_iterable(train_texts))
    most_common_words = [w for w, _ in counter.most_common(vocab_size - 2)]
    PAD, UNK = 0, 1
    word_to_id = {w: i + 2 for i, w in enumerate(most_common_words)}

    def encode(text: List[str]):
        ids = [word_to_id.get(word, UNK) for word in text]
        if is_torch:
            # Explicit dtype: torch.tensor([]) would otherwise be float32,
            # which breaks padding/embedding for empty descriptions.
            return torch.tensor(ids, dtype=torch.long)
        return np.array(ids, dtype=np.int32)

    x_train = [encode(text) for text in train_texts]
    x_val = [encode(text) for text in test_texts]

    return (x_train, y_train), (x_val, y_test), len(set(y_train)), PAD
def flax_make_numpy_loader(X, y, batch_size, pad_id=0):
    """Yield consecutive, unshuffled batches as dicts of JAX arrays.

    Each yielded dict has:
      'input': int32 token ids right-padded with ``pad_id`` to the longest
               sequence in the batch, shape (batch, max_len).
      'label': the matching slice ``y[start:start + batch_size]``.
      'mask':  boolean array of shape (batch, 1, 1, max_len), True where the
               token is not ``pad_id``.
    The final batch may hold fewer than ``batch_size`` items.
    """
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        tokens = np.array(pad_sequences(X[start:stop], pad_value=pad_id), dtype=np.int32)
        targets = np.array(y[start:stop])
        # Insert two singleton axes so the mask lines up as (batch, 1, 1, seq).
        keep = (tokens != pad_id)[:, None, None, :]

        batch = {}
        batch['input'] = jax.device_put(jnp.array(tokens))
        batch['label'] = jax.device_put(jnp.array(targets))
        batch['mask'] = jax.device_put(jnp.array(keep))
        yield batch
def collate_fn(batch, pad_value=0):
    """Collate ``(text, label)`` pairs into a padded batch for a torch ``DataLoader``.

    Args:
        batch: sequence of ``(text, label)`` pairs. ``text`` is a torch tensor
            whose first dimension is the sequence length (presumably a 1-D id
            tensor from ``load_agnews_data(..., is_torch=True)``; verify
            against the caller).
        pad_value: value used to right-pad shorter sequences. The default 0
            matches the module's ``PAD`` id.

    Returns:
        ``(texts_padded, labels)``: texts stacked batch-first and padded to the
        longest sequence, plus a tensor of labels.
    """
    texts = [text for text, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    texts_padded = torch.nn.utils.rnn.pad_sequence(
        texts, batch_first=True, padding_value=pad_value
    )
    return texts_padded, labels

def create_data_loader(dataset, batch_size=32, shuffle=True, **loader_kwargs):
    """Wrap ``dataset`` in a torch ``DataLoader`` that pads each batch.

    Args:
        dataset: map-style dataset yielding ``(text, label)`` pairs.
        batch_size: number of samples per batch.
        shuffle: whether the ``DataLoader`` shuffles samples.
        **loader_kwargs: extra ``DataLoader`` options (e.g. ``num_workers``,
            ``drop_last``, ``pin_memory``) forwarded unchanged. ``collate_fn``
            defaults to this module's ``collate_fn`` but may be overridden.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    # Only look up the module-level collate_fn when the caller didn't supply one.
    if "collate_fn" not in loader_kwargs:
        loader_kwargs["collate_fn"] = collate_fn
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **loader_kwargs)
