import json
import re
from tqdm import tqdm
import numpy as np
from functools import partial

import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer


def build_dataset_confidence_rank(
    tokenizer, datapath, max_len=4096
):
    """Load a JSON conversation dataset and tokenize it for confidence/rank training.

    Each record is expected to carry an ``id`` and a ``conversations`` list of
    ``{"from": ..., "value": ...}`` turns. Only the first user turn and the
    assistant turn that immediately follows it are used. The assistant
    ``value`` is either a plain string or a list of step dicts with ``type``,
    ``content`` and ``confidence_score`` keys (first step ``type == "think"``,
    last step ``type == "output"``).

    Args:
        tokenizer: Tokenizer with ``apply_chat_template``/``encode``/``decode``
            that additionally exposes ``start_think_text``/``stop_think_text``,
            ``stop_think_id`` (``stop_conv_id`` is used as a fallback) and,
            optionally, ``start_think_id``.
        datapath: Value passed to ``load_dataset('json', data_files=...)``.
        max_len: Samples whose prompt exceeds ``max_len - 100`` tokens are
            dropped; full samples are truncated to ``max_len`` tokens.

    Returns:
        A torch-formatted dataset with ``attention_mask``, ``input_ids``,
        ``loss_mask`` and ``confidence`` columns. ``confidence`` is, per token:

        * list answers: ``[step_score, progress, remain]``. ``step_score`` is
          linearly interpolated from the previous step's score, ``progress``
          ramps 0 -> 1 up to the stop-think token and then stays 1, and
          ``remain`` counts down to 0 at the stop-think token. Prompt tokens
          get ``[0, 0, 0]``.
        * string answers: only the ``remain`` countdown (a scalar per token).

        NOTE(review): the two answer formats give ``confidence`` different
        ranks; mixing them in one file likely breaks column typing — confirm
        inputs are homogeneous.
    """
    default_sys_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    roles = {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant"
    }
    num_proc = 8

    ds = load_dataset('json', data_files=datapath)['train']
    ds = ds.shuffle(seed=42)
    original_columns = ds.column_names

    def _find_stop_think(answer_input_ids):
        """Return the index of the stop-think token in ``answer_input_ids``, or None."""
        if hasattr(tokenizer, "stop_think_id"):
            stop_id = tokenizer.stop_think_id
        else:
            stop_id = tokenizer.stop_conv_id
        try:
            return answer_input_ids.index(stop_id)
        except ValueError:
            # The text-level check passed but the marker did not survive
            # tokenization as a standalone token: skip the sample instead of
            # aborting the whole map worker.
            return None

    def preprocess_function(examples):
        """Tokenize one batch; malformed samples are skipped silently."""
        start_think_text = tokenizer.start_think_text
        stop_think_text = tokenizer.stop_think_text

        new_examples = {
            "attention_mask": [],
            "input_ids": [],
            "loss_mask": [],
            "confidence": []
        }

        for i in range(len(examples['id'])):
            source = examples['conversations'][i]
            if not source:
                continue

            sys_prompt = None
            if source[0]["from"] == "system":
                sys_prompt = source[0]["value"]
                if isinstance(sys_prompt, list):
                    sys_prompt = sys_prompt[0]["content"]
                source = source[1:]
            if not sys_prompt:
                sys_prompt = default_sys_prompt
            messages = [{"role": "system", "content": sys_prompt}]

            # Require a user turn followed by an assistant turn; short or
            # unknown-role conversations are skipped rather than raising.
            if (len(source) < 2
                    or roles.get(source[0]["from"]) != "user"
                    or roles.get(source[1]["from"]) != "assistant"):
                continue

            question = source[0]["value"]
            answer = source[1]["value"]

            if isinstance(question, list):
                question = question[0]["content"]
            messages.append({"role": "user", "content": question})
            conversation = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            # Cut off any think-opening marker appended by the template; the
            # answer supplies its own.
            if start_think_text in conversation:
                conversation = conversation.split(start_think_text)[0]

            input_ids = tokenizer.encode(conversation, add_special_tokens=False)
            if len(input_ids) > max_len - 100:
                continue
            num_prompt = len(input_ids)
            loss_mask = [0] * num_prompt
            confidence = [0] * num_prompt

            if isinstance(answer, list):
                if not answer or answer[0]["type"] != "think" or answer[-1]["type"] != "output":
                    continue
                # Work on copies of the step texts so the incoming batch is not
                # mutated in place.
                step_texts = [step["content"] for step in answer]
                if not step_texts[0].startswith(start_think_text):
                    step_texts[0] = start_think_text + "\n" + step_texts[0]
                if not step_texts[-2].endswith(stop_think_text):
                    step_texts[-2] += stop_think_text

                output_content = ''.join(step_texts)
                if output_content.count(start_think_text) != 1 or \
                   output_content.count(stop_think_text) != 1 or \
                   output_content.find(stop_think_text) <= output_content.find(start_think_text):
                    continue

                answer_input_ids = []
                last_confidence = None
                for step_text, step in zip(step_texts, answer):
                    step_input_ids = tokenizer.encode(step_text, add_special_tokens=False)
                    answer_input_ids.extend(step_input_ids)
                    score = step["confidence_score"]
                    if last_confidence is None:
                        confidence.extend([score] * len(step_input_ids))
                    else:
                        # Ramp linearly from the previous step's score to this one.
                        confidence.extend(
                            np.linspace(last_confidence, score, len(step_input_ids)).tolist()
                        )
                    last_confidence = score

                stop_think_pos = _find_stop_think(answer_input_ids)
                if stop_think_pos is None:
                    continue
                input_ids.extend(answer_input_ids)
                loss_mask.extend([1] * len(answer_input_ids))

                num_total = len(input_ids)
                progress = [0] * num_prompt + np.linspace(0, 1, stop_think_pos + 1).tolist()
                progress.extend([1] * (num_total - len(progress)))
                remain = [0] * num_prompt + list(range(stop_think_pos, -1, -1))
                remain.extend([0] * (num_total - len(remain)))

                confidence = [[c, p, r] for c, p, r in zip(confidence, progress, remain)]

            elif isinstance(answer, str):
                if start_think_text not in answer:
                    answer = start_think_text + "\n" + answer
                if answer.count(start_think_text) != 1 or \
                   answer.count(stop_think_text) != 1 or \
                   answer.find(stop_think_text) <= answer.find(start_think_text):
                    continue
                answer_input_ids = tokenizer.encode(answer, add_special_tokens=False)
                stop_think_pos = _find_stop_think(answer_input_ids)
                if stop_think_pos is None:
                    continue
                input_ids.extend(answer_input_ids)
                loss_mask.extend([1] * len(answer_input_ids))

                # Countdown of tokens left until (and including) the stop-think token.
                confidence.extend(range(stop_think_pos, -1, -1))
                confidence.extend([0] * (len(input_ids) - len(confidence)))

            if len(input_ids) > max_len:
                input_ids = input_ids[:max_len]
                loss_mask = loss_mask[:max_len]
                confidence = confidence[:max_len]

            if not any(loss_mask):
                continue

            input_ids = torch.tensor(input_ids)
            loss_mask = torch.tensor(loss_mask)
            confidence = torch.tensor(confidence)

            if hasattr(tokenizer, 'start_think_id') and tokenizer.start_think_id not in input_ids:
                print(f"There's no {start_think_text} in sentence id={examples['id'][i]} truncation_input={tokenizer.decode(input_ids)!r}")
                continue

            attention_mask = torch.ones_like(loss_mask)

            new_examples["input_ids"].append(input_ids[None])
            new_examples["loss_mask"].append(loss_mask[None])
            new_examples["attention_mask"].append(attention_mask[None])
            new_examples["confidence"].append(confidence[None])

        return new_examples

    ds = ds.map(
        preprocess_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
        load_from_cache_file=False
    )
    print(f"ds1 keys = {ds.column_names}, len = {len(ds)}")
    # Guard: indexing an empty dataset would raise IndexError.
    if len(ds) > 0:
        print(f"data example: {ds[0]}")

    ds.set_format(type="torch")
    return ds


