try:
    from dpo_train import train_dpo
except ImportError:
    train_dpo = None

import wandb
import time
import torch
from tqdm import tqdm
from transformers import TrainerCallback, AutoTokenizer

from src.core.sample import sample_mutations
from examples.solver.codegen import solver_engine
from examples.mutator.codegen import mutator_engine
from examples.verdict.codegen import verdict_engine
from src.core.reward import get_dual_reward
from src.utils.cluster import cluster_statements_to_types
import json

class StaticSolverEvalCallback(TrainerCallback):
    """Evaluation callback used while training the attacker (mutator) model.

    On every evaluation (and once at the start of training) the current
    attacker model is wrapped in a ``mutator_engine`` and used to sample
    mutations for this process's shard of ``eval_dataset``. The mutations are
    scored by a static solver (``static_solver``) and, optionally, by a second
    solver (``solver_model_name``) through ``get_dual_reward``. Per-process
    score sums and counts are all-reduced so every rank sees the same global
    mean. That mean is written into the Trainer ``metrics`` dict, and on
    rank 0 it is also logged to wandb and merged into ``state.log_history``.

    Args:
        dataset_name: Passed to ``verdict_engine`` to build the judge.
        attacker_model_name: Model name for the mutator engine; also used to
            load the tokenizer when ``tokenizer`` is not given.
        eval_dataset: Sized, integer-indexable collection of eval problems.
        mutator_prompt_handler: Prompt handler for ``mutator_engine``.
        solver_prompt_handler: Prompt handler for ``solver_engine``.
        solver_model_name: Optional second solver scored alongside the static
            one; falsy disables it.
        static_solver: Model name of the static solver; falsy disables it.
        solver_is_static: When False, the second solver is rebuilt around the
            live ``model`` at every evaluation.
        filter_func: Forwarded to ``sample_mutations``.
        num_solver_iters: Forwarded to ``get_dual_reward``.
        num_mutator_iters: Forwarded to ``sample_mutations``.
        model: Optional local model for the second solver built at init time.
        tokenizer: Optional tokenizer shared by the local engines.
        **generation_kwargs: Forwarded to every engine and sampling call.
    """
    def __init__(self,
                 dataset_name,
                 attacker_model_name,
                 eval_dataset,
                 mutator_prompt_handler,
                 solver_prompt_handler,
                 solver_model_name: str = None,
                 static_solver: str = "gpt-4o-mini",
                 solver_is_static: bool = True,
                 filter_func: callable = lambda x: (True, None),
                 num_solver_iters: int = 1,
                 num_mutator_iters: int = 1,
                 model = None,
                 tokenizer = None,
                 **generation_kwargs):

        self.attacker_model_name = attacker_model_name
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(attacker_model_name)

        self.num_solver_iters = num_solver_iters
        self.num_mutator_iters = num_mutator_iters
        self.filter_func = filter_func

        self.eval_dataset = eval_dataset
        self.static_solver = static_solver
        self.solver_is_static = solver_is_static
        self.generation_kwargs = generation_kwargs

        self.mutator_prompt_handler = mutator_prompt_handler
        self.solver_prompt_handler = solver_prompt_handler
        self.solver_model_name = solver_model_name
        if solver_model_name:
            self.trained_solver = solver_engine(
                model_name=solver_model_name,
                local_tokenizer=self.tokenizer,
                local_model=model,
                prompt_handler=solver_prompt_handler,
                **generation_kwargs
            )
        if static_solver:
            self.solver = solver_engine(
                model_name=static_solver,
                prompt_handler=solver_prompt_handler,
                **generation_kwargs
            )
        self.verdict = verdict_engine(dataset_name)

        # device_count() is 0 on CPU-only hosts; clamp so the split below does
        # not divide by zero and that case is treated as a single shard.
        self.num_splits = max(torch.cuda.device_count(), 1)
        # Kept for backward compatibility. on_evaluate re-derives its shard
        # from the live world size so the shards always cover the data once.
        self.device_idx_mapping = {
            split: self._shard_range(len(eval_dataset), self.num_splits, split)
            for split in range(self.num_splits)
        }

    @staticmethod
    def _shard_range(total, num_shards, shard_idx):
        """Return the index range of shard ``shard_idx`` out of ``num_shards``.

        Every shard gets ``total // num_shards`` items; the last shard also
        absorbs the remainder so the shards exactly partition ``range(total)``.
        """
        per_shard = total // num_shards
        start = shard_idx * per_shard
        stop = total if shard_idx == num_shards - 1 else start + per_shard
        return range(start, stop)

    @staticmethod
    def _dist_ready():
        """True when a torch.distributed process group is usable."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def _score_mutations(self, mutations, solver):
        """Run ``solver`` on ``mutations`` and return the verdict scores."""
        _, verdicts, _, _ = get_dual_reward(
            mutations,
            solver,
            self.verdict,
            self.num_solver_iters,
            **self.generation_kwargs
        )
        return [d["score"] for d in verdicts]

    def _global_means(self, score_lists, device):
        """Return the mean of each score list, taken over all ranks.

        Sums and counts are packed into one tensor and all-reduced, so the
        result is the sample-weighted mean over the whole eval set. Every rank
        must pass the same number of lists (empty lists for disabled solvers)
        so the collective has a matching shape everywhere. An empty global
        list yields 0.0.
        """
        stats = []
        for scores in score_lists:
            stats.extend((float(sum(scores)), float(len(scores))))
        totals = torch.tensor(stats, dtype=torch.float64, device=device)
        if self._dist_ready():
            torch.distributed.all_reduce(totals, op=torch.distributed.ReduceOp.SUM)
        totals = totals.tolist()
        return [s / c if c else 0.0 for s, c in zip(totals[0::2], totals[1::2])]

    def on_train_begin(self, args, state, control, model, **kwargs):
        """Run one evaluation pass before the first training step."""
        print("Running evaluation at the beginning of training (step 0)...")
        return self.on_evaluate(args, state, control, model, **kwargs)

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        """Sample mutations with the current model and score them with the solvers.

        Every rank must call this method, because the score aggregation is a
        collective. The resulting ``eval_<solver>_score`` entries are merged
        into ``metrics`` (when given) on every rank.
        """
        distributed = self._dist_ready()
        world_size = torch.distributed.get_world_size() if distributed else 1
        rank = args.process_index if distributed else 0

        # This process's contiguous slice of the eval set.
        indices = self._shard_range(len(self.eval_dataset), world_size, rank)
        process_eval_dataset = [self.eval_dataset[i] for i in indices]

        print(f"Process {args.process_index} evaluating {len(indices)} examples on device {model.device}")

        # Create mutator with the current model state
        mutator = mutator_engine(
            model_name=self.attacker_model_name,
            local_tokenizer=self.tokenizer,
            local_model=model,
            prompt_handler=self.mutator_prompt_handler,
            **self.generation_kwargs
        )

        with torch.no_grad():
            start_time = time.time()
            mutations, _, _ = sample_mutations(
                mutator, process_eval_dataset,
                self.num_mutator_iters,
                filter_func=self.filter_func,
                **self.generation_kwargs
            )
            print(f"Mutation sampling: in {time.time() - start_time} seconds")
            # statements = ["\n<begin>\n" + m["problem"]["prompt"] +  "Bug:\n" + m["raw_response"] + "\n<end>" for m in mutations]

            # # NEED TO SET ANTHROPIC_API_KEY
            # clusters = cluster_statements_to_types("gpt-4o-mini",
            #                                        statements=statements,
            #                                        initial_batch_size=100,
            #                                        temperature=0,
            #                                        cluster_criteria_desc="The type of programming error introduced in the code. ",
            #                                        additional_requirements="Ensure groups are distinct, non-overlapping, and meaningful. Avoid overly broad or overly specific clusters. Capture a diverse range of patterns. The number of groups should below 50.")
            # # Create a wandb Table
            # cluster_table = wandb.Table(columns=["pattern", "count", "examples"])
            # # Add data to the table
            # for _, cluster in clusters.items():
            #     # Find examples of this cluster type (up to 5)
            #     pattern, indices = cluster["pattern"], cluster["statement_idx"]
            #     examples = [statements[i] for i in indices[:5]]
            #     cluster_table.add_data(pattern, len(indices), json.dumps(examples))

            # # Log the table and summary metrics
            # wandb.log({
            #     "mutation_clusters": cluster_table,
            #     "num_clusters": len(clusters)
            # })

            # Disabled solvers contribute empty lists. The collective below
            # then still has the same shape on every rank, and nothing is
            # read from an undefined verdict list.
            trained_scores = []
            if self.solver_model_name:
                start_time = time.time()
                if not self.solver_is_static:
                    # Re-wrap the live model so the solver sees current weights.
                    self.trained_solver = solver_engine(
                        model_name=self.solver_model_name,
                        local_tokenizer=self.tokenizer,
                        local_model=model,
                        prompt_handler=self.solver_prompt_handler,
                        **self.generation_kwargs
                    )
                trained_scores = self._score_mutations(mutations, self.trained_solver)
                print(f"Trained solver results: in {time.time() - start_time} seconds")

            static_scores = []
            if self.static_solver:
                start_time = time.time()
                static_scores = self._score_mutations(mutations, self.solver)
                print(f"Solver results: in {time.time() - start_time} seconds")

        # Global (all-shard) means, identical on every rank.
        global_static, global_trained = self._global_means(
            [static_scores, trained_scores], model.device
        )

        # Build metrics dict on every rank
        metrics_dict = {}
        if self.static_solver:
            metrics_dict[f"eval_{self.static_solver}_score"] = global_static
        if self.solver_model_name:
            metrics_dict[f"eval_{self.solver_model_name}_score"] = global_trained

        # Update the Trainer's metrics (so Trainer._determine_best_metric finds our key)
        if metrics is not None:
            metrics.update(metrics_dict)

        # Only rank0 logs and updates history
        if state.is_world_process_zero:
            if train_dpo is not None:
                train_dpo.latest_metrics = metrics_dict
            wandb.log(metrics_dict)
            if state.log_history:
                state.log_history[-1].update(metrics_dict)

        return control
