import os
import gzip
import json
import random
import logging
import torch
import transformers
from datasets import load_dataset
from dataclasses import dataclass
from .llm_dataset import DefaultToken, LLMDataset
from .llm_utils import download_url
from collections import defaultdict

logger = logging.getLogger(__name__)


@dataclass
class LLMDataCollator(object):
    """Collate tokenized examples into one right-padded batch for SFT.

    Attributes:
        tokenizer: provides ``pad_token_id``, which is used both as the
            padding value for ``input_ids`` and to build ``attention_mask``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Pad ``instances`` into batch tensors.

        Each instance is indexed with ``"input_ids"`` and ``"labels"``;
        these are presumably 1-D tensors (they are fed to ``pad_sequence``)
        -- TODO confirm against LLMDataset.

        Returns:
            dict with ``input_ids`` (padded with the tokenizer's pad id),
            ``labels`` (padded with ``DefaultToken.IGNORE_INDEX``) and a
            boolean ``attention_mask`` that is True wherever ``input_ids``
            differs from the pad id.

        NOTE(review): when the pad token equals the EOS token, genuine EOS
        tokens inside a sequence are masked out as well.
        """
        sequences = [example["input_ids"] for example in instances]
        targets = [example["labels"] for example in instances]

        pad_id = self.tokenizer.pad_token_id
        padded_sequences = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_id)
        padded_targets = torch.nn.utils.rnn.pad_sequence(
            targets,
            batch_first=True,
            padding_value=DefaultToken.IGNORE_INDEX.value)

        return {
            "input_ids": padded_sequences,
            "labels": padded_targets,
            "attention_mask": padded_sequences.ne(pad_id),
        }

def filter_data_by_category(list_data_dict, max_per_category=100):
    """Keep at most ``max_per_category`` records per category.

    Records are scanned in order and the first ones seen for each category
    are kept, so the relative order of the input is preserved.

    Args:
        list_data_dict: iterable of dicts, each with a ``'category'`` key
            (a missing key raises ``KeyError``).
        max_per_category: cap on records kept for any single category.

    Returns:
        list: the kept records, in input order.
    """
    per_category = defaultdict(int)
    kept = []
    for record in list_data_dict:
        label = record['category']
        if per_category[label] >= max_per_category:
            continue  # this category is already full
        per_category[label] += 1
        kept.append(record)
    return kept

from collections import Counter
from typing import List, Dict, Any

def filter_top_k_categories(data: List[Dict[str, Any]], k: int) -> List[Dict[str, Any]]:
    """
    Filters the dataset to include only the top k categories by the number of data points.

    Ties in count are broken by first appearance (``Counter.most_common``
    order). Input order is preserved in the result. The selected
    categories and their counts are printed to stdout.

    Parameters:
    data (List[Dict[str, Any]]): The dataset, a list of dictionaries where each dictionary
        represents a data point and has a hashable ``'category'`` value.
    k (int): The number of top categories to filter by (``0`` keeps nothing).

    Returns:
    List[Dict[str, Any]]: A filtered list of dictionaries containing only data points from the top k categories.
    """
    # Step 1: Count the number of data points for each category
    category_counts = Counter(item['category'] for item in data)

    # Step 2: Get the top k categories with the most data points
    top_k_categories = category_counts.most_common(k)

    # Step 3: Filter the dataset. The membership set is built once; the
    # previous version rebuilt a dict from top_k_categories for every item.
    keep = {category for category, _ in top_k_categories}
    filtered_data = [item for item in data if item['category'] in keep]

    # Report the selected categories and their counts
    print(f"Top {k} categories by data points:")
    for category, count in top_k_categories:
        print(f"Category {category}: {count} data points")

    return filtered_data

