import yaml
from collections import Counter
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .deepspeed import DeepspeedStrategy
import os
import torch
from datasets import interleave_datasets, load_dataset, load_from_disk


## For Data Sampling and Adapter Training using accelerator
# Abbreviations after which a ". " is not treated as the end of a sentence;
# extract_first_sentences checks them with str.endswith. Some entries occur
# more than once; the contents and order are kept exactly as they were.
EXCEPTION_LIST = [
    "Mr", "Mrs", "Ms", "Dr", "Sr", "Jr", "Prof", "St", "Mt", "Ft", "No", "no", "etc",
    "i.e", "e.g", "cf", "cf.", "Fig", "fig", "Figs", "figs", "Vol", "vol", "Vols", "vols",
    "Ch", "ch", "Sec", "sec", "Secs", "secs", "Eq", "eq", "Eqs", "eqs", "Fig", "fig", "Figs", "figs",
    "Ref", "ref", "Refs", "refs", "App", "app", "Apps", "apps",
    "Jan", "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Sept", "Oct", "Nov", "Dec",
    "vs", "Vs", "etc", "i.e", "e.g", "U.S", "U.K", "U.A.E", "P.R.C", "U.S.A", "J.R.R",
]

def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(security): FullLoader can construct more than plain scalars,
        # lists and dicts. If config files may come from an untrusted source,
        # switch to yaml.safe_load.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

def extract_sentences(text_list, search_step=1):
    """Reduce each text to its later numbered steps and final-answer lines.

    Every text is stripped and split on newlines. A line is kept when it
    starts with "<k>." where k is an integer greater than ``search_step``,
    or when it contains the phrase "answer is". The kept lines of each text
    are re-joined with newlines.

    Args:
        text_list: Iterable of strings to filter.
        search_step: Numbered lines with an index up to and including this
            value are dropped.

    Returns:
        A list with one (possibly empty) string per input text.
    """

    def _is_wanted(line):
        # Final-answer lines are kept regardless of numbering.
        if 'answer is' in line:
            return True
        numbered = re.match(r'^(\d+)\.', line)
        if numbered is None:
            return False
        return int(numbered.group(1)) > search_step

    results = []
    for raw_text in text_list:
        kept_lines = []
        for line in raw_text.strip().split('\n'):
            if _is_wanted(line):
                kept_lines.append(line)
        results.append('\n'.join(kept_lines))
    return results
        

def extract_first_sentences(text_list, search_step=1, gsm8k=True, finish_reasons=None, ensure_length=2048):
    """Extract the first reasoning step from each generated text.

    GSM8K mode (``gsm8k=True``): each text is split into lines and the first
    line that either starts with "<k>." (k > ``search_step``) or contains
    "answer is" is selected. Texts without such a line, or whose selected
    line is only a number (optionally followed by "."), are skipped, so the
    result can be shorter than ``text_list``. The selected line is then cut
    at the first ". " that does not directly follow an ``EXCEPTION_LIST``
    entry or a digit, and a trailing "." is appended when missing.

    Other mode (``gsm8k=False``): each text is split on blank lines and the
    matching finish reason decides what is returned:

    * ``'stop'``: the whole stripped text followed by ``'<|im_end|>'``.
    * ``'length'``: the first chunk if it is longer than ``ensure_length``
      characters; otherwise the chunks before the first index ``i >= 1`` at
      which the summed length of chunks ``1..i`` exceeds ``ensure_length``;
      if that never happens, the whole stripped text.

    Args:
        text_list: Generated texts.
        search_step: Numbered lines up to and including this index are
            ignored (GSM8K mode only).
        gsm8k: Selects the mode described above.
        finish_reasons: One finish reason per text. May be omitted in GSM8K
            mode, where it is not read.
        ensure_length: Length threshold in characters for the ``'length'``
            case.

    Returns:
        List of extracted first steps.

    Raises:
        AssertionError: If ``finish_reasons`` is given with a different
            length than ``text_list``.
        NotImplementedError: In non-GSM8K mode, for a finish reason other
            than ``'stop'`` or ``'length'``.
    """
    if finish_reasons is None:
        # The GSM8K path never reads the finish reason; a placeholder per text
        # keeps the zip below aligned instead of failing on len(None).
        finish_reasons = [None] * len(text_list)
    assert len(text_list) == len(finish_reasons), "The length of text_list and finish_reasons should be the same."

    first_sentences = []
    delimeter = '\n' if gsm8k else '\n\n'

    for text, finish_reason in zip(text_list, finish_reasons):
        # Split into lines (GSM8K) or blank-line separated chunks (otherwise).
        texts = text.strip().split(delimeter)
        if gsm8k:
            filtered_texts = [t for t in texts if
                    ((match := re.match(r'^(\d+)\.', t)) and int(match.group(1)) > search_step) or 'answer is' in t]
            if not filtered_texts:
                continue
            text = filtered_texts[0].strip()
            # Filter out incomplete steps that consist of the step number only.
            if text.isdigit() or (text.endswith('.') and text[:-1].isdigit()):
                continue

            first_sentence = ''
            for part in text.split('. '):
                if not first_sentence:
                    first_sentence = part
                elif first_sentence.endswith(tuple(EXCEPTION_LIST)) or first_sentence[-1].isdigit():
                    # The ". " followed an abbreviation or a number (e.g. the
                    # step index), so the sentence is not finished yet.
                    first_sentence += '. ' + part
                else:
                    break
            if not first_sentence.strip().endswith('.'):
                first_sentence += '.'
        elif finish_reason == 'stop':
            first_sentence = text.strip() + '<|im_end|>'
        elif finish_reason == 'length':
            if len(texts[0]) > ensure_length:
                # The first chunk alone is already longer than the threshold.
                first_sentence = texts[0]
            else:
                # Accumulate the following chunks until the threshold is passed.
                cur_length = 0
                for i in range(1, len(texts)):
                    cur_length += len(texts[i])
                    if cur_length > ensure_length:
                        first_sentence = delimeter.join(texts[:i])
                        break
                else:
                    # The threshold was never reached. Without this fallback
                    # `first_sentence` would be unbound, or would still hold
                    # the value computed for the previous text.
                    first_sentence = text.strip()
        else:
            raise NotImplementedError(f"Finish reason {finish_reason} is not implemented.")

        first_sentences.append(first_sentence)

    return first_sentences

