import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from myutils import print_debug
import re
from torch.cuda.amp import autocast, GradScaler

from tqdm import tqdm
########################## encode function and text dataset ##############################
class TextDataset(Dataset):
    """Map-style dataset over an already-encoded, indexable collection of chunks.

    Items are handed back exactly as stored, e.g. the
    ``{'input_ids': ..., 'attention_mask': ...}`` dicts produced by ``encode``.
    """

    def __init__(self, encoded_chunks):
        # Kept under this public attribute name; nothing is copied.
        self.encoded_chunks = encoded_chunks

    def __getitem__(self, idx):
        chunks = self.encoded_chunks
        return chunks[idx]

    def __len__(self):
        chunks = self.encoded_chunks
        return len(chunks)


def encode(texts, tokenizer, MAX_LEN=1024):
    """Tokenize ``texts`` as one stream and cut it into chunks of ``MAX_LEN`` tokens.

    The texts are joined with a blank line between consecutive texts and
    tokenized in a single ``tokenizer(..., return_tensors='pt')`` call. Row 0 of
    the resulting ``input_ids`` / ``attention_mask`` is then split into
    consecutive ``MAX_LEN``-long pieces.

    Args:
        texts: iterable of strings.
        tokenizer: callable returning a mapping with 'input_ids' and
            'attention_mask' tensors (assumed shape (1, n_tokens), as a Hugging
            Face tokenizer returns for a single string).
        MAX_LEN: chunk length in tokens.

    Returns:
        list of ``{'input_ids': Tensor, 'attention_mask': Tensor}`` dicts. When
        there is more than one chunk and the last one is shorter than
        ``MAX_LEN``, that short tail is dropped. Presumably this keeps every
        chunk the same length so the default collate function can stack them
        into a batch.
    """
    concatenated_text = "\n\n".join(texts)
    encodings = tokenizer(concatenated_text, return_tensors='pt')

    # print the length of the input
    print(f'input length: {encodings["input_ids"].shape[1]}')

    # Split the single long encoding into MAX_LEN-sized pieces.
    input_ids_chunks = encodings['input_ids'][0].split(MAX_LEN)
    attention_mask_chunks = encodings['attention_mask'][0].split(MAX_LEN)

    print(f'input_ids_chunks: {len(input_ids_chunks)}, attention_mask_chunks: {len(attention_mask_chunks)}')

    return_list = [{'input_ids': chunk, 'attention_mask': mask} for chunk, mask in zip(input_ids_chunks, attention_mask_chunks)]

    # Drop only a *short* trailing chunk. When the token count is an exact
    # multiple of MAX_LEN, the last chunk is full-sized and is kept. A lone
    # chunk is always kept so the result is never empty for non-empty input.
    if len(return_list) > 1 and return_list[-1]['input_ids'].shape[0] < MAX_LEN:
        return return_list[:-1]
    return return_list


# def downsample_dataset(train_texts, val_texts, batch_size_val, downsample_rate=1.0):
#     # get the dataloader from the texts, and downsample before encode (can make it faster for very large dataset)
    
#     # Downsampling
#     downsample_num_train = int(len(train_texts) * downsample_rate)
#     downsample_num_val = int(len(val_texts) * downsample_rate)
    
#     random_idx_train = np.random.choice(len(train_texts), downsample_num_train, replace=False)
#     random_idx_val = np.random.choice(len(val_texts), max(downsample_num_val, batch_size_val), replace=False)
    
#     # Create smaller datasets based on the random indices
#     small_train_texts = [train_texts[i] for i in random_idx_train]
#     small_val_texts = [val_texts[i] for i in random_idx_val]
    
#     return small_train_texts, small_val_texts

def downsample_dataset(texts, minimum_num=None, downsample_rate=1.0):
    """Randomly pick a subset of ``texts`` (without replacement) before encoding.

    The subset size is ``int(len(texts) * downsample_rate)``, raised to
    ``minimum_num`` when that is given and larger. Indices come from NumPy's
    global RNG (``np.random.choice(..., replace=False)``), so the returned order
    is random even when ``downsample_rate`` is 1.0.

    Raises:
        AssertionError: if the requested subset size exceeds ``len(texts)``.
    """
    downsample_num = int(len(texts) * downsample_rate)
    if minimum_num is not None and minimum_num > downsample_num:
        downsample_num = minimum_num

    assert downsample_num <= len(texts), f"downsample_num should be less than the length of the texts, but got {downsample_num} and {len(texts)}, downsample_rate: {downsample_rate}, minimum_num: {minimum_num}"

    picked_positions = np.random.choice(len(texts), downsample_num, replace=False)
    return [texts[pos] for pos in picked_positions]



