import math
import json
import gc
import os
import types
import torch
import torch.nn as nn
import torch.distributed as dist
import wandb
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TrainingArguments, 
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback,
    DataCollatorForSeq2Seq
)
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate import Accelerator
from accelerate.utils import DistributedType
import deepspeed
from transformers.integrations import HfDeepSpeedConfig
from datetime import datetime
import logging
import torch.utils.data
from torch.utils.data import DataLoader

class MoEFineTuner:
    """MoE Fine-tuner"""
    def __init__(
        self, 
        base_model_path: str, 
        output_dir: str,
        num_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.0,
        use_load_balancing: bool = True,
        local_rank: int = -1,
        **kwargs
    ):
        """Initialize MoE Fine-tuner
        
        Args:
            base_model_path: Base model path
            output_dir: Output directory
            num_experts: Number of MoE experts
            top_k: Number of experts per token
            capacity_factor: Routing capacity factor
            use_load_balancing: Whether to use load balancing
            local_rank: Local rank (-1 means non-distributed)
        """
        self.base_model_path = base_model_path
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_load_balancing = use_load_balancing
        self.local_rank = local_rank
        
        # Check if distributed training is used
        self.is_distributed = self.local_rank != -1
        
        # Set device
        if self.is_distributed:
            self.device = torch.device(f"cuda:{self.local_rank}")
            # Ensure distributed environment is initialized
            if not dist.is_initialized():
                torch.cuda.set_device(self.local_rank)
                dist.init_process_group(backend="nccl")
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        print(f"Device: {self.device}")
        
        # Create directory
        if self.local_rank in [0, -1]:
            os.makedirs(output_dir, exist_ok=True)
            
        # Initialize model and tokenizer
        self.init_model_and_tokenizer()

    def init_model_and_tokenizer(self):
        """Load tokenizer and base model, then swap one MLP projection for an MoE layer.

        The tokenizer and the causal-LM are read from ``self.base_model_path``.
        The model is loaded in bfloat16, moved to ``self.device`` and the first
        ``nn.Linear`` whose name contains ``mlp``, ``down_proj`` and
        ``layers.10.`` is replaced by a densely gated two-expert layer. Results
        are stored on ``self.tokenizer`` and ``self.model``.

        Note:
            ``self.num_experts``, ``self.top_k`` and ``self.capacity_factor``
            are not consulted here; the inserted layer always has two experts
            and mixes both of them for every token.

        Raises:
            Exception: Any failure while loading or patching is printed with
                its traceback and re-raised unchanged.
        """
        import traceback

        # Upper bound on the number of load-balancing values kept on the model.
        # Without a bound the list grows by one tensor per forward pass.
        max_lb_history = 256

        try:
            print(f"Loading base model and tokenizer: {self.base_model_path}")

            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)

            # Batching needs a pad token; reuse EOS when the tokenizer has none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print(f"Successfully loaded tokenizer: {self.tokenizer.__class__.__name__}")

            print(f"Loading base model: {self.base_model_path}")

            # Intended to keep Accelerate from placing the model automatically.
            # NOTE(review): confirm this variable is actually honoured.
            os.environ["ACCELERATE_USE_DEVICE_MAP"] = "false"

            model = AutoModelForCausalLM.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.bfloat16,
                device_map=None,
                low_cpu_mem_usage=True,
            )

            # Use the device chosen in __init__ so that every rank of a
            # distributed run loads onto its own GPU instead of all onto cuda:0.
            target_device = self.device
            model = model.to(target_device)

            print(f"Base model loaded successfully: {model.__class__.__name__}")

            # Rolling history of detached load-balancing values (see MOEWrapper).
            model.lb_losses = []

            from torch.nn import Linear

            class ExtremeSafeMoELayer(nn.Module):
                """Two-expert dense mixture used as a drop-in for one ``nn.Linear``."""

                def __init__(self, original_layer):
                    super().__init__()
                    self.in_features = original_layer.in_features
                    self.out_features = original_layer.out_features

                    source_weight = original_layer.weight.data
                    has_bias = original_layer.bias is not None
                    # Build the new sub-layers with the dtype/device of the layer
                    # they replace; the nn.Linear default (float32, CPU) would
                    # not match bfloat16 activations outside autocast.
                    factory = {"device": source_weight.device, "dtype": source_weight.dtype}

                    # Copy of the replaced layer's parameters. It is not used in
                    # forward(); kept so existing state-dict keys stay valid.
                    self.original_weight = nn.Parameter(source_weight.clone())
                    if has_bias:
                        self.original_bias = nn.Parameter(original_layer.bias.data.clone())
                    else:
                        self.register_parameter('original_bias', None)

                    # Only use 2 experts
                    self.num_experts = 2

                    self.experts = nn.ModuleList([
                        nn.Linear(self.in_features, self.out_features, bias=has_bias, **factory)
                        for _ in range(self.num_experts)
                    ])

                    # Gating network - a simple linear layer
                    self.gate = nn.Linear(self.in_features, self.num_experts, **factory)

                    with torch.no_grad():
                        # Zero gate => uniform softmax, so the initial mixture is
                        # the plain average of the experts.
                        nn.init.zeros_(self.gate.weight)
                        nn.init.zeros_(self.gate.bias)

                        # Expert 0 is an exact copy of the replaced layer; every
                        # further expert is that copy plus small Gaussian noise.
                        for index, expert in enumerate(self.experts):
                            expert.weight.copy_(self.original_weight)
                            if has_bias:
                                expert.bias.copy_(self.original_bias)
                            if index > 0:
                                expert.weight.add_(0.001 * torch.randn_like(expert.weight))
                                if has_bias:
                                    expert.bias.add_(0.001 * torch.randn_like(expert.bias))

                def forward(self, x):
                    """Return ``(mixed_output, lb_loss)``; ``lb_loss`` is None in eval mode."""
                    input_shape = x.shape

                    # Collapse leading dimensions so gating works per row.
                    if len(input_shape) > 2:
                        x_2d = x.reshape(-1, self.in_features)
                    else:
                        x_2d = x

                    gate_logits = self.gate(x_2d)
                    gate_logits = gate_logits.clamp(-10, 10)  # Prevent numerical overflow
                    gate_probs = nn.functional.softmax(gate_logits, dim=-1)

                    output = torch.zeros(x_2d.shape[0], self.out_features, device=x.device, dtype=x.dtype)

                    # Every expert sees every row; outputs are blended by gate probability.
                    for i, expert in enumerate(self.experts):
                        expert_out = expert(x_2d)
                        weight = gate_probs[:, i].unsqueeze(1)
                        output += weight * expert_out

                    # Restore the caller's leading dimensions.
                    if len(input_shape) > 2:
                        output = output.view(input_shape[:-1] + (self.out_features,))

                    if self.training:
                        # Penalise deviation of the mean gate probabilities from
                        # a uniform split, scaled down by 1e-4.
                        mean_probs = gate_probs.mean(0)
                        target = torch.ones_like(mean_probs) / self.num_experts
                        lb_loss = nn.functional.mse_loss(mean_probs, target) * 0.0001
                    else:
                        lb_loss = None

                    return output, lb_loss

            class MOEWrapper(nn.Module):
                """Expose the MoE layer with a plain ``forward(x) -> tensor`` signature."""

                def __init__(self, moe):
                    super().__init__()
                    self.moe = moe
                    self.lb_loss = None

                def forward(self, x):
                    output, lb_loss = self.moe(x)
                    if lb_loss is not None:
                        # Latest value, still attached to the autograd graph.
                        self.lb_loss = lb_loss
                        history = getattr(model, 'lb_losses', None)
                        if history is not None:
                            # Keep only detached values and cap the length so
                            # the history cannot pin graphs or grow unbounded.
                            history.append(lb_loss.detach())
                            overflow = len(history) - max_lb_history
                            if overflow > 0:
                                del history[:overflow]
                    return output

            # Only replace a limited number of layers for stability
            replaced_count = 0
            max_layers_to_replace = 1

            for name, module in list(model.named_modules()):
                if replaced_count >= max_layers_to_replace:
                    break
                if not (isinstance(module, Linear) and "mlp" in name and "down_proj" in name):
                    continue
                # Only the projection inside block 10 is swapped.
                if "layers.10." not in name:
                    continue

                parent_name, _, child_name = name.rpartition('.')
                parent = model if parent_name == '' else dict(model.named_modules())[parent_name]

                moe_layer = ExtremeSafeMoELayer(module)
                moe_wrapper = MOEWrapper(moe_layer).to(target_device)
                setattr(parent, child_name, moe_wrapper)
                replaced_count += 1

                print(f"Replaced layer {name} with MoE layer (num_experts: {moe_layer.num_experts})")

            if replaced_count == 0:
                # Make the no-op visible instead of silently training a dense model.
                print("Warning: no 'layers.10.' mlp down_proj layer found; model left without MoE layers")

            self.model = model
            print("Model and tokenizer initialization complete")

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            traceback.print_exc()
            raise

    def prepare_model_for_training(self):
        """Prepare model for training, ensure single device and remove all hooks"""
        print("Preparing model for training...")
        
        # Ensure model is a single variable, not wrapped
        model = self.model
        if hasattr(model, 'module'):
            model = model.module
        
        # Set model to train mode
        model.train()
        
        # Save original model config and important attributes
        original_config = model.config
        model_class = model.__class__
        
        # Create copies of all parameters and buffers
        param_dict = {}
        buffer_dict = {}
        for name, param in model.named_parameters():
            param_dict[name] = param.data.cpu().clone()
        
        for name, buffer in model.named_buffers():
            buffer_dict[name] = buffer.cpu().clone()
        
        # Build a brand new model instance
        print("Creating a clean model instance without hooks...")
        # Temporarily disable accelerate
        os.environ["HF_NO_ACCELERATE"] = "1"
        os.environ["ACCELERATE_USE_DEVICE_MAP"] = "false"
        
        # Try to create a new model instance
        try:
            # Try to load new model from HF config
            new_model = AutoModelForCausalLM.from_config(
                original_config,
                torch_dtype=torch.bfloat16
            )
        except Exception as e:
            print(f"Failed to create new model from config: {e}")
            print("Using existing model and trying to remove hooks...")
            new_model = model
            
            # Try to remove any hooks
            if hasattr(new_model, "_hf_hook"):
                print("Removing existing hooks...")
                try:
                    delattr(new_model, "_hf_hook")
                except:
                    print("Failed to remove hooks, will continue")
        
        # Set target device
        target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to single device: {target_device}")
        
        # Make sure model is on the same device, ensure no accelerator hooks
        try:
            # Disable use_cache for gradient checkpoint compatibility
            if hasattr(new_model.config, "use_cache"):
                new_model.config.use_cache = False
            
            # Move model
            new_model = new_model.to(target_device)
        except Exception as e:
            print(f"Error moving model to device: {e}")
        
        # Ensure new model is in train mode
        new_model.train()
        
        print("Model preparation complete")
        return new_model

    def supervised_finetuning(
        self,
        dataset_path,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_epochs=2,
        learning_rate=1e-7,  # Use very small learning rate
        warmup_steps=100,
        weight_decay=0.01,
        gradient_checkpointing=True,
        save_steps=500,
        logging_steps=10,
        deepspeed_config=None,
        max_grad_norm=0.5,  # Stricter gradient clipping
        **kwargs
    ):
        """Run supervised fine-tuning on a CSV dataset with a DeepSpeed engine.

        The CSV must have ``instruction``, ``input`` and ``output`` columns.
        Each row becomes ``prompt + output`` token ids with the prompt part of
        the labels masked by -100. A DeepSpeed checkpoint plus the tokenizer is
        written after every epoch (``epoch-N``), for the best epoch
        (``best_model``) and at the end (``final_model``) under
        ``self.output_dir``.

        Args:
            dataset_path: Path of the CSV training file.
            per_device_train_batch_size: Micro batch size per process.
            gradient_accumulation_steps: Micro batches per optimizer step.
            num_epochs: Number of passes over the data.
            learning_rate: AdamW learning rate for the built-in DeepSpeed
                config (ignored when ``deepspeed_config`` is given).
            warmup_steps: Accepted for compatibility; currently unused.
            weight_decay: AdamW weight decay for the built-in config.
            gradient_checkpointing: Enable gradient checkpointing on the model.
            save_steps: Accepted for compatibility; currently unused.
            logging_steps: Print the running mean loss every this many steps.
            deepspeed_config: Full DeepSpeed config; when None a conservative
                ZeRO-2 / CPU-offload config is built from the arguments above.
            max_grad_norm: Gradient clipping value for the built-in config.
            **kwargs: ``max_length`` (default 128) caps the token count per
                example; other keys are ignored.

        Returns:
            The DeepSpeed engine wrapping the trained model.

        Raises:
            ValueError: Required columns are missing, ``max_length`` is not
                positive, or no example contains a supervised token.
            Exception: Any other failure is printed with its traceback and
                re-raised unchanged.
        """
        import traceback
        from torch.utils.data.distributed import DistributedSampler

        try:
            if gradient_checkpointing:
                self.model.gradient_checkpointing_enable()
                print("Enabled gradient checkpointing to save memory")

            print(f"Loading dataset: {dataset_path}")
            dataset = load_dataset("csv", data_files={"train": dataset_path})

            print(f"Original dataset size: {len(dataset['train'])} records")

            # Short sequences keep memory usage low; callers may override.
            max_seq_length = int(kwargs.get("max_length", 128))
            if max_seq_length <= 0:
                raise ValueError("max_length must be a positive integer")

            required_columns = ("instruction", "input", "output")

            def preprocess_function(examples):
                """Tokenise one batch of rows into ``input_ids`` / ``labels`` lists."""
                if not all(col in examples for col in required_columns):
                    raise ValueError("Dataset must contain 'instruction', 'input', and 'output' columns")

                processed_inputs = []
                processed_labels = []

                rows = zip(examples["instruction"], examples["input"], examples["output"])
                for instruction, input_text, output in rows:
                    if input_text:
                        prompt = f"Instruction: {instruction}\nInput: {input_text}\nOutput: "
                    else:
                        prompt = f"Instruction: {instruction}\nOutput: "

                    # An empty CSV cell may arrive as None.
                    output_text = "" if output is None else str(output)

                    # NOTE(review): no BOS/EOS token is added to the sequence;
                    # confirm that is intended for this model.
                    prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
                    output_tokens = self.tokenizer.encode(output_text, add_special_tokens=False)

                    input_ids = (prompt_tokens + output_tokens)[:max_seq_length]
                    # Prompt positions are masked so only the answer is learned.
                    labels = ([-100] * len(prompt_tokens) + output_tokens)[:max_seq_length]

                    # Rows whose labels are all masked (empty answer, or prompt
                    # fills the whole window) can only produce a NaN loss.
                    if all(label == -100 for label in labels):
                        continue

                    processed_inputs.append(input_ids)
                    processed_labels.append(labels)

                return {
                    "input_ids": processed_inputs,
                    "labels": processed_labels
                }

            processed_dataset = dataset["train"].map(
                preprocess_function,
                batched=True,
                remove_columns=dataset["train"].column_names,
                desc="Preprocessing training data",
                num_proc=1  # single process to save memory
            )

            if len(processed_dataset) == 0:
                raise ValueError("No training example contains a supervised output token")
            print(f"Examples with trainable tokens: {len(processed_dataset)}")

            pad_token_id = self.tokenizer.pad_token_id

            def collate_fn(examples):
                """Right-pad a list of examples to the longest one in the batch."""
                longest = max(len(example["input_ids"]) for example in examples)

                input_ids_padded = []
                labels_padded = []
                attention_padded = []

                for example in examples:
                    ids, labels = example["input_ids"], example["labels"]
                    padding_length = longest - len(ids)
                    input_ids_padded.append(ids + [pad_token_id] * padding_length)
                    # -100 keeps padded positions out of the loss.
                    labels_padded.append(labels + [-100] * padding_length)
                    attention_padded.append([1] * len(ids) + [0] * padding_length)

                return {
                    "input_ids": torch.tensor(input_ids_padded),
                    "attention_mask": torch.tensor(attention_padded),
                    "labels": torch.tensor(labels_padded)
                }

            distributed = (
                dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
            )
            # Only one process per node writes the tokenizer files.
            is_main_process = self.local_rank in (0, -1)

            # With several ranks each one must train on its own shard.
            sampler = DistributedSampler(processed_dataset, shuffle=True) if distributed else None
            train_dataloader = DataLoader(
                processed_dataset, 
                batch_size=per_device_train_batch_size,
                collate_fn=collate_fn,
                shuffle=sampler is None,
                sampler=sampler
            )

            print(f"DataLoader created, total {len(train_dataloader)} batches")

            if deepspeed_config is None:
                # Conservative defaults driven by the caller's hyper-parameters.
                # "train_batch_size" is left out on purpose: DeepSpeed derives
                # it from the two values below and the world size.
                deepspeed_config = {
                    "train_micro_batch_size_per_gpu": per_device_train_batch_size,
                    "gradient_accumulation_steps": gradient_accumulation_steps,
                    "gradient_clipping": max_grad_norm,
                    "bf16": {"enabled": True},
                    "fp16": {"enabled": False},
                    "optimizer": {
                        "type": "AdamW",
                        "params": {
                            "lr": learning_rate,
                            "weight_decay": weight_decay,
                            "betas": [0.9, 0.999],
                            "eps": 1e-8
                        }
                    },
                    "zero_optimization": {
                        "stage": 2,
                        "offload_optimizer": {"device": "cpu"},
                        "allgather_bucket_size": 5e7,  # reduce communication
                        "reduce_bucket_size": 5e7  # reduce communication
                    }
                }

            print("Initializing DeepSpeed engine...")
            model_engine, optimizer, _, _ = deepspeed.initialize(
                model=self.model,
                model_parameters=self.model.parameters(),
                config=deepspeed_config
            )

            # Looked up once; batches are moved here every step.
            device = next(model_engine.parameters()).device

            def save_checkpoint_to(path):
                """Write a DeepSpeed checkpoint (all ranks) and the tokenizer."""
                os.makedirs(path, exist_ok=True)
                model_engine.save_checkpoint(path)
                if is_main_process:
                    self.tokenizer.save_pretrained(path)

            global_step = 0
            total_loss = 0
            best_loss = float('inf')
            total_steps = len(train_dataloader) * num_epochs

            print("Starting DeepSpeed training loop...")

            for epoch in range(num_epochs):
                print(f"\nStarting epoch {epoch+1}/{num_epochs}")
                epoch_loss = 0
                completed_batches = 0

                if sampler is not None:
                    # Gives every epoch a different shuffle across ranks.
                    sampler.set_epoch(epoch)

                progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")

                for step, batch in enumerate(progress_bar):
                    batch = {k: v.to(device) for k, v in batch.items()}

                    try:
                        with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
                            outputs = model_engine(
                                input_ids=batch["input_ids"],
                                attention_mask=batch["attention_mask"],
                                labels=batch["labels"]
                            )
                            loss = outputs.loss

                        # NOTE(review): skipping a batch on one rank only may
                        # desynchronise collectives in multi-GPU runs.
                        if torch.isnan(loss).any() or torch.isinf(loss).any():
                            print(f"Warning: Step {step} detected NaN/Inf loss ({loss.item()}), skipping this batch")
                            gc.collect()
                            torch.cuda.empty_cache()
                            continue

                        model_engine.backward(loss)
                        model_engine.step()

                        # Frequent cleanup to keep peak memory down.
                        if step % 2 == 0:
                            gc.collect()
                            torch.cuda.empty_cache()

                        # The load-balancing value is reported only; it is not
                        # part of the optimised loss.
                        load_balancing_losses = [
                            module.lb_loss
                            for module in model_engine.module.modules()
                            if hasattr(module, 'lb_loss') and module.lb_loss is not None
                        ]
                        lb_loss_value = 0
                        if load_balancing_losses:
                            lb_loss_value = sum(load_balancing_losses).item() * 0.001

                        current_loss = loss.item()
                        total_loss += current_loss
                        epoch_loss += current_loss
                        completed_batches += 1

                        progress_bar.set_postfix({
                            "loss": current_loss,
                            "lb_loss": lb_loss_value if lb_loss_value > 0 else 0
                        })

                        global_step += 1

                        if logging_steps and global_step % logging_steps == 0:
                            avg_loss = total_loss / logging_steps
                            print(f"Step {global_step}/{total_steps} | Loss: {avg_loss:.4f}")
                            if lb_loss_value > 0:
                                print(f"Load balancing loss: {lb_loss_value:.6f}")
                            total_loss = 0

                    except Exception as e:
                        print(f"Error during training step: {e}")
                        traceback.print_exc()
                        if step == 0:  # If error on first step, may be fundamental issue
                            raise
                        continue

                epoch_save_path = os.path.join(self.output_dir, f"epoch-{epoch+1}")
                save_checkpoint_to(epoch_save_path)
                print(f"Epoch {epoch+1} complete, checkpoint saved to {epoch_save_path}")

                if distributed:
                    # Every rank must reach the same "best" verdict, because
                    # saving a checkpoint is done by all ranks together.
                    stats = torch.tensor(
                        [float(epoch_loss), float(completed_batches)],
                        dtype=torch.float64,
                        device=device,
                    )
                    dist.all_reduce(stats)
                    epoch_loss = stats[0].item()
                    completed_batches = int(stats[1].item())

                # Average over the batches that actually trained.
                if completed_batches:
                    avg_epoch_loss = epoch_loss / completed_batches
                else:
                    avg_epoch_loss = float('inf')
                    print(f"Warning: epoch {epoch+1} completed no batches successfully")
                print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")

                if avg_epoch_loss < best_loss:
                    best_loss = avg_epoch_loss
                    best_model_path = os.path.join(self.output_dir, "best_model")
                    save_checkpoint_to(best_model_path)
                    print(f"New best model saved to {best_model_path}, loss: {best_loss:.4f}")

            final_model_path = os.path.join(self.output_dir, "final_model")
            save_checkpoint_to(final_model_path)
            print(f"Training complete, final model saved to: {final_model_path}")

            return model_engine

        except Exception as e:
            print(f"Error during training: {str(e)}")
            traceback.print_exc()
            raise

