import argparse
import os
import copy
import os.path as osp
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

from accelerate import Accelerator
from accelerate.utils import gather_object
from PIL import Image
from tqdm import tqdm

import decord # fake import of a missing but unused package to prevent error.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

class MPLUG(torch.nn.Module):
    """Wrapper around a ModelScope mPLUG visual-question-answering pipeline.

    ``evaluate_one_sample`` uses it to score an image against a set of
    yes/no questions that may depend on one another.
    """

    def __init__(self, ckpt='', device='gpu'):
        """Build the VQA pipeline.

        Args:
            ckpt: Passed as ``model=`` to ModelScope's ``pipeline``.
            device: Converted with ``str()`` and passed as ``device=`` to the
                pipeline, so non-string device objects are also accepted.
        """
        super().__init__()
        self.device = str(device)
        self.pipeline_vqa = pipeline(Tasks.visual_question_answering, model=ckpt, device=self.device)

    def vqa(self, image, question):
        """Ask ``question`` about ``image``; return the pipeline's ``'text'`` output."""
        input_vqa = {'image': image, 'question': question}
        result = self.pipeline_vqa(input_vqa)
        return result['text']

    def evaluate_one_sample(self, value, generated_image, *, return_details=False):
        """Score ``generated_image`` against a set of dependent yes/no questions.

        Every question in ``qid2question`` is asked via ``self.vqa``. An answer
        scores 1.0 only if it is exactly the string ``'yes'``, otherwise 0.0.

        Then, for each entry of ``qid2dependency`` (in dict order), that
        question's score is forced to 0 if any of its parents currently has a
        score of 0. A parent id of ``0`` is skipped, i.e. treated as "no parent".
        Scores are updated in place while iterating. So a question zeroed
        earlier can in turn zero its own children that are processed later.

        Args:
            value: Mapping with keys ``'qid2question'`` (question id -> question
                text) and ``'qid2dependency'`` (question id -> iterable of
                parent question ids).
            generated_image: Image passed unchanged to the VQA pipeline.
            return_details: If True, also return the per-question answers,
                scores and validity flags (keyword-only; defaults to False).

        Returns:
            The mean score over the questions in ``qid2question`` (a float in
            [0, 1]). When ``return_details`` is True, returns
            ``(score, details)``, where ``details`` has keys ``'qid2answer'``,
            ``'qid2scores'`` and ``'qid2validity'``.

        Raises:
            ValueError: If ``qid2question`` is empty.
            KeyError: If a dependency lists a parent id (other than 0) that has
                no score.
        """
        qid2question = value['qid2question']
        qid2dependency = value['qid2dependency']
        if not qid2question:
            # The score is a mean over questions; with none it is undefined.
            raise ValueError("evaluate_one_sample: 'qid2question' is empty")

        qid2answer = {}
        qid2scores = {}
        for qid, question in qid2question.items():
            answer = self.vqa(generated_image, question)
            qid2answer[qid] = answer
            qid2scores[qid] = float(answer == 'yes')

        qid2validity = {}
        for qid, parent_ids in qid2dependency.items():
            # Zero-out the score if any real parent (id != 0) is currently scored 0.
            # any() short-circuits at the first such parent, like the original break.
            any_parent_answered_no = any(
                parent_id != 0 and qid2scores[parent_id] == 0
                for parent_id in parent_ids
            )
            if any_parent_answered_no:
                qid2scores[qid] = 0.0
            qid2validity[qid] = not any_parent_answered_no

        # Average over the asked questions only. A dependency entry with no
        # matching question must not change the denominator.
        score = sum(qid2scores[qid] for qid in qid2question) / len(qid2question)

        if return_details:
            details = {
                'qid2answer': qid2answer,
                'qid2scores': qid2scores,
                'qid2validity': qid2validity,
            }
            return score, details
        return score