import torch
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel

import io
import random
import numpy as np
import math


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of ``vlen`` frames.

    Args:
        num_frames: Number of indices wanted by the ``'rand'``, ``'middle'`` and
            ``'interval'`` strategies. Not used by the ``'fps...'`` strategy.
        vlen: Total number of frames in the video.
        sample: Sampling strategy.
            * ``'rand'``: split the video into ``min(num_frames, vlen)`` equal
              intervals and draw one random frame from each. ``fix_start`` is
              ignored in this mode.
            * ``'middle'``: same intervals, but take the middle frame of each,
              or ``interval_start + fix_start`` when ``fix_start`` is given.
            * ``'fps<x>'`` (e.g. ``'fps0.5'``): sample sequentially at ``x``
              frames per second, optionally truncated to ``max_num_frames``.
            * any string containing ``'interval'``: evenly spaced indices from
              a randomly positioned clip.
        fix_start: Offset added to every interval start (``'middle'`` only).
        input_fps: Frame rate of the source video.
        max_num_frames: Upper bound on the result length for the ``'fps...'``
            strategy; values ``<= 0`` disable the bound.

    Returns:
        A list of frame indices. Depending on the strategy the elements are
        Python ints or NumPy integers.

    Raises:
        ValueError: If ``sample`` does not name a known strategy.
    """
    if sample in ['rand', 'middle']:
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals and sample from each one.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [(lo, hi - 1) for lo, hi in zip(intervals[:-1], intervals[1:])]
        if sample == 'rand':
            try:
                # The upper bound of `range` is exclusive, so the last frame of
                # every interval is never drawn (historical behaviour, kept).
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:
                # At least one interval was too short to draw from; fall back
                # to a sorted random subset of all frames.
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        else:  # sample == 'middle'
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices

    elif 'fps' in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        # Sample at the centre of each `delta`-long slot.
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    elif 'interval' in sample:
        if num_frames == 1:
            frame_indices = [random.randint(0, vlen - 1)]
        else:
            # NOTE(review): presumably a stride of 8 frames defined for 30 fps
            # footage and rescaled to the real frame rate - confirm.
            interval = 8
            clip_length = num_frames * interval * input_fps / 30
            max_idx = max(vlen - clip_length, 0)
            start_idx = random.uniform(0, max_idx)
            end_idx = start_idx + clip_length - 1

            frame_indices = torch.linspace(start_idx, end_idx, num_frames)
            # Clamp so that clips longer than the video stay inside it.
            frame_indices = torch.clamp(frame_indices, 0, vlen - 1).long().tolist()
    else:
        raise ValueError(f'Unsupported sample strategy: {sample!r}')
    return frame_indices


def get_frame_indices_start_end(num_frames, vlen, fps, start_time, end_time):
    """Randomly sample ``num_frames`` indices inside a time-bounded clip.

    The clip ``[start_time, end_time)`` (seconds) is converted to frame indices,
    split into ``min(num_frames, clip_len)`` equal intervals, and one random
    frame is drawn from each interval.

    Args:
        num_frames: Number of indices to return.
        vlen: Total number of frames in the video.
        fps: Frame rate used to convert seconds into frame indices.
        start_time: Clip start in seconds; ``None`` or NaN means the first frame.
        end_time: Clip end in seconds; ``None`` or NaN means the end of the video.

    Returns:
        A list of ``num_frames`` frame indices (Python ints or NumPy integers).
        When the clip holds fewer than ``num_frames`` frames the list is padded
        by repeating the last index.
    """
    start_idx = max(int(fps * start_time), 0) if start_time is not None and not math.isnan(start_time) else 0
    end_idx = min(int(fps * end_time), vlen) if end_time is not None and not math.isnan(end_time) else vlen

    if vlen > 0:
        # A clip that starts past the end of the video, or whose end does not
        # lie after its start, would otherwise be empty and crash below. Shrink
        # it to the nearest single valid frame instead; the padding step then
        # repeats that frame, as it already does for any short clip.
        start_idx = min(start_idx, vlen - 1)
        end_idx = max(end_idx, start_idx + 1)
    clip_len = end_idx - start_idx

    acc_samples = min(num_frames, clip_len)
    # Split the clip into `acc_samples` intervals and sample from each one.
    intervals = np.linspace(start=start_idx, stop=end_idx, num=acc_samples + 1).astype(int)
    ranges = [(lo, hi - 1) for lo, hi in zip(intervals[:-1], intervals[1:])]

    try:
        # The upper bound of `range` is exclusive, so the last frame of every
        # interval is never drawn (historical behaviour, kept).
        frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
    except IndexError:
        # At least one interval was too short to draw from; fall back to a
        # sorted random subset of the clip's frames.
        frame_indices = np.random.permutation(np.arange(start_idx, end_idx))[:acc_samples]
        frame_indices.sort()
        frame_indices = list(frame_indices)

    if len(frame_indices) < num_frames:  # padded with last frame
        padded_frame_indices = [frame_indices[-1]] * num_frames
        padded_frame_indices[:len(frame_indices)] = frame_indices
        frame_indices = padded_frame_indices

    return frame_indices


def read_frames_decord(
    video_path, width=None, height=None,
    num_frames=8, sample='rand', fix_start=None,
    max_num_frames=-1, start_time=None, end_time=None
):
    """Decode a selection of frames from a video file with decord.

    Args:
        video_path: Path of the video file.
        width: Target frame width; only applied when ``height`` is given too.
        height: Target frame height; only applied when ``width`` is given too.
        num_frames: Number of frames to sample.
        sample: Strategy forwarded to ``get_frame_indices`` (unbounded case).
        fix_start: Forwarded to ``get_frame_indices`` (unbounded case).
        max_num_frames: Forwarded to ``get_frame_indices`` (unbounded case).
        start_time: Clip start in seconds, or ``None`` for no bound.
        end_time: Clip end in seconds, or ``None`` for no bound.

    Returns:
        A ``(frames, timestamp)`` tuple. ``frames`` is a NumPy array of the
        decoded frames and ``timestamp`` is a dict with ``'num_frames'`` and a
        comma-separated ``'timestamp'`` string of the sampled times in seconds.

    Note:
        Calling this function switches decord's global bridge to ``'torch'``.
    """
    import decord
    decord.bridge.set_bridge('torch')
    if video_path.lower().endswith('.webm'):
        # a workaround for webm, large/auto num_threads will cause error.
        num_threads = 2
    else:
        num_threads = 0

    reader_kwargs = {'num_threads': num_threads}
    if width is not None and height is not None:
        reader_kwargs.update(width=width, height=height)
    video_reader = decord.VideoReader(video_path, **reader_kwargs)

    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    # Compare against None explicitly: a clip that starts at second 0 is a
    # valid bound and must not be mistaken for "no bound given".
    if start_time is not None and end_time is not None:
        frame_indices = get_frame_indices_start_end(
            num_frames, vlen, fps, start_time, end_time
        )
    else:
        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            input_fps=fps, max_num_frames=max_num_frames
        )
    frames = video_reader.get_batch(frame_indices)
    if isinstance(frames, torch.Tensor):
        frames = frames.numpy()  # (T, H, W, C), torch.uint8
    else:
        # The bridge did not hand back a torch tensor; convert decord's array.
        frames = frames.asnumpy()
    timestamp = {
        'num_frames': len(frame_indices),
        'timestamp': ', '.join([str(round(f / fps, 1)) for f in frame_indices])
    }
    return frames, timestamp


class mPLUG_Owl3(BaseModel):
    """Run an mPLUG-Owl3 checkpoint behind the ``BaseModel`` interface.

    The checkpoint is loaded through ``transformers`` with remote code enabled;
    prompts are assembled from interleaved text/image messages and video
    samples are fed to the model as a sequence of extracted frames.
    """

    # No separate model module is required, but the dependencies must be met.
    # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt
    INSTALL_REQ = True
    INTERLEAVE = True
    INSTALL_REQ_TXT = 'https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt'

    def __init__(self, model_path=None, **kwargs):
        """Load tokenizer, model (fp16, on CUDA) and processor.

        Args:
            model_path: Checkpoint name or path passed to ``from_pretrained``.
                Must not be ``None``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        assert model_path is not None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path
        )

        self.model = AutoModel.from_pretrained(
            model_path,
            attn_implementation='sdpa',
            torch_dtype=torch.half,
            trust_remote_code=True
        )
        self.model.eval().cuda()
        self.processor = self.model.init_processor(self.tokenizer)
        self.logger = get_logger('mPLUG_Owl3')
        if self.INSTALL_REQ:
            self.logger.info(
                f'Please remember to meet the requirements first\n'
                f'Here: {self.INSTALL_REQ_TXT}'
            )

    def use_custom_prompt(self, dataset):
        """Return True when ``build_prompt`` should be used for ``dataset``.

        Only MVBench and MMVet use the custom prompt; MMMU is excluded
        explicitly and is checked first.
        """
        assert dataset is not None
        if listinstr(['MMMU'], dataset):
            return False
        return bool(listinstr(['MVBench', 'MMVet'], dataset))

    def save_video_into_images(self, line, num_frames=16, dataset_class=None):
        """Turn one video sample into a list of frame image paths.

        Args:
            line: Sample record providing ``'prefix'``, ``'video'``, optionally
                ``'bound'``, and ``'start'``/``'end'`` when ``'bound'`` is set.
            num_frames: Number of frames to produce.
            dataset_class: Dataset object; its ``dataset_name`` (and
                ``data_root`` when the video path has to be resolved) is read
                on the video-file path of this method.

        Returns:
            A list of ``num_frames`` image paths. For a directory of frames the
            paths point into that directory; for a video file the frames are
            decoded and written under the LMU data root first.
        """
        video_path = osp.join(line['prefix'], line['video'])
        has_bound = line.get('bound', None)

        if osp.isdir(video_path):
            # The "video" is a directory of numbered JPEG frames.
            max_frame = len(os.listdir(video_path))
            fps = 3  # NOTE(review): assumes frames were extracted at 3 fps - confirm
            if has_bound:
                start, end = line['start'], line['end']
            else:
                # Sentinels that get clipped to the full frame range below.
                start, end = -100000, 100000
            start_idx = max(1, round(start * fps))
            end_idx = min(round(end * fps), max_frame)
            seg_size = float(end_idx - start_idx) / num_frames
            # One frame near the centre of each of `num_frames` segments.
            frame_indices = np.array([
                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
                for idx in range(num_frames)
            ])
            return [
                os.path.join(video_path, f'{frame_index:05d}.jpg')
                for frame_index in frame_indices
            ]

        if has_bound:
            start_time, end_time = line['start'], line['end']
        else:
            start_time = end_time = None

        if not osp.exists(video_path):  # for MVBench_MP4
            video_path = osp.join(dataset_class.data_root, video_path)
        video, _ = read_frames_decord(
            video_path, num_frames=num_frames, sample='middle', start_time=start_time, end_time=end_time
        )

        to_pil = transforms.ToPILImage()
        frames = [to_pil(video[ti]) for ti in range(video.shape[0])]
        lmu_root = LMUDataRoot()
        frame_root = osp.join(lmu_root, 'images', dataset_class.dataset_name, 'mplug_owl3')
        frame_root = osp.join(frame_root, video_path.split('/')[-1].split('.')[0])
        os.makedirs(frame_root, exist_ok=True)
        frame_tmpl = 'frame-{}-of-{}.jpg'
        frame_paths = [osp.join(frame_root, frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
        # Frames already on disk are reused. The file name does not encode the
        # clip bounds, so differently bounded clips of one video share files.
        for im, pth in zip(frames, frame_paths):
            if not osp.exists(pth):
                im.save(pth)

        return frame_paths

    # Currently same to mPLUG_Owl2
    def build_prompt(self, line, dataset=None, num_frames=16, video_llm=False):
        """Build the interleaved message list for one sample.

        Args:
            line: Sample record; read for ``'question'``, the option columns,
                ``'hint'`` and, for video datasets, ``'candidates'``.
            dataset: Dataset name, or a dataset object exposing
                ``dataset_name`` and ``MODALITY``.
            num_frames: Number of frames extracted for video samples.
            video_llm: Accepted for interface compatibility; not used here.

        Returns:
            A list of dicts: one ``type='text'`` entry holding the prompt,
            followed by one ``type='image'`` entry per image or frame path.

        Raises:
            NotImplementedError: If the dataset is neither MMVet nor an
                (Video-)MCQ dataset.
        """
        if isinstance(dataset, str):
            # Only a name was given: there is no dataset object to query, so
            # the sample is handled as a plain image sample.
            dataset_class = None
        else:
            dataset_class = dataset
            dataset = dataset_class.dataset_name
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)

        if dataset_class is not None and dataset_class.MODALITY == 'VIDEO':
            if listinstr(['MVBench'], dataset):
                tgt_path = self.save_video_into_images(line, num_frames, dataset_class)
            else:
                tgt_path = dataset_class.save_video_into_images(line, num_frames)
            if not isinstance(line['candidates'], list):
                # NOTE(security): eval() executes arbitrary code; this is only
                # acceptable while the dataset files are trusted. Consider
                # ast.literal_eval if the column is always a list literal.
                line['candidates'] = eval(line['candidates'])
            # Expose the candidates as option columns 'A', 'B', ...
            for idx, c in enumerate(line['candidates']):
                line[chr(ord('A') + idx)] = c
        else:
            tgt_path = self.dump_image(line, dataset)

        question = line['question']
        if dataset == 'MMVet':
            prompt = question + '\nAnswer the question directly. '
        elif listinstr(['MCQ', 'Video-MCQ'], DATASET_TYPE(dataset)):
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''.join(f'{key}. {item}\n' for key, item in options.items())

            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'Hint: {hint}\n' if hint is not None else ''
            prompt += f'{question}\n'
            prompt += (
                f'{options_prompt}\nAnswer with the option’s letter from the given choices directly. '
                if len(options) else 'Answer the question directly. '
            )
        else:
            raise NotImplementedError

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def preproc_image(self, fname, dataset=None):
        """Open an image as RGB and cap its longer side at ``448 * 16`` pixels.

        MVBench images are never resized. ``dataset`` may be ``None``, in
        which case the size cap applies.
        """
        from PIL import Image
        image = Image.open(fname).convert('RGB')
        # resize to max_size
        max_size = 448 * 16
        is_mvbench = dataset is not None and listinstr(['MVBench'], dataset)
        if max(image.size) > max_size and not is_mvbench:
            w, h = image.size
            # Scale the longer side to max_size and keep the aspect ratio.
            if w > h:
                new_w = max_size
                new_h = int(h * max_size / w)
            else:
                new_h = max_size
                new_w = int(w * max_size / h)
            image = image.resize((new_w, new_h), resample=Image.BICUBIC)
        return image

    def generate_inner(self, message, dataset=None):
        """Generate a response for one interleaved message.

        Args:
            message: List of dicts with ``'type'`` (``'image'`` or ``'text'``)
                and ``'value'`` (image path or text).
            dataset: Dataset name used to pick generation settings; may be
                ``None``.

        Returns:
            The first element of what ``self.model.generate`` returns
            (presumably the decoded text, since ``decode_text=True`` is passed).
        """
        images = []
        prompt_full = ''

        # Images become '<|image|>' placeholders at their position in the text.
        for msg in message:
            if msg['type'] == 'image':
                images.append(msg['value'])
                prompt_full += '<|image|>'
            elif msg['type'] == 'text':
                prompt_full += msg['value']

        needed_messages = [
            {'role': 'user', 'content': prompt_full},
            {'role': 'assistant', 'content': ''}
        ]

        images = [self.preproc_image(fname, dataset) for fname in images]

        inputs = self.processor(needed_messages, images=images, videos=None, cut_enable=False)

        inputs.to('cuda')
        # `dataset` defaults to None, which must not reach the substring check.
        if dataset is not None and listinstr(['MVBench'], dataset):
            inputs.update({
                'tokenizer': self.tokenizer,
                'max_new_tokens': 100,
                'decode_text': True,
                'do_sample': True,
                'top_k': 1,
            })
        else:
            inputs.update({
                'tokenizer': self.tokenizer,
                'max_new_tokens': 1024,
                'decode_text': True,
            })

        g = self.model.generate(**inputs)
        return g[0]
