import torch
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, TrainingArguments, AutoTokenizer, BitsAndBytesConfig
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
from opacus import PrivacyEngine
from opacus.optimizers import DPOptimizer
from opacus.optimizers import DistributedDPOptimizer
from opacus.data_loader import DPDataLoader
from torch.utils.data import DataLoader 
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
from torch.optim.lr_scheduler import LambdaLR
from transformers.trainer_utils import TrainOutput
import time 
import pdb
from tqdm import tqdm


class DP_SFTTrainer(SFTTrainer):
    """SFTTrainer variant that trains with DP-SGD through Opacus.

    The constructor takes one extra, required keyword argument ``dp_args``
    (a mapping) read with these keys:

    * ``"warmup_ratio"``     -- fraction (0..1) of optimizer steps used for
      linear LR warmup,
    * ``"noise_multiplier"`` -- noise multiplier passed to ``make_private``,
    * ``"l2_norm_clip"``     -- ``max_grad_norm`` passed to ``make_private``,
    * ``"delta"``            -- delta used when reporting epsilon.

    ``train`` is fully overridden with a hand-written loop: the model,
    optimizer and data loader are wrapped by ``PrivacyEngine.make_private``
    and the privacy spent is logged every ``logging_steps`` optimizer steps.
    """

    def __init__(self, *args, **kwargs):
        # ``dp_args`` is not understood by SFTTrainer, so strip it before
        # delegating. A missing ``dp_args`` raises KeyError on purpose.
        dp_args = kwargs.pop("dp_args")

        super().__init__(*args, **kwargs)
        self.dp_args = dp_args
        self.privacy_engine = PrivacyEngine()

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create a plain AdamW optimizer and a linear warmup/decay scheduler.

        The optimizer is an ordinary ``torch.optim.AdamW`` over the trainable
        parameters only. It becomes a DP optimizer later, when ``train`` hands
        it to ``PrivacyEngine.make_private``. An optimizer that is already set
        (e.g. passed to the constructor) is reused as-is.

        Args:
            num_training_steps: Total number of optimizer steps; sizes the
                learning-rate schedule.
        """
        if self.optimizer is None:
            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
            self.optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.args.learning_rate,
                # Honour the TrainingArguments hyper-parameters instead of
                # silently falling back to torch's AdamW defaults.
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
                weight_decay=self.args.weight_decay,
            )

        self.lr_scheduler = self.create_scheduler(num_training_steps, optimizer=self.optimizer)

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        """Return a linear-warmup, linear-decay LR scheduler.

        Args:
            num_training_steps: Total number of optimizer steps.
            optimizer: Optimizer to schedule; defaults to ``self.optimizer``
                (matching the base ``Trainer.create_scheduler`` signature).

        Returns:
            The scheduler from ``get_linear_schedule_with_warmup``. Warmup
            lasts ``int(num_training_steps * dp_args["warmup_ratio"])`` steps,
            where ``warmup_ratio`` is expected to be a fraction in [0, 1].
        """
        if optimizer is None:
            optimizer = self.optimizer
        num_warmup_steps = int(num_training_steps * self.dp_args["warmup_ratio"])

        return get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def get_train_dataloader(self) -> DataLoader:
        """Return the parent's training loader re-wrapped as an Opacus ``DPDataLoader``.

        ``DPDataLoader.from_data_loader`` switches batching to the Poisson
        sampling that Opacus' privacy accounting expects. ``distributed=False``
        is passed, so this loader is not set up for multi-process training.
        """
        train_dataloader = super().get_train_dataloader()
        return DPDataLoader.from_data_loader(train_dataloader, distributed=False)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run forward and backward on one micro-batch; do not step the optimizer.

        The optimizer step (and gradient zeroing) happens in ``train``. The
        returned loss is not rescaled for gradient accumulation.

        Args:
            model: The (Opacus-wrapped) model being trained.
            inputs: A collated batch; moved to the right device by
                ``self._prepare_inputs``.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` versions that pass it; unused here.

        Returns:
            The detached scalar loss of this micro-batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        loss.backward()

        return loss.detach()

    def _training_schedule(self, num_batches: int):
        """Return ``(max_steps, num_epochs)`` for a loader yielding ``num_batches`` batches.

        ``train`` steps the optimizer every ``gradient_accumulation_steps``
        batches *and* on the last batch of each epoch, so one epoch performs
        ``ceil(num_batches / gradient_accumulation_steps)`` optimizer steps.
        ``max_steps`` covers ``num_train_epochs`` such epochs (fractional values
        allowed) and is capped by ``args.max_steps`` when that is positive and
        smaller. ``num_epochs`` is the number of passes needed to reach it.
        """
        grad_accum = self.args.gradient_accumulation_steps
        steps_per_epoch = max(-(-num_batches // grad_accum), 1)  # ceil division
        max_steps = int(self.args.num_train_epochs * steps_per_epoch)
        requested_max_steps = getattr(self.args, "max_steps", -1)
        if 0 < requested_max_steps < max_steps:
            max_steps = requested_max_steps
        num_epochs = -(-max_steps // steps_per_epoch)  # ceil division
        return max_steps, num_epochs

    def _logging_interval(self, max_steps: int) -> int:
        """Return the number of optimizer steps between log entries (at least 1).

        Follows ``TrainingArguments.logging_steps``: a value below 1 is treated
        as a fraction of ``max_steps``.
        """
        import math

        logging_steps = self.args.logging_steps
        if logging_steps < 1:
            logging_steps = math.ceil(max_steps * logging_steps)
        return max(int(logging_steps), 1)

    def _privacy_spent(self):
        """Return ``(epsilon, best_alpha)`` for ``dp_args["delta"]``.

        Opacus 0.x exposed ``PrivacyEngine.get_privacy_spent``. Opacus 1.x
        exposes ``PrivacyEngine.get_epsilon`` and, for the RDP accountant,
        ``accountant.get_privacy_spent``. They are tried in that order, and
        ``best_alpha`` is ``None`` when only ``get_epsilon`` is available.
        """
        delta = self.dp_args["delta"]
        engine = self.privacy_engine
        if hasattr(engine, "get_privacy_spent"):
            return engine.get_privacy_spent(delta)
        accountant = getattr(engine, "accountant", None)
        if hasattr(accountant, "get_privacy_spent"):
            return accountant.get_privacy_spent(delta=delta)
        return engine.get_epsilon(delta), None

    def train(self, *args, **kwargs):
        """Run DP-SGD training with a hand-written loop.

        Arguments are accepted for signature compatibility with
        ``Trainer.train`` but are ignored.

        Returns:
            ``TrainOutput`` with the number of optimizer steps taken, the mean
            micro-batch training loss, and a metrics dict.
        """
        # Poisson-sampled loader (see get_train_dataloader).
        train_dataloader = self.get_train_dataloader()

        # The optimizer and scheduler must exist *before* make_private, which
        # wraps the optimizer it is given.
        max_steps, num_epochs = self._training_schedule(len(train_dataloader))
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state.max_steps = max_steps

        self.model, self.optimizer, self.train_dataloader = self.privacy_engine.make_private(
            module=self.model,
            optimizer=self.optimizer,
            data_loader=train_dataloader,
            noise_multiplier=self.dp_args["noise_multiplier"],
            max_grad_norm=self.dp_args["l2_norm_clip"],
            poisson_sampling=True,
        )

        grad_accum = self.args.gradient_accumulation_steps
        num_batches = len(self.train_dataloader)
        logging_interval = self._logging_interval(max_steps)

        total_loss = 0.0
        num_micro_batches = 0
        num_examples = 0
        epoch = 0  # defined even if the loop below never runs
        self.state.log_history = []
        train_start_time = time.time()

        with tqdm(total=max_steps, desc="Training") as pbar:
            for epoch in range(num_epochs):
                for step, batch in enumerate(self.train_dataloader):
                    loss = self.training_step(self.model, batch)
                    loss_value = loss.item()
                    total_loss += loss_value
                    num_micro_batches += 1
                    num_examples += batch["input_ids"].shape[0]

                    # Step every `grad_accum` micro-batches and on the epoch's last batch.
                    is_update_step = (step + 1) % grad_accum == 0 or (step + 1) == num_batches
                    if not is_update_step:
                        continue

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    self.state.global_step += 1

                    epsilon = None
                    if self.state.global_step % logging_interval == 0:
                        epsilon, best_alpha = self._privacy_spent()
                        self.state.log_history.append(
                            {
                                "epoch": epoch,
                                "step": self.state.global_step,
                                "loss": loss_value,
                                "epsilon": epsilon,
                                "best_alpha": best_alpha,
                            }
                        )

                    pbar.update(1)
                    pbar.set_postfix({"loss": loss_value, "epsilon": epsilon})

                    if self.state.global_step >= max_steps:
                        break
                if self.state.global_step >= max_steps:
                    break

        # Detach the hooks that the make_private-wrapped model installed.
        self.model.remove_hooks()
        self.model.disable_hooks()

        train_runtime = time.time() - train_start_time
        global_step = self.state.global_step
        # training_step returns unscaled per-micro-batch losses, so average over
        # micro-batches (dividing by optimizer steps would inflate by ~grad_accum).
        train_loss = total_loss / num_micro_batches if num_micro_batches else 0.0

        metrics = {
            "train_runtime": train_runtime,
            "train_samples_per_second": num_examples / train_runtime if train_runtime > 0 else 0.0,
            "train_steps_per_second": global_step / train_runtime if train_runtime > 0 else 0.0,
            "train_loss": train_loss,
            "epoch": epoch,
        }

        return TrainOutput(
            global_step=global_step,
            training_loss=train_loss,
            metrics=metrics,
        )