import argparse
import json
from transformers import (
    AutoTokenizer,
    RobertaTokenizer,
    RobertaConfig,
    AutoConfig,
)
from preprocess import AST
import torch
import torch.nn.functional as F
import math
from tqdm import tqdm

from models import LlavaCodeConfig, LlavaCodeForConditionalGeneration
from eval_metric import compute_metric_stmt
from eval_metric_cceval import compute_metric_stmt_cceval
from datamodule.const import STRUCTURE_TOKEN, FIMMAP

# All model weights and input tensors in this script are moved to the first CUDA device.
device = torch.device("cuda", 0)


def tokenize_patches(tokens, patch_length, tokenizer):
    """Cut ``tokens`` into consecutive patches and convert every patch to ids.

    Each patch holds at most ``patch_length`` tokens and is wrapped as
    ``[cls, "<encoder-only>", sep] + patch + [sep]`` before being passed to
    ``tokenizer.convert_tokens_to_ids``. The ids of all patches are
    concatenated into one flat sequence (no padding is applied here).

    Returns:
        A tuple ``(ids, num_patches)`` where ``ids`` is a 1-D ``torch.long``
        tensor and ``num_patches`` equals ``ceil(len(tokens) / patch_length)``.
    """
    num_injection_tokens = math.ceil(len(tokens) / patch_length)

    flat_ids = []
    # Walk over the start offset of every patch instead of the patch index.
    for start in range(0, num_injection_tokens * patch_length, patch_length):
        window = tokens[start:start + patch_length]
        wrapped = (
            [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
            + window
            + [tokenizer.sep_token]
        )
        flat_ids += tokenizer.convert_tokens_to_ids(wrapped)

    return torch.tensor(flat_ids, dtype=torch.long), num_injection_tokens


# Marker line that introduces every retrieved fragment inside the cross-file context.
_CFC_FRAGMENT_HEADER = '# the below code fragment can be found in:'


def _split_crossfile_chunks(crossfile_cxt, strip_lines):
    """Split the cross-file context text into one chunk per retrieved fragment.

    The first line (the "Here are some examples..." preamble) is dropped. Every
    line starting with the fragment header opens a new chunk, the line that
    follows the header (the file path) is skipped, and empty lines are ignored.
    When ``strip_lines`` is true each kept line is stripped of surrounding
    whitespace.
    """
    chunks = []
    current_lines = []
    skip_next = False
    for line in crossfile_cxt.splitlines()[1:]:
        if line.startswith(_CFC_FRAGMENT_HEADER):
            skip_next = True
            if current_lines:
                chunks.append('\n'.join(current_lines))
            current_lines = []
        elif skip_next:  # skipping the file path
            skip_next = False
        elif line:
            current_lines.append(line.strip() if strip_lines else line)
    if current_lines:
        chunks.append('\n'.join(current_lines))
    return chunks


def _select_chunks(args, crossfile_cxt, strip_lines):
    """Return at most ``args.num_structure_tokens`` chunks and how many were kept."""
    chunks = _split_crossfile_chunks(crossfile_cxt, strip_lines)
    # restrict number of injection tokens (RAG files)
    num_injection_tokens = min(args.num_structure_tokens, len(chunks))
    return chunks[:num_injection_tokens], num_injection_tokens


def _unixcoder_chunk_ids(structure_tokenizer, tokens, max_structure_length):
    """Wrap ``tokens`` with the encoder special tokens and convert them to an id tensor."""
    tokens = tokens[:max_structure_length - 4]  # 4 special tokens for unixcoder
    wrapped = [structure_tokenizer.cls_token, "<encoder-only>", structure_tokenizer.sep_token] \
        + tokens + [structure_tokenizer.sep_token]
    return torch.tensor(structure_tokenizer.convert_tokens_to_ids(wrapped), dtype=torch.long)


def _concat_padded(chunk_ids, max_structure_length, pad_token_id):
    """Right-pad every 1-D id tensor to ``max_structure_length`` and concatenate them."""
    padded = [
        F.pad(ids, (0, max_structure_length - len(ids)), value=pad_token_id)
        for ids in chunk_ids
    ]
    if not padded:
        # torch.cat rejects an empty list; keep the empty 1-D long tensor.
        return torch.tensor([], dtype=torch.long)
    return torch.cat(padded).to(torch.long)


def _keep_last_tokens(tokenizer, text, budget):
    """Decode the last ``budget`` tokens of ``text`` (nothing when ``budget <= 0``)."""
    token_ids = tokenizer.encode(text)
    # A bare ``[-budget:]`` would return *everything* for budget == 0.
    kept = token_ids[-budget:] if budget > 0 else []
    return tokenizer.decode(kept)


def _keep_first_tokens(tokenizer, text, budget):
    """Decode the first ``budget`` tokens of ``text`` (nothing when ``budget <= 0``)."""
    return tokenizer.decode(tokenizer.encode(text)[:max(budget, 0)])


def _truncate_by_ratio(args, tokenizer, left_cxt, right_cxt):
    """Truncate left/right context so their budgets follow ``args.lc_rc_ratio``."""
    # argparse hands over a string when the option has no ``type``; coerce defensively.
    ratio = float(args.lc_rc_ratio)
    lr_budget = args.max_seq_length - args.gen_length - 3  # 3 tokens for FIM
    rc_budget = int(lr_budget / (ratio + 1))
    lc_budget = int(rc_budget * ratio)
    return (_keep_last_tokens(tokenizer, left_cxt, lc_budget),
            _keep_first_tokens(tokenizer, right_cxt, rc_budget))


def _truncate_by_right_length(args, tokenizer, left_cxt, right_cxt, num_injection_tokens):
    """Truncate right context to ``args.right_context_length``; left context gets the rest."""
    lc_budget = (args.max_seq_length - args.gen_length
                 - num_injection_tokens - args.right_context_length)
    return (_keep_last_tokens(tokenizer, left_cxt, lc_budget),
            _keep_first_tokens(tokenizer, right_cxt, args.right_context_length))


def prepare_prompt(args,
                   tokenizer,
                   structure_tokenizer,
                   fim_tokens,
                   left_cxt,
                   right_cxt=None,
                   crossfile_cxt=None):
    """Build the fill-in-the-middle prompt (and optional structure inputs) for one sample.

    The behaviour is selected by ``args.data_prefix``:

    * ``'code_cfc_uxc'``: currently disabled by an unconditional assert.
    * ``'code_cfc_jina'`` / ``'code_cfc_qwen'``: each cross-file chunk is encoded
      by calling ``structure_tokenizer`` directly; left/right context budgets
      follow ``args.lc_rc_ratio``.
    * ``'ast_cfc'``: each chunk is converted with ``AST(...)`` after removing
      ``#`` characters and wrapped with the encoder special tokens.
    * ``'default'``: plain FIM prompt without cross-file context.
    * ``'default_cfc'``: FIM prompt with the (truncated) cross-file context
      inserted as text right before the FIM-middle token.

    Args:
        args: namespace providing ``data_prefix`` and the length settings read
            by the selected mode.
        tokenizer: tokenizer of the text model (``encode`` / ``decode`` are used).
        structure_tokenizer: tokenizer of the structure encoder (unused by the
            ``default*`` modes).
        fim_tokens: triple ``(fim_prefix, fim_suffix, fim_middle)``.
        left_cxt: text before the completion point.
        right_cxt: text after the completion point.
        crossfile_cxt: retrieved cross-file context as one string.

    Returns:
        ``(prompt, structure_ids, num_structure_tokens)``. For the injection
        modes ``structure_ids`` is a flat ``torch.long`` tensor containing every
        chunk padded to ``args.max_structure_length`` and
        ``num_structure_tokens`` is a one-element tensor; for the ``default*``
        modes both are ``None``.

    Raises:
        ValueError: if ``args.data_prefix`` is not one of the modes above.
    """
    fim_prefix, fim_suffix, fim_middle = fim_tokens
    data_prefix = args.data_prefix

    if data_prefix == 'code_cfc_uxc':
        # NOTE(review): this mode is switched off by the unconditional assert;
        # the statements below only run when assertions are stripped (python -O).
        assert False, "data_prefix 'code_cfc_uxc' is currently disabled"
        chunks, num_injection_tokens = _select_chunks(args, crossfile_cxt, strip_lines=True)
        chunk_ids = [
            _unixcoder_chunk_ids(structure_tokenizer,
                                 structure_tokenizer.tokenize(cfc),
                                 args.max_structure_length)
            for cfc in chunks
        ]
        left_cxt_truncated, right_cxt_truncated = _truncate_by_right_length(
            args, tokenizer, left_cxt, right_cxt, num_injection_tokens)

    elif data_prefix in ('code_cfc_jina', 'code_cfc_qwen'):
        # Lines are kept unstripped for these encoders.
        chunks, num_injection_tokens = _select_chunks(args, crossfile_cxt, strip_lines=False)
        chunk_ids = [
            structure_tokenizer(cfc,
                                return_tensors='pt',
                                truncation=True,
                                max_length=args.max_structure_length).input_ids[0]
            for cfc in chunks
        ]
        left_cxt_truncated, right_cxt_truncated = _truncate_by_ratio(
            args, tokenizer, left_cxt, right_cxt)

    elif data_prefix == 'ast_cfc':
        chunks, num_injection_tokens = _select_chunks(args, crossfile_cxt, strip_lines=True)
        chunk_ids = [
            _unixcoder_chunk_ids(structure_tokenizer,
                                 AST(cfc.replace('#', ''), 'python', structure_tokenizer),  # decommenting
                                 args.max_structure_length)
            for cfc in chunks
        ]
        left_cxt_truncated, right_cxt_truncated = _truncate_by_right_length(
            args, tokenizer, left_cxt, right_cxt, num_injection_tokens)

    elif data_prefix == 'default':
        left_cxt_truncated, right_cxt_truncated = _truncate_by_ratio(
            args, tokenizer, left_cxt, right_cxt)
        prompt = f"{fim_prefix}{left_cxt_truncated}{fim_suffix}{right_cxt_truncated}{fim_middle}"
        return prompt, None, None

    elif data_prefix == 'default_cfc':
        assert crossfile_cxt is not None, "'default_cfc' requires crossfile_cxt"

        # making the same lr_budget as for other experiments, not considering cfc length
        left_cxt_truncated, right_cxt_truncated = _truncate_by_ratio(
            args, tokenizer, left_cxt, right_cxt)
        crossfile_cxt_truncated = _keep_first_tokens(tokenizer, crossfile_cxt, args.cfc_seq_length)

        # The same layout is used for every text model, so the prompt is
        # always defined regardless of ``args.text_model_id``.
        prompt = (f'{fim_prefix}{left_cxt_truncated}{fim_suffix}{right_cxt_truncated}'
                  f'{crossfile_cxt_truncated}{fim_middle}')
        return prompt, None, None

    else:
        raise ValueError(f'Unrecognized data_prefix: {args.data_prefix}')

    # Shared tail of the injection modes: one STRUCTURE_TOKEN placeholder per chunk.
    structure_ids = _concat_padded(chunk_ids, args.max_structure_length,
                                   structure_tokenizer.pad_token_id)
    prompt = f"{fim_prefix}{left_cxt_truncated}{fim_suffix}{right_cxt_truncated}{STRUCTURE_TOKEN * num_injection_tokens}{fim_middle}"

    return prompt, structure_ids, torch.tensor([num_injection_tokens])


def build_dataset(args, code_tokenizer, ast_tokenizer, fim_tokens):
    """Load the JSONL prompt file and attach model inputs to every entry.

    Every non-blank line of ``args.prompt_file`` is parsed as one JSON object
    that must provide ``"prompt"`` (left context) and ``"right_context"``.
    An optional ``"crossfile_context"`` may either be a string or a mapping
    whose ``"text"`` item holds the string.

    Each entry is extended in place with ``'llm_prompt'``, ``'structure_ids'``
    and ``'num_structure_tokens'`` as returned by ``prepare_prompt``.

    Returns:
        The list of augmented entries, in file order.
    """
    # Predictions are written as UTF-8 by this script, so read the input the same way.
    with open(args.prompt_file, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        raw_data = [json.loads(line) for line in f if line.strip()]

    data = []
    for entry in raw_data:
        left_cxt = entry["prompt"]
        right_cxt = entry["right_context"]

        crossfile_cxt = entry.get("crossfile_context")
        if crossfile_cxt is not None and not isinstance(crossfile_cxt, str):
            crossfile_cxt = crossfile_cxt['text']

        entry['llm_prompt'], entry['structure_ids'], entry['num_structure_tokens'] = \
            prepare_prompt(
                args, code_tokenizer, ast_tokenizer, fim_tokens, left_cxt, right_cxt, crossfile_cxt)

        data.append(entry)

    return data


# Special tokens stripped from generations by default.
_DEFAULT_SKIP_TOKENS = (
    "<|fim_prefix|>", "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>",
    "<|repo_name|>", "<|file_sep|>", "<|im_start|>", "<|im_end|>",
)


def remove_tokens(s, tokens=_DEFAULT_SKIP_TOKENS):
    """Return ``s`` with every occurrence of the given literal tokens removed.

    Args:
        s: text to clean.
        tokens: iterable of literal token strings to delete. Defaults to the
            FIM / chat special tokens listed in ``_DEFAULT_SKIP_TOKENS``.

    Returns:
        The cleaned string; ``s`` unchanged when ``tokens`` is empty.
    """
    import re

    # Escape each token so characters such as '|' are matched literally.
    escaped = [re.escape(token) for token in tokens if token]
    if not escaped:
        return s
    return re.sub("|".join(escaped), "", s)


if __name__ == "__main__":
    # Script entry point: parse arguments, build the model and dataset, generate
    # one completion per prompt, write them to <output_dir>/prediction.jsonl and
    # run the selected metric.

    # Local import: only the script entry point needs it.
    import os

    parser = argparse.ArgumentParser()

    parser.add_argument("--text_model_id", type=str, required=True)
    parser.add_argument("--structure_model_id", type=str, required=True)
    parser.add_argument("--language", type=str, required=True, help="language name")
    parser.add_argument("--model_checkpoint", type=str)
    parser.add_argument("--projector_checkpoint", type=str)
    parser.add_argument("--task", type=str, choices=["line_completion", "api_completion", "function_completion"])
    parser.add_argument("--prompt_file", type=str, default=None, help="file with a list of prompts")
    parser.add_argument("--gen_length", type=int, default=50, help="max length of generated token sequence")
    parser.add_argument("--max_seq_length", type=int, default=2048, help="max length of prompt")
    parser.add_argument("--max_structure_length", type=int, default=512, help="max length of structure sequence")
    parser.add_argument("--right_context_length",
                        type=int,
                        default=512,
                        help="For model_type=codelm_leftright_context: Text sequence length corresponding to the right context")
    parser.add_argument("--cfc_seq_length",
                        type=int,
                        default=512,
                        help="For model_type=codelm_cfc: Text sequence length corresponding to the retrieved nodes")
    parser.add_argument("--num_structure_tokens", type=int, required=True, help='number of embeddings reserved for RAG injection')
    parser.add_argument("--output_dir", type=str, default="output_dir", help="output directory to save predictions")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    # NOTE(review): --only_compute_metric is parsed but never read in this script;
    # generation always runs. Confirm whether it should skip the generation loop.
    parser.add_argument("--only_compute_metric", action="store_true", help="only compute metric")
    parser.add_argument("--compute_cceval_metric", type=lambda x: bool(int(x)), help="use cceval metric")
    parser.add_argument("--data_prefix", type=str, help="Determines data preprocessing")
    parser.add_argument('--config', type=str, help='path to args config')
    parser.add_argument("--do_sample", action='store_true')
    # type=float is required: without it a value given on the command line stays
    # a string and breaks the budget arithmetic in prepare_prompt.
    parser.add_argument("--lc_rc_ratio", type=float, default=2.0,
                        help="ratio of left-context to right-context token budget")

    args = parser.parse_args()
    print('Input args:', args)

    # The prompt file is opened unconditionally below, so fail early and clearly.
    if args.prompt_file is None:
        parser.error("--prompt_file is required")

    # Create the output directory up front so that a bad path is reported before
    # the (expensive) generation loop instead of after it.
    os.makedirs(args.output_dir, exist_ok=True)

    code_tokenizer = AutoTokenizer.from_pretrained(args.text_model_id, use_fast=False)
    code_tokenizer.add_tokens([STRUCTURE_TOKEN])
    if code_tokenizer.pad_token_id is None:  # case with starcoder
        code_tokenizer.pad_token_id = code_tokenizer.eos_token_id
    structure_token_id = code_tokenizer.convert_tokens_to_ids(STRUCTURE_TOKEN)

    structure_tokenizer = AutoTokenizer.from_pretrained(args.structure_model_id)

    structure_config = AutoConfig.from_pretrained(args.structure_model_id)
    structure_config.model_id = args.structure_model_id
    structure_config.pad_token_id = structure_tokenizer.pad_token_id
    text_config = AutoConfig.from_pretrained(args.text_model_id)
    text_config.model_id = args.text_model_id

    # TODO: is it possible to include <CODE_STRUCTURE> -> vector mapping without resizing embeddings?
    # possible implementation: qwen tokens <|repo_name|> and <|file_sep|> tokens
    # this is necessary for resize_token_embeddings() call
    text_config.vocab_size = text_config.vocab_size + 1  # for a new <CODE_STRUCTURE>
    configuration = LlavaCodeConfig(structure_config, text_config,
                                    pad_token_id=code_tokenizer.pad_token_id,
                                    structure_token_id=structure_token_id,
                                    injector=False)
    print('tokenizer shapes:', code_tokenizer.vocab_size, len(code_tokenizer))  # delete later

    if args.model_checkpoint:
        print(f'Loading model from checkpoint: {args.model_checkpoint}')
        model = LlavaCodeForConditionalGeneration \
            .load_from_checkpoint(args.model_checkpoint, config=configuration).to(device)
        print(f'after checkpoint: {model.model.multi_modal_projector.linear_1.weight.data.norm(2)}')
    else:
        model = LlavaCodeForConditionalGeneration(configuration).to(device)
    if args.projector_checkpoint:
        print(f'Loading projection weighs from {args.projector_checkpoint}')
        # NOTE(review): the projector is reached as model.model.multi_modal_projector
        # above but as model.multi_modal_projector here; confirm which path is valid.
        model.multi_modal_projector.load_state_dict(torch.load(args.projector_checkpoint))
    model.eval()

    text_model_name = args.text_model_id.lower()
    is_qwen = 'qwen' in text_model_name
    if is_qwen:
        fim_tokens = FIMMAP['qwen2.5']
    elif 'starcoder' in text_model_name:
        fim_tokens = FIMMAP['starcoder']
    else:
        raise NotImplementedError('No such model in FIM mapping')
    print('fim tokens:', fim_tokens)
    data = build_dataset(args, code_tokenizer, structure_tokenizer, fim_tokens)

    # The plain-text modes carry no structure inputs for generate().
    uses_structure = args.data_prefix not in ['default', 'default_cfc']

    all_preds = []
    for entry in tqdm(data):

        with torch.no_grad():

            inputs = code_tokenizer(entry['llm_prompt'], return_tensors='pt').to(device)
            # Generated ids start right after the prompt ids.
            cut_at = inputs.input_ids.shape[1]

            generate_kwargs = {
                "do_sample": args.do_sample,
                "max_new_tokens": args.gen_length,
            }
            if uses_structure:
                generate_kwargs["structure_values"] = entry['structure_ids'].to(device)
                generate_kwargs["num_structure_tokens"] = entry['num_structure_tokens'].to(device)

            cur_pred = model.generate(**inputs, **generate_kwargs)

            prediction = code_tokenizer.decode(cur_pred[0][cut_at:], skip_special_tokens=True)

            # <|fim_pad|>, <|file_sep|>, <|fim_prefix|> are not removed by skip_special_tokens=True, manual removal
            if is_qwen:
                prediction = remove_tokens(prediction)

            all_preds.append({
                "task_id": entry["metadata"]["task_id"],
                "pred": prediction,
            })

    with open(f"{args.output_dir}/prediction.jsonl", "w", encoding="utf-8") as f_pred:
        for entry in all_preds:
            f_pred.write(json.dumps(entry) + "\n")

    if args.compute_cceval_metric:
        compute_metric_stmt_cceval(args)
    else:
        compute_metric_stmt(args)
