import sys
sys.path.append('..')

from tqdm import tqdm
import jsonlines

from utils import RESPONSE_DICT

class Mllm:
    """Base class for multimodal-LLM wrappers that run prompt/image batches.

    Subclasses override :meth:`evaluate` and whichever ``evaluate_*`` hook a
    ``batch_*`` method dispatches to (``evaluate_of_caption``,
    ``evaluate_with_intervention``, ``evaluate_with_i2t``, ...). Those hooks
    are not defined on this base class.

    Every ``batch_*`` method shares the same contract:

    * ``args`` must expose ``verbose`` (print each record when truthy) and
      ``save_path`` (output file).
    * ``data`` is an iterable of dict-like samples with at least ``'prompt'``
      and ``'img_url'`` keys.
    * The hook's return value is stored under ``'response'``, and all records
      are written to ``args.save_path`` in JSON Lines format (overwriting the
      file) once the whole batch has finished.
    """

    def __init__(self, model_name_or_path, *args, **kwargs) -> None:
        """No-op placeholder; presumably overridden to load the model."""
        pass

    def evaluate(self, prompt, filepath):
        """Return the model response for ``prompt`` and the image ``filepath``.

        The base implementation returns ``None``; subclasses are expected to
        override it.
        """
        pass

    # ------------------------------------------------------------------ #
    # Private helpers shared by the batch drivers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _save_responses(save_path, records):
        """Overwrite ``save_path`` with ``records``, one JSON object per line."""
        with jsonlines.open(save_path, 'w') as writer:
            writer.write_all(records)

    @staticmethod
    def _call_with_retry(fn, max_attempts=10, image=None):
        """Call ``fn()`` until it succeeds, at most ``max_attempts`` times.

        Each failure is printed (tagged with ``image``) and retried. If every
        attempt fails, the sentinel string ``'Error'`` is returned instead of
        raising, so one bad sample does not abort the whole batch.
        """
        for _ in range(max_attempts):
            try:
                return fn()
            except Exception as e:  # deliberate best-effort: keep the batch running
                print(f'Image{image} Error: {e}')
        return 'Error'

    def _run_batch(self, args, data, infer, make_record=None):
        """Apply ``infer`` to every sample in ``data`` and save the records.

        Args:
            args: object exposing ``verbose`` and ``save_path``.
            data: iterable of dict-like samples.
            infer: ``infer(sample)`` returns the value stored under
                ``'response'``.
            make_record: ``make_record(sample)`` builds the output record;
                defaults to a shallow copy of the sample, so every input field
                is kept in the output.
        """
        records = []
        for sample in tqdm(data):
            record = sample.copy() if make_record is None else make_record(sample)
            record['response'] = infer(sample)
            if args.verbose:
                print(record)
            records.append(record)
        self._save_responses(args.save_path, records)

    def _run_intervention_batch(self, args, data, single_method, interventions,
                                intervention_fn, multiple,
                                multiple_method='evaluate_with_multiple_intervention'):
        """Shared driver for the ``batch_evaluate_with_intervention*`` family.

        Dispatches each sample to the hook named ``multiple_method`` when
        ``multiple`` is truthy, otherwise to the hook named ``single_method``,
        calling it as ``hook(prompt, img_url, interventions, intervention_fn)``.
        """
        if interventions is None:
            interventions = {}  # fresh dict per call; never a shared default
        method_name = multiple_method if multiple else single_method

        def infer(sample):
            # Resolved lazily (as before) so a subclass need not define hooks
            # it never uses.
            hook = getattr(self, method_name)
            return hook(sample['prompt'], sample['img_url'], interventions, intervention_fn)

        self._run_batch(args, data, infer)

    # ------------------------------------------------------------------ #
    # Public batch drivers
    # ------------------------------------------------------------------ #

    def batch_evaluate(self, args, data):
        """Run :meth:`evaluate` on every sample and save the results."""
        self._run_batch(
            args, data,
            lambda sample: self.evaluate(sample['prompt'], sample['img_url']),
        )

    def batch_evaluate_of_caption(self, args, data):
        """Run ``evaluate_of_caption`` on every sample (needs a ``'caption'`` key)."""
        self._run_batch(
            args, data,
            lambda sample: self.evaluate_of_caption(
                sample['prompt'], sample['img_url'], sample['caption']),
        )

    def batch_evaluate_of_caption_img(self, args, data):
        """Run ``evaluate_of_caption_img`` on every sample (needs a ``'caption'`` key)."""
        self._run_batch(
            args, data,
            lambda sample: self.evaluate_of_caption_img(
                sample['prompt'], sample['img_url'], sample['caption']),
        )

    def batch_evaluate_with_intervention(self, args, data, interventions=None,
                                         intervention_fn=None, multiple=False):
        """Run ``evaluate_with_intervention`` (or the multiple-intervention hook).

        ``interventions`` defaults to a new empty dict on each call.
        """
        self._run_intervention_batch(
            args, data, 'evaluate_with_intervention',
            interventions, intervention_fn, multiple,
        )

    def batch_evaluate_with_intervention_youare(self, args, data, interventions=None,
                                                intervention_fn=None, multiple=False):
        """Run ``evaluate_with_intervention_youare`` (or the multiple-intervention hook).

        NOTE(review): with ``multiple=True`` this uses the generic
        ``evaluate_with_multiple_intervention`` hook, not a ``_youare``
        variant, matching the original behavior. Confirm this is intended.
        """
        self._run_intervention_batch(
            args, data, 'evaluate_with_intervention_youare',
            interventions, intervention_fn, multiple,
        )

    def batch_evaluate_with_intervention_youare_offset(self, args, data, interventions=None,
                                                       intervention_fn=None, multiple=False):
        """Run ``evaluate_with_intervention_youare_offset`` (or the multiple hook).

        NOTE(review): with ``multiple=True`` this uses the generic
        ``evaluate_with_multiple_intervention`` hook, matching the original
        behavior. Confirm this is intended.
        """
        self._run_intervention_batch(
            args, data, 'evaluate_with_intervention_youare_offset',
            interventions, intervention_fn, multiple,
        )

    def batch_evaluate_with_i2t(self, args, data, intervention):
        """Run ``evaluate_with_i2t`` with a single fixed ``intervention`` per sample."""
        self._run_batch(
            args, data,
            lambda sample: self.evaluate_with_i2t(
                sample['prompt'], sample['img_url'], intervention),
        )

    def batch_evaluate_with_intervention2(self, args, data, interventions=None,
                                          intervention_fn=None, multiple=False):
        """Fault-tolerant intervention batch with a reduced output record.

        Unlike the other drivers, each output record starts from a copy of
        ``RESPONSE_DICT`` and keeps only ``prompt``, ``img_url`` and ``lan``
        from the sample (so samples must have a ``'lan'`` key). Each hook call
        is retried up to 10 times, and ``'Error'`` is stored if all attempts
        fail. With ``multiple=True`` the ``evaluate_with_multiple_intervention2``
        hook is used.
        """
        if interventions is None:
            interventions = {}  # fresh dict per call; never a shared default
        method_name = ('evaluate_with_multiple_intervention2' if multiple
                       else 'evaluate_with_intervention')

        def make_record(sample):
            record = RESPONSE_DICT.copy()
            record['prompt'] = sample['prompt']
            record['img_url'] = sample['img_url']
            record['lan'] = sample['lan']
            return record

        def infer(sample):
            prompt, image = sample['prompt'], sample['img_url']
            # The hook lookup happens inside the retried call, so a missing or
            # failing hook degrades to 'Error' rather than aborting the batch.
            return self._call_with_retry(
                lambda: getattr(self, method_name)(prompt, image, interventions, intervention_fn),
                10,
                image,
            )

        self._run_batch(args, data, infer, make_record=make_record)