def deduplication(
    text_list,
    num_to_keep,
    fill_to,
    semantic_model: "AutoModelForSequenceClassification | None" = None,
    semantic_tokenizer: "AutoTokenizer | None" = None,
    context: "str | None" = None,
):
    """Deduplicate texts, then pad the result to fixed minimum lengths.

    Steps:
      1. Exact duplicates are collapsed; the distinct texts are ordered by
         descending frequency, ties in order of first appearance.
      2. If ``semantic_model`` is given, every pair of distinct texts is
         classified in both directions on ``"<a> [SEP] <b>"``. When both
         directions yield class index 2 (treated as "entailment" by this
         code; confirm against the label mapping of the model in use), the
         lower-ranked text of the pair is dropped.
      3. The surviving texts are repeated cyclically until there are at
         least ``num_to_keep`` of them.
      4. ``"<EMPTY>"`` is appended until there are at least ``fill_to``.

    Args:
        text_list: Non-empty list of strings.
        num_to_keep: Minimum number of entries after cyclic repetition.
        fill_to: Minimum number of entries after ``"<EMPTY>"`` padding.
        semantic_model: Optional sequence-classification model; called as
            ``semantic_model(**encoded)['logits']``.
        semantic_tokenizer: Tokenizer matching ``semantic_model``; required
            whenever ``semantic_model`` is given.
        context: Currently unused.

    Returns:
        A new list of strings.

    Raises:
        ValueError: If ``text_list`` is empty.
    """
    if len(text_list) == 0:
        raise ValueError("The list of texts is empty.")

    # most_common() sorts by count and keeps first-seen order among ties, so
    # the result does not depend on set/hash iteration order.
    unique_items = [item for item, _ in Counter(text_list).most_common()]

    if semantic_model is not None and len(unique_items) > 1:
        # Feed inputs to wherever the model lives; fall back to 'cuda' for
        # objects that do not expose a `device` attribute.
        device = getattr(semantic_model, "device", "cuda")

        def _predicted_class(first, second):
            encoded = semantic_tokenizer(first + ' [SEP] ' + second, padding=True, return_tensors='pt')
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
            return semantic_model(**encoded)['logits'].argmax(dim=1).item()

        keep = [True] * len(unique_items)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i, anchor in enumerate(unique_items):
                for j in range(i + 1, len(unique_items)):
                    if not keep[j]:
                        # Already dropped; another comparison cannot change that.
                        continue
                    candidate = unique_items[j]
                    # Mutual entailment: keep the higher-frequency text (index i).
                    if _predicted_class(anchor, candidate) == 2 and _predicted_class(candidate, anchor) == 2:
                        keep[j] = False
        unique_items = [item for item, kept in zip(unique_items, keep) if kept]

    # Repeat the kept items cyclically (highest frequency first) until there
    # are at least `num_to_keep` entries.
    base_count = len(unique_items)
    for k in range(max(0, num_to_keep - base_count)):
        unique_items.append(unique_items[k % base_count])

    # Pad the tail with "<EMPTY>" up to `fill_to` entries.
    unique_items.extend(["<EMPTY>"] * (fill_to - len(unique_items)))
    return unique_items