def get_datasets_for_SFT(train_texts, val_texts, batch_size_val, tokenizer, downsample_rate=1.0):
    """Downsample train/val texts, end each with EOS, and wrap them as HF datasets.

    Args:
        train_texts, val_texts: lists of strings.
        batch_size_val: lower bound on the number of validation texts kept.
        tokenizer: object exposing ``eos_token`` (a string appended to every text).
        downsample_rate: fraction of each split to keep.

    Returns:
        (dataset_train, dataset_val): ``datasets.Dataset`` objects with a single
        'text' column.
    """
    # Train is sampled before val, so NumPy's global RNG is consumed in the same order every time.
    sampled_train = downsample_dataset(train_texts, downsample_rate=downsample_rate)
    sampled_val = downsample_dataset(val_texts, minimum_num=batch_size_val, downsample_rate=downsample_rate)

    sampled_train = [sample + tokenizer.eos_token for sample in sampled_train]
    sampled_val = [sample + tokenizer.eos_token for sample in sampled_val]

    # Aliased: the module-level name ``Dataset`` refers to torch's Dataset class.
    from datasets import Dataset as HFDataset
    dataset_train = HFDataset.from_dict({'text': sampled_train})
    dataset_val = HFDataset.from_dict({'text': sampled_val})

    print(f'=========== train_data: {len(sampled_train)}, val_data: {len(sampled_val)} ===========')

    return dataset_train, dataset_val


def get_encoded_dataloader_from_texts(texts, batch_size, tokenizer, MAX_LEN, downsample_rate=1.0, is_val=False):
    """Downsample, EOS-terminate, chunk-encode ``texts`` and wrap them in a DataLoader.

    Args:
        texts: list of strings.
        batch_size: loader batch size. For validation it is also the minimum
            number of texts kept by the downsampling.
        tokenizer: tokenizer with ``eos_token`` that ``encode`` can call.
        MAX_LEN: chunk length in tokens passed to ``encode``.
        downsample_rate: fraction of ``texts`` to keep.
        is_val: validation mode (no shuffling; downsampling keeps at least
            ``batch_size`` texts).

    Returns:
        A ``DataLoader`` over a ``TextDataset`` of encoded chunks (8 workers, pinned memory).
    """
    minimum_num = batch_size if is_val else None
    small_texts = downsample_dataset(texts, minimum_num=minimum_num, downsample_rate=downsample_rate)

    print(f'@@@@@@@@@@@@ Have added eos_token @@@@@@@@@@@@@@')
    small_texts = [sample + tokenizer.eos_token for sample in small_texts]

    chunks = encode(small_texts, tokenizer, MAX_LEN)

    is_shuffle = not is_val  # only the training loader is shuffled
    loader = DataLoader(TextDataset(chunks), batch_size=batch_size, shuffle=is_shuffle, num_workers=8, pin_memory=True)

    split_label = 'Val' if is_val else 'Train'
    print(f"=========== {split_label} loader: {len(loader)}, downsampled data: {len(small_texts)}, is_shuffle: {is_shuffle}===========")

    return loader


def get_train_val_dataloader_from_texts(train_texts, val_texts, batch_size, batch_size_val=1, tokenizer=None, MAX_LEN=1024, downsample_rate=1.0):
    """Build the (train_loader, val_loader) pair from raw text lists.

    Note: ``batch_size_val`` is accepted but not used here. The validation
    loader is always built with batch size 1.
    """
    loader_for_train = get_encoded_dataloader_from_texts(
        train_texts, batch_size, tokenizer, MAX_LEN,
        downsample_rate=downsample_rate, is_val=False,
    )
    loader_for_val = get_encoded_dataloader_from_texts(
        val_texts, 1, tokenizer, MAX_LEN,
        downsample_rate=downsample_rate, is_val=True,
    )
    return loader_for_train, loader_for_val

# def get_train_val_dataloader_from_texts(train_texts, val_texts, batch_size, batch_size_val, tokenizer, MAX_LEN, downsample_rate=1.0):
#     # small_train_texts, small_val_texts = downsample_dataset(train_texts, val_texts, batch_size_val, downsample_rate)
#     small_train_texts = downsample_dataset(train_texts, downsample_rate=downsample_rate)
#     small_val_texts = downsample_dataset(val_texts, minimum_num=batch_size_val, downsample_rate=downsample_rate)
    
#     # Encode
#     small_train_encoded = encode(small_train_texts, tokenizer, MAX_LEN)
#     small_val_encoded = encode(small_val_texts, tokenizer, MAX_LEN)

#     # Create DataLoaders
#     train_loader = DataLoader(TextDataset(small_train_encoded), batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(TextDataset(small_val_encoded), batch_size=batch_size_val, shuffle=False)

#     print(f'=========== train_loader: {len(train_loader)}, val_loader: {len(val_loader)}, downsampled train_data: {len(small_train_texts)}, val_data: {len(small_val_texts)} ===========')
    
#     return train_loader, val_loader



