import io

import torch
import numpy as np
import ffmpeg
import math

from PIL import Image

from extract_feature import clip


class ClipFeatureExtractor:
    """Encode videos and texts with a CLIP model.

    Video frames are decoded by ``VideoProcessor``, normalised by
    ``Preprocessing`` and passed through the CLIP image encoder in batches;
    texts are tokenised with ``clip.tokenize`` and passed through the CLIP
    text encoder in batches.
    """

    def __init__(self, framerate=1 / 2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
        """
        Args:
            framerate: frame-sampling rate (frames per second) for videos.
            size: output frame size forwarded to ``VideoProcessor``.
            centercrop: whether ``VideoProcessor`` center-crops frames.
            model_name_or_path: model identifier passed to ``clip.load``.
            device: device the model is loaded on and inputs are moved to.
        """
        self.video_loader = VideoProcessor(framerate=framerate, size=size, centercrop=centercrop)
        print("Loading CLIP models")
        self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
        self.tokenizer = clip.tokenize
        self.video_preprocessor = Preprocessing()
        self.device = device

    @torch.no_grad()
    def encode_video(self, video_path: str, bsz=60):
        """Return CLIP image features for the frames sampled from *video_path*.

        Frames are encoded ``bsz`` at a time; the result has one row per
        sampled frame (``(T, d)``).

        NOTE(review): if probing the video fails, ``VideoProcessor`` returns a
        placeholder dict instead of frames and preprocessing will then raise.
        """
        video_frames = self.video_loader.read_video_from_file(video_path)  # (T, 3, H, W)
        video_frames = self.video_preprocessor(video_frames)
        video_features = [
            self.clip_extractor.encode_image(video_frames[st_idx:st_idx + bsz].to(self.device))
            for st_idx in range(0, len(video_frames), bsz)
        ]
        return torch.cat(video_features, dim=0)  # (T=#frames, d) torch tensor

    def _encode_text_batch(self, texts):
        """Tokenise one batch of *texts* and run the CLIP text encoder on it.

        Returns ``(output, valid_lengths)``: the encoder's output mapping and,
        for each text, the number of non-zero token ids (id 0 is treated as
        padding).
        """
        encoded_texts = self.tokenizer(texts, context_length=77).to(self.device)
        output = self.clip_extractor.encode_text(encoded_texts)
        valid_lengths = (encoded_texts != 0).sum(1).tolist()
        return output, valid_lengths

    @torch.no_grad()
    def encode_text(self, text_list, bsz=60):
        """Return per-token hidden states for each text in *text_list*.

        Texts are encoded ``bsz`` at a time. Each element of the returned list
        is the ``last_hidden_state`` of one text, trimmed to its non-padding
        tokens (``(L_j, d)``).
        """
        text_features = []
        for st_idx in range(0, len(text_list), bsz):
            output, valid_lengths = self._encode_text_batch(text_list[st_idx:st_idx + bsz])
            last_hidden_states = output["last_hidden_state"]
            text_features.extend(
                last_hidden_states[j, :valid_len] for j, valid_len in enumerate(valid_lengths)
            )
        return text_features  # List([L_j, d]) torch tensor

    @torch.no_grad()
    def encode_text_query(self, text_list, bsz=60):
        """Return trimmed hidden states plus the pooled output for each text.

        Like ``encode_text``, but each element is a dict with
        ``last_hidden_state`` (trimmed to non-padding tokens) and
        ``pooler_output`` (that text's row of the encoder's pooled output).
        """
        text_features = []
        for st_idx in range(0, len(text_list), bsz):
            output, valid_lengths = self._encode_text_batch(text_list[st_idx:st_idx + bsz])
            last_hidden_states = output["last_hidden_state"]
            pooler_outputs = output["pooler_output"]
            text_features.extend(
                dict(last_hidden_state=last_hidden_states[j, :valid_len],
                     pooler_output=pooler_outputs[j])
                for j, valid_len in enumerate(valid_lengths)
            )
        return text_features  # List(dict(last_hidden_state=[L_j, d], pooler_output=...))


def convert_to_float(frac_str):
    """Parse a number, a fraction ``"a/b"`` or a mixed number ``"w a/b"`` to float.

    Used here for ffprobe rate strings such as ``"30000/1001"``. For a mixed
    number the fractional part takes the sign of the whole part, so
    ``"-1 1/2"`` parses as -1.5.

    Returns:
        The parsed float, or ``None`` when *frac_str* is neither a number nor
        a well-formed fraction, or when the denominator is zero (e.g. a
        ``"0/0"`` rate string).
    """
    try:
        return float(frac_str)
    except ValueError:
        pass
    try:
        num, denom = frac_str.split('/')
        # At most one whole-number part may precede the numerator.
        *leading, num = num.split()
        if len(leading) > 1:
            return None
        value = float(num) / float(denom)
        if not leading:
            return value
        whole = float(leading[0])
        # copysign also treats "-0" as negative, unlike a `< 0` test.
        return whole + math.copysign(1.0, whole) * value
    except (ValueError, ZeroDivisionError):
        return None


class Normalize(object):
    """Per-channel normalisation ``(x - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` are stored with shape ``(1, 3, 1, 1)`` so they
    broadcast over a batch of 3-channel images laid out as ``(N, 3, H, W)``.
    """

    def __init__(self, mean, std):
        self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = torch.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        # Follow the input's device so GPU tensors can be normalised too;
        # .to() returns the same tensor when the device already matches.
        mean = self.mean.to(tensor.device)
        std = self.std.to(tensor.device)
        # 1e-8 guards against division by a zero std.
        return (tensor - mean) / (std + 1e-8)


class Preprocessing(object):
    """Map raw 0-255 pixel values to [0, 1], then apply per-channel normalisation.

    The mean/std below are presumably CLIP's image-normalisation statistics
    (in R, G, B order).
    """

    def __init__(self):
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        scaled = tensor / 255.0
        return self.norm(scaled)


def generate_image_from_raw_data(raw_data, width, height):
    """Build an RGB PIL image from a ``(height, width, 3)`` array of uint8 pixels.

    Uses ``Image.fromarray`` on the whole array instead of converting every
    pixel to a Python tuple.

    Raises:
        ValueError: if *raw_data* does not hold exactly ``height * width * 3``
        values.
    """
    pixels = np.asarray(raw_data, dtype=np.uint8).reshape(height, width, 3)
    return Image.fromarray(pixels)


def ffmpeg_frame_to_pil(frame_data):
    """Decode one encoded image held in *frame_data* (bytes) into an RGB PIL image."""
    buffer = io.BytesIO(frame_data)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")


class VideoProcessor:
    """Pytorch video loader.
    Copied and modified from:
    https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py

    Frames are sampled with ffmpeg's ``fps`` filter at ``framerate`` frames per
    second and resized so the shorter side equals ``size``. When ``size`` is a
    ``(height, width)`` tuple, frames are resized to exactly that size instead.
    When ``centercrop`` is set and ``size`` is an int, frames are then
    center-cropped to ``size x size``.
    """

    def __init__(
            self,
            framerate=1 / 2,
            size=224,
            centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate

    def _get_video_info(self, video_path):
        """Probe *video_path* and return metadata of its first video stream.

        Returns a dict with ``duration`` (seconds), ``frames_length``, ``fps``
        (floored average frame rate), ``height`` and ``width``. Each of the
        first three is ``-1`` when ffprobe does not report a usable value.

        Raises:
            ValueError: if the file has no video stream.
            Any error raised by ``ffmpeg.probe`` is propagated.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        if video_stream is None:
            raise ValueError('no video stream in {}'.format(video_path))
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        # The rate string can be unparseable (e.g. "0/0"). fps is informational
        # only, so that must not make the whole probe fail.
        try:
            fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
        except (KeyError, TypeError, ValueError, ZeroDivisionError, OverflowError):
            fps = -1
        # Parse independently: a missing nb_frames must not discard a valid duration.
        try:
            frames_length = int(video_stream['nb_frames'])
        except (KeyError, TypeError, ValueError):
            frames_length = -1
        try:
            duration = float(video_stream['duration'])
        except (KeyError, TypeError, ValueError):
            duration = -1
        return {"duration": duration, "frames_length": frames_length,
                "fps": fps, "height": height, "width": width}

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` to scale an ``h x w`` video to.

        A 2-tuple ``size`` is returned as-is. An int ``size`` becomes the
        shorter side, and the longer side is scaled to keep the aspect ratio.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _sampling_fps(self, duration):
        """Return the frame-sampling rate to use for a video of *duration* seconds.

        This is normally ``self.framerate``. A clip with
        ``0 < duration < 1 / framerate + 0.1`` is instead sampled at
        ``2 / max(int(duration), 1)`` fps, presumably so very short clips still
        yield frames. Falls back to ``self.framerate`` if the check cannot be
        computed (e.g. a zero framerate).
        """
        try:
            fps = self.framerate
            if duration > 0 and duration < 1 / fps + 0.1:
                fps = 2 / max(int(duration), 1)
                print(duration, fps)
        except (TypeError, ValueError, ZeroDivisionError):
            fps = self.framerate
        return fps

    @staticmethod
    def _probe_failure(video_path):
        """Report a failed probe and build the placeholder returned in its place."""
        print('ffprobe failed at: {}'.format(video_path))
        return {'video': torch.zeros(1), 'input': video_path,
                'info': {}}

    def _decode_frames(self, video_path, h, w, duration):
        """Decode sampled, resized and (optionally) center-cropped RGB frames.

        Args:
            video_path: input video file.
            h, w: source frame height and width.
            duration: source duration in seconds (``-1`` if unknown).

        Returns:
            ``(frames, height, width)``: *frames* is a uint8 numpy array of
            shape ``(T, height, width, 3)``.
        """
        height, width = self._get_output_dim(h, w)
        fps = self._sampling_fps(duration)
        cmd = (
            ffmpeg
            .input(video_path)
            .filter('fps', fps=fps)
            .filter('scale', width, height)
        )
        # Cropping needs a scalar size. A tuple size is already the exact
        # output size, so it is never cropped.
        if self.centercrop and isinstance(self.size, int):
            x = int((width - self.size) / 2.0)
            y = int((height - self.size) / 2.0)
            cmd = cmd.crop(x, y, self.size, self.size)
            height, width = self.size, self.size
        out, _ = (
            cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True, quiet=True)
        )
        frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
        return frames, height, width

    def read_video_from_file(self, video_path):
        """Decode *video_path* into a float32 tensor of shape ``(T, 3, H, W)``.

        Pixel values stay in the raw 0-255 range. If probing fails, a
        placeholder dict ``{'video', 'input', 'info'}`` is returned instead of
        a tensor.
        """
        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            return self._probe_failure(video_path)
        frames, _, _ = self._decode_frames(video_path, h, w, info["duration"])
        video = torch.from_numpy(frames.astype('float32'))
        return video.permute(0, 3, 1, 2)

    def extract_all_frames(self, video_path):
        """Grab one frame every ``1 / fps`` seconds as RGB PIL images.

        Each frame is fetched with a separate ffmpeg seek and encoded as MJPEG
        before it is decoded into a PIL image.
        """
        probe = ffmpeg.probe(video_path)
        video_info = next(s for s in probe['streams'] if s['codec_type'] == 'video')
        # Not every stream reports its own duration; fall back to the
        # container-level duration from ffprobe's format section.
        if 'duration' in video_info:
            duration = float(video_info['duration'])
        else:
            duration = float(probe['format']['duration'])
        fps = self._sampling_fps(duration)

        frames = []
        for idx in range(int(duration * fps)):
            frame_data, _ = (
                ffmpeg
                .input(video_path, ss=idx / fps)  # ss seeks to the idx-th sampling instant
                .output('pipe:', vframes=1, format='image2', vcodec='mjpeg')
                .run(capture_stdout=True)
            )
            frames.append(ffmpeg_frame_to_pil(frame_data))
        return frames

    def read_raw_image_from_video_file(self, video_path):
        """Decode *video_path* into a list of RGB PIL images plus its probe info.

        Returns ``(images, info)``. If probing fails, a placeholder dict
        ``{'video', 'input', 'info'}`` is returned instead of the tuple.
        """
        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            return self._probe_failure(video_path)
        frames, height, width = self._decode_frames(video_path, h, w, info["duration"])
        video = [generate_image_from_raw_data(frame, width, height) for frame in frames]
        return video, info