def preprocess_function_question(examples=None, tokenizer=None):
    """Tokenize (system prompt, question, output) rows with prompt/confidence masks.

    For every row the chat is rendered twice with thinking enabled: once as
    the prompt only (with a generation prompt) and once including the
    assistant output. The prompt's token count marks where the answer starts.

    Per row, the appended tensors all have shape ``(1, seq_len)``:

    * ``loss_mask`` (bool): ``True`` on answer tokens, ``False`` on the prompt.
    * ``confidence_loss_mask`` (bool): ``True`` from the first answer token
      through the first ``tokenizer.stop_think_id`` in the answer, if any.
    * ``confidence`` (long): over that same span, the number of tokens left
      until the stop-think token (counting down to 0); zero elsewhere.
    * ``attention_mask``: all ones.
    """
    fallback_sys_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
    batch = {
        "input_ids": [],
        "loss_mask": [],
        "attention_mask": [],
        "confidence_loss_mask": [],
        "confidence": [],
    }

    for idx, _ in enumerate(examples['id']):
        chat = [
            {"role": "system", "content": examples['sys_prompt'][idx] or fallback_sys_prompt},
            {"role": "user", "content": examples['question'][idx]},
            {"role": "assistant", "content": examples['output'][idx]},
        ]
        prompt_str = tokenizer.apply_chat_template(
            chat[:-1], tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
        full_str = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=False, enable_thinking=True
        )
        prompt_enc = tokenizer(prompt_str, return_tensors="pt")
        full_enc = tokenizer(full_str, return_tensors="pt")

        ids = full_enc["input_ids"]
        n_prompt = prompt_enc["input_ids"].shape[1]

        answer_mask = torch.ones_like(ids, dtype=torch.bool)
        answer_mask[:, :n_prompt] = False

        conf_mask = torch.zeros_like(ids, dtype=torch.bool)
        conf = torch.zeros_like(ids, dtype=torch.long)
        hits = torch.nonzero(ids[0, n_prompt:] == tokenizer.stop_think_id, as_tuple=True)[0]
        if len(hits) > 0:
            # One past the first stop-think token, in absolute positions.
            end = n_prompt + int(hits[0]) + 1
            conf_mask[:, n_prompt:end] = True
            conf[:, n_prompt:end] = torch.arange(end - n_prompt - 1, -1, -1)

        batch["input_ids"].append(ids)
        batch["loss_mask"].append(answer_mask)
        batch["attention_mask"].append(torch.ones_like(full_enc["attention_mask"]))
        batch["confidence_loss_mask"].append(conf_mask)
        batch["confidence"].append(conf)

    return batch