############## load the dataset ################
class DatasetManager:
    """Load (train, val, test) text lists for a fixed set of datasets.

    Each ``get_<dataset>_all`` loader returns ``(train_texts, val_texts, test_texts)``.
    In every loader here, the validation split is the same object as the test
    split. Most loaders shuffle with NumPy's global RNG (``np.random.shuffle``),
    so seed ``np.random`` beforehand for reproducible splits.
    """

    def __init__(self):
        # Names accepted inside a '::'-joined ``dataset_names`` string.
        self.all_names = ['wikitext', 'arxiv-math', 'dialogsum', 'alpaca-gpt4', 'databricks-dolly-15k', 'OpenOrca', 'gsm8k', 'fineweb', 'fineweb-edu']
        # Valid evaluation targets: any single dataset, or 'all' to pool every loaded dataset.
        self.test_types = self.all_names + ['all']

    def get_dataset_texts(self, dataset_names, test_type='wikitext', original_skip_size=0):
        """Load and combine texts for one or more datasets.

        Args:
            dataset_names: names joined by '::' (e.g. 'wikitext::arxiv-math'), or
                'all' for every name in ``self.all_names``.
            test_type: which val/test split to return. It can be a name from
                ``self.all_names``, 'all' (pool the splits of every loaded
                dataset), or 'default' (use the only dataset; with more than one
                dataset 'default' fails the assertion in
                ``get_test_texts_from_all``). A named test dataset is loaded even
                if it is not in ``dataset_names``, but its training texts are NOT
                used.
            original_skip_size: forwarded to the streaming (fineweb) loaders as
                the number of leading stream records to skip.

        Returns:
            (train_texts, val_texts, test_texts). ``train_texts`` concatenates the
            training splits of ``dataset_names`` only and is shuffled in place
            with ``np.random.shuffle``.
        """
        if dataset_names == 'all':
            dataset_names = '::'.join(self.all_names)

        dataset_names_list = dataset_names.split('::')
        print(f'dataset_names_list before test: {dataset_names_list}')
        if test_type in self.all_names and test_type not in dataset_names_list:
            dataset_names_list.append(test_type)  # loaded for evaluation only

        print(f'dataset_names_list after test: {dataset_names_list}')

        assert all(name in self.all_names for name in dataset_names_list), f"dataset_names should be one of {self.all_names} or 'all'"

        train_texts_dict, val_texts_dict, test_texts_dict = {}, {}, {}
        for name in dataset_names_list:
            train_texts_dict[name], val_texts_dict[name], test_texts_dict[name] = self.get_dataset_texts_util(name, original_skip_size=original_skip_size)

        # Train only on the requested datasets, not on an appended test_type dataset.
        final_train_texts = []
        for key in dataset_names.split('::'):
            final_train_texts += train_texts_dict[key]

        if test_type == 'default' and len(dataset_names_list) == 1:
            test_type = dataset_names_list[0]

        final_test_texts = self.get_test_texts_from_all(test_texts_dict, test_type)
        final_val_texts = self.get_test_texts_from_all(val_texts_dict, test_type)

        print(f'len of train, val, test: {len(final_train_texts)}, {len(final_val_texts)}, {len(final_test_texts)}')

        np.random.shuffle(final_train_texts)

        return final_train_texts, final_val_texts, final_test_texts

    def get_test_texts_from_all(self, test_texts_dict, test_type):
        """Select the split for ``test_type``: one dataset's texts, or all of them concatenated in dict order for 'all'."""
        assert test_type in self.test_types, f"test_type should be one of {self.test_types}, but got {test_type}"

        if test_type == 'all':
            return [text for texts in test_texts_dict.values() for text in texts]
        return test_texts_dict[test_type]

    def get_dataset_texts_util(self, dataset_name, original_skip_size=0):
        """Return ``(train, val, test)`` texts for a single dataset name.

        Raises:
            ValueError: if ``dataset_name`` is not one of ``self.all_names``.
        """
        loaders = {
            'wikitext': self.get_wikitext_all,
            'arxiv-math': self.get_arxiv_math_all,
            'dialogsum': self.get_dialogsum_all,
            'alpaca-gpt4': self.get_alpaca_gpt4_all,
            'databricks-dolly-15k': self.get_databricks_dolly_15k_all,
            'OpenOrca': self.get_OpenOrca_all,
            'gsm8k': self.get_gsm8k_all,
            'fineweb': lambda: self.get_fineweb_all(original_skip_size=original_skip_size),
            'fineweb-edu': lambda: self.get_fineweb_edu_all(original_skip_size=original_skip_size),
        }
        if dataset_name not in loaders:
            raise ValueError(f"dataset_name should be one of {self.all_names}, but got {dataset_name}")

        train_texts, val_texts, test_texts = loaders[dataset_name]()

        if len(train_texts) > 0:  # avoid IndexError on an empty training split
            print(f'{dataset_name} example: {train_texts[0]}')

        return train_texts, val_texts, test_texts

    def get_online_dataset_generator(self, dataset_name, subset_size=40000, subset_num=3, splits=(0.99, 0.01), **loader_kwargs):
        """Dispatch to the streaming dataloader builder for ``dataset_name``.

        Args:
            dataset_name: 'fineweb' or 'fineweb-edu'.
            subset_size, subset_num, splits: forwarded to the builder.
            **loader_kwargs: remaining builder arguments. ``batch_size``,
                ``batch_size_val``, ``tokenizer`` and ``max_len`` are required;
                ``ratio``, ``global_test_texts_size`` and ``original_skip_size``
                are optional.

        Returns:
            The builder's ``(dataloader_generator, global_test_texts)`` pair.

        Raises:
            ValueError: for any other ``dataset_name``.
        """
        builders = {
            'fineweb': self.get_fineweb_dataloaders_online,
            'fineweb-edu': self.get_fineweb_edu_dataloaders_online,
        }
        if dataset_name not in builders:
            raise ValueError(f"dataset_name should be 'fineweb' or 'fineweb-edu', but got {dataset_name}")
        return builders[dataset_name](subset_size=subset_size, subset_num=subset_num, splits=splits, **loader_kwargs)

    @staticmethod
    def _shuffle_and_split(dataset_list, train_ratio):
        """Shuffle ``dataset_list`` in place (NumPy global RNG) and split it into (train, test) at ``int(len * train_ratio)``."""
        train_num = int(len(dataset_list) * train_ratio)
        np.random.shuffle(dataset_list)
        return dataset_list[:train_num], dataset_list[train_num:]

    def get_wikitext_all(self):
        """wikitext-2-raw-v1: the 'train' split for training; the 'test' split serves as both val and test."""
        train_texts = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split="train")['text']
        val_texts = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split="test")['text']
        test_texts = val_texts

        return train_texts, val_texts, test_texts

    def get_arxiv_math_all(self):
        """arxiv-math-instruct-50k as 'Question/Answer' texts, randomly split 99% / 1% (val == test)."""
        template = "Question: {}\nAnswer: {}"
        all_data = load_dataset('ArtifactAI/arxiv-math-instruct-50k', split="train")
        dataset_list = [template.format(item['question'], item['answer']) for item in all_data]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.99)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    def get_dialogsum_all(self):
        """dialogsum train split; the official test split is shuffled and cut to its first 1% (val == test)."""
        template = "Write the summary and the topic based on the dialogue.\nDialogue: {}\nSummary: {}\nTopic: {}"
        train_data = load_dataset('knkarthick/dialogsum', split="train")
        test_data = load_dataset('knkarthick/dialogsum', split="test")

        train_texts = [template.format(item['dialogue'], item['summary'], item['topic']) for item in train_data]
        test_texts = [template.format(item['dialogue'], item['summary'], item['topic']) for item in test_data]

        # Random downsampling of the test set to 1%.
        np.random.shuffle(test_texts)
        test_texts = test_texts[:int(len(test_texts) * 0.01)]

        val_texts = test_texts

        return train_texts, val_texts, test_texts

    def get_alpaca_gpt4_all(self):
        """alpaca-gpt4 'text' field with '### ' header markers stripped, randomly split 99% / 1% (val == test)."""
        all_data = load_dataset('vicgalle/alpaca-gpt4', split="train")

        template = "{}\n"
        dataset_list = [template.format(item['text']) for item in all_data]

        # "### Instruction" -> "Instruction", likewise for "Input" and "Response".
        dataset_list = [re.sub(r"### (Instruction|Input|Response)", r"\1", item) for item in dataset_list]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.99)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    @staticmethod
    def _format_dolly(item):
        """Render one databricks-dolly record as Instruction / optional Context / Response text."""
        template_part1 = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\nInstruction: {}\n"
        template_part2 = "Context: {}\n"
        template_part3 = "Response: {}\n"
        # Skip the Context line for empty or None contexts. Checking only for
        # None would emit a dangling "Context: " line for empty-string contexts.
        context_part = template_part2.format(item['context']) if item['context'] else ""
        return template_part1.format(item['instruction']) + context_part + template_part3.format(item['response'])

    def get_databricks_dolly_15k_all(self):
        """databricks-dolly-15k rendered by ``_format_dolly``, randomly split 99% / 1% (val == test)."""
        all_data = load_dataset("databricks/databricks-dolly-15k", split="train")
        dataset_list = [self._format_dolly(item) for item in all_data]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.99)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    def get_OpenOrca_all(self):
        """First 50,000 streamed OpenOrca records (system prompt / question / response), randomly split 99.5% / 0.5% (val == test)."""
        # The 'Response:' label takes a colon, matching 'Question:' and the other templates in this class.
        template = "{}\nQuestion: {}\nResponse: {}\n"
        all_data = load_dataset('Open-Orca/OpenOrca', split="train", streaming=True).take(50000)

        dataset_list = [template.format(item['system_prompt'], item['question'], item['response']) for item in all_data]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.995)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    def get_gsm8k_all(self):
        """gsm8k 'main' train split as prompted Question/Answer texts, randomly split 99% / 1% (val == test).

        Answers are kept verbatim (full worked solution), not reduced to the final number.
        """
        prompt_prefix = "Answer the following question.\n\n"

        template = "Question: {}\nAnswer: {}\n"  # NOTE: the open-instruct use \n\n to judge the end of output, so the training matters!
        all_data = load_dataset("openai/gsm8k", "main", split="train")

        dataset_list = [prompt_prefix + template.format(item['question'].strip(), item['answer']) for item in all_data]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.99)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    def _load_streaming_texts(self, hf_path, pre_load, original_skip_size):
        """Stream ``hf_path`` ('sample-10BT'), skip ``original_skip_size`` records, take ``pre_load`` texts, split 99.9% / 0.1% (val == test)."""
        print(f'############## Original skip size = {original_skip_size} #################')
        template = "{}\n"

        fw = load_dataset(hf_path, name="sample-10BT", split="train", streaming=True)
        all_data = fw.skip(original_skip_size).take(pre_load)

        dataset_list = [template.format(item['text']) for item in all_data]

        train_texts, test_texts = self._shuffle_and_split(dataset_list, 0.999)
        val_texts = test_texts

        return train_texts, val_texts, test_texts

    # fineweb is too large to load, so only a streamed prefix of pre_load records is used.
    def get_fineweb_all(self, pre_load=100000, original_skip_size=0):
        """Streamed prefix of HuggingFaceFW/fineweb; see ``_load_streaming_texts``."""
        return self._load_streaming_texts("HuggingFaceFW/fineweb", pre_load, original_skip_size)

    def get_fineweb_edu_all(self, pre_load=100000, original_skip_size=0):
        """Streamed prefix of HuggingFaceFW/fineweb-edu; see ``_load_streaming_texts``."""
        return self._load_streaming_texts("HuggingFaceFW/fineweb-edu", pre_load, original_skip_size)

    def _get_dataloaders_online(self, hf_path, batch_size, batch_size_val, tokenizer, max_len, ratio,
                                global_test_texts_size, original_skip_size, subset_size, subset_num, splits):
        """Shared implementation of the online (streaming) dataloader builders.

        Returns ``(dataloader_generator, global_test_texts)``:

        * ``global_test_texts`` is a lazy stream of the first
          ``global_test_texts_size`` records.
        * The generator yields up to ``subset_num`` ``(train_loader, test_loader)``
          pairs. Each pair comes from the next ``subset_size`` streamed records,
          starting after ``original_skip_size + global_test_texts_size``. When the
          stream runs out, reading wraps around to just after the global test
          texts.
        """
        template = "{}\n"
        fw = load_dataset(hf_path, name="sample-10BT", split="train", streaming=True)

        global_test_texts = fw.take(global_test_texts_size)
        skip_size = original_skip_size + global_test_texts_size

        def dataloader_generator():
            nonlocal skip_size

            count = 0  # number of subsets yielded so far

            print(f'*********** skip size = {skip_size}, global_test_texts_size = {global_test_texts_size}, original_skip_size = {original_skip_size}')

            while count < subset_num:
                # skip()/take() always return an IterableDataset (never None), so
                # exhaustion is detected from the materialized list instead.
                dataset_list = [template.format(item['text']) for item in fw.skip(skip_size).take(subset_size)]

                if not dataset_list:
                    if skip_size == global_test_texts_size:
                        # Already at the first training record and still empty: wrapping would loop forever.
                        raise RuntimeError(f'no training texts available in {hf_path} after skipping {skip_size} records')
                    # Support multiple epochs: restart just after the global test texts.
                    skip_size = global_test_texts_size
                    print(f'*********** reset skip size = {skip_size}')
                    continue

                np.random.shuffle(dataset_list)

                # Equals int(subset_size * splits[0]) for a full subset; scales down for a short final subset.
                train_end = int(len(dataset_list) * splits[0])
                train_texts = dataset_list[:train_end]
                test_texts = dataset_list[train_end:]

                train_loader = get_encoded_dataloader_from_texts(train_texts, batch_size, tokenizer, max_len, downsample_rate=ratio, is_val=False)
                # batch_size_val is not used here: validation runs with batch size 1.
                test_loader = get_encoded_dataloader_from_texts(test_texts, 1, tokenizer, max_len, downsample_rate=ratio, is_val=True)

                yield train_loader, test_loader
                count += 1
                skip_size += subset_size

                print(f'count = {count}, len of train, test: {len(train_texts)}, {len(test_texts)}, dataloader: {len(train_loader)}, {len(test_loader)}, skip_size = {skip_size}')

        return dataloader_generator(), global_test_texts

    def get_fineweb_dataloaders_online(
        self,
        # encode
        batch_size,
        batch_size_val,
        tokenizer,
        max_len,
        ratio=1.0,
        global_test_texts_size=200,
        # subset splits
        original_skip_size=0,
        subset_size=30000,
        subset_num=4,
        splits=(0.95, 0.05),
    ):
        """Streaming train/test dataloaders over HuggingFaceFW/fineweb; see ``_get_dataloaders_online``."""
        return self._get_dataloaders_online(
            "HuggingFaceFW/fineweb", batch_size, batch_size_val, tokenizer, max_len, ratio,
            global_test_texts_size, original_skip_size, subset_size, subset_num, splits,
        )

    def get_fineweb_edu_dataloaders_online(
        self,
        # encode
        batch_size,
        batch_size_val,
        tokenizer,
        max_len,
        ratio=1.0,
        global_test_texts_size=200,
        # subset splits
        original_skip_size=0,
        subset_size=30000,
        subset_num=4,
        splits=(0.95, 0.05),
    ):
        """Streaming train/test dataloaders over HuggingFaceFW/fineweb-edu; see ``_get_dataloaders_online``."""
        return self._get_dataloaders_online(
            "HuggingFaceFW/fineweb-edu", batch_size, batch_size_val, tokenizer, max_len, ratio,
            global_test_texts_size, original_skip_size, subset_size, subset_num, splits,
        )