# Define logging callback class
class LoggingCallback(TrainerCallback):
    """Trainer callback that prints the loss at most once per ``log_interval`` steps."""

    def __init__(self, log_interval=10):
        """Remember the interval; nothing has been printed yet.

        Args:
            log_interval: Minimum number of global steps between two prints.
        """
        self.log_interval = log_interval
        # Global step at which the loss was last printed.
        self.step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print ``logs['loss']`` once ``log_interval`` steps have passed.

        Args:
            args: Unused.
            state: Object whose ``global_step`` is read.
            control: Unused.
            logs: Mapping of logged values; nothing happens when empty/None.
            **kwargs: Unused.
        """
        if not logs:
            return
        # ">=" so that a log event exactly log_interval steps after the last
        # print is reported (a strict ">" would skip every such event).
        if state.global_step - self.step < self.log_interval:
            return

        self.step = state.global_step
        if "loss" in logs:
            print(f"Step {state.global_step}: Loss = {logs['loss']:.4f}")

# Main function
def main():
    """Configure the process environment and run MoE supervised fine-tuning.

    Reads ``LOCAL_RANK`` to decide between single-process and distributed
    execution, builds a DeepSpeed ZeRO-2 configuration, constructs a
    ``MoEFineTuner`` and starts training with that configuration.
    """
    # Set environment variables for memory optimization
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Disable tokenizer parallelism to avoid fork warnings
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Intended to disable DeepSpeed custom kernels.
    # NOTE(review): confirm these DS_BUILD_* switches have an effect at runtime.
    os.environ["DS_BUILD_CPU_ADAM"] = "0"
    os.environ["DS_BUILD_FUSED_ADAM"] = "0"
    os.environ["DS_BUILD_TRANSFORMER"] = "0"
    os.environ["DS_BUILD_UTILS"] = "0"
    os.environ["DS_BUILD_SPARSE_ATTN"] = "0"

    # -1 means no launcher set LOCAL_RANK, i.e. a single process.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))

    if local_rank == -1:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        world_size = 1
        print(f"Using device: {device}")
    else:
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        world_size = dist.get_world_size()
        print(f"Distributed training, using device: {device}, local_rank: {local_rank}, world_size: {world_size}")

    # Single source of truth for the hyper-parameters, shared by the DeepSpeed
    # config and the fine-tuning call so the two can no longer disagree.
    micro_batch_size = 1  # micro batch size per GPU
    gradient_accumulation_steps = 16  # gradient accumulation steps
    learning_rate = 1e-6  # very low learning rate
    weight_decay = 0.001  # smaller weight decay

    # Global batch size = micro batch * grad accum * world size
    global_batch_size = micro_batch_size * gradient_accumulation_steps * world_size

    ds_config = {
        "train_batch_size": global_batch_size,
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "steps_per_print": 10,
        "gradient_clipping": 0.5,  # reduce gradient clipping threshold to help prevent NaN
        "fp16": {
            "enabled": False
        },
        "bf16": {
            "enabled": True
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": learning_rate,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": weight_decay
            }
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
                "buffer_count": 1,  # reduce buffer count
                "fast_init": True
            },
            "offload_param": {
                "device": "none"
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": 1e7,  # further reduce communication buffer
            "allgather_bucket_size": 1e7,
            "reduce_scatter": True
        },
        "wall_clock_breakdown": False,
        "zero_allow_untested_optimizer": True
    }

    finetuner = MoEFineTuner(
        base_model_path="/hy-tmp/llama-7b/",
        output_dir="./output_deepspeed",
        num_experts=2,           # only use 2 experts
        top_k=1,                 # only 1 expert per token
        capacity_factor=1.0,     
        use_load_balancing=True, 
        local_rank=local_rank    
    )

    # The config built above is handed over explicitly; previously it was
    # constructed but never used.
    finetuner.supervised_finetuning(
        dataset_path="/hy-tmp/dense2MOE/sft_dataset_complete.csv",
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_epochs=2,
        learning_rate=learning_rate,
        warmup_steps=100,
        weight_decay=weight_decay,
        gradient_checkpointing=True,        # enable gradient checkpointing
        save_steps=500,
        logging_steps=10,
        deepspeed_config=ds_config
    )

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

