from torch.utils.data import Dataset, DataLoader
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
import pandas as pd
from datasets import load_dataset
import pandas as pd
from keybert import KeyBERT
from tqdm import tqdm

# Dataset location placeholders; both are empty strings in this file.
TRAIN_DATASET_PATH, EVAL_DATASET_PATH = "", ""


def clean(keywords):
    """Return every item of *keywords*, converted with ``str``, joined by single spaces."""
    return " ".join(str(word) for word in keywords)

def clean_keywords(keywords):
    """Join the keyword strings of ``(keyword, score)``-style pairs with spaces.

    :param keywords: sequence of pairs whose first element is the keyword,
        e.g. the list returned by ``KeyBERT.extract_keywords`` for one document.
    :return: the first elements joined by single spaces (via ``clean``);
        an empty string when *keywords* is empty.
    """
    # Taking pair[0] directly (instead of transposing with zip(*...)) also
    # handles an empty keyword list, which previously raised IndexError.
    return clean([pair[0] for pair in keywords])

def make_keywords(dataset):
    """Build a keyword/text DataFrame by running KeyBERT over ``dataset["text"]``.

    :param dataset: object indexable by ``"text"``, yielding the documents
        (e.g. a Hugging Face ``Dataset`` split).
    :return: DataFrame with columns ``"text"`` (the original documents) and
        ``"keywords"`` (space-joined keywords from ``clean_keywords``), one row
        per document with a default RangeIndex.
    """
    kw_model = KeyBERT()
    texts = list(dataset["text"])
    # Collect all results first and build the frame once; writing
    # df["keywords"][i] per row is chained assignment, which pandas may apply
    # to a temporary copy (and never applies under Copy-on-Write).
    keywords = [
        clean_keywords(kw_model.extract_keywords(text))
        for text in tqdm(texts)
    ]
    return pd.DataFrame({"text": texts, "keywords": keywords})

def make_dataset(dataset="common_gen", split="train", cache_dir='src/data/fine_tuning/csr/cache', from_local=False):
    """Load one split of a dataset as a DataFrame with ``keywords``/``text`` columns.

    :param dataset: dataset name passed to ``load_dataset``. ``"common_gen"``
        is handled specially; anything else goes through ``make_keywords``.
    :param split: split name passed to ``load_dataset``.
    :param cache_dir: cache directory, used only for the local common_gen script.
    :param from_local: for common_gen, load via the local loading script
        instead of by name. Ignored for other datasets.
    :return: DataFrame with ``"keywords"`` and ``"text"`` columns.
    """
    if dataset != "common_gen":
        # Generic path: keywords are extracted from the "text" column.
        loaded = load_dataset(dataset, split=split)
        print(loaded)
        return make_keywords(dataset=loaded)

    if from_local:
        loaded = load_dataset("src/utils/data_generation/pretrained/data/common_gen.py", split=split, cache_dir=cache_dir)
    else:
        loaded = load_dataset(dataset, split=split)

    # common_gen: "concepts" entries become space-joined keyword strings,
    # "target" is the text to generate.
    frame = pd.DataFrame()
    frame["keywords"] = loaded["concepts"]
    frame["text"] = loaded["target"]
    frame["keywords"] = frame["keywords"].apply(clean)
    return frame
  

class DataModule(Dataset):
    """
    PyTorch ``Dataset`` that tokenizes keyword/text pairs for seq2seq training.

    Each row of ``data`` supplies a ``"keywords"`` string (model input) and a
    ``"text"`` string (generation target). Rows are tokenized lazily in
    ``__getitem__``, padded and truncated to fixed lengths.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
    ):
        """

        :param data: DataFrame with ``"keywords"`` and ``"text"`` columns.
        :param tokenizer: tokenizer applied to both columns.
        :param source_max_token_len: pad/truncate length for the keywords input.
        :param target_max_token_len: pad/truncate length for the text labels.
        """
        self.data = data
        self.target_max_token_len = target_max_token_len
        self.source_max_token_len = source_max_token_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def _encode(self, value, max_length: int):
        """Tokenize one string, padded/truncated to ``max_length``, as PyTorch tensors."""
        return self.tokenizer(
            value,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        """Return the raw strings plus flattened token ids/masks for row ``index``.

        Label positions that are padding are replaced with -100.
        """
        data_row = self.data.iloc[index]

        keywords_encoding = self._encode(data_row["keywords"], self.source_max_token_len)
        text_encoding = self._encode(data_row["text"], self.target_max_token_len)

        labels = text_encoding["input_ids"]
        # Mask padding via the attention mask rather than a hard-coded pad id
        # of 0, so this works for any tokenizer's pad token. -100 is the default
        # ignore_index of torch.nn.CrossEntropyLoss.
        labels[text_encoding["attention_mask"] == 0] = -100

        return dict(
            keywords=data_row["keywords"],
            text=data_row["text"],
            keywords_input_ids=keywords_encoding["input_ids"].flatten(),
            keywords_attention_mask=keywords_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=text_encoding["attention_mask"].flatten(),
        )
        
        
def get_data_loaders(
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        val_df: pd.DataFrame,
        tokenizer: T5Tokenizer,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        batch_size: int = 4,
        split: float = 0.1,
        num_workers: int = 2,
    ):
    """Wrap the train/test/validation DataFrames in ``DataModule`` DataLoaders.

    :param train_df: training rows (``"keywords"`` / ``"text"`` columns).
    :param test_df: test rows.
    :param val_df: validation rows.
    :param tokenizer: tokenizer passed to every ``DataModule``.
    :param source_max_token_len: pad/truncate length for keyword inputs.
    :param target_max_token_len: pad/truncate length for text labels.
    :param batch_size: batch size for all three loaders.
    :param split: unused; kept for backward compatibility with existing callers.
    :param num_workers: worker processes per DataLoader (default 2, as before).
    :return: ``(train_dataloader, test_dataloader, val_dataloader)``; only the
        training loader shuffles.
    """
    def _make_loader(df, shuffle):
        # One DataModule + DataLoader pair sharing the common settings.
        dataset = DataModule(
            df,
            tokenizer,
            source_max_token_len,
            target_max_token_len,
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    train_dataloader = _make_loader(train_df, shuffle=True)
    test_dataloader = _make_loader(test_df, shuffle=False)
    val_dataloader = _make_loader(val_df, shuffle=False)

    return train_dataloader, test_dataloader, val_dataloader