def get_tokenizer(model_name, cache_dir, tok_len=128, token = None):
    """Load a slow Hugging Face tokenizer configured for right padding.

    Args:
        model_name: model identifier passed to ``AutoTokenizer``.
        cache_dir: download/cache directory.
        tok_len: value for the tokenizer's ``model_max_length``.
        token: optional Hugging Face access token.

    Returns:
        tuple: ``(tokenizer, num_new_tokens)``. ``num_new_tokens`` is always
        ``None`` because no special tokens are added to the vocabulary.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        model_max_length=tok_len,
        padding_side="right",
        use_fast=False,
        token=token,
    )

    # The vocabulary is deliberately left untouched (no add_special_tokens),
    # so there are no new embeddings for callers to resize.
    num_new_tokens = None

    # Reuse EOS as the padding token, but only when EOS exists: assigning
    # None would leave pad_token_id unset and break batch padding later.
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    elif tokenizer.pad_token is None:
        logger.warning(
            f'Tokenizer for {model_name} defines neither an EOS nor a PAD '
            f'token; batch padding will fail.')
    return tokenizer, num_new_tokens


def load_json(file_path,
              instruction='instruction',
              input='input',
              output='output',
              category='category'):
    """Read a JSON array of records and normalize their field names.

    The file must hold a list of dicts, e.g.
    ``[{'instruction': ..., 'input': ..., 'output': ...}]``. Each record is
    mapped to a dict with the keys ``instruction``, ``input``, ``output`` and
    ``category``. The source key for each is given by the matching argument,
    and a source key absent from the record yields ``None``.

    Returns:
        list[dict]: one normalized dict per record, in file order.
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        raw_records = json.load(f)

    # (target key, source key) pairs, in the output key order.
    field_map = (
        ('instruction', instruction),
        ('input', input),
        ('output', output),
        ('category', category),
    )
    return [{
        target: (record[source] if source in record else None)
        for target, source in field_map
    } for record in raw_records]


def load_jsonl(file_path,
               instruction='instruction',
               input='input',
               output='output',
               category='category',
               is_gzip=False):
    """Read a JSON-lines file (optionally gzip-compressed) into records.

    Each non-blank line is one JSON object, e.g.
    ``{'instruction': ..., 'input': ..., 'output': ...}``. It is mapped to a
    dict with the keys ``instruction``, ``input``, ``output`` and
    ``category``. The source key for each is given by the matching argument,
    and a missing source key yields ``None``.

    Args:
        file_path: path to the ``.jsonl`` (or ``.jsonl.gz``) file.
        instruction, input, output, category: source key names.
        is_gzip: read the file through ``gzip``.

    Returns:
        list[dict]: normalized records in file order.
    """
    list_data_dict = []
    # Text mode with explicit UTF-8 for both paths. A plain open() would use
    # the locale encoding (inconsistent with load_json), and gzip's default
    # 'r' mode is binary.
    if is_gzip:
        f = gzip.open(file_path, 'rt', encoding='utf-8')
    else:
        f = open(file_path, 'r', encoding='utf-8')
    with f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            item = json.loads(line)
            list_data_dict.append(dict(
                instruction=item[instruction] if instruction in item else None,
                input=item[input] if input in item else None,
                output=item[output] if output in item else None,
                category=item[category] if category in item else None))
    return list_data_dict

def split_text_into_half(text):
    """Split ``text`` near its middle without cutting a word in two.

    The cut point is the first space at or after the midpoint. If there is
    none, it is the last space before the midpoint. If the text has no
    spaces at all, the exact midpoint is used. Both halves are stripped.

    Returns:
        tuple[str, str]: ``(first_half, second_half)``.
    """
    cut = len(text) // 2
    boundary = text.find(' ', cut)
    if boundary < 0:
        boundary = text.rfind(' ', 0, cut)
    if boundary >= 0:
        cut = boundary
    return text[:cut].strip(), text[cut:].strip()


def transform_dataset(list_data_dict): #data preprocessing for last-token prediction
    """Rewrite records for last-word prediction.

    For each record, the instruction is split on whitespace. The last word
    becomes ``output`` (overwriting any existing value), and ``instruction``
    keeps the preceding words re-joined with single spaces. Empty or
    ``None`` instructions yield ``''`` for both fields.

    The input records are not modified; each result is a shallow copy.

    Returns:
        list[dict]: the transformed copies, in input order.
    """
    new_list_data_dict = []
    for item in list_data_dict:
        # load_json/load_jsonl produce None for a missing instruction key.
        words = (item['instruction'] or '').split()
        new_item = dict(item)  # avoid mutating the caller's records
        new_item['instruction'] = ' '.join(words[:-1])
        new_item['output'] = words[-1] if words else ''
        new_list_data_dict.append(new_item)
    return new_list_data_dict