#####################################################
################### for warmup ######################
#####################################################


def get_mlp_io(layer_idx, inputs, outputs):
    """Build a forward hook that records a module's first positional input and its output.

    Each time the hook fires, ``inputs[layer_idx]`` and ``outputs[layer_idx]``
    are overwritten with detached copies of the module's first input tensor and
    its output.
    """
    def _record(module, args, result):
        inputs[layer_idx] = args[0].detach()
        outputs[layer_idx] = result.detach()

    return _record

@torch.no_grad()
def extract_layers_mlp_ios(model, dataloader, device, start_layer_idx=0, end_layer_idx=31):
    """Run ``model`` over ``dataloader`` and collect MLP inputs/outputs for a range of layers.

    Forward hooks are attached to ``model.model.layers[i].mlp`` for every i in
    ``[start_layer_idx, end_layer_idx]`` and removed again before returning,
    even if an exception is raised.

    Args:
        model: module exposing ``model.model.layers[i].mlp`` and accepting
            ``(input_ids, attention_mask=...)``.
        dataloader: yields dicts with 'input_ids' and 'attention_mask' tensors.
        device: device the batches are moved to.
        start_layer_idx, end_layer_idx: inclusive layer range to record.

    Returns:
        (final_inputs, final_outputs): dicts mapping layer index to the CPU
        tensors captured for every batch, joined with ``torch.cat`` along dim 0.
    """
    layer_range = range(start_layer_idx, end_layer_idx + 1)
    all_inputs = {layer_idx: [] for layer_idx in layer_range}
    all_outputs = {layer_idx: [] for layer_idx in layer_range}

    # Register each hook ONCE and always remove it. Registering inside the
    # batch loop without removal would stack one extra hook per layer per
    # batch, and leave them attached (holding activations) after return.
    batch_inputs, batch_outputs = {}, {}
    handles = [
        model.model.layers[layer_idx].mlp.register_forward_hook(get_mlp_io(layer_idx, batch_inputs, batch_outputs))
        for layer_idx in layer_range
    ]

    model.eval()
    try:
        for batch in tqdm(dataloader, desc='Extracting MLP inputs and outputs'):
            batch_inputs.clear()
            batch_outputs.clear()

            with autocast(enabled=True):  # Enable mixed precision for acceleration
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                _ = model(input_ids, attention_mask=attention_mask)

            for layer_idx in layer_range:
                all_inputs[layer_idx].append(batch_inputs[layer_idx].cpu())
                all_outputs[layer_idx].append(batch_outputs[layer_idx].cpu())

            # Drop the device-side references before releasing cached memory.
            del input_ids, attention_mask
            batch_inputs.clear()
            batch_outputs.clear()
            torch.cuda.empty_cache()
    finally:
        for handle in handles:
            handle.remove()

    final_inputs = {k: torch.cat(v) for k, v in all_inputs.items()}
    final_outputs = {k: torch.cat(v) for k, v in all_outputs.items()}

    print(f"Final shape for input of layer {end_layer_idx}: {final_inputs[end_layer_idx].shape}, output: {final_outputs[end_layer_idx].shape}")

    return final_inputs, final_outputs