def accumulate_strings(string_list, gsm8k=True):
    """Expand each string into the list of its cumulative step prefixes.

    Each string is split on the delimiter ('\\n' in GSM8K mode, a blank line
    otherwise). Pieces that are empty after stripping are discarded; in
    GSM8K mode only pieces that start with "<k>." or contain "answer is" are
    kept. For the kept pieces p1..pn of one string, the prefixes p1,
    p1+p2, ..., p1+...+pn (re-joined with the delimiter) are emitted.

    Args:
        string_list: Iterable of strings to expand.
        gsm8k: Selects the delimiter and whether the step filter applies.

    Returns:
        A flat list containing the prefixes of all strings, in input order.
    """
    delimiter = '\n' if gsm8k else '\n\n'
    prefixes = []

    for s in string_list:
        kept = []
        for piece in s.split(delimiter):
            cleaned = piece.strip()
            if not cleaned:
                continue
            # The step filter looks at the raw piece, not the stripped one.
            if gsm8k and not (re.match(r'^(\d+)\.', piece) or 'answer is' in piece):
                continue
            kept.append(cleaned)
            prefixes.append(delimiter.join(kept))

    return prefixes

def get_answer_start_from(input_ids, pattern):
    """Return the index just past the first occurrence of ``pattern``.

    Both tensors are detached and moved to the CPU, then ``input_ids`` is
    scanned window by window along its first dimension.
    NOTE(review): both arguments are presumably 1-D token-id tensors;
    confirm against the callers.

    Args:
        input_ids: Tensor to search in.
        pattern: Tensor to search for.

    Returns:
        ``start + len(pattern)`` for the first matching window, or 0 when
        the pattern does not occur (including when it is longer than
        ``input_ids``).
    """
    long_tensor = input_ids.detach().cpu()
    pattern_tensor = pattern.detach().cpu()
    pattern_len = len(pattern_tensor)

    # `+ 1` so the final window is checked too; without it a pattern located
    # at the very end of `input_ids` is never found.
    for start in range(len(long_tensor) - pattern_len + 1):
        if torch.all(long_tensor[start:start + pattern_len] == pattern_tensor):
            return start + pattern_len
    return 0

def formulate_string(s):
    """Keep only the numbered-step and final-answer lines of ``s``.

    The string is split on newlines. A line is kept when it is non-empty
    after stripping and either starts with "<k>." or contains "answer is".
    Kept lines are stripped and re-joined with single newlines.

    Args:
        s: The text to normalise.

    Returns:
        The filtered text; an empty string when no line qualifies.
    """
    kept = []
    for raw_line in s.split('\n'):
        cleaned = raw_line.strip()
        if not cleaned:
            continue
        # The step pattern is matched against the raw line, so a numbered
        # line with leading whitespace does not qualify.
        if re.match(r'^(\d+)\.', raw_line) is None and 'answer is' not in raw_line:
            continue
        kept.append(cleaned)
    return '\n'.join(kept)

