import sys
import torch
from PIL import Image
import os.path as osp
import warnings
from .base import BaseModel
from ..smp import *
from huggingface_hub import snapshot_download


class OpenFlamingo(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = True

    def __init__(self,
                 name,
                 mpt_pth=None,
                 ckpt_pth=None,
                 **kwargs):

        if mpt_pth is None:
            raise ValueError(
                'Please set `mpt_pth` to the directory of MPT-7B, which is cloned from here: '
                'https://huggingface.co/mosaicml/mpt-7b. '
            )
            raise ValueError
        if ckpt_pth is None:
            raise ValueError(
                'Please set `ckpt_pth` to the openflamingo ckpt, which is the `checkpoint.pt` file downloaded '
                'from: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b/tree/main. '
            )
        else:
            if osp.exists(ckpt_pth):
                if ckpt_pth.endswith('checkpoint.pt'):
                    pass
                elif osp.isdir(ckpt_pth):
                    ckpt_pth = osp.join(ckpt_pth, 'checkpoint.pt')
                    if not osp.exists(ckpt_pth):
                        raise ValueError(f'File {ckpt_pth} does not exist. ')
            elif splitlen(ckpt_pth, '/') == 2:
                cache_path = get_cache_path(ckpt_pth)
                if cache_path is None:
                    snapshot_download(ckpt_pth)
                cache_path = get_cache_path(ckpt_pth)
                if cache_path is None:
                    raise ValueError(f'Directory {cache_path} does not exist. ')
                else:
                    ckpt_pth = osp.join(cache_path, 'checkpoint.pt')

        self.name = name
        assert name in ['v2']
        self.mpt_pth = mpt_pth
        try:
            from open_flamingo import create_model_and_transforms
        except Exception as e:
            logging.critical('Please first install open_flamingo to use OpenFlamingo')
            raise e

        model, image_processor, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path='ViT-L-14',
            clip_vision_encoder_pretrained='openai',
            lang_encoder_path=mpt_pth,
            tokenizer_path=mpt_pth,
            cross_attn_every_n_layers=4)
        ckpt = torch.load(ckpt_pth)
        model.load_state_dict(ckpt, strict=False)
        torch.cuda.empty_cache()
        self.model = model.eval().cuda()
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = 'left'
        self.image_proc = image_processor

        kwargs_default = dict(max_new_tokens=512, num_beams=3)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def generate_inner(self, message, dataset=None):
        vision_x = []
        prompt = ''
        for msg in message:
            if msg['type'] == 'image':
                img = Image.open(msg['value'])
                vision_x.append(self.image_proc(img).unsqueeze(0))
                prompt += '<image>'
            elif msg['type'] == 'text':
                prompt += msg['value']
        prompt += 'Answer: '
        vision_x = torch.cat(vision_x, dim=0) if len(vision_x) > 1 else vision_x[0]
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
        lang_x = self.tokenizer([prompt], return_tensors='pt')
        generated_text = self.model.generate(
            vision_x=vision_x.cuda(),
            lang_x=lang_x['input_ids'].cuda(),
            attention_mask=lang_x['attention_mask'].cuda(),
            **self.kwargs)
        generated_text = self.tokenizer.decode(generated_text[0])
        text = generated_text[len(prompt):].split('<|endofchunk|>')[0]
        return text
