import os
import random
import torch
import torch.distributed as dist
import re
import numpy as np
from datasets import load_dataset
import json
from datasets import Dataset

# Template for a user turn; `{input}` is replaced with the message text and
# the turn is terminated by a blank line.
PROMPT_USER: str = 'User: {input}\n\n'
# Cue appended directly after the user turn (see get_formatted_question).
PROMPT_ASSISTANT: str = 'Assistant:' 
# Template for an assistant reply; it starts with a single space before the
# reply text.
ASSISTANT_RESPONSE: str = ' {input}'

def get_formatted_question(line):
    """Render *line* as a user turn followed by the assistant cue.

    The value is converted with ``str`` and stripped of surrounding
    whitespace, substituted into ``PROMPT_USER``, and ``PROMPT_ASSISTANT``
    is appended so the returned text ends where the reply should begin.
    """
    question_text = str(line).strip()
    user_turn = PROMPT_USER.format(input=question_text)
    return user_turn + PROMPT_ASSISTANT

def get_formatted_answer(line):
    """Render *line* as an assistant reply using ``ASSISTANT_RESPONSE``.

    The value is converted with ``str`` and stripped of surrounding
    whitespace before substitution.
    """
    answer_text = str(line).strip()
    return ASSISTANT_RESPONSE.format(input=answer_text)

def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, mask_prompt=True, train_mode=True):
    """Tokenize a chat transcript into parallel input-id and label-id lists.

    Args:
        messages: Iterable of dicts with ``'role'`` (``"user"`` or
            ``"assistant"``) and ``'content'`` keys. An assistant message may
            carry ``'mask': 1`` to exclude its tokens from the labels.
        tokenizer: Object exposing ``bos_token_id``, ``eos_token_id`` and
            ``encode(text, add_special_tokens=False)``. If it has no BOS id,
            this function assigns one on the tokenizer (CLS, then SEP, else a
            newly registered ``"<s>"``), i.e. the tokenizer is mutated.
        IGNORE_TOKEN_ID: Label value written for positions that are masked.
        mask_prompt: When true, user-turn tokens are labelled with
            ``IGNORE_TOKEN_ID``; otherwise they are copied into the labels.
        train_mode: When true, assistant turns (plus a trailing EOS id) are
            appended. When false, assistant turns are skipped so only the
            prompt side of the conversation is produced.

    Returns:
        ``[input_ids, target_ids]``. Both lists grow in lockstep, and both
        start with the BOS id when ``messages`` is non-empty.

    Raises:
        AssertionError: If a message has a role other than ``"user"`` or
            ``"assistant"``.
    """
    input_ids = []
    target_ids = []
    if tokenizer.bos_token_id is None:
        if getattr(tokenizer, "cls_token_id", None) is not None:
            tokenizer.bos_token = tokenizer.cls_token
        elif getattr(tokenizer, "sep_token_id", None) is not None:
            tokenizer.bos_token = tokenizer.sep_token
        else:
            # NOTE(review): registering a new token may enlarge the
            # vocabulary; the caller presumably has to resize the model's
            # embeddings in that case - confirm.
            tokenizer.add_special_tokens({"bos_token": "<s>"})
    for idx, message in enumerate(messages):
        if idx == 0:
            input_ids.append(tokenizer.bos_token_id)
            target_ids.append(tokenizer.bos_token_id)

        role = message['role']
        if role == "user":
            formatted_question = get_formatted_question(message['content'])
            tokenized_line = tokenizer.encode(formatted_question, add_special_tokens=False)
            input_ids.extend(tokenized_line)
            if mask_prompt:
                target_ids.extend([IGNORE_TOKEN_ID] * len(tokenized_line))
            else:
                target_ids.extend(tokenized_line)
        elif role == "assistant":
            if not train_mode:
                # Previously this case fell through to the "Unknown role"
                # failure below, so any transcript containing a reference
                # answer crashed outside of training. Skip the turn instead.
                continue
            formatted_answer = get_formatted_answer(message['content'])
            tokenized_line = tokenizer.encode(formatted_answer, add_special_tokens=False) + [tokenizer.eos_token_id]
            input_ids.extend(tokenized_line)
            if message.get('mask', 0) == 1:
                target_ids.extend([IGNORE_TOKEN_ID] * len(tokenized_line))
            else:
                target_ids.extend(tokenized_line)
        else:
            # Raised explicitly (rather than `assert False`) so the check
            # also fires when Python runs with -O.
            raise AssertionError(f"Unknown role: {role}")

    return [input_ids, target_ids]