def _load_huggingface_list_data_dict(hf_dataset_name):
    """Load a supported Hugging Face dataset as a list of SFT records.

    Each record has the keys ``instruction`` (the example text), ``input``
    (always ``''``), ``output`` (the label) and ``category`` (the same
    label, used downstream for per-category capping).

    Args:
        hf_dataset_name: Hugging Face dataset identifier.

    Returns:
        list[dict]: one record per training example.

    Raises:
        ValueError: if ``hf_dataset_name`` is not a supported dataset.
    """
    if hf_dataset_name == 'tyqiangz/multilingual-sentiments':
        dataset = load_dataset(hf_dataset_name, 'all', split='train')
        # Optional experiment filters (disabled):
        #   list_data_dict = transform_dataset(list_data_dict)
        #   desired_categories = ['indonesian', 'french', 'spanish',
        #                         'portuguese', 'arabic', 'chinese',
        #                         'german', 'italian']
        return [{
            'instruction': item['text'],
            'input': '',
            'output': item['language'],
            'category': item.get('language'),
        } for item in dataset]

    if hf_dataset_name == 'papluca/language-identification':
        dataset = load_dataset(hf_dataset_name, split='train')
        # Optional experiment filters (disabled):
        #   list_data_dict = transform_dataset(list_data_dict)
        #   desired_categories = ['ar', 'bg', 'de', 'el', 'es', 'fr', 'hi',
        #                         'nl']
        return [{
            'instruction': item['text'],
            'input': '',
            'output': item['labels'],
            'category': item.get('labels'),
        } for item in dataset]

    if hf_dataset_name == 'seara/ru_go_emotions':
        dataset = load_dataset(hf_dataset_name, 'simplified', split='train')
        # Only the first label of each (multi-label) example is used.
        list_data_dict = [{
            'instruction': item['text'],
            'input': '',
            'output': item['labels'][0],
            'category': item.get('labels')[0],
        } for item in dataset]
        # Keep the 12 most frequent labels.
        return filter_top_k_categories(list_data_dict, k=12)

    raise ValueError(f'Not support huggingface dataset {hf_dataset_name}.')