@torch.no_grad()
def extract_layers_mlp_ios_fast(model, dataloader, device, start_layer_idx=0, end_layer_idx=31):
    """Record MLP inputs/outputs of layers ``start_layer_idx..end_layer_idx`` (inclusive).

    For every batch of ``dataloader`` a forward pass is run on ``model`` under
    autocast. Forward hooks created by ``get_mlp_io`` (defined elsewhere in this
    file; presumably it stores the MLP input/output of ``layer_idx`` into the two
    dicts it is given — confirm) capture the activations of
    ``model.model.layers[i].mlp``. The captured tensors are moved to CPU.

    Unlike ``extract_layers_mlp_ios``, the per-batch tensors are NOT
    concatenated: each layer maps to a list with one tensor per batch.

    Args:
        model: model exposing ``model.model.layers[i].mlp``
            (presumably a HF Llama-style decoder — verify).
        dataloader: yields dicts with ``'input_ids'`` and ``'attention_mask'``.
        device: device the batch tensors are moved to.
        start_layer_idx, end_layer_idx: inclusive range of layers to record.

    Returns:
        ``(all_inputs, all_outputs)``: dicts ``layer_idx -> list[Tensor]``.
        When a layer has more than one batch, the last batch is dropped.
    """
    layer_range = range(start_layer_idx, end_layer_idx + 1)
    all_inputs = {layer_idx: [] for layer_idx in layer_range}
    all_outputs = {layer_idx: [] for layer_idx in layer_range}

    model.eval()
    for batch in tqdm(dataloader, desc='Extracting MLP inputs and outputs'):
        # Fresh per-batch containers, filled by the hooks during the forward pass.
        batch_inputs, batch_outputs = {}, {}

        # register_forward_hook returns a handle. Without removing it, hooks
        # accumulate on the model: one extra per layer per batch, and they
        # survive after this function returns.
        handles = [
            model.model.layers[layer_idx].mlp.register_forward_hook(
                get_mlp_io(layer_idx, batch_inputs, batch_outputs)
            )
            for layer_idx in layer_range
        ]
        try:
            with autocast(enabled=True):  # mixed precision for speed
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                _ = model(input_ids, attention_mask=attention_mask)
        finally:
            for handle in handles:
                handle.remove()

        for layer_idx in layer_range:
            all_inputs[layer_idx].append(batch_inputs[layer_idx].cpu())
            all_outputs[layer_idx].append(batch_outputs[layer_idx].cpu())

        # Release GPU references before the next batch.
        del input_ids, attention_mask, batch_inputs, batch_outputs
        torch.cuda.empty_cache()

    # Drop the last batch of every list; presumably the final batch can be
    # smaller than the others, which would give non-equal lengths. A lone
    # batch is kept.
    for layer_idx in layer_range:
        if len(all_inputs[layer_idx]) > 1:
            all_inputs[layer_idx] = all_inputs[layer_idx][:-1]
            all_outputs[layer_idx] = all_outputs[layer_idx][:-1]

    # Summarise the last recorded layer. Guarded so an empty layer range or
    # an empty dataloader does not crash on the diagnostics.
    if layer_range:
        last_layer = end_layer_idx
        print(f"Final length for input of layer {last_layer}: {len(all_inputs[last_layer])}, output: {len(all_outputs[last_layer])}")
        if all_inputs[last_layer]:
            # show the shape of the last element
            print(f"Final shape for input of layer {last_layer}: {all_inputs[last_layer][-1].shape}, output: {all_outputs[last_layer][-1].shape}")

    return all_inputs, all_outputs


