from dataclasses import dataclass
import os
import re
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import torch
import pickle

def get_collate_fn(tokenizer, max_length, clients):
    """Build a ``collate_fn`` that pads every client's token chunk to a fixed length.

    Args:
        tokenizer: tokenizer whose ``pad`` method is used for padding.
        max_length: length every client chunk is padded to (``padding='max_length'``).
        clients: number of client chunks expected in each example's
            ``'client_partitions'`` list.

    Returns:
        A function mapping a list of examples to
        ``(list_of_LLMBatchInput_per_client, label_tensor)``, or ``None`` for an
        empty batch.
    """
    def collate_fn(batch):
        if len(batch) == 0:
            return None

        partitions_per_example = [example['client_partitions'] for example in batch]
        label_list = [example['label'] for example in batch]

        def pad_one_client(idx):
            # Gather chunk ``idx`` from every example and pad them together.
            encoded = tokenizer.pad(
                [parts[idx] for parts in partitions_per_example],
                padding='max_length',
                max_length=max_length,
                return_tensors="pt",
            )
            return LLMBatchInput(encoded["input_ids"], encoded["attention_mask"])

        per_client_inputs = [pad_one_client(idx) for idx in range(clients)]
        return (per_client_inputs, torch.tensor(label_list))

    return collate_fn

@dataclass
class LLMBatchInput:
    """Padded token ids and attention mask for one client's slice of a batch."""

    input_ids: torch.Tensor  # token ids as produced by ``tokenizer.pad``
    attention_mask: torch.Tensor  # attention mask as produced by ``tokenizer.pad``

    def to(self, device=None, dtype=None):
        """Move both tensors to ``device`` in place and return ``self``.

        ``dtype`` is accepted so callers can use a ``Tensor.to``-like signature,
        but it is ignored. Presumably this avoids casting integer token ids to a
        float dtype.
        """
        for field_name in ("input_ids", "attention_mask"):
            moved = getattr(self, field_name).to(device=device)
            setattr(self, field_name, moved)
        return self
    


