from torch.utils.data import DataLoader, Dataset, ConcatDataset
from datasets import load_from_disk
import torch
import random
from evaluate import load_and_concatenate_datasets,valid_path

class GluePromptDataset(Dataset):
    """Map-style dataset that tokenizes prompt/label examples on access.

    Every item of the wrapped ``dataset`` is passed to ``preprocess_data``
    together with the tokenizer and ``max_length``. When ``task_ids`` has been
    assigned, the entry for the requested index is attached to the processed
    item under the key ``"task_id"``.
    """

    def __init__(self, dataset, tokenizer, max_length=128, label_max_length=8):
        """Store the wrapped dataset and the tokenization settings.

        Args:
            dataset: Indexable collection supporting ``len()``; its items are
                handed to ``preprocess_data`` unchanged.
            tokenizer: Tokenizer forwarded to ``preprocess_data``.
            max_length: Sequence length forwarded to ``preprocess_data``.
            label_max_length: Stored on the instance; not read by this class.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_max_length = label_max_length
        # Optional per-sample task ids, assigned by the caller after
        # construction. Initialised here so __getitem__ also works on a
        # dataset that never received them.
        self.task_ids = None
        # Filled in by compute_task_embedding().
        self.task_embedding = None

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed example at ``idx``.

        The result is whatever ``preprocess_data`` returns, plus a
        ``"task_id"`` entry when ``task_ids`` is set.
        """
        item = self.dataset[idx]
        processed_item = preprocess_data(item, self.tokenizer, self.max_length)
        if self.task_ids is not None:
            processed_item["task_id"] = self.task_ids[idx]
        return processed_item

    def compute_task_embedding(self, model, samples_per_task=20, device="cuda"):
        """Average mask-weighted mean-pooled hidden states over random samples.

        The model is moved to ``device``, switched to eval mode, and its
        config is updated in place (``use_cache = False``,
        ``output_hidden_states = True``). The result is also stored on
        ``self.task_embedding``.

        Args:
            model: Model exposing ``.config`` and returning an output with
                ``hidden_states`` (presumably a Hugging Face model -- confirm
                against the caller).
            samples_per_task: Upper bound on the number of sampled examples.
            device: Device the model and inputs are moved to.

        Returns:
            The mean over the sampled examples of each example's pooled last
            hidden state.

        Raises:
            ValueError: If the dataset is empty.
        """
        if len(self) == 0:
            raise ValueError("Cannot compute a task embedding for an empty dataset.")

        model.to(device)
        model.eval()
        model.config.use_cache = False
        model.config.output_hidden_states = True

        num_samples = min(samples_per_task, len(self))
        sampled_indices = random.sample(range(len(self)), num_samples)

        embeddings = []
        with torch.no_grad():
            for idx in sampled_indices:
                sample = self[idx]
                input_ids = sample["input_ids"].unsqueeze(0).to(device)            # [1, seq_len]
                attention_mask = sample["attention_mask"].unsqueeze(0).to(device)  # [1, seq_len]

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                    output_hidden_states=True,
                    # Only one forward pass is needed, so no key/value cache;
                    # this also matches the config override above.
                    use_cache=False
                )
                hs = outputs.hidden_states
                last_hidden = hs[-1] if isinstance(hs, (tuple, list)) else hs  # [1, L, H]

                mask = attention_mask.unsqueeze(-1).expand_as(last_hidden).float()
                summed = (last_hidden * mask).sum(dim=1)                     # [1, hidden_size]
                lengths = attention_mask.sum(dim=1, keepdim=True).float()    # [1, 1]
                # Guard against a fully masked sample producing NaNs.
                mean_emb = summed / lengths.clamp(min=1.0)                   # [1, hidden_size]
                embeddings.append(mean_emb)

        embeddings = torch.cat(embeddings, dim=0)  # [num_samples, hidden_size]
        task_emb = embeddings.mean(dim=0)          # [hidden_size]
        self.task_embedding = task_emb
        return task_emb
    