def preprocess_function(examples, tokenizer, max_len=16384):
    """Tokenize ShareGPT-style multi-turn conversations with an assistant-only loss mask.

    Masking assumes a ChatML-style template (``<|im_start|>``/``<|im_end|>``).
    The rendered conversation is split on the user separator, and the loss is
    zeroed on:

    * everything up to and including each assistant header;
    * each user separator;
    * the final two tokens (presumably the closing ``<|im_end|>`` and newline).

    Sub-span token counts are measured with ``add_special_tokens=False``, the
    same way the full sequence is encoded. This keeps mask boundaries aligned
    for tokenizers that would otherwise prepend a BOS token to every piece.

    Args:
        examples: Batched dict with ``id`` and ``conversations`` columns; each
            conversation is a list of ``{"from": ..., "value": ...}`` turns,
            with roles ``human``/``user`` and ``gpt``/``assistant``.
        tokenizer: Tokenizer whose chat template emits the ChatML separators.
        max_len: Truncation length for the tokenized conversation.

    Returns:
        Dict with ``attention_mask``, ``input_ids`` and ``loss_mask`` lists of
        ``(1, seq_len)`` tensors. Empty conversations, and conversations that
        contain only a system message, are skipped.

    Raises:
        AssertionError: If the remaining turns do not alternate user/assistant.
        KeyError: If a turn after the first has an unknown ``from`` role.
    """
    default_sys_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    roles = {"human": "user", "user": "user", "gpt": "assistant", "assistant": "assistant"}
    conv_roles = ["user", "assistant"]
    # ChatML separators: assistant header and user header following a closed turn.
    sep = "<|im_end|>\n<|im_start|>assistant\n"
    sep2 = "<|im_end|>\n<|im_start|>user\n"
    # Hard-coded token count of `sep2` (5 tokens for Qwen-style tokenizers).
    sep2_len = 5

    def _num_tokens(text):
        """Count tokens in `text` the same way the full conversation is encoded."""
        return len(tokenizer(text, add_special_tokens=False).input_ids)

    new_examples = {
        "attention_mask": [],
        "input_ids": [],
        "loss_mask": []
    }
    for i in range(len(examples['id'])):
        source = examples['conversations'][i]
        if not source:
            continue

        if source[0]["from"] == "system":
            messages = [{"role": "system", "content": source[0]["value"]}]
            source = source[1:]
        else:
            messages = [{"role": "system", "content": default_sys_prompt}]
        if not source:
            # Only a system message: nothing to train on.
            continue
        if roles.get(source[0]["from"]) != "user":
            # Skip the first one if it is not from the user.
            source = source[1:]

        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv_roles[j % 2], f"{i}"
            messages.append({"role": role, "content": sentence["value"]})

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        input_ids = tokenizer(
            conversation,
            return_tensors="pt",
            max_length=max_len,
            add_special_tokens=False,
            truncation=True,
        ).input_ids[0]
        loss_mask = torch.ones_like(input_ids)

        turns = conversation.split(sep2)
        if len(turns) == 1:
            # Single turn: mask everything up to and including the assistant header.
            loss_mask[:_num_tokens(turns[0].split(sep)[0] + sep)] = 0
        else:
            cur_len = 0
            for turn in turns:
                instruction_len = _num_tokens(turn.split(sep)[0] + sep)
                loss_mask[cur_len:cur_len + instruction_len] = 0
                cur_len += _num_tokens(turn) + sep2_len
                # The user separator that split this turn from the next one.
                loss_mask[cur_len - sep2_len:cur_len] = 0
        loss_mask[-2:] = 0

        attention_mask = torch.ones_like(loss_mask)

        new_examples["input_ids"].append(input_ids[None, :])
        new_examples["loss_mask"].append(loss_mask[None, :])
        new_examples["attention_mask"].append(attention_mask[None, :])
    return new_examples