def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100):
    """Pack tokenized examples into rows of exactly ``seq_length`` tokens.

    Examples are concatenated in order. An example that does not fit in the
    space left in the current row is cut down to its last tokens so the row
    is filled exactly. A trailing, partially filled row is padded.

    Args:
        buffer: Iterable of ``(input_ids, target_ids)`` pairs of token-id
            lists (assumed to be of equal length per pair).
        seq_length: Number of tokens per output row.
        tokenizer: Object exposing ``pad_token_id``, used to pad inputs.
        random_concat_ratio: Probability of dropping the first token of an
            example that is appended to a non-empty row.
        IGNORE_TOKEN_ID: Label value used for padded positions.

    Returns:
        A dict with ``"input_ids"`` and ``"labels"`` ``torch.long`` tensors
        built from the packed rows, or ``None`` when no row was produced.
    """
    all_input_ids_list, all_target_ids_list = [], []
    all_input_ids, all_target_ids = [], []

    for input_ids, target_ids in buffer:
        if len(input_ids) > seq_length - len(all_input_ids):
            # Keep only the tail that still fits into the current row.
            input_ids = input_ids[-(seq_length - len(all_input_ids)):]
            target_ids = target_ids[-(seq_length - len(all_target_ids)):]
        if len(all_input_ids) > 0 and random.random() < random_concat_ratio:
            # Drop the leading token when joining onto an existing row.
            input_ids = input_ids[1:]
            target_ids = target_ids[1:]
        all_input_ids.extend(input_ids)
        all_target_ids.extend(target_ids)
        if len(all_input_ids) >= seq_length:
            assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}"
            all_input_ids_list.append(all_input_ids)
            all_target_ids_list.append(all_target_ids)
            all_input_ids, all_target_ids = [], []

    # Pad and keep the remainder only if it holds real tokens. Padding
    # unconditionally used to append a row made entirely of pad tokens with
    # every label ignored whenever the last row was filled exactly (or the
    # buffer was empty), and it made the `None` return unreachable.
    if all_input_ids:
        all_input_ids = all_input_ids + [tokenizer.pad_token_id] * (seq_length - len(all_input_ids))
        all_target_ids = all_target_ids + [IGNORE_TOKEN_ID] * (seq_length - len(all_target_ids))
        all_input_ids_list.append(all_input_ids)
        all_target_ids_list.append(all_target_ids)

    if not all_input_ids_list:
        return None
    return {
        "input_ids": torch.tensor(all_input_ids_list, dtype=torch.long),
        "labels": torch.tensor(all_target_ids_list, dtype=torch.long)
    }