def compute_task_embedding_t5(raw_dataset, t5_model, t5_tokenizer, samples_per_task=20, device="cuda"):
    """Return the average mean-pooled encoder state over randomly drawn prompts.

    Up to ``samples_per_task`` examples are drawn without replacement from
    ``raw_dataset``. Each example's ``'prompt'`` is tokenized (padded and
    truncated to 512 tokens), run through ``t5_model.encoder``, and pooled by
    an attention-mask-weighted mean over the sequence axis. The per-example
    vectors are then averaged.

    Args:
        raw_dataset: Indexable collection of mappings with a ``'prompt'`` key.
        t5_model: Model exposing ``.encoder``; moved to ``device`` and put in
            eval mode.
        t5_tokenizer: Callable returning ``input_ids`` and ``attention_mask``.
        samples_per_task: Upper bound on the number of sampled examples.
        device: Device used for the model and the inputs.

    Returns:
        The averaged embedding tensor.
    """
    t5_model.to(device)
    t5_model.eval()

    sample_count = min(samples_per_task, len(raw_dataset))
    chosen = random.sample(range(len(raw_dataset)), sample_count)

    # Tokenizer settings are the same for every prompt.
    tokenizer_kwargs = {
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
        "max_length": 512,
    }

    pooled = []
    with torch.no_grad():
        for index in chosen:
            encoded = t5_tokenizer(raw_dataset[index]['prompt'], **tokenizer_kwargs)
            token_ids = encoded["input_ids"].to(device)
            token_mask = encoded["attention_mask"].to(device)

            encoder_states = t5_model.encoder(
                input_ids=token_ids, attention_mask=token_mask
            ).last_hidden_state

            # Zero out padded positions, then divide by the real token count.
            broadcast_mask = token_mask.unsqueeze(-1).expand_as(encoder_states)
            token_total = (encoder_states * broadcast_mask).sum(dim=1)
            pooled.append(token_total / token_mask.sum(dim=1, keepdim=True))

    return torch.cat(pooled).mean(dim=0)

def preprocess_data(example, tokenizer, max_length=128):
    """Tokenize ``prompt + " " + label`` and build prompt-masked labels.

    Args:
        example: Mapping with ``"prompt"`` (str) and ``"label"`` entries.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Length the full text is padded/truncated to.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        ``labels`` is a copy of ``input_ids`` whose first ``prompt_len``
        positions are set to -100, where ``prompt_len`` is the token count of
        the prompt tokenized without special tokens.
    """
    prompt_text = example["prompt"]

    encoded = tokenizer(
        prompt_text + " " + str(example["label"]),
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    # The prompt is tokenized on its own only to learn how many leading
    # positions to mask.
    # NOTE(review): if the tokenizer prepends special tokens to the full text,
    # this count is presumably short by that many positions -- confirm.
    prompt_token_ids = tokenizer(
        prompt_text,
        add_special_tokens=False,
        return_tensors="pt"
    )["input_ids"]

    token_ids = encoded["input_ids"][0]

    # NOTE(review): padded positions keep their token id in the labels (they
    # are not set to -100) -- confirm this is intended.
    targets = token_ids.clone()
    targets[:prompt_token_ids.shape[-1]] = -100

    return {
        "input_ids": token_ids,
        "attention_mask": encoded["attention_mask"][0],
        "labels": targets
    }


from collections import Counter
def load_and_concat_datasets(task_names, tokenizer, t5_encoder,t5_tokenizer, max_length=128, label_max_length=8, samples_per_task=20):
    """Load per-task train/validation data and compute one embedding per task.

    For every task, the training set is loaded from
    ``DATASET_FILES[task]['train']``, examples whose prompt tokenizes to
    ``max_length`` or more tokens are dropped, and the remainder is wrapped in
    a ``GluePromptDataset`` whose ``task_ids`` all equal the task's index in
    ``task_names``. The validation set is loaded through
    ``load_and_concatenate_datasets(valid_path[task])`` and returned as-is.

    Args:
        task_names: Task names; a task's position is its task id.
        tokenizer: Tokenizer used for length filtering and by the datasets.
        t5_encoder: Model passed to ``compute_task_embedding_t5``. When
            ``None``, a zero vector of size 768 (float16) is used instead.
        t5_tokenizer: Tokenizer passed to ``compute_task_embedding_t5``.
        max_length: Maximum sequence length for training examples.
        label_max_length: Forwarded to ``GluePromptDataset``.
        samples_per_task: Number of samples used for each task embedding.

    Returns:
        A tuple ``(train_datasets, valid_datasets, task_embeddings)``: a list
        of per-task ``GluePromptDataset`` objects, a dict mapping task name to
        its unwrapped validation dataset, and the stacked task embeddings
        (one row per task).

    Raises:
        ValueError: If ``task_names`` is empty.
    """
    # Materialise once so a one-shot iterable is not exhausted by enumerate().
    task_names = list(task_names)
    if not task_names:
        raise ValueError("task_names must contain at least one task.")

    task_id_map = {task: idx for idx, task in enumerate(task_names)}
    task_embeddings = []
    train_datasets = []
    valid_datasets = {}

    def is_valid_sample(example):
        # Keep only prompts that leave room for at least one label token.
        prompt_ids = tokenizer(example["prompt"], add_special_tokens=False)["input_ids"]
        return len(prompt_ids) < max_length

    for task in task_names:
        train_path = DATASET_FILES[task]['train']
        raw_train_ds = load_from_disk(train_path)
        filtered_train_ds = raw_train_ds.filter(is_valid_sample, num_proc=4)
        dropped = len(raw_train_ds) - len(filtered_train_ds)
        if dropped > 0:
            print(f"Task {task} (Train): Filtered out {dropped} samples due to excessive prompt length.")

        train_dataset = GluePromptDataset(filtered_train_ds, tokenizer, max_length, label_max_length)
        train_dataset.task_ids = torch.tensor([task_id_map[task]] * len(train_dataset), dtype=torch.long)
        print("You have loaded the dataset of {}. Train_Length:{}".format(task, len(train_dataset)))

        if t5_encoder is None:
            print("Embedding Model is None, skip computing task embedding")
            task_embedding = torch.zeros((768,), dtype=torch.float16)
        else:
            # The embedding is computed on the unfiltered training set.
            task_embedding = compute_task_embedding_t5(raw_train_ds,t5_encoder,t5_tokenizer, samples_per_task=samples_per_task)
        task_embeddings.append(task_embedding)

        # Validation data is handed back unwrapped; the previous code also
        # built a GluePromptDataset around it that was never used.
        valid_datasets[task] = load_and_concatenate_datasets(valid_path[task])
        train_datasets.append(train_dataset)

    # The per-task list is what callers receive, so no ConcatDataset is built.
    task_embeddings = torch.stack(task_embeddings)  # [num_tasks, hidden_size]
    return train_datasets, valid_datasets, task_embeddings

def collate_fn(batch):
    """Combine a list of per-sample dicts into one batch dict.

    ``input_ids``, ``attention_mask`` and ``labels`` are stacked along a new
    leading dimension; the ``task_id`` values are gathered into a single
    ``torch.long`` tensor.
    """
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }
    task_ids = [sample["task_id"] for sample in batch]
    collated["task_id"] = torch.tensor(task_ids, dtype=torch.long)
    return collated