def preprocess_function_1turn(examples, tokenizer, max_len=4 * 1024):
    """Tokenize the first user/assistant exchange of each conversation.

    The prompt (system + first user turn) is rendered with a generation prompt
    and encoded with ``add_special_tokens=False``. Any ``start_think_text`` the
    template appends is cut off, because the answer supplies its own. Only
    answer tokens get ``loss_mask == 1``.

    The assistant value may be a string or a list of step dicts (``type`` and
    ``content`` keys; first ``think``, last ``output``). In both cases the
    answer is normalised to contain exactly one think-open marker followed by
    exactly one think-close marker; otherwise the sample is skipped.

    Args:
        examples: Batched dict with ``id`` and ``conversations`` columns.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``encode``,
            ``decode``, ``start_think_text``, ``stop_think_text`` and
            optionally ``start_think_id``.
        max_len: Prompts longer than ``max_len - 100`` tokens are dropped;
            samples are truncated to ``max_len`` tokens.

    Returns:
        Dict with ``attention_mask``, ``input_ids`` and ``loss_mask`` lists of
        ``(1, seq_len)`` tensors. Malformed samples (too few turns, unknown
        roles, bad think markers) are skipped rather than raising.
    """
    new_examples = {
        "attention_mask": [],
        "input_ids": [],
        "loss_mask": []
    }
    start_think_text = tokenizer.start_think_text
    stop_think_text = tokenizer.stop_think_text
    roles = {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant"
    }
    default_sys_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."

    for i in range(len(examples['id'])):
        source = examples['conversations'][i]
        if not source:
            continue

        sys_prompt = None
        if source[0]["from"] == "system":
            sys_prompt = source[0]["value"]
            if isinstance(sys_prompt, list):
                sys_prompt = sys_prompt[0]["content"]
            source = source[1:]
        if not sys_prompt:
            sys_prompt = default_sys_prompt
        messages = [{"role": "system", "content": sys_prompt}]

        # Require a user turn followed by an assistant turn; short or
        # unknown-role conversations are skipped rather than raising.
        if (len(source) < 2
                or roles.get(source[0]["from"]) != "user"
                or roles.get(source[1]["from"]) != "assistant"):
            continue

        question = source[0]["value"]
        answer = source[1]["value"]

        if isinstance(question, list):
            question = question[0]["content"]
        messages.append({"role": "user", "content": question})
        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        if start_think_text in conversation:
            conversation = conversation.split(start_think_text)[0]

        input_ids = tokenizer.encode(conversation, add_special_tokens=False)
        if len(input_ids) > max_len - 100:
            continue
        loss_mask = [0] * len(input_ids)

        if isinstance(answer, list):
            if not answer or answer[0]["type"] != "think" or answer[-1]["type"] != "output":
                continue
            # Copy the step texts so the incoming batch is not mutated in place.
            step_texts = [step["content"] for step in answer]
            if not step_texts[0].startswith(start_think_text):
                step_texts[0] = start_think_text + "\n" + step_texts[0]
            if not step_texts[-2].endswith(stop_think_text):
                step_texts[-2] += stop_think_text

            output_content = ''.join(step_texts)
            if output_content.count(start_think_text) != 1 or \
               output_content.count(stop_think_text) != 1 or \
               output_content.find(stop_think_text) <= output_content.find(start_think_text):
                continue

            # Encode step by step so token boundaries match per-step segmentation.
            for step_text in step_texts:
                step_input_ids = tokenizer.encode(step_text, add_special_tokens=False)
                input_ids.extend(step_input_ids)
                loss_mask.extend([1] * len(step_input_ids))

        elif isinstance(answer, str):
            if start_think_text not in answer:
                answer = start_think_text + "\n" + answer
            if answer.count(start_think_text) != 1 or \
               answer.count(stop_think_text) != 1 or \
               answer.find(stop_think_text) <= answer.find(start_think_text):
                continue
            answer_input_ids = tokenizer.encode(answer, add_special_tokens=False)
            input_ids.extend(answer_input_ids)
            loss_mask.extend([1] * len(answer_input_ids))

        input_ids = input_ids[:max_len]
        loss_mask = loss_mask[:max_len]

        if not any(loss_mask):
            continue

        input_ids = torch.tensor(input_ids)
        loss_mask = torch.tensor(loss_mask)

        if hasattr(tokenizer, 'start_think_id') and tokenizer.start_think_id not in input_ids:
            print(f"There's no {start_think_text} in sentence id={examples['id'][i]} truncation_input={tokenizer.decode(input_ids)!r}")
            continue

        attention_mask = torch.ones_like(loss_mask)

        new_examples["input_ids"].append(input_ids[None, :])
        new_examples["loss_mask"].append(loss_mask[None, :])
        new_examples["attention_mask"].append(attention_mask[None, :])
    return new_examples


