import sys
import torch
import os.path as osp
import warnings
from .base import BaseModel
from ..smp import *


class PandaGPT(BaseModel):
    """Wrapper around the PandaGPT ``OpenLLAMAPEFTModel`` for image + text prompts.

    The model code is not bundled with this package: it is imported from
    ``<root>/code`` of a local PandaGPT checkout
    (https://github.com/yxuansu/PandaGPT), and the checkpoints are read from
    ``<root>/pretrained_ckpt``.
    """

    # Class-level flags; their consumers are not visible in this file
    # (presumably read by the surrounding evaluation framework -- TODO confirm).
    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, name, root=None, **kwargs):
        """Load the PandaGPT 13B model onto the GPU in half precision.

        Args:
            name: Model identifier; only ``'PandaGPT_13B'`` is accepted.
            root: Path to the PandaGPT checkout. ``<root>/code`` is added to
                ``sys.path`` and checkpoints are resolved under
                ``<root>/pretrained_ckpt``.
            **kwargs: Overrides for the default generation settings
                (``top_p``, ``do_sample``, ``max_tgt_len``, ``temperature``);
                extra keys are kept and forwarded to generation as well.

        Raises:
            ValueError: If ``root`` is ``None`` or ``name`` is not
                ``'PandaGPT_13B'``.
            Exception: Whatever the import of ``model.openllama`` raises is
                logged and re-raised unchanged.
        """
        if root is None:
            raise ValueError(
                'Please set `root` to PandaGPT code directory, '
                'which is cloned from here: https://github.com/yxuansu/PandaGPT. '
            )

        # Explicit check instead of `assert`, which is removed under `python -O`.
        if name != 'PandaGPT_13B':
            raise ValueError(f"Unsupported model name {name!r}; only 'PandaGPT_13B' is supported. ")
        self.name = name

        # Make the checkout's `model` package importable; avoid piling up
        # duplicate entries when the class is instantiated more than once.
        code_dir = osp.join(root, 'code')
        if code_dir not in sys.path:
            sys.path.append(code_dir)
        try:
            from model.openllama import OpenLLAMAPEFTModel
        except Exception:
            logging.critical(
                'Please first install PandaGPT and set the root path to use PandaGPT, '
                'which is cloned from here: https://github.com/yxuansu/PandaGPT. '
            )
            raise

        self.args = {
            'model': 'openllama_peft',
            'imagebind_ckpt_path': osp.join(root, 'pretrained_ckpt/imagebind_ckpt'),
            'vicuna_ckpt_path': osp.join(root, 'pretrained_ckpt/vicuna_ckpt/13b_v0'),
            'delta_ckpt_path': osp.join(root, 'pretrained_ckpt/pandagpt_ckpt/13b/pytorch_model.pt'),
            'stage': 2,
            'max_tgt_len': 512,
            'lora_r': 32,
            'lora_alpha': 32,
            'lora_dropout': 0.1,
        }
        model = OpenLLAMAPEFTModel(**self.args)
        # The delta checkpoint is loaded on CPU first and applied non-strictly,
        # so keys absent from the checkpoint keep their constructed values.
        delta_ckpt = torch.load(self.args['delta_ckpt_path'], map_location=torch.device('cpu'))
        model.load_state_dict(delta_ckpt, strict=False)
        torch.cuda.empty_cache()
        self.model = model.eval().half().cuda()

        # Caller-supplied kwargs take precedence over these defaults.
        kwargs_default = {'top_p': 0.9, 'do_sample': False, 'max_tgt_len': 128, 'temperature': 0.001}
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def generate_inner(self, message, dataset=None):
        """Run one generation for a message and return the model's response.

        Args:
            message: Input passed to ``self.message_to_promptimg`` (a helper not
                defined in this class; presumably inherited from ``BaseModel``),
                which yields a prompt and a single image path.
            dataset: Optional dataset identifier forwarded to
                ``message_to_promptimg``.

        Returns:
            Whatever ``self.model.generate`` returns for the assembled request.
        """
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        struct = {
            'prompt': prompt,
            'image_paths': [image_path],
            'audio_paths': [],
            'video_paths': [],
            'thermal_paths': [],
            'modality_embeds': []
        }
        # Generation settings are merged last, so they override any same-named key above.
        struct.update(self.kwargs)
        return self.model.generate(struct)
