import torch
import logging
import traceback
logger = logging.getLogger(__name__)


def set_seed(seed: int):
    """Seed PyTorch's random number generators with *seed*.

    Logs the seed, then calls ``torch.manual_seed`` followed by
    ``torch.cuda.manual_seed``.
    """
    logger.info(f"Setting random seed: {seed}")
    # NOTE: the torch.cuda.manual_seed call only targets the current CUDA
    # device; it is kept second so the seeding order is unchanged.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

def execute_template(synchronize, device, context, num_correct_trials, seed_num=2026,
                     *, atol=1e-02, rtol=1e-02):
    """Check that ``ModelNew`` reproduces the outputs of ``Model``.

    Both models are built from the same ``get_init_inputs()`` values, each
    right after reseeding with ``seed_num``, and are then run on the same
    inputs for ``num_correct_trials`` trials. The first failing trial ends
    the check.

    Args:
        synchronize: Callable invoked as ``synchronize(device=device)`` after
            model construction and around every forward pass.
        device: Device the tensors and models are moved to.
        context: Mapping providing ``'get_inputs'``, ``'get_init_inputs'``,
            ``'Model'`` and ``'ModelNew'``.
        num_correct_trials: Number of input batches to compare.
        seed_num: Seed applied before constructing each model.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose``.

    Returns:
        A ``(correctness, correctness_information)`` tuple: ``(True, '')``
        when every trial matches, otherwise ``False`` plus a message that
        starts with ``"[FAIL]"``.

    Raises:
        KeyError: If ``context`` lacks one of the required keys; these
            lookups happen before the guarded section. Exceptions raised
            while building or running the models are caught and reported
            through the return value instead.
    """
    logger.info(
        "Starting execute_template: device=%s, num_correct_trials=%s, seed_num=%s",
        device, num_correct_trials, seed_num,
    )

    get_inputs = context['get_inputs']
    get_init_inputs = context['get_init_inputs']
    Model = context['Model']
    ModelNew = context['ModelNew']

    try:
        logger.info("Getting init_inputs...")
        init_inputs = get_init_inputs()
        logger.info("init_inputs type: %s", [type(x) for x in init_inputs])
        init_inputs = [
            x.to(device=device) if isinstance(x, torch.Tensor) else x
            for x in init_inputs
        ]
        logger.info("Moved init_inputs to device")

        with torch.no_grad():
            # Reseed before each construction so both models start from the
            # same RNG state when initialising their weights.
            set_seed(seed_num)
            original_model = Model(*init_inputs).to(device)
            synchronize(device=device)
            set_seed(seed_num)
            custom_model = ModelNew(*init_inputs).to(device)
            synchronize(device=device)

            for _ in range(num_correct_trials):
                inputs = [
                    x.to(device) if isinstance(x, torch.Tensor) else x
                    for x in get_inputs()
                ]
                synchronize(device=device)
                ref_output = original_model(*inputs)
                synchronize(device=device)
                new_output = custom_model(*inputs)
                # Ensure all device work has finished before comparing.
                synchronize(device=device)

                if ref_output.shape != new_output.shape:
                    return False, (
                        f"[FAIL] Output shape mismatch: "
                        f"Expected {ref_output.shape}, got {new_output.shape}"
                    )
                if not torch.allclose(ref_output, new_output, atol=atol, rtol=rtol):
                    return False, "[FAIL] Output mismatch"
    except Exception as e:
        print('[FAIL] runtime error when evaluating correctness')
        # Keep the full traceback in the log; the returned message stays short.
        logger.exception("Runtime error when evaluating correctness")
        # Include the exception type: str(e) alone is empty for exceptions
        # raised without a message, which would leave a bare "[FAIL] ".
        return False, f"[FAIL] {type(e).__name__}: {e}"

    return True, ''
