import os
import torch
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, TrainerCallback
from slice import inject_peft, restore_layers
class EarlyStopOnStepCallback(TrainerCallback):
    """Trainer callback that requests a stop once a step threshold is reached.

    Args:
        stop_step: Value compared against ``state.global_step``; training is
            flagged to stop as soon as the global step is at or past it.
    """

    def __init__(self, stop_step):
        self.stop_step = stop_step

    def on_step_end(self, args, state, control, **kwargs):
        """Set ``control.should_training_stop`` when the threshold is hit.

        Returns:
            The (possibly modified) ``control`` object.
        """
        threshold_reached = state.global_step >= self.stop_step
        if threshold_reached:
            control.should_training_stop = True
        return control

class SliceTrainer:
    """
    HuggingFace-style trainer for blockwise PEFT.
    Alternates between row- and column-wise PEFT (or user choice).

    Each call to :meth:`run` repeatedly injects a PEFT block via
    ``inject_peft``, trains it with a fresh ``transformers.Trainer`` up to a
    growing global-step budget, then calls ``restore_layers`` before moving on
    to the next mode/position.

    Args:
        model: Model passed to ``inject_peft``, ``Trainer`` and
            ``restore_layers``; must expose ``parameters()`` and
            ``save_pretrained()``.
        train_dataset: Training dataset forwarded to ``Trainer``.
        eval_dataset: Evaluation dataset forwarded to ``Trainer``.
        compute_metrics: Metrics callable forwarded to ``Trainer``.
        training_args: Training arguments forwarded to ``Trainer``. Its
            ``learning_rate`` attribute is overwritten during :meth:`run`, and
            ``output_dir`` is used as the save/resume directory.
        move_steps: Number of global steps granted to each trained block.
        rank: Rank passed to ``inject_peft``; also the amount by which the
            position advances per sweep over ``peft_modes``. May be reduced
            in place during :meth:`run`.
        position: Starting position passed to ``inject_peft``.
        max_position: Positions at or beyond this value wrap back to 0.
        bias: Forwarded to ``inject_peft``.
        data_collator: Forwarded to ``Trainer``.
        peft_modes: Sequence of modes tried at each position, e.g.
            ``("row", "column")`` or ``("column",)``.
        targets: Stored on the instance; not used by any method of this class.
        verbose: When true, progress and parameter counts are printed.
        learnig_rate_decay: Amount subtracted from the learning rate each time
            the rank is reduced (parameter name kept as-is for compatibility).
        min_learning_rate: Lower bound for the decayed learning rate.
        tollerance: Number of consecutive non-improving blocks tolerated
            before rank/learning rate are reduced (name kept as-is).
        rank_decay: Amount subtracted from the rank on each reduction.
        min_rank: Lower bound for the reduced rank.
    """

    def __init__(
        self,
        model,
        train_dataset,
        eval_dataset,
        compute_metrics,
        training_args,
        move_steps=500,
        rank=1,
        position=0,
        max_position=768,
        bias=True,
        data_collator=None,
        peft_modes=("row", "column"),
        targets=None,
        verbose=True,
        learnig_rate_decay=5e-7,
        min_learning_rate=5e-6,
        tollerance=1,
        rank_decay=1,
        min_rank=1,
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.training_args = training_args
        self.move_steps = move_steps
        self.rank = rank
        self.max_position = max_position
        self.bias = bias
        self.data_collator = data_collator
        self.peft_modes = peft_modes  # Can be ("row", "column"), ("column",), etc.
        self.targets = targets
        self.verbose = verbose
        self.position = position
        self.learnig_rate_decay = learnig_rate_decay
        self.min_learning_rate = min_learning_rate
        self.tollerance = tollerance
        self.rank_decay = rank_decay
        self.min_rank = min_rank

    def get_last_loss(self, trainer):
        """Return the most recent loss recorded in ``trainer.state.log_history``.

        The latest entry containing ``"eval_loss"`` wins, wherever it sits in
        the history; only if no entry has it is the latest ``"loss"``
        (training loss) used.

        Returns:
            The recorded loss value, or ``None`` when neither key is present
            in any log entry.
        """
        logs = trainer.state.log_history
        # Keys are tried in priority order; each scan walks newest -> oldest.
        for key in ("eval_loss", "loss"):
            for entry in reversed(logs):
                if key in entry:
                    return entry[key]
        return None

    def count_parameters(self, prefix=""):
        """Count the model's total and trainable parameters.

        Args:
            prefix: String prepended to each printed line.

        Returns:
            A ``(total, trainable)`` tuple of element counts. The counts are
            printed only when ``self.verbose`` is true.
        """
        total = 0
        trainable = 0
        # Single pass over the parameters instead of one pass per statistic.
        for param in self.model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        if self.verbose:
            print(f"{prefix}Total Parameters: {total:,}")
            print(f"{prefix}Trainable Parameters: {trainable:,}")
        return total, trainable

    def evaluate(self, eval_dataset=None, **kwargs):
        """Run evaluation on the given dataset (defaults to self.eval_dataset).

        A new ``Trainer`` is built for the call; extra keyword arguments are
        forwarded to ``Trainer.evaluate``.

        Returns:
            Whatever ``Trainer.evaluate`` returns.
        """
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            eval_dataset=eval_dataset if eval_dataset is not None else self.eval_dataset,
            compute_metrics=self.compute_metrics,
            data_collator=self.data_collator,
        )
        return trainer.evaluate(eval_dataset, **kwargs)

    def run(self):
        """Main training loop: iterates through model slices.

        For every position, each mode in ``self.peft_modes`` is injected,
        trained until the cumulative step budget
        ``move_steps * (blocks_trained + 1)`` is reached, and then restored.
        The position advances by the current rank after each sweep over the
        modes and wraps to 0 once it reaches ``max_position``, so the loop
        ends only when a training call no longer advances
        ``trainer.state.global_step``.

        When the last recorded loss fails to improve for ``self.tollerance``
        consecutive blocks, ``self.rank`` and the learning rate are reduced
        (bounded by ``min_rank`` / ``min_learning_rate``).

        Side effects: mutates ``self.rank`` and
        ``self.training_args.learning_rate``; writes the model and trainer
        state into ``self.training_args.output_dir``.
        """
        position = self.position
        block_idx = 0
        save_dir = self.training_args.output_dir
        # Written by trainer.save_state(); its presence selects the resume path.
        state_file = os.path.join(save_dir, "trainer_state.json")

        before_steps = after_steps = 0
        learnig_rate = self.training_args.learning_rate
        prev_avg_loss = float("inf")
        tollerance = 0  # consecutive blocks whose loss did not improve

        pbar = tqdm(total=self.max_position * len(self.peft_modes), desc="Training PEFT Blocks", unit="block")
        try:
            while position < self.max_position:
                for mode in self.peft_modes:
                    self.training_args.learning_rate = learnig_rate

                    # ---- 1. Inject PEFT block ----
                    inject_peft(
                        self.model,
                        rank=self.rank,
                        position=position,
                        bias=self.bias,
                        mode=mode,
                    )

                    # Saved after injection and before training; presumably so
                    # that resuming from save_dir loads weights matching the
                    # injected structure -- TODO confirm.
                    self.model.save_pretrained(save_dir)

                    if self.verbose:
                        print(f"\n[Slice {block_idx+1} | Mode={mode} | Pos={position}] | Learning Rate: {learnig_rate:.6f} | Rank: {self.rank}")
                        self.count_parameters(prefix="  ")

                    # ---- 2. Train this block ----
                    # The step budget is cumulative because global_step
                    # carries over when resuming from the saved state.
                    stop_steps = self.move_steps * (block_idx + 1)
                    trainer = Trainer(
                        model=self.model,
                        args=self.training_args,
                        train_dataset=self.train_dataset,
                        eval_dataset=self.eval_dataset,
                        compute_metrics=self.compute_metrics,
                        data_collator=self.data_collator,
                        callbacks=[EarlyStopOnStepCallback(stop_step=stop_steps)],
                    )

                    if os.path.exists(state_file):
                        trainer.train(resume_from_checkpoint=save_dir)
                    else:
                        trainer.train()
                    before_steps = after_steps
                    after_steps = trainer.state.global_step

                    # ---- 3. Merge/restore before next block ----
                    restore_layers(self.model)
                    trainer.save_state()
                    if after_steps <= before_steps:
                        # No progress was made: training is exhausted.
                        break
                    pbar.update(self.rank)
                    block_idx += 1

                    # ---- 4. Adapt rank / learning rate on stalled loss ----
                    avg_loss = self.get_last_loss(trainer)
                    if avg_loss is not None and avg_loss >= prev_avg_loss:
                        tollerance += 1
                        if tollerance >= self.tollerance:
                            # Loss did NOT improve → reduce rank
                            old_rank = self.rank
                            self.rank = max(self.rank - self.rank_decay, self.min_rank)
                            learnig_rate = max(
                                self.training_args.learning_rate - self.learnig_rate_decay,
                                self.min_learning_rate,
                            )
                            tollerance = 0
                            if self.verbose:
                                print(f"  [Rank Reduced] Mode={mode}: Loss {avg_loss:.4f} >= {prev_avg_loss:.4f} → Rank {old_rank} → {self.rank}")
                    else:
                        tollerance = 0
                    if avg_loss is not None:
                        prev_avg_loss = avg_loss

                position += self.rank

                # Wrap around so the sweep over positions restarts.
                if position >= self.max_position:
                    position = 0
                if after_steps <= before_steps:
                    break
            self.model.save_pretrained(save_dir)
        finally:
            # Always release the progress bar, even if training raised.
            pbar.close()
 