class MLPIODataset(Dataset):
    """Dataset of ``(MLP input, MLP output)`` pairs.

    ``inputs`` is either a Python list (e.g. one tensor per recorded batch) or
    a tensor indexed along its first dimension. ``outputs`` must have the same
    length. Item ``i`` is ``(inputs[i], outputs[i])``.
    """

    def __init__(self, inputs, outputs):
        """
        inputs: list or Tensor of input features for the MLP.
        outputs: list or Tensor of target outputs for the MLP (same length).
        """
        # Lists are measured with len(), anything else with .size(0).
        inputs_are_list = isinstance(inputs, list)
        assert (
            len(inputs) == len(outputs) if inputs_are_list else inputs.size(0) == outputs.size(0)
        ), "The number of inputs must match the number of outputs."

        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        stored = self.inputs
        return len(stored) if isinstance(stored, list) else stored.size(0)

    def __getitem__(self, idx):
        pair = (self.inputs[idx], self.outputs[idx])
        return pair
    
################# use a dataset obtained from the dataset generator to generate the mlp inputs/outputs using teacher model
def generate_mlp_ios_from_dataloaders(dataloader, teacher_model, batch_size, device='cuda:0', start_layer_idx=0, end_layer_idx=31, is_val=False):
    """Build one DataLoader of (MLP input, MLP output) pairs per teacher layer.

    The activations are collected by running ``extract_layers_mlp_ios`` on
    ``teacher_model`` over ``dataloader`` without gradients.

    Returns:
        dict mapping each layer index in ``start_layer_idx..end_layer_idx`` to
        a ``DataLoader`` over an ``MLPIODataset``. The loaders use batches of
        ``batch_size`` and are shuffled unless ``is_val`` is True.
    """
    with torch.no_grad():
        inputs_by_layer, outputs_by_layer = extract_layers_mlp_ios(
            teacher_model,
            dataloader,
            device,
            start_layer_idx=start_layer_idx,
            end_layer_idx=end_layer_idx,
        )

    shuffle = not is_val  # validation data keeps its original order
    return {
        layer: DataLoader(
            MLPIODataset(layer_inputs, outputs_by_layer[layer]),
            batch_size=batch_size,
            shuffle=shuffle,
        )
        for layer, layer_inputs in inputs_by_layer.items()
    }


