import wandb
from configs import LoggingConfig
from transformers import TrainerCallback
from transformers.keras_callbacks import PushToHubCallback
from datasets import load_dataset
from utils.eval_utils import HeadlinesBackdoorTask
import os
import torch
from transformers import AutoModelForCausalLM
from transformers import Trainer
from trl import SFTTrainer
from typing import Optional
import torch.distributed as dist


class WandbTrainCallback(TrainerCallback):
    """Forward the training loss reported by the Trainer to Weights & Biases.

    Args:
        log_every_n_steps: Stored on the instance as ``log_every_n_steps``.
            ``on_log`` does not consult it; a value is logged every time the
            hook is invoked with a ``"loss"`` entry.
    """

    def __init__(self, log_every_n_steps):
        self.log_every_n_steps = log_every_n_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log ``logs["loss"]`` to wandb under the key ``train/loss``.

        Nothing is logged when ``logs`` is ``None`` or has no ``"loss"`` key.
        """
        # ``logs`` defaults to None, and ``"loss" in None`` raises TypeError,
        # so the membership test must be guarded.
        if logs is not None and "loss" in logs:
            wandb.log({"train/loss": logs["loss"]})


class WandbEvalCallback(TrainerCallback):
    """Run a task-specific evaluation during training and log it to wandb.

    Args:
        task: Object exposing ``get_results(...)`` and ``get_metrics()``.
        eval_args: Evaluation settings; this class reads ``eval_steps``,
            ``eval_batch_size``, ``eval_temperature``, ``n_eval_batches`` and
            ``eval_output_file`` from it.
        model_args: Model settings; ``reinstantiate_model`` reads
            ``model_id``, ``use_4bit_quantization``, ``use_8bit_quantization``,
            ``use_flash_attn`` and ``device`` from it.
        bnb_config: Quantization config forwarded to ``from_pretrained`` as
            ``quantization_config``.
    """

    def __init__(self, task, eval_args, model_args, bnb_config):
        self.task = task
        self.eval_args = eval_args
        self.model_args = model_args
        self.bnb_config = bnb_config
        self.epoch = 0

    def on_step_begin(self, args, state, control, **kwargs):
        """Request an evaluation at the start of the step where global_step is 1."""
        if state.global_step == 1:  # The first step is 1 after the increment from 0
            control.should_evaluate = True

    def reinstantiate_model(self):
        """Load a fresh copy of the model for inference.

        Returns:
            The model loaded from ``model_args.model_id`` and moved to
            ``model_args.device``.

        Raises:
            AssertionError: If the loaded model is an FSDP-wrapped instance.
        """
        # Import the submodule explicitly; attribute access on
        # ``torch.distributed`` only works if something else already loaded it.
        from torch.distributed.fsdp import FullyShardedDataParallel

        # The stored ``self.bnb_config`` is used as-is. The previous code tried
        # to rebuild it through an un-imported ``BitsAndBytesConfig`` name,
        # which raised NameError whenever 4-bit quantization was enabled.
        if self.model_args.use_4bit_quantization:
            torch_dtype = self.bnb_config.bnb_4bit_quant_storage_dtype
        else:
            torch_dtype = torch.bfloat16

        attn_implementation = (
            "flash_attention_2" if self.model_args.use_flash_attn else "eager"
        )

        model_inf = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_id,
            load_in_8bit=self.model_args.use_8bit_quantization,
            quantization_config=self.bnb_config,
            trust_remote_code=True,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        ).to(self.model_args.device)

        assert not isinstance(model_inf, FullyShardedDataParallel), "New model instance is incorrectly sharded!"

        return model_inf

    def on_evaluate(self, args, state, control, model, **kwargs):
        """Evaluate ``model`` on the task and log the resulting metrics.

        The evaluation runs only when ``state.global_step`` is a multiple of
        ``eval_args.eval_steps`` or equals 2. The latter appears to pair with
        the evaluation requested by ``on_step_begin`` at global_step 1;
        NOTE(review): confirm against the Trainer's step accounting.
        """
        print("attempting evaluation", state.global_step, self.eval_args.eval_steps)

        due_by_schedule = state.global_step % self.eval_args.eval_steps == 0
        if not (due_by_schedule or state.global_step == 2):
            return

        # Called for its side effects; the metrics are read back separately
        # through ``get_metrics`` below.
        self.task.get_results(
            model,
            self.eval_args.eval_batch_size,
            self.eval_args.eval_temperature,
            self.eval_args.n_eval_batches,
            self.eval_args.eval_output_file,
            self.eval_args.eval_steps,
            state.global_step,
        )

        eval_metrics = self.task.get_metrics()
        wandb.log(eval_metrics)

class CustomPushToHubCallback(PushToHubCallback):
    """``PushToHubCallback`` subclass exposing Trainer-style event hooks.

    Only ``on_epoch_begin`` and ``on_epoch_end`` do any work: each reads
    ``epoch`` from its second positional argument and passes it to the parent
    class's ``on_epoch_begin``. Every other hook is overridden as a no-op.
    """

    def __init__(self, output_dir: str, tokenizer=None, **kwargs):
        """Store ``tokenizer`` and initialise the parent without it.

        Args:
            output_dir: Forwarded to the parent constructor.
            tokenizer: Kept on the instance; never forwarded to the parent.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # NOTE(review): if the parent constructor also assigns
        # ``self.tokenizer``, the value stored here is overwritten -- confirm.
        self.tokenizer = tokenizer
        # ``tokenizer`` is a named parameter, so it can never appear inside
        # ``kwargs``; forwarding ``kwargs`` as-is therefore excludes it.
        super().__init__(output_dir=output_dir, **kwargs)

    def on_init_end(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_train_begin(self, *args, **kwargs):
        """No-op hook."""

    def on_train_end(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_epoch_begin(self, *args, **kwargs):
        """Pass the current epoch to the parent's ``on_epoch_begin``."""
        # The second positional argument is expected to be the TrainerState.
        trainer_state = args[1]
        super().on_epoch_begin(trainer_state.epoch)

    def on_epoch_end(self, *args, **kwargs):
        """Pass the current epoch to the parent's ``on_epoch_begin``."""
        # NOTE(review): this forwards to the parent's on_epoch_begin, not
        # on_epoch_end -- confirm whether that is intentional.
        trainer_state = args[1]
        super().on_epoch_begin(trainer_state.epoch)

    def on_step_begin(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_step_end(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_substep_begin(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_substep_end(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_log(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_prediction_step(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_predict(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_evaluate(self, args, state, control, **kwargs):
        """No-op hook."""

    def on_save(self, args, state, control, **kwargs):
        """No-op hook."""



class DeleteCheckpointsCallback(TrainerCallback):
    """Delete each checkpoint folder right after the Trainer's save event."""

    def on_save(self, args, state, control, **kwargs):
        """Remove ``<output_dir>/checkpoint-<global_step>`` if it exists.

        Only the process for which ``state.is_local_process_zero`` is true
        performs the deletion. Removal is best-effort: filesystem errors are
        ignored rather than raised.
        """
        # Local import so this block does not depend on the file's import list.
        import shutil

        if not state.is_local_process_zero:
            return

        checkpoint_folder = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        if os.path.exists(checkpoint_folder):
            # shutil.rmtree avoids building a shell command from the path, so
            # spaces or shell metacharacters in it are handled safely, and it
            # works on platforms without ``rm``.
            shutil.rmtree(checkpoint_folder, ignore_errors=True)