## For Proposal DPO Training using deepspeed
def get_tokenizer(pretrain, model, padding_side="left", use_fast=True, token=None):
    """Load a tokenizer and make sure it has a padding token.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` (with
    ``trust_remote_code=True``) and its ``padding_side`` is set. If it has
    no pad token, the EOS token is reused as the pad token and the resulting
    id is also written to ``model.config.pad_token_id``.

    Args:
        pretrain: Identifier or path handed to ``from_pretrained``.
        model: Model whose ``config.pad_token_id`` is updated when the
            tokenizer has no pad token of its own.
        padding_side: Value assigned to ``tokenizer.padding_side``.
        use_fast: Forwarded to ``from_pretrained``.
        token: Forwarded to ``from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    loaded = AutoTokenizer.from_pretrained(
        pretrain,
        trust_remote_code=True,
        use_fast=use_fast,
        token=token,
    )
    loaded.padding_side = padding_side

    if loaded.pad_token is not None:
        return loaded

    # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
    # https://github.com/facebookresearch/llama-recipes/pull/196
    # Reuse EOS as the pad token and keep the model config in sync with it.
    loaded.pad_token = loaded.eos_token
    loaded.pad_token_id = loaded.eos_token_id
    model.config.pad_token_id = loaded.pad_token_id
    return loaded

def get_strategy(args):
    """Build a ``DeepspeedStrategy`` from an argument namespace.

    ``seed``, ``max_norm``, ``micro_train_batch_size``, ``train_batch_size``
    and ``bf16`` are read from ``args`` with defaults (42, 1.0, 1, 128 and
    True respectively) when the attribute is missing. ``zero_stage`` has no
    default and must be present on ``args``. The namespace itself is passed
    on as ``args``.

    Args:
        args: Object carrying the attributes listed above.

    Returns:
        The constructed ``DeepspeedStrategy``.

    Raises:
        AttributeError: If ``args`` has no ``zero_stage`` attribute.
    """
    settings = {
        name: getattr(args, name, fallback)
        for name, fallback in (
            ("seed", 42),
            ("max_norm", 1.0),
            ("micro_train_batch_size", 1),
            ("train_batch_size", 128),
        )
    }
    # Required: read directly so a missing value surfaces as AttributeError.
    settings["zero_stage"] = args.zero_stage
    settings["bf16"] = getattr(args, "bf16", True)
    return DeepspeedStrategy(**settings, args=args)

def blending_datasets(
    datasets,
    probabilities,
    strategy=None,
    seed=42,
    max_count=5000000,
    return_eval=True,
    stopping_strategy="first_exhausted",
    train_split="train",
    eval_split="test",
):
    """Load several datasets and interleave them by sampling probability.

    Args:
        datasets: Comma-separated dataset specs. Each spec is a path or name,
            optionally followed by ``@<data_dir>``. A spec is loaded as:
            a loading script if it ends in ``.py`` or is a directory that
            contains ``<basename>.py``; a ``json``/``csv`` file if it ends in
            ``.json``, ``.jsonl`` or ``.csv`` (case-insensitive); a dataset
            saved to disk if it is any other directory; otherwise it is
            handed to ``load_dataset`` together with ``data_dir``.
        probabilities: Comma-separated floats, one per dataset spec.
        strategy: Optional object providing ``print`` and ``is_rank_0``.
            When omitted, the built-in ``print`` is used and the final
            summary is always printed.
        seed: Seed forwarded to ``interleave_datasets``.
        max_count: Upper bound on the rows taken from each split.
        return_eval: Whether an evaluation dataset is built and returned.
        stopping_strategy: Forwarded to ``interleave_datasets``.
        train_split: Name of the training split to use when present.
        eval_split: Name of the evaluation split to use when present. When
            absent, the first 3% (at least one row) of the selected training
            rows is used instead.

    Returns:
        ``(train_dataset, eval_dataset)`` if ``return_eval`` is true,
        otherwise just ``train_dataset``.

    Raises:
        AssertionError: If the number of probabilities differs from the
            number of dataset specs.
    """
    # `strategy` defaults to None, so it must not be dereferenced blindly.
    log = strategy.print if strategy is not None else print

    datasets = datasets.split(",")
    probabilities = list(map(float, probabilities.split(",")))
    assert len(probabilities) == len(datasets)

    train_data_list = []
    eval_data_list = []
    for dataset in datasets:
        dataset = dataset.strip()
        log(f"dataset: {dataset}")

        data_dir = dataset.split("@")[1].strip() if "@" in dataset else None
        dataset = dataset.split("@")[0].strip()
        dataset_basename = os.path.basename(dataset)

        # Normalise case before comparing so e.g. ".JSON" takes the same
        # branch as ".json".
        ext = os.path.splitext(dataset)[-1].lower()
        # local python script
        if ext == ".py" or (
            os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py"))
        ):
            data = load_dataset(dataset, trust_remote_code=True)
            log(f"loaded {dataset} with python script")
        # local text file
        elif ext in (".json", ".jsonl", ".csv"):
            # JSON-lines files are read by the "json" builder.
            builder = "json" if ext == ".jsonl" else ext.strip(".")
            data = load_dataset(builder, data_files=dataset)
            log(f"loaded {dataset} with data_files={dataset}")
        # local dataset saved with `datasets.Dataset.save_to_disk`
        elif os.path.isdir(dataset):
            data = load_from_disk(dataset)
            log(f"loaded {dataset} from disk")
        # remote/local folder or common file
        else:
            data = load_dataset(dataset, data_dir=data_dir)
            log(f"loaded {dataset} from files, data_dir={data_dir}")

        if train_split and train_split in data:
            train_data = data[train_split].select(range(min(max_count, len(data[train_split]))))
        else:
            train_data = data.select(range(min(max_count, len(data))))
        train_data_list.append(train_data)

        if return_eval:
            if eval_split and eval_split in data:
                eval_data = data[eval_split].select(range(min(max_count, len(data[eval_split]))))
            # TODO: this fallback takes eval rows from the head of the training
            # rows, so train and eval overlap.
            else:
                eval_data = train_data.select(range(min(max_count, max(int(len(train_data) * 0.03), 1))))
            eval_data_list.append(eval_data)

    # merge datasets
    if strategy is None or strategy.is_rank_0():
        print(train_data_list)

    train_dataset = interleave_datasets(
        train_data_list,
        probabilities=probabilities,
        seed=seed,
        stopping_strategy=stopping_strategy,
    )
    if not return_eval:
        return train_dataset

    eval_dataset = interleave_datasets(
        eval_data_list,
        probabilities=probabilities,
        seed=seed,
        stopping_strategy=stopping_strategy,
    )
    return train_dataset, eval_dataset