def generate_mlp_ios_from_dataloaders_fast(dataloader, teacher_model, batch_size, device='cuda:0', start_layer_idx=0, end_layer_idx=31, is_val=False):
    """Like ``generate_mlp_ios_from_dataloaders`` but built on the fast extractor.

    ``extract_layers_mlp_ios_fast`` returns, per layer, a list with one tensor
    per batch rather than a single concatenated tensor. Each
    ``MLPIODataset`` item is therefore a whole recorded batch.

    Returns:
        dict ``layer_idx -> DataLoader``. The loaders use batches of
        ``batch_size`` and are shuffled unless ``is_val`` is True.
    """
    with torch.no_grad():
        inputs_by_layer, outputs_by_layer = extract_layers_mlp_ios_fast(
            teacher_model,
            dataloader,
            device,
            start_layer_idx=start_layer_idx,
            end_layer_idx=end_layer_idx,
        )

    shuffle = not is_val  # validation data keeps its original order
    return {
        layer: DataLoader(
            MLPIODataset(layer_inputs, outputs_by_layer[layer]),
            batch_size=batch_size,
            shuffle=shuffle,
        )
        for layer, layer_inputs in inputs_by_layer.items()
    }

def _load_or_compute_input_spectrum(store_data_dir, input_layer, dim):
    """Return ``(Vt, S)`` with ``X^T X ~= Vt^T diag(S**2) Vt`` for recorded MLP inputs X.

    Reads the cached ``Vt_complete_{input_layer}.pt`` / ``S_complete_{input_layer}.pt``
    from ``store_data_dir`` when both exist. Otherwise it computes them from
    the recorded inputs and writes them back to the cache.
    """
    vt_path = os.path.join(store_data_dir, f'Vt_complete_{input_layer}.pt')
    s_path = os.path.join(store_data_dir, f'S_complete_{input_layer}.pt')

    # Only trust the cache when BOTH files are present.
    if os.path.exists(vt_path) and os.path.exists(s_path):
        Vt = torch.load(vt_path)
        S = torch.load(s_path)
        print(f'read Vt and S from {store_data_dir}, Vt size: {Vt.size()}, S size: {S.size()}')
        return Vt, S

    # NOTE(review): hard-coded source path kept from the original; dataset_type 2
    # reads the same file name from store_data_dir instead — confirm which is intended.
    target_data = torch.load(f'xxx/llama_reader/train_inputs_complete_{input_layer}.pt').view(-1, dim)
    print_debug(f'== target_data: {target_data.size()}')  # flattened to (N, dim)

    # N may be huge, so eigendecompose the small (dim x dim) X^T X instead of
    # taking an SVD of X. Accumulate in at least float32: a half-precision sum
    # over many rows overflows.
    x = target_data.float()
    Cov = x.t() @ x
    # eigh in float64 for accuracy; the matrix is only dim x dim.
    Cov_numpy = Cov.cpu().double().numpy()
    print_debug(f'Cov_numpy size: {Cov_numpy.shape}')
    eigvals, eigvecs = np.linalg.eigh(Cov_numpy)
    print_debug(f'eigvecs {eigvecs.shape}, dist between cov {np.linalg.norm(Cov_numpy - eigvecs @ np.diag(eigvals) @ eigvecs.T)}, eigvecs @ eigvecs.T dist: {np.linalg.norm(eigvecs @ eigvecs.T - np.eye(dim))}')

    # X^T X is positive semi-definite, but round-off can produce tiny negative
    # eigenvalues; clamp them so sqrt does not yield NaN.
    eigvals = np.clip(eigvals, 0.0, None)
    Vt = torch.tensor(eigvecs, dtype=torch.float32).t()
    S = torch.sqrt(torch.tensor(eigvals, dtype=torch.float32))

    os.makedirs(store_data_dir, exist_ok=True)
    torch.save(Vt, vt_path)
    torch.save(S, s_path)
    print(f"calculate Vt and S, Vt size: {Vt.size()}, S size: {S.size()}, Vt save to {vt_path}")
    return Vt, S