def convert_to_sharegpt_format(data_path_list, out_path):
    """Merge chat datasets into ShareGPT format and write a 95/5 train/test split.

    Each input file is ``.json`` (a list of records) or ``.jsonl`` (one record
    per non-blank line). A record is either a dict with a ``conversations``
    key or a bare list of turns. Turns may use ``role``/``content`` or
    ``from``/``value`` keys.

    Conversion works as follows:

    * A leading ``system`` turn is kept.
    * Later turns are kept only while they alternate ``human``/``gpt``;
      conversion stops at the first unknown or out-of-order role.
    * Records with fewer than two converted turns are dropped.

    The records are shuffled with ``random.seed(42)``; note this reseeds the
    global ``random`` module. The first 95% of the *kept* records become the
    train split and the rest the test split. Both are written as UTF-8 JSON
    next to ``out_path``: the last ``.json`` in the path gets ``_train`` /
    ``_test`` inserted before it, or the suffixes are appended when
    ``out_path`` contains no ``.json``.

    Raises:
        ValueError: If an input path is neither ``.json`` nor ``.jsonl``.
    """
    import random

    datas = []
    for data_path in tqdm(data_path_list):
        if data_path.endswith(".json"):
            with open(data_path, "r", encoding="utf-8") as f:
                sub_data = json.load(f)
        elif data_path.endswith(".jsonl"):
            with open(data_path, "r", encoding="utf-8") as f:
                sub_data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError(f"Invalid data path: {data_path}")
        print(f"Loaded {len(sub_data)} datas from {data_path}")
        datas.extend(sub_data)

    role_map = {
        "user": "human",
        "human": "human",
        "assistant": "gpt",
        "gpt": "gpt"
    }
    roles = ["human", "gpt"]

    random.seed(42)
    random.shuffle(datas)

    kept = []
    for data in tqdm(datas):
        input_conversations = data["conversations"] if isinstance(data, dict) else data
        if not input_conversations:
            continue

        conversations = []
        first = input_conversations[0]
        if first.get("role", first.get("from", "")) == "system":
            conversations.append({
                "from": "system",
                "value": first.get("content", first.get("value", ""))
            })
            input_conversations = input_conversations[1:]
        for i, conv in enumerate(input_conversations):
            role = conv.get("role", conv.get("from", ""))
            if role not in role_map or role_map[role] != roles[i % 2]:
                break
            conversations.append({
                "from": role_map[role],
                "value": conv.get("content", conv.get("value", ""))
            })
        if len(conversations) < 2:
            continue
        kept.append(conversations)

    # Split on the number of *kept* records; comparing against the raw count
    # would shrink (or empty) the test split whenever records are filtered.
    train_cutoff = 0.95 * len(kept)
    train_data = []
    test_data = []
    for idx, conversations in enumerate(kept):
        record = {"id": str(idx), "conversations": conversations}
        if idx < train_cutoff:
            train_data.append(record)
        else:
            test_data.append(record)

    # Insert the split suffix before the last ".json" only; without one, append
    # it so train and test never collide on the same path.
    head, ext, tail = out_path.rpartition(".json")
    if ext:
        train_out_path = f"{head}_train.json{tail}"
        test_out_path = f"{head}_test.json{tail}"
    else:
        train_out_path = out_path + "_train"
        test_out_path = out_path + "_test"
    with open(train_out_path, "w", encoding="utf-8") as f:
        json.dump(train_data, f, ensure_ascii=False, indent=2)
    with open(test_out_path, "w", encoding="utf-8") as f:
        json.dump(test_data, f, ensure_ascii=False, indent=2)

