import torch
import sys
import os.path as osp
import warnings
from PIL import Image
from vlmeval.smp import get_cache_path, load, dump, splitlen
from huggingface_hub import snapshot_download
from .base import BaseModel


"""
You can perform inference of Yi-VL through the following steps:
1. clone the repo https://github.com/01-ai/Yi to path-to-Yi
2. set up the environment and install the required packages in path-to-Yi/VL/requirements.txt
3. set Yi_ROOT in vlmeval/config.py
    Yi_ROOT = path-to-Yi

You are all set now! To run a demo for Yi-VL:
```python
from vlmeval import *
model = supported_VLM['Yi_VL_6B']()
model.generate('apple.jpg', 'What is in this image?')
```
To run evaluation for Yi-VL, use `python run.py --model Yi_VL_6B --data {dataset_list}`
"""


def edit_config(repo_id):
    if not osp.exists(repo_id):
        root = get_cache_path(repo_id)
    else:
        root = repo_id
    assert root is not None and osp.exists(root)
    cfg = osp.join(root, 'config.json')
    data = load(cfg)
    mm_vision_tower = data['mm_vision_tower']
    if mm_vision_tower.startswith('./vit/'):
        data['mm_vision_tower'] = osp.join(root, mm_vision_tower)
        assert osp.exists(data['mm_vision_tower'])
        dump(data, cfg)


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
    setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)


class Yi_VL(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self,
                 model_path='01-ai/Yi-VL-6B',
                 root=None,
                 **kwargs):

        if root is None:
            warnings.warn(
                'Please set root to the directory of Yi, '
                'which is cloned from here: https://github.com/01-ai/Yi.'
            )

        self.root = osp.join(root, 'VL')
        sys.path.append(self.root)

        if splitlen(model_path, '/') == 2 and not osp.exists(model_path):
            if get_cache_path(model_path) is None:
                snapshot_download(repo_id=model_path)
            edit_config(model_path)
        elif osp.exists(model_path):
            edit_config(model_path)

        from llava.mm_utils import get_model_name_from_path, load_pretrained_model
        from llava.model.constants import key_info

        disable_torch_init()
        key_info['model_path'] = model_path
        get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path,
            device_map='cpu')
        self.model = self.model.cuda()
        self.conv_mode = 'mm_default'

        kwargs_default = dict(temperature=0.2,
                              num_beams=1,
                              do_sample=False,
                              max_new_tokens=1024,
                              top_p=None)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        from llava.conversation import conv_templates
        from llava.model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
        from llava.mm_utils import KeywordsStoppingCriteria, expand2square, tokenizer_image_token

        qs = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
            .unsqueeze(0)
            .cuda()
        )

        image = Image.open(image_path)
        if getattr(self.model.config, 'image_aspect_ratio', None) == 'pad':
            if image.mode == 'L':
                background_color = int(sum([int(x * 255) for x in self.image_processor.image_mean]) / 3)
            else:
                background_color = tuple(int(x * 255) for x in self.image_processor.image_mean)
            image = expand2square(image, background_color)
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')[
            'pixel_values'
        ][0]

        stop_str = conv.sep
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        self.model = self.model.to(dtype=torch.bfloat16)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
                stopping_criteria=[stopping_criteria],
                use_cache=True,
                **self.kwargs)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'
            )
        outputs = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0]
        outputs = outputs.strip()

        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
        return outputs