def init_parallel_groups(ep_size=1):
    """Initialise NCCL and build two families of process groups.

    The first family splits the world into consecutive blocks of ``ep_size``
    ranks (``ep_group``); the second collects ranks that share the same
    offset inside their block, i.e. ranks ``i, i + ep_size, ...``
    (``edp_group``). The names presumably stand for expert-parallel and
    expert-data-parallel groups.

    Args:
        ep_size: Number of ranks per ``ep_group``. Must divide ``WORLD_SIZE``.

    Returns:
        ``(world_size, local_rank, ep_group, edp_group)`` where the first two
        come from the ``WORLD_SIZE`` / ``LOCAL_RANK`` environment variables
        and the groups are the ones this process belongs to.

    Raises:
        ValueError: If ``ep_size`` is not a positive divisor of the world
            size.
    """
    world_size = int(os.getenv("WORLD_SIZE", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if ep_size <= 0 or world_size % ep_size != 0:
        raise ValueError(
            f"ep_size ({ep_size}) must be a positive divisor of WORLD_SIZE ({world_size})"
        )

    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)

    # Group rank lists are expressed in global ranks, so membership must be
    # tested with the global rank. LOCAL_RANK restarts at 0 on every node and
    # only coincides with the global rank on a single-node job.
    rank = dist.get_rank()

    ep_group = edp_group = None
    # Every process must call new_group for every group, in the same order,
    # even for groups it is not a member of.
    for i in range(0, world_size, ep_size):
        ranks = list(range(i, i + ep_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            ep_group = group
    for i in range(ep_size):
        ranks = list(range(i, world_size, ep_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            edp_group = group

    # One all-reduce per group; presumably a warm-up that forces the
    # communicators to be created up front.
    dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
    dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
    return world_size, local_rank, ep_group, edp_group
import re
from typing import Optional

# Lower-case tokens treated as an affirmative answer; `_pick_bool` checks its
# regex capture against this tuple.
BOOL_TRUE = ("true", "yes", "y")
# Lower-case tokens for a negative answer.
# NOTE(review): not referenced by any code visible in this part of the file.
BOOL_FALSE = ("false", "no", "n")

def _norm(s: str) -> str:
    return re.sub(r"\s+", " ", s.strip())

def _pick_bool(text: str) -> Optional[str]:
    s = text.lower()
    m = re.search(r"(?:answer|label)\s*[:=]\s*(true|false|yes|no)\b", s)
    if m:
        return "True" if m.group(1) in BOOL_TRUE else "False"
    if re.search(r"\b(true|yes)\b", s):
        return "True"
    if re.search(r"\b(false|no)\b", s):
        return "False"
    return None

def _pick_letter(text: str) -> Optional[str]:
    m = re.search(r"(?i)(?:answer|option)\s*[:=]?\s*\(?\s*([a-e])\s*\)?", text)
    if m:
        return m.group(1).upper()
    m2 = re.search(r"(?<![A-Za-z])([A-E])(?![A-Za-z])", text)
    if m2:
        return m2.group(1).upper()
    return None

def _pick_number(text: str) -> Optional[str]:
    m = re.search(r"\\boxed\{\s*([^}]+)\s*\}", text)
    if m:
        inside = m.group(1).strip()
        nums = re.findall(r"-?\d+(?:\.\d+)?", inside)
        return nums[-1] if nums else inside
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

def _pick_number_from_text_math(s: str):
    if s is None: return None
    m = re.search(r"####\s*([-+]?\d[\d,]*\.?\d*)", s)
    if not m: m = re.search(r"The answer is\s*([-+]?\d[\d,]*\.?\d*)", s)
    if not m: m = re.search(r"([-+]?\d[\d,]*\.?\d*)\s*$", s)
    if m:
        num_str = m.group(1).replace(",", "")
        try: return int(float(num_str))
        except Exception: return None
    return None

def extract_answer(dataset_name, sentence: str) -> str:
    """Return the answer extracted from *sentence* for the given dataset.

    The dataset name is lower-cased and ``-`` is replaced by ``_`` before it
    is looked up; the sentence is whitespace-normalised with ``_norm``. Each
    extractor registered for the dataset is tried in order, and the first
    non-empty result is returned as a stripped string. An extractor that
    raises is reported with ``print`` and treated as having found nothing.

    Returns:
        The extracted answer, or ``""`` for empty input or no result.

    Raises:
        NotImplementedError: If no extractor is registered for the dataset.
    """
    if not sentence:
        return ""

    key = (dataset_name or "").lower().replace("-", "_")
    normalized = _norm(sentence)

    extractors_by_dataset = {
        "metamathqa": [_pick_number_from_text_math],
        "gsm8k": [_pick_number_from_text_math],
    }
    extractors = extractors_by_dataset.get(key)
    if extractors is None:
        raise NotImplementedError(f"Dataset {key} not supported")

    for extractor in extractors:
        try:
            candidate = extractor(normalized)
        except Exception:
            candidate = None
            print(f"Error applying {extractor} on text: {normalized}")
        if candidate is None:
            continue
        rendered = str(candidate).strip()
        if rendered != "":
            return rendered

    return ""

def set_seed(seed: int = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also stores the seed in the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

def filter_none_dialogue(example):
    """Return ``True`` when the example's ``"dialogue"`` field is not ``None``."""
    dialogue = example["dialogue"]
    return dialogue is not None

def convert_gsm8k_to_messages(example):
    """Map a record with ``question`` / ``answer`` fields to a two-turn chat."""
    user_turn = {'role': 'user', 'content': example['question']}
    assistant_turn = {'role': 'assistant', 'content': example['answer']}
    return {'messages': [user_turn, assistant_turn]}

def convert_humaneval_to_messages(example):
    """Map a record with ``prompt`` / ``canonical_solution`` to a two-turn chat.

    The prompt is wrapped in a fixed "complete the function" instruction.
    """
    stub = example['prompt']
    instruction = f"Complete the following Python function:\n\n{stub}\n"
    solution = example['canonical_solution']
    return {
        'messages': [
            {'role': 'user', 'content': instruction},
            {'role': 'assistant', 'content': solution},
        ]
    }


def convert_math_to_messages(example):
    """Map a record with ``problem`` / ``solution`` fields to a two-turn chat."""
    user_turn = {'role': 'user', 'content': example['problem']}
    assistant_turn = {'role': 'assistant', 'content': example['solution']}
    return {'messages': [user_turn, assistant_turn]}
def convert_metamathqa_to_messages(example):
    """Map a record with ``query`` / ``response`` fields to a two-turn chat."""
    user_turn = {'role': 'user', 'content': example['query']}
    assistant_turn = {'role': 'assistant', 'content': example['response']}
    return {'messages': [user_turn, assistant_turn]}

def convert_prefeval_to_messages(example):
    """Map a preference-evaluation record to a two-turn chat.

    The user turn embeds a fixed system prompt, the stated ``preference`` and
    the stored reply to it (``response_to_pref``) in ``[INST]`` markup,
    followed by the actual ``question``. The assistant turn is
    ``response_to_q``.
    """
    system_prompt = "You are an AI assistant."
    preference = example["preference"]
    question = example["question"]
    preference_reply = example["response_to_pref"]
    answer = example["response_to_q"]

    # First exchange (preference + reply), then the follow-up question.
    first_exchange = f"<s>[INST]{system_prompt}{preference}[/INST]{preference_reply}</s>"
    follow_up = f"[INST]{question}</s>"

    return {
        'messages': [
            {'role': 'user', 'content': first_exchange + follow_up},
            {'role': 'assistant', 'content': answer},
        ]
    }

def convert_mbpp_to_messages(example):
    """Map a record with ``text`` / ``test_list`` / ``code`` to a two-turn chat.

    The user turn holds the task description plus the test cases, one per
    line; the assistant turn holds the reference ``code``.
    """
    test_cases = "\n".join(map(str, example['test_list']))
    task = example['text']
    prompt = (
        f"Complete the following Python function:\n\n{task}\n\n"
        f"The function should pass the following test cases:\n{test_cases}\n\n"
    )
    return {
        'messages': [
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': example['code']},
        ]
    }
def convert_codealpaca_to_messages(example):
    """Map a record with ``instruction`` / ``input`` / ``output`` to a chat.

    The optional ``input`` field is appended to the instruction on an
    ``Input:`` line when it contains non-whitespace text. A missing, ``None``
    or blank ``input`` is treated as absent.
    """
    # `or ''` also covers a key that is present but None, which previously
    # raised AttributeError on `.strip()`.
    extra_input = example.get('input') or ''
    if extra_input.strip():
        user_content = f"{example['instruction']}\nInput: {example['input']}\n\n"
    else:
        user_content = f"{example['instruction']}\n\n"
    assistant_content = example['output']
    return {
        'messages': [
            {'role': 'user', 'content': user_content},
            {'role': 'assistant', 'content': assistant_content}
        ]
    }

def get_dataset_converter(dataset_name):
    """Return the record-to-messages converter registered for *dataset_name*.

    Names without a registered converter get an identity function, so such
    records pass through unchanged.
    """
    registry = {
        'gsm8k': convert_gsm8k_to_messages,
        'humaneval': convert_humaneval_to_messages,
        'math': convert_math_to_messages,
        'metamathqa': convert_metamathqa_to_messages,
        'prefeval': convert_prefeval_to_messages,
        'mbpp': convert_mbpp_to_messages,
        'codealpaca': convert_codealpaca_to_messages,
    }
    converter = registry.get(dataset_name)
    if converter is None:
        return lambda x: x
    return converter

def load_hf_dataset_train(dataset_name, subset="main", split='train'):
    """Load the training split for a dataset identified by a short name.

    Args:
        dataset_name: One of the short names handled below, or any identifier
            accepted by ``datasets.load_dataset``.
        subset: Configuration name, used for ``"gsm8k"``.
        split: Split to load from the Hub datasets. Not used for
            ``"prefeval"``, which always reads the local ``train.json``.

    Returns:
        The loaded dataset object.
    """
    # gsm8k, math and mbpp mirror load_hf_dataset_test so the same short
    # names resolve to the same Hub repositories and configurations. gsm8k
    # in particular needs a configuration name, and `subset` was previously
    # never used.
    if dataset_name == "gsm8k":
        return load_dataset(dataset_name, subset, split=split)
    if dataset_name == "math":
        return load_dataset("EleutherAI/hendrycks_math", "algebra", split=split)
    if dataset_name == "mbpp":
        return load_dataset("google-research-datasets/mbpp", "full", split=split)
    if dataset_name == "metamathqa":
        return load_dataset("meta-math/MetaMathQA-40K", split=split)
    if dataset_name == "prefeval":
        # Local JSON file stored next to this module.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_dir, "benchmarks", "prefeval", "train.json")
        with open(file_path, "r", encoding="utf-8") as f:
            samples = json.load(f)
        return Dataset.from_list(samples)
    if dataset_name == "codealpaca":
        return load_dataset("sahil2801/CodeAlpaca-20k", split=split)
    return load_dataset(dataset_name, split=split)

def load_hf_dataset_test(dataset_name, subset="main", split='test'):
    """Load the evaluation split for a dataset identified by a short name.

    ``"prefeval"`` is read from ``benchmarks/prefeval/test.json`` next to
    this module. Every other name is fetched with ``load_dataset``, using a
    fixed repository and configuration for the known short names and the
    name itself otherwise. ``subset`` is only used for ``"gsm8k"``.
    """
    if dataset_name == "prefeval":
        module_dir = os.path.dirname(os.path.abspath(__file__))
        json_path = os.path.join(module_dir, "benchmarks", "prefeval", "test.json")
        with open(json_path, "r", encoding="utf-8") as handle:
            records = json.load(handle)
        return Dataset.from_list(records)

    # Positional arguments for load_dataset, keyed by short name.
    hub_sources = {
        "gsm8k": (dataset_name, subset),
        "math": ("EleutherAI/hendrycks_math", "algebra"),
        "mbpp": ("google-research-datasets/mbpp", "full"),
        "humaneval": ("openai/openai_humaneval",),
    }
    hub_args = hub_sources.get(dataset_name, (dataset_name,))
    return load_dataset(*hub_args, split=split)


