import torch
from models.gpt4video import GPT4Video as model
from dataset.data_helper import FieldParser
from configs.config_inference import parser
from models.text2video_zero.text2video_zero import Text2Video_Zero
from models.videofusion.text2video_ali import VideoFusion, Zeroscope
from models.VideoCrafter1.scripts.gradio.t2v_test import Text2Video
import gradio as gr


class GPT4Video:
    """Inference wrapper pairing the GPT4Video language model with text-to-video decoders.

    ``run`` streams the model's reply token by token. When the model emits a
    ``<`` token, the following tokens are collected as a video prompt. On the
    closing ``</`` token that prompt (with ``<video>`` stripped) is sent to the
    selected video decoder, and the answer text that follows is then streamed
    separately.
    """

    # Decoder name -> (instance attribute holding the decoder, name of its frame-count kwarg).
    # Insertion order matters: the key list is shown in the "invalid decoder" error message.
    _VIDEO_DECODERS = {
        "Text2Video-Zero": ("t2vzero", "num_frames"),
        "VideoCrafter": ("videocrafter", "num_frames"),
        "VideoFusion": ("videofusion", "num_frames"),
        "Zeroscope": ("zeroscope", "num_frames"),
        "VideoCrafter1": ("videocrafter1", "frames"),
    }
    # Decoded tokens of the closing video tag, skipped before the post-video answer starts.
    _VIDEO_TAG_TOKENS = ("</", "video", ">")
    # Decoded tokens that end generation.
    _STOP_TOKENS = ("</s>", "Human")

    def __init__(
        self,
        delta_file="GPT4Video/save_model/improve_v4/checkpoints/checkpoint_epoch10_step2090_val_loss0.888093.pth",
        device="cuda:1",
        device_video="cuda:0",
    ):
        """Build the language model and load every available video decoder.

        Args:
            delta_file: value assigned to ``args.delta_file`` before the model is
                constructed (defaults to the previously hard-coded checkpoint).
            device: device for the language model, its embeddings and inputs.
            device_video: device passed to VideoFusion, Zeroscope and Text2Video-Zero.
        """
        self.args = parser.parse_args(args=[])
        self.args.delta_file = delta_file
        self.device_video = device_video
        self.device = device

        self.model = model(self.args).to(self.device)
        self.tokenizer = self.model.tokenizer
        self.model.eval()

        self.parser = FieldParser(self.args)

        # Video generation models. The "VideoCrafter" (GenVideo) decoder is not
        # loaded here; gen_video raises a clear ValueError if it is requested.
        self.videocrafter = None
        self.videocrafter1 = Text2Video()
        self.videofusion = VideoFusion(device=self.device_video)
        self.zeroscope = Zeroscope(device=self.device_video)
        self.t2vzero = Text2Video_Zero(device=self.device_video)

        # System prompt prepended to every conversation; embedded once up front.
        self.prefix_embs = self._embed_text(
            "The following is a conversation between a curious human and AI assistant GPT4Video. GPT4Video generates video prompts at the most appropriate time and gives helpful, detailed, and polite answers to the user's questions.\n"
        )

        # Embedding caches reused across turns; cleared by run() when the
        # history exceeds max_turns.
        self.text_cache = {}
        self.video_cache = {}

    def _embed_text(self, text):
        """Tokenize ``text`` without special tokens and return its input embeddings.

        The result has a leading batch dimension of 1 (``unsqueeze(0)``), and the
        token ids are moved to ``self.device`` before embedding.
        """
        ids = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids[0]
        return self.model.input_embeddings(ids.to(self.device)).unsqueeze(0)

    def add_video_idx(self, cnt, video_emb):
        """Prepend the embedded label ``"this is video {cnt}"`` to ``video_emb``.

        The label is concatenated in front along dim 1. ``video_emb`` is assumed
        to be ``(1, seq, hidden)`` on ``self.device`` (TODO: confirm against the
        video encoder output).
        """
        return torch.cat([self._embed_text(f"this is video {cnt}"), video_emb], dim=1)

    def get_prompt_embed(self, history):
        """Embed the system prefix, the whole chat history and the ``AI:`` cue.

        ``history`` is presumably a list of turns, each a list of messages, each
        an iterable of strings or ``None`` (text, or paths ending in ``mp4``).
        Confirm this against the Gradio caller. Videos are numbered in order of
        appearance. Embeddings are cached in ``self.video_cache`` (keyed by
        ``(path, index)``, because the index is part of the embedding) and in
        ``self.text_cache`` (keyed by the text).

        Returns:
            All embeddings concatenated along dim 1, on ``self.device``.
        """
        video_count = 0
        input_embeds = [self.prefix_embs]
        for turn in history:
            for message in turn:
                for item in message:
                    if not isinstance(item, str):
                        # None marks an empty slot; other non-string items are unsupported.
                        continue
                    if item.endswith("mp4"):
                        video_count += 1
                        key = (item, video_count)
                        if key not in self.video_cache:
                            video = self.parser.processor.process_video([item]).to(self.device)
                            video_embs = self.model.video_encoder(video).to(self.device)
                            self.video_cache[key] = self.add_video_idx(video_count, video_embs)
                        input_embeds.append(self.video_cache[key])
                    else:
                        if item not in self.text_cache:
                            self.text_cache[item] = self._embed_text(item)
                        input_embeds.append(self.text_cache[item])

        input_embeds.append(self._embed_text('</s> \nAI:'))
        return torch.cat(input_embeds, 1).to(self.device)

    def gen_video(self, video_decoder, num_frames, fps, prompt):
        """Generate a video for ``prompt`` with the named decoder and return its path.

        Raises:
            ValueError: if ``video_decoder`` is not a known name, or names a
                decoder that is not loaded on this instance.
        """
        if video_decoder not in self._VIDEO_DECODERS:
            valid_video_decoders = list(self._VIDEO_DECODERS)
            raise ValueError(f"Invalid video_decoder: {video_decoder}. Must be one of {valid_video_decoders}")

        attr_name, frames_kwarg = self._VIDEO_DECODERS[video_decoder]
        decoder = getattr(self, attr_name, None)
        if decoder is None:
            raise ValueError(f"Video decoder {video_decoder!r} is not loaded in this instance.")
        # VideoCrafter1 names its frame-count argument `frames`; the others use `num_frames`.
        return decoder.generate_video_from_text(prompt=prompt, fps=fps, **{frames_kwarg: num_frames})

    def run(self, video_decoder, history, num_frames=16, fps=8, max_turns=3, temperature=0.0, max_len=50, top_p=1.0, max_new_tokens=120):
        """Stream a reply for ``history``, generating a video when the model asks for one.

        Yields 4-tuples ``(text, video_path, text_after_video, video_prompt)``:

        - The last three entries are ``None`` until the video has been generated.
        - ``text_after_video`` stays ``None`` until its first token arrives.

        Args:
            video_decoder: one of the names accepted by ``gen_video``.
            history: chat history, passed to ``get_prompt_embed``.
            num_frames, fps: forwarded to the video decoder.
            max_turns: maximum ``len(history)``. Beyond it the caches are cleared
                and ``gr.Error`` is raised.
            temperature, top_p: accepted for interface compatibility but unused;
                decoding is greedy (argmax).
            max_len: the prompt embeddings are left-truncated to this many positions.
                NOTE(review): the default of 50 may be shorter than the system
                prefix alone, which would cut it off. Confirm this is intended.
            max_new_tokens: upper bound on generated tokens (previously fixed at 120).
        """
        # Reject over-long histories before paying for video/text encoding.
        if len(history) > max_turns:
            self.text_cache = {}
            self.video_cache = {}
            raise gr.Error('Error: History exceeds maximum rounds, please clear history and restart.')

        embeddings = self.get_prompt_embed(history)
        if embeddings.shape[1] > max_len:
            embeddings = embeddings[:, -max_len:]

        first_flag = True       # still streaming the answer text before the video tag
        gen_video_flag = False  # video has been generated
        hold_flag = True        # still skipping the closing-tag tokens after the video prompt
        out = None              # token ids of the pre-video answer
        out_2 = None            # token ids of the post-video answer
        prompt = None           # token ids of the video prompt (starting at '<')
        prompt_text = None
        out_text = ""
        video_path = None
        with torch.no_grad():
            for i in range(max_new_tokens):
                # No KV cache: the full sequence is re-run every step.
                output = self.model.model.language_model(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
                logits = output.logits[:, -1, :]  # (N, vocab_size)
                next_token = torch.argmax(logits, keepdim=True, dim=-1)  # (N, 1)
                next_token = next_token.long().to(embeddings.device)
                token_text = self.tokenizer.decode(next_token[0])

                # First part: answer text until the '<' that opens the video tag.
                if out is None:
                    out = next_token
                elif first_flag and token_text != '<':
                    out = torch.cat([out, next_token], dim=-1)
                else:
                    first_flag = False

                # Second part: collect the video prompt; generate on the closing '</'.
                if not first_flag and not gen_video_flag:
                    if token_text == "</":
                        prompt_text = self.tokenizer.decode(prompt[0]).replace("<video>", "")
                        video_path = self.gen_video(video_decoder, num_frames, fps, prompt_text)
                        gen_video_flag = True
                    elif prompt is None:
                        prompt = next_token
                    else:
                        prompt = torch.cat([prompt, next_token], dim=-1)
                        prompt_text = self.tokenizer.decode(prompt[0])
                        print(prompt_text)

                # Third part: skip the remaining closing-tag tokens, then collect
                # every following token as the post-video answer.
                if not first_flag and gen_video_flag:
                    if not (hold_flag and token_text in self._VIDEO_TAG_TOKENS):
                        hold_flag = False
                        out_2 = next_token if out_2 is None else torch.cat([out_2, next_token], dim=-1)

                next_embedding = self.model.input_embeddings(next_token)
                embeddings = torch.cat([embeddings, next_embedding], dim=1)

                if token_text in self._STOP_TOKENS:
                    print(f"break at {i} with {token_text}")
                    break

                if not gen_video_flag:
                    out_text = self.tokenizer.decode(out[0]).replace('AI:', '')
                    print(out_text)
                    yield (out_text, None, None, None)
                elif out_2 is not None:
                    out_text_2 = self.tokenizer.decode(out_2[0])
                    print(out_text_2)
                    yield (out_text, video_path, out_text_2, prompt_text)
                else:
                    yield (out_text, video_path, None, prompt_text)