import tqdm
from qwen_vl_utils import smart_resize
class AndroidControl(Task):
    """AndroidControl GUI-agent evaluation task.

    ``generate_jobs`` expands every record of ``METAFILE`` into one job per
    step.  ``compute_scores`` extracts the model's ``mobile_use`` tool call
    from each job's generated text and scores it against that step's
    ground-truth ``check_pams`` via ``evaluate_android_control_action``.
    """

    # Opening text of an untagged mobile_use tool call in the model output.
    _TOOL_CALL_PREFIX = '{"name": "mobile_use", "arguments":'
    # Pixel budget handed to smart_resize (presumably 12800 patches of 28x28
    # pixels -- confirm against the model's preprocessing config).
    _MAX_PIXELS = 12800 * 28 * 28
    # Maximum number of jobs kept when ``sub_debug`` is enabled.
    _DEBUG_SAMPLE_SIZE = 1500

    def __init__(self, metafile='example_android_control_high.jsonl'):
        """Create the task.

        Args:
            metafile: path passed to ``read_json`` in ``generate_jobs``; each
                record is expected to describe one episode.
        """
        # NOTE(review): Task.__init__ is not invoked (unchanged from before);
        # confirm the base class needs no initialisation.
        self.METAFILE = metafile

    def generate_jobs(self, args) -> List[JobType]:
        """Expand every episode in ``self.METAFILE`` into one job per step.

        Args:
            args: namespace of generation options.  Any of ``greedy``,
                ``top_p``, ``top_k``, ``temperature``, ``repetition_penalty``,
                ``presence_penalty``, ``out_seq_length`` and ``seed`` that is
                absent falls back to the default shown below.

        Returns:
            A list of job dicts holding a sequential ``question_id``, the
            episode and step ids, original and resized screen size, the
            step's ``messages`` and ``check_pams``, and an ``eval_args`` dict.
            If ``self.sub_debug`` is truthy, at most 1500 jobs are returned,
            chosen reproducibly using ``args.seed``.
        """
        eval_args = {
            'greedy': getattr(args, 'greedy', False),
            'top_p': getattr(args, 'top_p', 0.01),
            'top_k': getattr(args, 'top_k', 1),
            'temperature': getattr(args, 'temperature', 0.01),
            'repetition_penalty': getattr(args, 'repetition_penalty', 1.0),
            'presence_penalty': getattr(args, 'presence_penalty', 0.0),
            'out_seq_length': getattr(args, 'out_seq_length', 1024),
            'seed': getattr(args, 'seed', 1),
        }

        jobs = []
        # `import tqdm` binds the module, so the progress bar is tqdm.tqdm.
        for episode in tqdm.tqdm(read_json(self.METAFILE)):
            width, height = episode['width'], episode['height']
            # smart_resize is called with (height, width) and unpacked as
            # (height, width).
            resized_height, resized_width = smart_resize(
                height, width, max_pixels=self._MAX_PIXELS)
            steps = zip(episode['messages_rounds'], episode['step_check_pams'])
            for step_id, (step_message, check_pams) in enumerate(steps):
                jobs.append({
                    'question_id': len(jobs),
                    # NOTE: 'eposide' spelling kept; it must match the
                    # metafile's key.
                    'episode_id': episode['eposide']['episode_id'],
                    'check_pams': check_pams,
                    'width': width,
                    'height': height,
                    'resized_width': resized_width,
                    'resized_height': resized_height,
                    'messages': step_message,
                    'step_id': step_id,
                    # Per-job copy so mutating one job never affects another.
                    'eval_args': dict(eval_args),
                })

        if getattr(self, 'sub_debug', False):
            # Seeded, size-clamped subset: reproducible across runs and no
            # ValueError when fewer than _DEBUG_SAMPLE_SIZE jobs exist.
            rng = random.Random(eval_args['seed'])
            jobs = rng.sample(jobs, min(self._DEBUG_SAMPLE_SIZE, len(jobs)))
        return jobs

    @classmethod
    def _extract_tool_call_json(cls, output):
        """Return the JSON text of the first mobile_use tool call in *output*.

        Uses the body of the first ``<tool_call>`` tag if present, otherwise
        the text from the first bare ``{"name": "mobile_use", "arguments":``.
        The call ends at ``</tool_call>`` or, failing that, at the last
        ``}}`` before any ``<conclusion>``.

        Raises:
            IndexError: if neither a ``<tool_call>`` tag nor the bare prefix
                occurs in *output*.
        """
        if '<tool_call>' in output:
            pred = output.split('<tool_call>')[1]
        else:
            pred = cls._TOOL_CALL_PREFIX + output.split(cls._TOOL_CALL_PREFIX, 1)[1]
        if '</tool_call>' in pred:
            pred = pred.split('</tool_call>')[0]
        else:
            pred = pred.split('<conclusion>')[0]
            pred = pred.rsplit('}}', 1)[0] + '}}'
        return pred.strip()

    @staticmethod
    def _percent(num, den):
        """Return ``num / den * 100``, or 0.0 when *den* is zero."""
        return num / den * 100 if den else 0.0

    def compute_scores(self, jobs) -> Dict[str, Any]:
        """Score model outputs against the per-step ground truth.

        Args:
            jobs: job dicts from ``generate_jobs`` with the model's raw text
                stored at ``job['result']['gen']``.

        Returns:
            ``task_samples`` (number of jobs) plus ``type_match_acc``,
            ``extact_match_acc``, ``click_match_acc`` (percentages) and
            ``error_num`` (jobs that could not be parsed or scored).
            Type and exact accuracy are over all jobs.  Click accuracy is
            over all jobs whose ground-truth action is ``click``, including
            ones that failed to parse.  Any accuracy with a zero denominator
            is reported as 0.0.
        """
        import traceback

        type_match_num = 0
        exact_match_num = 0
        click_match_num = 0
        all_click_num = 0
        error_num = 0
        for job in jobs:
            output = None
            try:
                output = job['result']['gen']
                check_pam = job['check_pams']
                # Count ground-truth clicks before parsing so that unparsable
                # outputs still count against click accuracy.
                if check_pam['action'] == 'click':
                    all_click_num += 1
                pred_action = json.loads(self._extract_tool_call_json(output))['arguments']
                type_match, exact_match = evaluate_android_control_action(
                    pred_action, check_pam,
                    job['width'], job['height'],
                    job['resized_width'], job['resized_height'],
                    pred_type='abs_resized', gt_type='abs_resized',
                )
                click_hit = exact_match and pred_action['action'] == 'click'
            except Exception:
                traceback.print_exc()
                print(output)
                print(job)
                error_num += 1
            else:
                # Only update the match counters once every step succeeded,
                # so a job is never counted both as a match and as an error.
                type_match_num += bool(type_match)
                exact_match_num += bool(exact_match)
                click_match_num += bool(click_hit)

        total = len(jobs)
        print('Type_match_num and Extact_match_num: ', type_match_num, exact_match_num, '/ all =', total)
        print('click_match_num:', click_match_num, '/ all =', all_click_num)
        print('error num', error_num)

        res = {
            'type_match_acc': self._percent(type_match_num, total),
            'extact_match_acc': self._percent(exact_match_num, total),
            'click_match_acc': self._percent(click_match_num, all_click_num),
            'error_num': error_num,
        }

        print(json.dumps(res, indent=' '))

        return {
            'task_samples': total,
            **res,
        }