class DataPreprocessor:
    """Base class holding a tokenized dataset and building per-client DataLoaders.

    Subclasses are expected to fill ``self.dataset`` with ``'train'`` and ``'test'``
    splits whose rows carry ``'client_partitions'`` (the output of
    ``partition_tokens``) and ``'label'``. These are the keys ``get_collate_fn``
    reads.
    """

    def __init__(self, args):
        # ``args`` must provide ``num_clients`` and ``batch_size``.
        self.max_length = 512  # token budget per example before splitting across clients
        self.clients = args.num_clients
        self.batch_size = args.batch_size
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.use_subset = True
        self.dataset = None

    def take_subset(self, n):
        """Shuffle (seed 42) and keep about ``n`` examples: 90% train, 10% test.

        Requested counts are clamped to each split's size, so a split smaller
        than requested is kept whole instead of raising.

        Raises:
            ValueError: if no dataset has been loaded yet.
        """
        if self.dataset is None:
            raise ValueError("Dataset not loaded. Please load and preprocess the dataset first using 'load_dataset' method.")
        n_train = min(9 * (n // 10), len(self.dataset['train']))
        n_test = min(n // 10, len(self.dataset['test']))
        self.dataset['train'] = self.dataset['train'].shuffle(seed=42).select(range(n_train))
        self.dataset['test'] = self.dataset['test'].shuffle(seed=42).select(range(n_test))

    def partition_tokens(self, token_ids):
        """Split ``token_ids`` into ``self.clients`` contiguous, order-preserving chunks.

        Chunk sizes differ by at most one: the first ``len(token_ids) % clients``
        chunks get one extra token. No client gets an empty chunk unless there
        are fewer tokens than clients. Every chunk has at most
        ``len(token_ids) // self.clients + 1`` tokens, which is the pad length
        used in ``create_dataloaders``.

        Returns:
            A list with one ``{"input_ids": chunk}`` dict per client, in client order.
        """
        base, extra = divmod(len(token_ids), self.clients)
        partitions = []
        start = 0
        for i in range(self.clients):
            end = start + base + (1 if i < extra else 0)
            partitions.append({"input_ids": token_ids[start:end]})
            start = end
        return partitions

    def _client_max_length(self):
        """Pad length for one client's chunk.

        This is an upper bound on any chunk that ``partition_tokens`` produces
        from at most ``self.max_length`` tokens.
        """
        return self.max_length // self.clients + 1

    def create_dataloaders(self):
        """Return ``(train_loader, val_loader)`` built from the ``'train'`` and ``'test'`` splits.

        Raises:
            ValueError: if no dataset has been loaded yet.
        """
        if self.dataset is None:
            raise ValueError("Dataset not loaded. Please load and preprocess the dataset first.")

        # The pad length must follow the client count. A hard-coded ``// 4`` leaves
        # chunks longer than the pad length when there are fewer than 4 clients,
        # and the padded tensors then come out ragged.
        collate_fn = get_collate_fn(self.tokenizer, self._client_max_length(), self.clients)

        train_loader = DataLoader(self.dataset['train'], batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(self.dataset['test'], batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

        return train_loader, val_loader


class AmazonPolarityPreprocessor(DataPreprocessor):
    """Preprocessor for datasets with ``content``/``title``/``label`` columns.

    Tokenizes each row's ``content``, drops rows shorter than ``min_tokens``
    tokens, splits the remaining token ids across clients and caches the result
    on disk.
    """

    # Examples with fewer tokens than this are filtered out before partitioning.
    min_tokens = 40

    def __init__(self, args):
        super().__init__(args)
        self.data = args.data  # dataset name/path passed to ``load_dataset``

    def load(self):
        """Load ``self.data`` with Hugging Face ``load_dataset`` into ``self.dataset``."""
        self.dataset = load_dataset(self.data, cache_dir=".cache/huggingface/datasets")

    def preprocess_text(self, text):
        """Tokenize ``text`` (requesting truncation to ``self.max_length``) and return token ids."""
        tokens = self.tokenizer.tokenize(text, max_length=self.max_length, truncation=True)
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def _cache_file(self):
        """Return the cache path, keyed by every setting that changes the processed output.

        The previous fixed name ``processed_data.pkl`` let a run with a different
        client count silently reuse partitions built for another count.
        """
        cache_dir = "./datasets/amazon_polarity"
        os.makedirs(cache_dir, exist_ok=True)
        data_tag = re.sub(r"[^A-Za-z0-9_.-]+", "_", str(self.data))
        subset_tag = "subset" if self.use_subset else "full"
        file_name = (
            f"processed_data_{data_tag}_clients{self.clients}"
            f"_len{self.max_length}_min{self.min_tokens}_{subset_tag}.pkl"
        )
        return os.path.join(cache_dir, file_name)

    def preprocess_and_partition(self):
        """Load, tokenize, filter and partition the dataset, using an on-disk cache.

        Returns:
            The processed dataset, which is also stored in ``self.dataset``.
            Each row has ``'client_partitions'`` (from ``partition_tokens``) and
            ``'label'``.
        """
        cache_file = self._cache_file()

        if os.path.exists(cache_file):
            print(f"Loading processed data from {cache_file}")
            # NOTE: pickle.load can execute arbitrary code. It is only acceptable
            # because this file is written by this method; never point it at
            # untrusted files.
            with open(cache_file, 'rb') as f:
                self.dataset = pickle.load(f)
            return self.dataset

        self.load()
        if self.use_subset:
            self.take_subset(100000)

        def process_example(example):
            token_ids = self.preprocess_text(example['content'])
            if len(token_ids) < self.min_tokens:
                # Mark too-short examples so the filter below drops them.
                return {'client_partitions': None, 'label': None}
            return {
                'client_partitions': self.partition_tokens(token_ids),
                'label': example['label'],
            }

        self.dataset = self.dataset.map(process_example, remove_columns=['content', 'title'], batched=False)
        self.dataset = self.dataset.filter(lambda x: x['client_partitions'] is not None)

        # Write to a temp file and rename atomically, so an interrupted run cannot
        # leave a truncated cache that breaks every later run.
        # NOTE(review): pickling a datasets.Dataset may store references to
        # Arrow cache files rather than the rows themselves. Confirm this;
        # ``save_to_disk`` would be self-contained.
        tmp_file = cache_file + ".tmp"
        with open(tmp_file, 'wb') as f:
            pickle.dump(self.dataset, f)
        os.replace(tmp_file, cache_file)

        return self.dataset

