"""
step4_update_model_params.py
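Build weighted fine-tuning datasets for the counterexample solver and the
formal-proof prover from per-problem results (pass rates and prover outputs),
write them into LLaMA-Factory's data directory, and run each model's training
script.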
"""

import argparse
import json
import logging
import os
import subprocess
import sys
import time
from typing import List, Dict, Any, Optional, Tuple, Callable

from tqdm import tqdm

from train_utils import load_large_data, save_json, get_solver, get_prover
from logger import setup_logger

# Module-level logger shared by every function in this script. Its handlers are
# presumably attached by setup_logger() (called in main()) — confirm in logger.py.
logger = logging.getLogger(__name__)

def preprocess_data(
    data: List[Dict[str, Any]],
    alpha: float = 0.2
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Turn per-problem results into weighted solver and prover training records.

    Every record has the keys ``instruction``, ``input`` (always ``""``),
    ``output`` and ``weight``. A result is skipped entirely when both its
    hypothesis and statement pass rates are zero. For each remaining result:

    * one solver record pairs ``solver_input`` with ``solver_outputs``, weighted
      ``hyp_passed_rate * alpha + state_passed_rate * (1 - alpha)``;
    * one prover record is emitted per passing hypothesis proof (share
      ``alpha``) and per passing statement proof (share ``1 - alpha``).

    Args:
        data: Problems, each holding a ``"results"`` list of result dicts.
        alpha: Fraction of the weight assigned to the hypothesis side; the
            statement side receives ``1 - alpha``.

    Returns:
        A ``(solver_records, prover_records)`` pair.
    """

    def as_record(instruction: Any, output: Any, weight: Any) -> Dict[str, Any]:
        # Key order is fixed so the dumped JSON keeps the same layout.
        return {"instruction": instruction, "input": "", "output": output, "weight": weight}

    solver_records: List[Dict[str, Any]] = []
    prover_records: List[Dict[str, Any]] = []

    for problem in data:
        for result in problem["results"]:
            hyp_rate = float(result.get("hyp_passed_rate", 0.0))
            state_rate = float(result.get("state_passed_rate", 0.0))
            if not hyp_rate and not state_rate:
                # Neither the hypothesis nor the statement side passed: skip.
                continue

            # Counterexample-solver sample, weighted by the blended pass rate.
            solver_weight = hyp_rate * alpha + state_rate * (1 - alpha)
            solver_records.append(
                as_record(result["solver_input"], result["solver_outputs"], solver_weight)
            )

            # Formal-proof samples: hypothesis side first, then statement side.
            proof_sides = (
                ("hyp_prover_input", "hyp_passed", "hyp_prover_outputs", alpha),
                ("state_prover_input", "state_passed", "state_prover_outputs", 1 - alpha),
            )
            for prompt_key, flags_key, proofs_key, share in proof_sides:
                prompt = result[prompt_key]
                flags = result.get(flags_key, [])
                proofs = result.get(proofs_key, [])
                # NOTE(review): only truthy `passed` values reach the weight; if
                # they are booleans, `1 - passed` is 0 and every prover record
                # gets weight 0. Confirm whether a pass-rate-based weight was
                # intended here.
                prover_records.extend(
                    as_record(prompt, proof, share * (1 - passed))
                    for passed, proof in zip(flags, proofs)
                    if passed
                )

    logger.info(
        f"Preprocessed {len(solver_records)} counterexample solver problems "
        f"and {len(prover_records)} formal proof generator problems."
    )
    return solver_records, prover_records


def update_model_params(
    training_data: List[Dict[str, Any]],
    model: str = "solver",
):
    """
    Write a model's training set to disk and run its training script.

    The dataset is dumped as indented JSON into LLaMA-Factory's ``data``
    directory, then ``train_scripts/<model>_train.sh`` is run with bash.
    Both paths are relative to the current working directory.

    Args:
        training_data: Records to write (the instruction/input/output/weight
            dicts produced by ``preprocess_data``).
        model: Which model to train: ``"solver"`` (default) or ``"prover"``.

    Raises:
        ValueError: If ``model`` is neither ``"solver"`` nor ``"prover"``.
        subprocess.CalledProcessError: If the training script exits with a
            non-zero status.
    """
    # The dataset file name must match what the corresponding training script
    # reads from LLaMA-Factory's data directory.
    if model == "solver":
        dataset_path = "./LLaMA-Factory/data/solver_train_data.json"
        script_path = "train_scripts/solver_train.sh"
    elif model == "prover":
        dataset_path = "./LLaMA-Factory/data/prover_train_data.json"
        script_path = "train_scripts/prover_train.sh"
    else:
        raise ValueError(f"Invalid model: {model}")

    with open(dataset_path, "w", encoding="utf-8") as f:
        json.dump(training_data, f, indent=2)
    logger.info(f"Updated dataset to {dataset_path}")
    logger.info(f"Length of dataset: {len(training_data)}")

    # check=True: a failed training run must raise rather than be reported as
    # a successful parameter update below.
    subprocess.run(["bash", script_path], check=True)
    logger.info(f"Updated model {model} parameters")
    

def main():
    """
    Command-line entry point: build both training sets and train both models.

    Loads ``--input_file`` and splits it with ``preprocess_data`` (the
    hypothesis/statement weighting is set by ``--alpha``, default 0.2). It then
    trains the solver first and the prover second. ``--solver_name`` and
    ``--prover_name`` are used only in log messages.
    """
    parser = argparse.ArgumentParser(description="Update model parameters.")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSON file")
    parser.add_argument("--solver_name", type=str, required=True, help="Name of the solver")
    parser.add_argument("--prover_name", type=str, required=True, help="Name of the prover")
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.2,
        help="Share of sample weight given to hypothesis results; statement results get 1 - alpha",
    )
    args = parser.parse_args()
    # Weights are convex combinations of alpha and 1 - alpha, so reject values
    # outside [0, 1] (this also rejects NaN).
    if not 0.0 <= args.alpha <= 1.0:
        parser.error(f"--alpha must be in [0, 1], got {args.alpha}")

    # setup logger
    log_path = "logs/step4_update_model_params.log"
    setup_logger(log_path)

    # Load and preprocess data
    data_list = load_large_data(args.input_file)
    counterex_sol_data, formalproof_gen_data = preprocess_data(data_list, alpha=args.alpha)

    update_model_params(counterex_sol_data, "solver")
    logger.info(f"Updated model {args.solver_name} parameters")

    update_model_params(formalproof_gen_data, "prover")
    logger.info(f"Updated model {args.prover_name} parameters")
    

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