# Registry of dataset locations keyed by task name. load_and_concat_datasets
# reads DATASET_FILES[task]['train'] and passes it to load_from_disk.
# It is empty as defined here.
DATASET_FILES = {}


import random
import numpy as np
class MultiTaskDataLoader:
    """Iterate over several datasets, picking the source dataset per sample.

    Each batch is assembled from ``batch_size`` draws. For every draw a
    dataset is chosen at random according to ``dataset_probabilities`` and the
    next sample is taken from that dataset's own single-sample ``DataLoader``;
    an exhausted dataset is restarted. One pass yields
    ``ceil(total_size / batch_size)`` batches.

    Batches are dicts with ``"input_ids"``, ``"attention_mask"``, ``"labels"``
    and ``"task_ids"`` (built from each sample's ``"task_id"``).
    """

    def __init__(self, datasets, batch_size, shuffle=True, seed=2025, sampling_weights=None,
                 num_workers=4, pin_memory=True):
        """Build one single-sample loader per dataset.

        Args:
            datasets: Sequence of map-style datasets.
            batch_size: Number of samples per yielded batch.
            shuffle: Whether each per-dataset loader shuffles its samples.
            seed: Seed applied to the global ``random`` and ``numpy`` RNGs.
            sampling_weights: Optional non-negative weight per dataset. They
                are normalised to sum to one. When omitted, a softmax over the
                datasets' size proportions is used.
            num_workers: Worker count for every per-dataset loader.
            pin_memory: ``pin_memory`` flag for every per-dataset loader.

        Raises:
            ValueError: If ``sampling_weights`` does not have one entry per
                dataset.
        """
        self.datasets = datasets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        random.seed(self.seed)
        np.random.seed(self.seed)

        self.data_sizes = [len(dataset) for dataset in self.datasets]
        self.total_size = sum(self.data_sizes)

        ## rte, mnli, mrpc, sst2, qqp, qnli, cola, stsb
        ## boolq,obqa,piqa,arc_e,arc_c,siqa,winogrande,hellaswag
        if sampling_weights is None:
            # Softmax over size proportions: flatter than size-proportional
            # sampling, so small tasks are drawn more often.
            proportions = np.array(self.data_sizes) / self.total_size
            exp_proportions = np.exp(proportions)
            self.dataset_probabilities = exp_proportions / np.sum(exp_proportions)
        else:
            # Explicit weights were previously overwritten by the default
            # distribution; honour them.
            if len(sampling_weights) != len(datasets):
                raise ValueError(
                    "sampling_weights must have one entry per dataset "
                    f"(got {len(sampling_weights)} for {len(datasets)} datasets)."
                )
            weights = np.asarray(sampling_weights, dtype=float)
            self.dataset_probabilities = weights / weights.sum()

        self.dataset_loaders = [
            DataLoader(
                dataset,
                batch_size=1,
                shuffle=self.shuffle,
                num_workers=num_workers,
                pin_memory=pin_memory
            ) for dataset in self.datasets
        ]
        # Iterators are created lazily so worker processes are not started
        # (and then thrown away) before iteration actually begins.
        self.dataset_iters = None
        self.current_step = 0

    def __iter__(self):
        """Restart every per-dataset iterator and reset the step counter."""
        self.current_step = 0
        self.dataset_iters = [iter(loader) for loader in self.dataset_loaders]
        return self

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` after one pass."""
        if self.current_step * self.batch_size >= self.total_size:
            raise StopIteration
        if self.dataset_iters is None:
            self.dataset_iters = [iter(loader) for loader in self.dataset_loaders]

        samples = []
        for _ in range(self.batch_size):
            dataset_idx = np.random.choice(len(self.datasets), p=self.dataset_probabilities)
            try:
                sample = next(self.dataset_iters[dataset_idx])
            except StopIteration:
                # This dataset ran out before the pass finished: start it over.
                self.dataset_iters[dataset_idx] = iter(self.dataset_loaders[dataset_idx])
                sample = next(self.dataset_iters[dataset_idx])
            samples.append(sample)

        # Each sample already has a leading dimension of 1 (batch_size=1
        # loaders), so concatenating along dim 0 forms the batch.
        batch = {
            "input_ids": torch.cat([s["input_ids"] for s in samples], dim=0),
            "attention_mask": torch.cat([s["attention_mask"] for s in samples], dim=0),
            "labels": torch.cat([s["labels"] for s in samples], dim=0),
            "task_ids": torch.cat([s['task_id'] for s in samples])
        }
        self.current_step += 1
        return batch

    def __len__(self):
        """Return the number of batches one pass yields (ceiling division)."""
        return -(-self.total_size // self.batch_size)


from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
class MultiTaskSamplerDataLoader:
    """Multi-task loader backed by a single ``WeightedRandomSampler``.

    All datasets are concatenated. Every sample receives the weight
    ``task_probability / task_size``, so a task is drawn with its task
    probability and samples within a task are drawn uniformly. Sampling is
    done with replacement, and one pass draws ``num_samples`` samples.
    Batches are produced by the wrapped ``DataLoader`` with its default
    collation.
    """

    def __init__(
        self,
        datasets,
        batch_size: int,
        sampling_weights=None,
        seed: int = 2025,
        num_workers: int = 4,
        pin_memory: bool = True,
        num_samples: int = 8 * 20000
    ):
        """Build the weighted sampler and the wrapped ``DataLoader``.

        Args:
            datasets: Sequence of map-style datasets to concatenate.
            batch_size: Batch size of the wrapped loader.
            sampling_weights: Optional per-task scores; a softmax over them
                gives the task probabilities. When omitted, the softmax is
                taken over the tasks' size proportions.
            seed: Seed applied to the global torch (CPU and CUDA) and numpy
                RNGs.
            num_workers: Worker count of the wrapped loader.
            pin_memory: ``pin_memory`` flag of the wrapped loader.
            num_samples: Number of samples drawn per pass. The default keeps
                the previously hard-coded value of ``8 * 20000``.

        Raises:
            ValueError: If ``sampling_weights`` does not have one entry per
                dataset.
        """
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        data_sizes = np.array([len(ds) for ds in datasets])
        if sampling_weights is None:
            scores = data_sizes / data_sizes.sum()
        else:
            if len(sampling_weights) != len(datasets):
                raise ValueError(
                    "sampling_weights must have one entry per dataset "
                    f"(got {len(sampling_weights)} for {len(datasets)} datasets)."
                )
            scores = np.asarray(sampling_weights, dtype=float)

        # Softmax; subtracting the maximum leaves the result unchanged but
        # avoids overflow in exp() for large scores.
        exp_scores = np.exp(scores - scores.max())
        probs = exp_scores / exp_scores.sum()

        concat_ds = ConcatDataset(datasets)
        # One weight per sample, in the same order ConcatDataset lays them out.
        sample_weights = torch.as_tensor(
            np.repeat(probs / data_sizes, data_sizes), dtype=torch.double
        )

        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=num_samples,
            replacement=True
        )
        self.loader = DataLoader(
            concat_ds,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

    def __iter__(self):
        """Return a fresh iterator over the wrapped loader."""
        return iter(self.loader)

    def __len__(self):
        """Return the number of batches per pass of the wrapped loader."""
        return len(self.loader)

        