def load_llm_dataset(config=None, **kwargs):
    """Build the training dataset and tokenizer described by ``config``.

    Config fields read:
      * ``config.model.type``: ``'<prefix>:<model_name>'``; ``model_name``
        is passed to ``get_tokenizer``.
      * ``config.data.root``: tokenizer cache dir and local data dir.
      * ``config.llm.tok_len``: tokenizer ``model_max_length``.
      * ``config.huggingface.token``: Hugging Face access token.
      * ``config.data.name``: ``'<dataset>@<suffix>'`` (exactly one ``@``).
        ``<dataset>`` is ``huggingface:<id>``, a ``.json``/``.jsonl`` path
        relative to ``config.data.root``, or a built-in name (alpaca,
        alpaca_cleaned, dolly-15k, gsm8k, code_search_net, rosetta_alpaca).
      * ``config.data.subsample`` (code_search_net only) and
        ``config.data.splitter`` (rosetta_alpaca only).

    ``kwargs`` is accepted for interface compatibility and ignored.

    Returns:
        tuple: ``(dataset, tokenizer)`` where ``dataset`` is an LLMDataset.

    Raises:
        ValueError: for an unsupported dataset name.
        FileNotFoundError: if code_search_net data has not been downloaded.
    """
    _, model_name = config.model.type.split(':')
    tokenizer, _ = get_tokenizer(model_name, config.data.root,
                                 config.llm.tok_len,
                                 config.huggingface.token)

    dataset_name, _ = config.data.name.split('@')

    if 'huggingface:' in dataset_name:
        hf_dataset_name = dataset_name.split('huggingface:')[1]
        list_data_dict = _load_huggingface_list_data_dict(hf_dataset_name)
        # Cap each category at 5000 records (the first ones seen are kept).
        filtered_list_data_dict = filter_data_by_category(
            list_data_dict, max_per_category=5000)
        dataset = LLMDataset(filtered_list_data_dict, tokenizer)
        return (dataset, tokenizer)

    if dataset_name.endswith('.json'):
        fp = os.path.join(config.data.root, dataset_name)
        list_data_dict = load_json(fp)
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.endswith('.jsonl'):
        fp = os.path.join(config.data.root, dataset_name)
        list_data_dict = load_jsonl(fp)
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'alpaca':
        fp = os.path.join(config.data.root, 'alpaca_data.json')
        download_url(
            'https://raw.githubusercontent.com/tatsu-lab'
            '/stanford_alpaca/'
            '761dc5bfbdeeffa89b8bff5d038781a4055f796a/'
            'alpaca_data.json', config.data.root)
        list_data_dict = load_json(fp)
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'alpaca_cleaned':
        fp = os.path.join(config.data.root, 'alpaca_data_cleaned.json')
        download_url(
            'https://raw.githubusercontent.com/gururise/AlpacaDataCleaned/'
            'a7d629079a95c2e4b7ec7dfe55087fbd18d9eba8/'
            'alpaca_data_cleaned.json', config.data.root)
        list_data_dict = load_json(fp)
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'dolly-15k':
        fp = os.path.join(config.data.root, 'databricks-dolly-15k.jsonl')
        download_url(
            'https://raw.githubusercontent.com/databrickslabs'
            '/dolly/d000e3030970379aabbf6d291f50ffdd3b715b64'
            '/data/databricks-dolly-15k.jsonl', config.data.root)
        list_data_dict = load_jsonl(fp,
                                    instruction='instruction',
                                    input='context',
                                    output='response',
                                    category='category')
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'gsm8k':
        fp = os.path.join(config.data.root, 'gsm8k_train.jsonl')
        if not os.path.exists(fp):
            download_url(
                'https://raw.githubusercontent.com/openai/grade-school-math'
                '/3101c7d5072418e28b9008a6636bde82a006892c/'
                'grade_school_math/data/train.jsonl', config.data.root)
            os.rename(os.path.join(config.data.root, 'train.jsonl'), fp)
        list_data_dict = load_jsonl(fp,
                                    instruction='question',
                                    output='answer')
        # GSM8K marks the final answer with '####'; spell it out instead.
        for sample in list_data_dict:
            sample['output'] = sample['output'].replace(
                '####', 'The answer is')
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'code_search_net':
        from tqdm import tqdm
        from code_search_net import \
            CSN_FILE_NUM_DICT

        list_data_dict = []
        logger.info('Loading code search net data file...')
        try:
            for language in tqdm(CSN_FILE_NUM_DICT.keys()):
                sub_list_data_dict = []
                for file_index in range(CSN_FILE_NUM_DICT[language]['train']):
                    fp = \
                        os.path.join(config.data.root, language,
                                     'final', 'jsonl', 'train',
                                     f'{language}_train_{file_index}.jsonl.gz')
                    tmp_list_data_dict = load_jsonl(
                        fp,
                        instruction='docstring',
                        input='language',
                        output='code',
                        category='language',
                        is_gzip=True,
                    )
                    sub_list_data_dict += tmp_list_data_dict
                # Subsample each language at rate config.data.subsample
                raw_size = len(sub_list_data_dict)
                num_subsample = int(raw_size * config.data.subsample)
                list_data_dict += random.sample(sub_list_data_dict,
                                                num_subsample)
                logger.info(f"Subsample "
                            f"{sub_list_data_dict[0]['category']} with "
                            f"rate {config.data.subsample}: "
                            f"the sample size is # {num_subsample} "
                            f"(the raw size is {raw_size}).")
            # Prefix each instruction with its language name
            for sample in list_data_dict:
                sample['instruction'] = \
                    sample['category'] + ' ' + sample['instruction']
        except FileNotFoundError as err:
            raise FileNotFoundError(
                'Data not found! Please run `python '
                'federatedscope/llm/dataset/code_search_net.py` '
                'to download data.') from err
        dataset = LLMDataset(list_data_dict, tokenizer)
    elif dataset_name.lower() == 'rosetta_alpaca':
        fp = os.path.join(config.data.root, 'rosetta_alpaca.json')
        download_url(
            'https://raw.githubusercontent.com/'
            'sahil280114/codealpaca/'
            'd269da106a579a623a654529b3cb91b5dfa9c72f/'
            'data/rosetta_alpaca.json', config.data.root)
        list_data_dict = load_json(fp,
                                   instruction='instruction',
                                   input='input',
                                   output='output',
                                   category='input')
        # Drop 'X86-64 Assembly' under the `meta` splitter because it has
        # too few samples.
        if config.data.splitter == 'meta':
            list_data_dict = [
                i for i in list_data_dict if i['category'] != 'X86-64 Assembly'
            ]
        dataset = LLMDataset(list_data_dict, tokenizer)
    else:
        raise ValueError(f'Not support data type {dataset_name}.')

    return (dataset, tokenizer)
