import transformers
import utils.conversation as conversation_lib
from typing import Dict, List
import torch
import numpy as np
from torch_geometric.data import Data

IGNORE_INDEX = -100


def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    mode: str = 'train',
    answers: List[str] = None,
    mem_size: int = 128,
) -> Dict:
    """Tokenize conversations with the 'vicuna_v1_1' template and build labels.

    Each item of ``sources`` is a list of turns of the form
    ``{"from": "human" | "gpt", "value": text}``. The turns are rendered
    through the conversation template and tokenized as one batch.

    Args:
        sources: Batch of conversations (list of lists of turn dicts).
        tokenizer: Tokenizer used for both the prompts and, outside of
            training, the answers.
        mode: ``'train'`` pads on the right and derives labels from the
            inputs with masking; any other value pads on the left and uses
            the tokenized ``answers`` as labels.
        answers: Reference answers; only read when ``mode != 'train'``.
        mem_size: Number of token positions skipped (plus one) after the
            start of the user prompt before the span whose loss is
            re-enabled begins. Presumably the count of memory placeholder
            tokens -- TODO confirm against the caller.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    conv = conversation_lib.conv_templates['vicuna_v1_1'].copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    # [[{'from': role}, {'value': text}], [{'from': role}, {'value': text}]]
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        # The template object is reused across samples, so reset its history.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with the first role.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    inputs = tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        # max_length=2048,
        truncation=True,
        padding_side='right' if mode == 'train' else 'left',
    )
    input_ids = inputs.input_ids
    if mode != 'train':
        # Evaluation/inference: labels are simply the tokenized answers.
        targets = tokenizer(
            answers,
            return_tensors="pt",
            padding="longest",
            max_length=100,
            truncation=True,
            padding_side='right',
        ).input_ids
    else:
        # Training: start from a copy of the inputs, then mask positions
        # that must not contribute to the loss.
        targets = input_ids.clone()

        assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

        # Mask targets
        # Text that separates the instruction part from the response part.
        sep = conv.sep + conv.roles[1] + ": "
        for idx, (conversation, target) in enumerate(zip(conversations, targets)):
            # Number of non-padding tokens in this sample.
            total_len = int(target.ne(tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep2)
            # Position 0 is always masked (presumably the BOS token -- TODO confirm).
            cur_len = 1
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    # Unexpected round layout: stop here; everything from
                    # cur_len onward is masked after the loop.
                    break
                parts[0] += sep
                round_len = len(tokenizer(rou).input_ids)
                # The -2 presumably compensates for the BOS token and the
                # trailing-space tokenization of `sep` -- TODO confirm.
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

                # Mask the instruction part of this round.
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                # prompt loss
                # Token offset (within the round) just past the " USER: " marker.
                parts_prompt = rou.split(" USER: ")
                if len(parts_prompt) != 2:
                    break
                parts_prompt[0] += " USER: "
                prompt_start = len(tokenizer(parts_prompt[0]).input_ids) - 2

                # Token offset (within the round) where "\nQuestion: " begins.
                parts_prompt = rou.split("\nQuestion: ")
                if len(parts_prompt) != 2:
                    break
                prompt_end = len(tokenizer(parts_prompt[0]).input_ids) - 1
                # Re-enable the loss on the span that starts mem_size + 1
                # tokens after the prompt start and ends before "\nQuestion: ".
                target[cur_len + prompt_start + mem_size + 1 : cur_len + prompt_end] = input_ids[idx, cur_len + prompt_start + mem_size + 1 : cur_len + prompt_end].clone()

                cur_len += round_len
            # Mask everything after the last processed round (including padding).
            target[cur_len:] = IGNORE_INDEX

            # Sanity check, only performed when the computed length is below
            # the tokenizer's maximum length.
            if cur_len < tokenizer.model_max_length:
                if cur_len != total_len:
                    # Offsets do not line up with the real token count, so
                    # drop the whole sample from the loss.
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

    return dict(
        input_ids=input_ids,
        attention_mask=inputs.attention_mask,
        labels=targets,
    )

def preprocess_llama_2(
    sources,
    tokenizer,
    mode: str = 'train',
    answers: List[str] = None,
    mem_size: int = 128,
) -> Dict:
    """Tokenize conversations with the 'llama2' template and build labels.

    Each item of ``sources`` is a list of turns of the form
    ``{"from": "human" | "gpt", "value": text}``. The turns are rendered
    through the conversation template and tokenized as one batch
    (truncated to 2048 tokens).

    Args:
        sources: Batch of conversations (list of lists of turn dicts).
        tokenizer: Tokenizer used for both the prompts and, outside of
            training, the answers.
        mode: ``'train'`` pads on the right and derives labels from the
            inputs with masking; any other value pads on the left and uses
            the tokenized ``answers`` as labels.
        answers: Reference answers; only read when ``mode != 'train'``.
        mem_size: Number of token positions skipped (plus a fixed offset)
            at the start of each round before the span whose loss is
            re-enabled begins. Presumably the count of memory placeholder
            tokens -- TODO confirm against the caller.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    conv = conversation_lib.conv_templates['llama2'].copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        # The template object is reused across samples, so reset its history.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with the first role.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    inputs = tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=2048,
        truncation=True,
        padding_side='right' if mode == 'train' else 'left',
    )
    input_ids = inputs.input_ids
    
    if mode != 'train':
        # Evaluation/inference: labels are simply the tokenized answers.
        targets = tokenizer(
            answers,
            return_tensors="pt",
            padding="longest",
            max_length=100,
            truncation=True,
            padding_side='right',
        ).input_ids
    else:
        # Training: start from a copy of the inputs, then mask positions
        # that must not contribute to the loss.
        targets = input_ids.clone()

        assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA2
        # Mask targets
        # Text that separates the instruction part from the response part.
        sep = "[/INST] "
        for idx, (conversation, target) in enumerate(zip(conversations, targets)):
            # Number of non-padding tokens in this sample.
            total_len = int(target.ne(tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep2)
            # Position 0 is always masked (presumably the BOS token -- TODO confirm).
            cur_len = 1
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    # Unexpected round layout: stop here; everything from
                    # cur_len onward is masked after the loop.
                    break
                parts[0] += sep

                round_len = len(tokenizer(rou).input_ids)
                # The -2 presumably compensates for the BOS token and the
                # trailing-space tokenization of `sep` -- TODO confirm.
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

                # Mask the instruction part of this round.
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                # prompt loss
                # Token offset (within the round) where "\nQuestion: " begins.
                parts_prompt = rou.split("\nQuestion: ")
                if len(parts_prompt) != 2:
                    break
                prompt_len = len(tokenizer(parts_prompt[0]).input_ids) - 1
                # Re-enable the loss on the span that starts 3 + mem_size + 1
                # tokens into the round and ends before "\nQuestion: ".
                # The constant 3 is presumably the token count of the
                # instruction-opening prefix -- TODO confirm.
                target[cur_len + 3 + mem_size + 1 : cur_len + prompt_len] = input_ids[idx, cur_len + 3 + mem_size + 1 : cur_len + prompt_len].clone()

                cur_len += round_len
            # Mask everything after the last processed round (including padding).
            target[cur_len:] = IGNORE_INDEX

            # Sanity check, only performed when the computed length is below
            # the tokenizer's maximum length.
            # NOTE(review): tokenization above truncates at 2048, while this
            # check uses tokenizer.model_max_length -- confirm they agree.
            if cur_len < tokenizer.model_max_length:
                if cur_len != total_len:
                    # Offsets do not line up with the real token count, so
                    # drop the whole sample from the loss.
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

    return dict(
        input_ids=input_ids,
        attention_mask=inputs.attention_mask,
        labels=targets,
    )

def preprocess_llama_3(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    mode: str = 'train',
    answers: List[str] = None,
    mem_size: int = 128,
) -> Dict:
    """Tokenize conversations with the 'llama3' template and build labels.

    Each item of ``sources`` is a list of turns of the form
    ``{"from": "human" | "gpt", "value": text}``. The turns are rendered
    through the conversation template and tokenized as one batch
    (truncated to 4096 tokens).

    Args:
        sources: Batch of conversations (list of lists of turn dicts).
        tokenizer: Tokenizer used for both the prompts and, outside of
            training, the answers.
        mode: ``'train'`` pads on the right and derives labels from the
            inputs with masking; any other value pads on the left and uses
            the tokenized ``answers`` as labels.
        answers: Reference answers; only read when ``mode != 'train'``.
        mem_size: Number of token positions skipped (plus one) after the
            user header before the span whose loss is re-enabled begins.
            Presumably the count of memory placeholder tokens -- TODO
            confirm against the caller.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    conv = conversation_lib.conv_templates['llama3'].copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    # [[{'from': role}, {'value': text}], [{'from': role}, {'value': text}]]
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        # The template object is reused across samples, so reset its history.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate, starting with the first role.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    inputs = tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        # max_length=tokenizer.model_max_length,
        max_length=4096,
        truncation=True,
        padding_side='right' if mode == 'train' else 'left',
    )
    input_ids = inputs.input_ids
    if mode != 'train':
        # Evaluation/inference: labels are simply the tokenized answers.
        targets = tokenizer(
            answers,
            return_tensors="pt",
            padding="longest",
            max_length=100,
            truncation=True,
            padding_side='right',
        ).input_ids
    else:
        # Training: start from a copy of the inputs, then mask positions
        # that must not contribute to the loss.
        targets = input_ids.clone()

        assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA3

        # Mask targets
        # Text that separates the instruction part from the response part.
        sep = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        for idx, (conversation, target) in enumerate(zip(conversations, targets)):
            # Per the original author: the padding token is the EOS token here,
            # so the non-pad count misses the EOS tokens; add 2 to compensate.
            total_len = int(target.ne(tokenizer.pad_token_id).sum()) + 2 # padding token = eos token, so add 2 eos token

            rounds = conversation.split(conv.sep2)
            # Position 0 is always masked (presumably the BOS token -- TODO confirm).
            cur_len = 1
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    # Unexpected round layout: stop here; everything from
                    # cur_len onward is masked after the loop.
                    break
                parts[0] += sep
                round_len = len(tokenizer(rou).input_ids) - 1 # no extra terminator token is appended
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # `sep` ends with a newline, so no extra token is counted for a trailing space

                # Mask the instruction part of this round.
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                # prompt loss
                # Token length of the user header, minus one token
                # (presumably the BOS token added by the tokenizer -- TODO confirm).
                parts_start = "<|start_header_id|>user<|end_header_id|>\n\n"
                prompt_start = len(tokenizer(parts_start).input_ids) - 1

                # Token offset (within the round) where "\nQuestion: " begins.
                parts_prompt = rou.split("\nQuestion: ")
                if len(parts_prompt) != 2:
                    break
                prompt_end = len(tokenizer(parts_prompt[0]).input_ids) - 1
                # Re-enable the loss on the span that starts mem_size + 1
                # tokens after the user header and ends before "\nQuestion: ".
                target[cur_len + prompt_start + mem_size + 1 : cur_len + prompt_end] = input_ids[idx, cur_len + prompt_start + mem_size + 1 : cur_len + prompt_end].clone()

                cur_len += round_len
            # Mask everything after the last processed round (including padding).
            target[cur_len:] = IGNORE_INDEX

            # Sanity check, only performed when the computed length is below
            # the tokenizer's maximum length.
            if cur_len < tokenizer.model_max_length:
                if cur_len != total_len:
                    # Offsets do not line up with the real token count, so
                    # drop the whole sample from the loss.
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

    return dict(
        input_ids=input_ids,
        attention_mask=inputs.attention_mask,
        labels=targets,
    )

def pad_1d_unsqueeze(x, padlen):
    """Right-pad a 1-D tensor with zeros to ``padlen`` and add a batch dim.

    Args:
        x: Tensor whose first dimension is the sequence length.
        padlen: Target length. Inputs that are already at least this long
            are returned unpadded (never truncated).

    Returns:
        ``x`` (padded if needed) with a new leading dimension of size 1.
    """
    length = x.size(0)
    if length >= padlen:
        # Already long enough: only add the batch dimension.
        return x.unsqueeze(0)

    padded = x.new_zeros([padlen], dtype=x.dtype)
    padded[:length] = x
    return padded.unsqueeze(0)


def pad_2d_unsqueeze(x, padlen):
    """Zero-pad the rows of a 2-D tensor to ``padlen`` and add a batch dim.

    Args:
        x: 2-D tensor of shape ``(rows, cols)``.
        padlen: Target number of rows. Inputs that already have at least
            this many rows are returned unpadded (never truncated).

    Returns:
        ``x`` (padded with zero rows if needed) with a new leading
        dimension of size 1.
    """
    rows, cols = x.size()
    if rows >= padlen:
        # Already tall enough: only add the batch dimension.
        return x.unsqueeze(0)

    padded = x.new_zeros([padlen, cols], dtype=x.dtype)
    padded[:rows, :] = x
    return padded.unsqueeze(0)


def pad_spatial_pos_unsqueeze(x, padlen):
    """Zero-pad a square matrix to ``padlen x padlen`` and add a batch dim.

    Args:
        x: Tensor whose first two dimensions are both of size ``x.size(0)``.
        padlen: Target side length. Inputs whose first dimension is already
            at least this large are returned unpadded (never truncated).

    Returns:
        ``x`` (placed in the top-left corner of a zero matrix if padding is
        needed) with a new leading dimension of size 1.
    """
    side = x.size(0)
    if side >= padlen:
        # Already large enough: only add the batch dimension.
        return x.unsqueeze(0)

    padded = x.new_zeros([padlen, padlen], dtype=x.dtype)
    padded[:side, :side] = x
    return padded.unsqueeze(0)


def pad_shortest_path_unsqueeze(x, padlen1, padlen2):
    """Pad a ``(n, n, d)`` tensor with -1 to ``(padlen1, padlen1, padlen2)``.

    Args:
        x: 3-D tensor; its first and last dimension sizes are inspected.
        padlen1: Target size for the first two dimensions.
        padlen2: Target size for the last dimension.

    Returns:
        ``x`` with a new leading dimension of size 1. Padding (filled with
        -1) is applied when the first dimension is smaller than ``padlen1``
        or the last dimension is smaller than ``padlen2``; otherwise the
        input is returned unpadded.
    """
    n_nodes = x.size(0)
    n_dist = x.size(-1)
    if n_nodes >= padlen1 and n_dist >= padlen2:
        # Neither dimension needs padding: only add the batch dimension.
        return x.unsqueeze(0)

    padded = x.new_full([padlen1, padlen1, padlen2], fill_value=-1, dtype=x.dtype)
    padded[:n_nodes, :n_nodes, :n_dist] = x
    return padded.unsqueeze(0)


def collator_graph_data(batch, max_node=512):
    """Collate the graphs of a batch into one padded ``Data`` object.

    Every entry of ``batch`` must provide ``entry['graph']`` with the
    attributes ``rel_postion`` (sic), ``graph_attention_mask``, ``x``,
    ``edge_attr`` and ``edge_type``. Graphs with more than ``max_node``
    nodes are cut down to their first ``max_node`` nodes; ``edge_attr`` is
    passed through unchanged in that case.

    Args:
        batch: Iterable of dict-like samples holding a ``'graph'`` entry.
        max_node: Maximum number of nodes kept per graph.

    Returns:
        ``Data`` with batched ``graph_attention_mask``, ``spatial_pos``,
        ``x``, ``edge_attr`` and ``edge_type`` tensors, each padded to the
        largest size found in the batch.
    """
    per_graph_fields = []
    for entry in batch:
        graph = entry['graph']
        if graph.x.size(0) > max_node:
            # Oversized graph: keep only the leading max_node nodes.
            fields = (
                graph.rel_postion[:max_node, :max_node],
                graph.graph_attention_mask[:max_node],
                graph.x[:max_node],
                graph.edge_attr,
                graph.edge_type[:max_node, :max_node, :],
            )
        else:
            fields = (
                graph.rel_postion,
                graph.graph_attention_mask,
                graph.x,
                graph.edge_attr,
                graph.edge_type,
            )
        per_graph_fields.append(fields)

    # Transpose "one tuple per graph" into "one tuple per field".
    spatial_poses, graph_attention_masks, xs, edge_attrs, edge_types = zip(
        *per_graph_fields
    )

    # Padding targets are the largest sizes present in this batch.
    node_pad = max(t.size(0) for t in xs)
    edge_pad = max(t.size(0) for t in edge_attrs)
    dist_pad = max(t.size(-1) for t in edge_types)

    batched_x = torch.cat([pad_2d_unsqueeze(t, node_pad) for t in xs])
    batched_spatial_pos = torch.cat(
        [pad_spatial_pos_unsqueeze(t, node_pad) for t in spatial_poses]
    )
    batched_mask = torch.cat(
        [pad_1d_unsqueeze(t, node_pad) for t in graph_attention_masks]
    )
    batched_edge_attr = torch.cat(
        [pad_2d_unsqueeze(t, edge_pad) for t in edge_attrs]
    )
    batched_edge_type = torch.cat(
        [pad_shortest_path_unsqueeze(t, node_pad, dist_pad) for t in edge_types]
    )

    return Data(
        graph_attention_mask=batched_mask,
        spatial_pos=batched_spatial_pos,
        x=batched_x,
        edge_attr=batched_edge_attr,
        edge_type=batched_edge_type
    )

def output_decode(eval_output, eval_label, tokenizer):
    """Decode batched model outputs and labels into flat lists of strings.

    Args:
        eval_output: Sequence of batches of generated token ids.
        eval_label: Sequence of batches of label token ids; must have the
            same number of batches as ``eval_output``.
        tokenizer: Object providing
            ``batch_decode(batch, skip_special_tokens=True)``.

    Returns:
        Tuple ``(decoded_outputs, decoded_labels)`` where each element is
        the concatenation of the decoded strings of all batches.
    """
    assert len(eval_output) == len(eval_label)

    decoded_outputs = []
    decoded_labels = []
    for output_batch, label_batch in zip(eval_output, eval_label):
        decoded_outputs += tokenizer.batch_decode(output_batch, skip_special_tokens=True)
        decoded_labels += tokenizer.batch_decode(label_batch, skip_special_tokens=True)

    assert len(decoded_labels) == len(decoded_outputs)
    return decoded_outputs, decoded_labels



