from abc import ABC, abstractmethod
from typing import List, Dict, Any, Union

import torch
import gc
import requests
import uuid
import json
from video_utils.probe_video import probe_video

class BaseModelProcessor(ABC):
    """Abstract interface for model processors.

    Subclasses must provide model initialisation, single and batched
    generation, and a session-style streaming API (``new_session`` /
    ``stream``).
    """

    def __init__(self, model_path: str):
        """Remember ``model_path``; nothing else happens in this constructor."""
        self.model_path = model_path

    @abstractmethod
    def initialize_model(self) -> None:
        """Prepare the model for use (implemented by subclasses)."""

    @abstractmethod
    def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
        """Produce one response string for a list of message dicts."""

    @abstractmethod
    def batch_generate(self, messages_list: List[List[Dict[str, Any]]], **kwargs) -> List[str]:
        """Produce one response string per message list in ``messages_list``."""

    @abstractmethod
    def new_session(self, chunk: Union[None, Dict[str, Any]] = None, **kwargs) -> str:
        """Start a new streaming session, optionally seeded with ``chunk``."""

    @abstractmethod
    def stream(self, chunk: Dict[str, Any], **kwargs) -> str:
        """Feed ``chunk`` into the current session and return a response."""

    def lightweight_gpu_reset(self):
        """Release cached GPU memory and print how much is still allocated.

        Does nothing when CUDA is not available.
        """
        if not torch.cuda.is_available():
            return

        # Collect garbage first: tensors kept alive only by reference cycles
        # must be freed before empty_cache() can hand their blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.init()
        print(f"GPU memory_allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB")

class Truncate(BaseModelProcessor, ABC):
    """Streaming processor that bounds the total video duration in context.

    Each chunk is appended to a running conversation (``self.context``).
    Whenever the summed duration of the videos in the conversation exceeds
    ``max_video_length_in_seconds``, video elements are dropped starting from
    the oldest message until the total fits again. ``generate`` and the other
    abstract methods are left to subclasses.
    """

    def __init__(self, model_path, max_video_length_in_seconds=30.0, **kwargs):
        """Store the configuration and create empty session state.

        Args:
            model_path: Forwarded to ``BaseModelProcessor.__init__``.
            max_video_length_in_seconds: Upper bound on the summed duration
                of the videos kept in the context.
            **kwargs: Accepted and ignored here.
        """
        super().__init__(model_path)
        self.max_video_length_in_seconds = max_video_length_in_seconds
        # Create the session state up front so add_chunk()/stream() do not
        # fail with AttributeError when called before new_session().
        self.session_id = None
        self.silent_helper = lambda _: False
        self._reset_session_state()

    def _reset_session_state(self):
        """Clear the context and all per-session video bookkeeping."""
        self.cum_video_length_in_seconds = 0.0
        self.video_chunk_info = {}
        self.context = []
        self.current_context_video_info = []

    def new_session(self, chunk: Union[None, Dict[str, Any]] = None, **kwargs) -> str:
        """Start a fresh session and return its id.

        Args:
            chunk: Optional first message; added to the context when truthy.
            **kwargs: ``silent_helper`` may supply a callable that receives a
                response; when it returns a truthy value, ``stream`` does not
                append that response to the context. Defaults to a callable
                that always returns ``False``.

        Returns:
            The new session id.
        """
        # NOTE(review): the annotation says ``str`` but a ``uuid.UUID`` is
        # returned; kept unchanged so existing callers see the same object.
        self.session_id = uuid.uuid4()
        self._reset_session_state()
        self.silent_helper = kwargs.get("silent_helper", lambda _: False)

        if chunk:
            self.add_chunk(chunk=chunk)

        return self.session_id

    def stream(self, chunk: Dict[str, Any], **kwargs) -> str:
        """Add ``chunk`` to the context and generate a response.

        Args:
            chunk: Message dict with ``"role"`` and ``"content"`` keys.
            **kwargs: ``skip_response=True`` only records the chunk and
                returns ``""``; all kwargs are forwarded to ``generate``.

        Returns:
            The generated response, or ``""`` when the response is skipped.
        """
        self.add_chunk(chunk=chunk)

        if kwargs.get("skip_response", False):
            return ""

        response = self.generate(self.context, **kwargs)

        # Responses flagged by silent_helper are returned but not remembered.
        if not self.silent_helper(response):
            self.add_chunk({
                "role": "assistant",
                "content": [{"type": "text", "text": response}]
            })

        return response

    def add_chunk(self, chunk):
        """Append ``chunk`` to the context, then enforce the video budget.

        Video elements are replaced by a normalised video entry (see
        ``_register_video``); every other element is kept as given.

        Args:
            chunk: Message dict with ``"role"`` and ``"content"`` keys, where
                ``"content"`` is an iterable of dicts carrying a ``"type"``.
        """
        new_content = []
        for ele in chunk["content"]:
            if ele["type"] == "video":
                new_content.append(self._register_video(ele["video"]))
            else:
                new_content.append(ele)

        self.context.append({
            "role": chunk["role"],
            "content": new_content,
        })

        self._evict_oldest_videos()
        self._refresh_current_video_info()

    def _register_video(self, video):
        """Probe ``video``, record its info and return its context entry."""
        # Assumes probe_video returns a dict that may carry "duration" and a
        # nested "video" dict with "total_frames" -- TODO confirm with helper.
        info = probe_video(video)
        num_frames = info.get("video", {}).get("total_frames", 0)
        duration = info.get("duration", 0)

        self.video_chunk_info[video] = info
        self.cum_video_length_in_seconds += duration

        return {
            "type": "video",
            "video": video,
            "max_frames": num_frames,
            "max_pixels": 384 * 28 * 28,
            "fps": 2.0
        }

    def _evict_oldest_videos(self):
        """Drop videos, oldest first, until the duration budget is met."""
        chunk_idx = 0
        ele_idx = 0
        while (
            self.cum_video_length_in_seconds > self.max_video_length_in_seconds
            and chunk_idx < len(self.context)
        ):
            content = self.context[chunk_idx]["content"]
            if ele_idx >= len(content):
                # Finished this message; continue with the next one.
                chunk_idx += 1
                ele_idx = 0
                continue

            ele = content[ele_idx]
            if ele["type"] != "video":
                ele_idx += 1
                continue

            # Same lookup as on ingestion, so a probe result without a
            # "duration" key counts as 0 instead of raising KeyError.
            self.cum_video_length_in_seconds -= (
                self.video_chunk_info[ele["video"]].get("duration", 0)
            )

            # Build new lists rather than mutating in place; indices stay put
            # because the following element slides into the freed slot.
            remaining = content[:ele_idx] + content[ele_idx + 1:]
            if remaining:
                self.context[chunk_idx]["content"] = remaining
            else:
                # A message left without content is removed entirely.
                self.context = self.context[:chunk_idx] + self.context[chunk_idx + 1:]
                ele_idx = 0

    def _refresh_current_video_info(self):
        """Rebuild ``current_context_video_info`` from the current context."""
        infos = []
        for msg in self.context:
            for ele in msg["content"]:
                if ele["type"] == "video":
                    infos.append(
                        self.video_chunk_info[ele["video"]] | {"fps": ele["fps"]}
                    )
        self.current_context_video_info = infos