def generate_random_dataloader(
    dataset_type,

    random_size,
    dim,

    batch_size,
    batch_size_val,

    store_data_dir,
    input_layer,
):
    """Generate data for MLP warm-up and return ``(train_loader, eval_loader)``.

    Args:
        dataset_type:
            0 -> i.i.d. Gaussian samples ``randn(random_size, dim) * 0.1``.
            1 -> ``Z @ diag(S) @ Vt`` with ``Z = randn(random_size, dim) / 500``,
                 where ``Vt``/``S`` come from the eigendecomposition of ``X^T X``
                 for the recorded inputs ``X`` of ``input_layer`` (cached in
                 ``store_data_dir``), giving a similar covariance structure.
            2 -> the recorded inputs themselves, loaded from
                 ``store_data_dir/train_inputs_complete_{input_layer}.pt``.
        random_size: number of rows to generate (types 0 and 1).
        dim: feature dimension.
        batch_size: batch size of the (shuffled) training loader.
        batch_size_val: batch size of the (unshuffled) evaluation loader.
        store_data_dir: directory for cached spectra / recorded inputs.
        input_layer: layer index used in the file names.

    Returns:
        ``(train_loader, eval_loader)`` over the first 95% / remaining 5% of rows.

    Raises:
        ValueError: if ``dataset_type`` is not 0, 1 or 2.
    """
    if dataset_type == 0:
        random_data = torch.randn(random_size, dim) * 0.1

    elif dataset_type == 1:
        # Z is drawn before the spectrum is loaded, matching the original RNG order.
        Z = torch.randn(random_size, dim) / 500
        Vt, S = _load_or_compute_input_spectrum(store_data_dir, input_layer, dim)
        print_debug(f'S[:5]: {S[:5]}, S[-5:]: {S[-5:]}')
        # (Z * S) scales column j by S[j], i.e. Z @ diag(S), without building a dim x dim matrix.
        random_data = (Z * S) @ Vt

    elif dataset_type == 2:  # read the ground truth data
        random_data = torch.load(os.path.join(store_data_dir, f'train_inputs_complete_{input_layer}.pt')).view(-1, dim)
        print_debug(f'== random_data: {random_data.size()}')

    else:
        raise ValueError(f'unknown dataset_type {dataset_type!r}; expected 0, 1 or 2')

    length = random_data.size(0)
    # split the data into train (95%) and eval (5%)
    train_size = int(length * 0.95)

    print(f'train_size: {train_size}, eval_size: {length - train_size}, random_data size: {random_data.size()}')

    train_loader = DataLoader(random_data[:train_size, :], batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(random_data[train_size:, :], batch_size=batch_size_val, shuffle=False)

    print(f'train_loader size: {len(train_loader)}, eval_loader size: {len(eval_loader)}')

    return train_loader, eval_loader
            
    