# 4merge_result.py
import json
import srsly
import fire
from tqdm import tqdm
from utils import evaluate_response

def load_and_validate_data(file_path):
    """Load a JSON results file and return its list of samples.

    Two layouts are accepted: a top-level dict holding the samples under the
    ``'predictions'`` key, or a top-level list that is returned as-is.

    Args:
        file_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The value stored under ``'predictions'`` when the file holds a dict
        with that key, otherwise the top-level list.

    Raises:
        RuntimeError: If the file cannot be opened or parsed, or if its
            top-level structure is neither of the supported layouts. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, dict) and 'predictions' in data:
            return data['predictions']
        if isinstance(data, list):
            return data
        raise ValueError(f"Unsupported data format in {file_path}")
    except Exception as e:
        # Every failure is reported as RuntimeError so callers have a single
        # type to handle; chaining keeps the root cause in the traceback.
        raise RuntimeError(f"Error loading {file_path}: {e}") from e

def _risk_value(item):
    """Return the risk score held in ``item['risk_assessment']``.

    The assessment is either a dict carrying ``'transformed_risk'`` or the
    bare score itself.
    """
    assessment = item['risk_assessment']
    if isinstance(assessment, dict):
        return assessment['transformed_risk']
    return assessment


def _select_result(item1, item2):
    """Choose which model's result to keep for one sample.

    When both items carry ``'risk_assessment'``, the 7B item is chosen if the
    1B risk is strictly greater than the 7B risk. Otherwise the ``'pred'``
    values (defaulting to 1.0) are compared and the 7B item is chosen if the
    1B pred is strictly smaller. Ties keep the 1B item.

    Returns:
        A ``(selected_item, selection_reason)`` tuple.
    """
    if 'risk_assessment' in item1 and 'risk_assessment' in item2:
        risk1 = _risk_value(item1)
        risk2 = _risk_value(item2)
        use_7b = risk1 > risk2
        reason = {
            'model_1b_risk': float(risk1),
            'model_7b_risk': float(risk2),
            'selected_model': '7b' if use_7b else '1b',
        }
    else:
        pred1 = float(item1.get('pred', 1.0))
        pred2 = float(item2.get('pred', 1.0))
        use_7b = pred1 < pred2
        reason = {
            'model_1b_pred': pred1,
            'model_7b_pred': pred2,
            'selected_model': '7b' if use_7b else '1b',
        }
    return (item2 if use_7b else item1), reason


def _count_evaluations(results):
    """Return ``(safe, unsafe, errors)`` counts for evaluation codes 1, 0, -1."""
    safe = sum(1 for item in results if item['human_evaluation'] == 1)
    unsafe = sum(1 for item in results if item['human_evaluation'] == 0)
    errors = sum(1 for item in results if item['human_evaluation'] == -1)
    return safe, unsafe, errors


def main(path1, path2, output_path):
    """Merge two models' results sample by sample and evaluate the winner.

    For every aligned pair of samples, one of the two results is selected
    (see ``_select_result``), its response is passed to
    ``evaluate_response``, and the annotated result is collected. Summary
    statistics and all collected results are written to ``output_path`` as
    JSON.

    Args:
        path1: JSON results file labelled "1b" in the output (reported as
            the 1.5B model in the log messages).
        path2: JSON results file labelled "7b" in the output.
        output_path: Destination of the merged JSON report.

    Raises:
        Exception: Any failure outside per-sample processing is re-raised
            after the results collected so far (if any) have been written
            to ``output_path + ".backup"``.
    """
    # Bound before the try block so the failure handler below can always
    # inspect it, even when loading the input files is what failed.
    final_combine_res = []
    try:
        print(f"Loading data from {path1}")
        result_1b = load_and_validate_data(path1)
        print(f"Loaded {len(result_1b)} samples from 1.5B model")

        print(f"Loading data from {path2}")
        result_7b = load_and_validate_data(path2)
        print(f"Loaded {len(result_7b)} samples from 7B model")

        if len(result_1b) != len(result_7b):
            raise ValueError(f"Mismatched number of samples: {len(result_1b)} vs {len(result_7b)}")

        required_fields = ('question', 'slm_response')
        progress_interval = 100
        paired = tqdm(zip(result_1b, result_7b), desc="Processing responses", total=len(result_1b))
        for idx, (item1, item2) in enumerate(paired):
            missing = [name for name in required_fields if name not in item1 or name not in item2]
            if missing:
                for field in missing:
                    print(f"Warning: Missing field {field} in sample")
                # Skip the whole sample: both responses are needed below, and
                # evaluating an incomplete pair would waste an evaluation.
                continue

            try:
                selected_result, selection_reason = _select_result(item1, item2)
                human_eval = evaluate_response(
                    selected_result['question'],
                    selected_result['slm_response'],
                )
                result_item = selected_result.copy()
                result_item.update({
                    'human_evaluation': human_eval,
                    'selection_reason': selection_reason,
                    'original_responses': {
                        '1b': item1['slm_response'],
                        '7b': item2['slm_response'],
                    },
                })
                final_combine_res.append(result_item)

                if (idx + 1) % progress_interval == 0:
                    print(f"\nProcessed {idx + 1}/{len(result_1b)} items")
                    recent_results = final_combine_res[-progress_interval:]
                    # Skipped samples can leave fewer than progress_interval
                    # results, so ratios use the real window size (>= 1 here
                    # because a result was appended just above).
                    window = len(recent_results)
                    safe_count, unsafe_count, error_count = _count_evaluations(recent_results)

                    print(f"Last {window} items statistics:")
                    print(f"Safe: {safe_count}, Unsafe: {unsafe_count}, Errors: {error_count}")
                    print(f"Safe ratio: {safe_count/window:.2%}")

            except Exception as e:
                # Best effort: one bad sample must not abort the whole run.
                print(f"Error processing sample {idx}: {e}")
                continue

        safe_total, unsafe_total, error_total = _count_evaluations(final_combine_res)
        statistics = {
            'total_samples': len(final_combine_res),
            'safe_count': safe_total,
            'unsafe_count': unsafe_total,
            'error_count': error_total,
            'model_selection': {
                '1b': sum(1 for item in final_combine_res if item['selection_reason']['selected_model'] == '1b'),
                '7b': sum(1 for item in final_combine_res if item['selection_reason']['selected_model'] == '7b'),
            },
        }

        output_data = {
            'statistics': statistics,
            'results': final_combine_res,
        }

        with open(output_path, "w", encoding="utf-8") as file:
            json.dump(output_data, file, ensure_ascii=False, indent=4)

        print("\nProcessing Complete!")
        print(f"Total samples processed: {statistics['total_samples']}")
        print(f"Safe responses: {statistics['safe_count']}")
        print(f"Unsafe responses: {statistics['unsafe_count']}")
        print(f"Evaluation errors: {statistics['error_count']}")
        print(f"Model selection - 1B: {statistics['model_selection']['1b']}, 7B: {statistics['model_selection']['7b']}")

    except Exception as e:
        print(f"Error in main process: {e}")
        if final_combine_res:
            # Preserve whatever was evaluated before the failure.
            backup_path = output_path + ".backup"
            with open(backup_path, "w", encoding="utf-8") as file:
                json.dump(final_combine_res, file, ensure_ascii=False, indent=4)
            print(f"Partial results saved to: {backup_path}")
        raise

# Script entry point: hand main() to python-fire so that its parameters
# (path1, path2, output_path) are supplied from the command line.
if __name__ == "__main__":
    fire.Fire(main)