import os
import sys
import time
import json
import wandb
import warnings
import pandas as pd
import pyarrow as pa

# Add the environment variable to skip DeepSpeed CUDA check
os.environ["DS_SKIP_CUDA_CHECK"] = "1"

from tqdm import tqdm
from datasets import Dataset

import torch
import torch.distributed as dist

from transformers import TrainingArguments, TrainerCallback
from transformers import AutoModelForCausalLM, AutoTokenizer

import trl
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig

import peft
from peft import LoraConfig

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin


def load_hfmodel(ckpt=None):
    """Load a causal LM and its tokenizer, prepared for fp16 LoRA fine-tuning.

    Args:
        ckpt: Hugging Face model id or local checkpoint path. ``None`` falls
            back to a placeholder path that must be filled in before use.

    Returns:
        ``(base_model, tokenizer)`` where ``base_model`` is cast to
        ``torch.float16`` with ``config.use_cache`` disabled, and
        ``tokenizer`` pads on the left with its EOS token and has
        ``add_eos_token`` enabled.
    """
    # NOTE(review): placeholder default path; replace it before calling with None.
    path = "    " if ckpt is None else ckpt

    base_model = AutoModelForCausalLM.from_pretrained(
        path,
        # Load directly in the fp16 training dtype. Loading in bf16 first and
        # then casting to fp16 throws away precision: bf16 keeps 7 explicit
        # mantissa bits versus fp16's 10, so fp32/fp16 checkpoints were
        # needlessly rounded before reaching fp16.
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # `token` is the replacement for the deprecated `use_auth_token`.
        token=True,
        attn_implementation="sdpa",
    )
    # Disable the KV cache for training (transformers reports it as
    # incompatible with gradient checkpointing).
    base_model.config.use_cache = False
    base_model.config.pretraining_tp = 1
    # Explicit cast so that any submodules a model class may keep in fp32
    # also end up in fp16; a no-op for parameters already in fp16.
    base_model = base_model.to(torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    # Reuse EOS as the padding token and pad on the left.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # Only tokenizers implementing an `add_eos_token` setter (e.g. Llama-family
    # fast tokenizers) act on this by appending EOS in encode(); for others it
    # is just an attribute.
    tokenizer.add_eos_token = True
    print("Loaded Model and Tokenizer")

    return base_model, tokenizer


class PeftSavingCallback(TrainerCallback):
    """Save the model via ``save_pretrained`` into each Trainer checkpoint dir.

    After saving, a ``pytorch_model.bin`` file left in the checkpoint
    directory (if any) is removed.
    """

    def on_save(self, args, state, control, **kwargs):
        """Write ``kwargs["model"]`` into ``<output_dir>/checkpoint-<global_step>``."""
        # Trainer callbacks run on every process. Only the global main process
        # writes, so ranks do not race on the same files in a shared output_dir.
        if not getattr(state, "is_world_process_zero", True):
            return

        checkpoint_path = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        # NOTE(review): under ZeRO-3 this relies on the LoRA tensors staying
        # unpartitioned (they are small relative to
        # stage3_param_persistence_threshold); confirm the saved adapter
        # weights are non-empty.
        kwargs["model"].save_pretrained(checkpoint_path)

        full_weights = os.path.join(checkpoint_path, "pytorch_model.bin")
        if os.path.isfile(full_weights):
            os.remove(full_weights)


class DeepSpeedInfoCallback(TrainerCallback):
    """Report per-process DeepSpeed / torch.distributed information at train start.

    When a trainer with a DeepSpeed engine is available and torch.distributed
    is initialized, every rank writes ``deepspeed_process_<global_rank>.json``
    to the current working directory and prints its rank layout. Rank 0 then
    prints a run summary (batch sizes, ZeRO stage, GPU memory of its node).
    Otherwise the callback does nothing.

    The trainer is taken from ``kwargs["trainer"]`` when the callback handler
    supplies it, else from the constructor argument.
    NOTE(review): the HF Trainer does not appear to pass itself in callback
    kwargs, so this likely needs ``DeepSpeedInfoCallback(trainer)`` to report
    anything; verify against the installed transformers version.

    Args:
        trainer: Optional trainer whose ``deepspeed`` engine is inspected.
    """

    def __init__(self, trainer=None):
        super().__init__()
        self.trainer = trainer

    def on_train_begin(self, args, state, control, **kwargs):
        # Run on all processes
        trainer = kwargs.get("trainer") or getattr(self, "trainer", None)
        engine = getattr(trainer, "deepspeed", None)
        # Every rank evaluates the same condition, so either all ranks return
        # here or all of them reach the collective barrier below.
        if not engine or not dist.is_initialized():
            return

        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        self._dump_process_info(local_rank, global_rank, world_size)
        print(
            f"Process {global_rank}/{world_size} (Local rank: {local_rank}): Data parallel worker initialized"
        )
        self._print_rank_layout(engine, global_rank)

        # barrier() is a collective and must be entered by EVERY rank; calling
        # it on rank 0 alone would leave rank 0 waiting on peers that never
        # arrive. Synchronizing here makes rank 0's summary follow the
        # per-rank output.
        dist.barrier()

        if global_rank == 0:
            self._print_summary(args, engine, world_size)

    @staticmethod
    def _dump_process_info(local_rank, global_rank, world_size):
        """Save this rank's ids, GPU name and CUDA memory (GB) to a JSON file."""
        process_info = {
            "local_rank": local_rank,
            "global_rank": global_rank,
            "world_size": world_size,
            "gpu_name": torch.cuda.get_device_name(local_rank),
            "memory_allocated_GB": torch.cuda.memory_allocated(local_rank) / 1e9,
            "memory_reserved_GB": torch.cuda.memory_reserved(local_rank) / 1e9,
        }
        with open(
            f"deepspeed_process_{global_rank}.json", "w", encoding="utf-8"
        ) as f:
            json.dump(process_info, f, indent=2)

    @staticmethod
    def _print_rank_layout(engine, global_rank):
        """Print data-parallel group info and, if exposed, the module's tied comms."""
        if hasattr(engine, "data_parallel_group"):
            dp_group = engine.data_parallel_group
            dp_size = dist.get_world_size(group=dp_group)
            dp_rank = dist.get_rank(group=dp_group)
            print(
                f"Process {global_rank}: Data parallel size: {dp_size}, Data parallel rank: {dp_rank}"
            )

        # NOTE(review): `tied_comms` looks like a pipeline-module attribute
        # rather than ZeRO-3 shard info; the log label is kept unchanged.
        if hasattr(engine, "module") and hasattr(engine.module, "tied_comms"):
            print(
                f"Process {global_rank}: ZeRO-3 model shard info: {engine.module.tied_comms.keys()}"
            )

    @staticmethod
    def _zero_stage(engine):
        """Best-effort lookup of the ZeRO stage; ``'None'`` when unavailable."""
        config = getattr(engine, "config", None)
        if isinstance(config, dict):
            return config.get("zero_optimization", {}).get("stage", "None")
        stage_fn = getattr(engine, "zero_optimization_stage", None)
        return stage_fn() if callable(stage_fn) else "None"

    @classmethod
    def _print_summary(cls, args, engine, world_size):
        """Print the global batch layout, ZeRO stage and this node's GPU memory."""
        print(
            f"\n=== DeepSpeed distributed environment info (total {world_size} processes) ==="
        )
        print(f"- Data Parallelism size: {world_size}")
        print(f"- Batch size per GPU: {args.per_device_train_batch_size}")
        print(
            f"- Accumulated batch size: {args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size}"
        )
        print(f"- ZeRO optimization stage: {cls._zero_stage(engine)}")

        # GPU memory usage information (master node only)
        print("\nGPU memory usage (master node only):")
        for i in range(torch.cuda.device_count()):
            print(
                f"  GPU {i}: {torch.cuda.memory_allocated(i) / 1e9:.2f} GB / {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
            )


def main(cfg):
    """Run LoRA supervised fine-tuning with TRL's SFTTrainer under DeepSpeed ZeRO-3.

    Args:
        cfg: Attribute-style config object. Reads ``seed``, ``model_name``,
            ``train_path``, ``ckpt_dir``, ``per_gpu_bsz``,
            ``gradient_accumulation_steps``, ``lr``, ``logging_steps``,
            ``n_epochs``, ``warmup_ratio``, ``weight_decay``, ``save_steps``
            and, if present, ``resume_checkpoint`` (an empty value means
            "do not resume").
    """
    # Output debugging information
    print(f"Micro batch size per GPU: {cfg.per_gpu_bsz}")
    print(f"Gradient accumulation steps: {cfg.gradient_accumulation_steps}")
    print(f"Process batch size: {cfg.per_gpu_bsz * cfg.gradient_accumulation_steps}")

    # Normalize an empty path to None: Trainer.train() would otherwise try to
    # load a checkpoint from the path "".
    resume_checkpoint = getattr(cfg, "resume_checkpoint", None) or None

    # DeepSpeed ZeRO-3 settings; "auto" values are filled in from the
    # TrainingArguments by the HF integration.
    # NOTE(review): there is no "optimizer" section, so the Trainer builds a
    # torch AdamW (optim="adamw_torch" below) while offload_optimizer targets
    # the CPU. Recent DeepSpeed versions reject a client optimizer under
    # ZeRO-Offload unless "zero_force_ds_cpu_optimizer" is false; verify
    # against the installed DeepSpeed version.
    deepspeed_config = {
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "offload_param": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": 5e8,
            "stage3_prefetch_bucket_size": 5e8,
            "stage3_param_persistence_threshold": 1e6,
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": 1.0,
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }

    # Create DeepSpeedPlugin instance
    deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=deepspeed_config)

    # Initialize accelerator with DeepSpeed plugin
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        log_with="wandb",
        deepspeed_plugin=deepspeed_plugin,
    )

    # Set seed for reproducibility
    set_seed(cfg.seed)

    if accelerator.is_main_process:
        os.makedirs(cfg.ckpt_dir, exist_ok=True)
        accelerator.init_trackers(
            project_name="project_name",
            config=cfg,
            init_kwargs={"wandb": {"name": "run_name"}},
        )

    # Load model and tokenizer.
    # NOTE(review): this happens before SFTConfig (which registers the
    # DeepSpeed config) is built, so ZeRO-3 zero.Init is presumably not used
    # and each rank may materialize the full model; confirm memory headroom.
    base_model, tokenizer = load_hfmodel(cfg.model_name)

    # Load CSV data
    train_df = pd.read_csv(cfg.train_path)

    # Wrap in a HF Dataset; SFTTrainer reads the text from the "history"
    # column (dataset_text_field below) and tokenizes it.
    train_dataset = Dataset(pa.Table.from_pandas(train_df))

    training_args = SFTConfig(
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        output_dir=cfg.ckpt_dir,
        per_device_train_batch_size=cfg.per_gpu_bsz,
        per_device_eval_batch_size=cfg.per_gpu_bsz,
        fp16=True,
        bf16=False,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=True,
        learning_rate=cfg.lr,
        logging_steps=cfg.logging_steps,
        num_train_epochs=cfg.n_epochs,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        report_to="wandb",
        save_strategy="steps",
        save_steps=cfg.save_steps,
        seed=cfg.seed,
        group_by_length=True,
        dataset_text_field="history",
        max_seq_length=2**13,
        resume_from_checkpoint=resume_checkpoint,
        ddp_find_unused_parameters=False,
        deepspeed=deepspeed_config,
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # Tokenize the response template without special tokens. Slicing off the
    # first token ([1:]) assumes a BOS is always prepended (wrong for
    # tokenizers without one). With add_eos_token=True it also keeps a
    # trailing EOS on tokenizers that honour that flag, so the template could
    # never be matched inside a conversation and all labels were masked.
    sep_tokens = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=sep_tokens, tokenizer=tokenizer
    )

    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=data_collator,
        callbacks=[PeftSavingCallback],
        peft_config=peft_config,
    )

    # Status messages, printed by the main process first.
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            print("Set Trainer")
            print("Start Training!")

    # Output distributed GPU information
    if accelerator.is_main_process:
        if torch.cuda.is_available():
            print(f"Total available GPUs: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
                print(
                    f"Memory allocated: {torch.cuda.memory_allocated(i) / 1e9:.2f} GB"
                )
                print(f"Memory reserved: {torch.cuda.memory_reserved(i) / 1e9:.2f} GB")

        print(f"Local rank: {accelerator.local_process_index}")
        print(f"Global rank: {accelerator.process_index}")
        print(f"Num processes: {accelerator.num_processes}")
        print(f"Is distributed: {accelerator.distributed_type != 'NO'}")
        print(f"Distribution type: {accelerator.distributed_type}")

        # Output environment variable information
        print(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK', 'Not set')}")
        print(f"RANK: {os.environ.get('RANK', 'Not set')}")
        print(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE', 'Not set')}")

    # Add DeepSpeed callback
    trainer.add_callback(DeepSpeedInfoCallback())

    # Pass the checkpoint explicitly: the TrainingArguments field alone does
    # not make Trainer.train() resume (optimizer / scheduler state etc.).
    trainer.train(resume_from_checkpoint=resume_checkpoint)

    # End tracking
    accelerator.end_training()


if __name__ == "__main__":
    # RANK identifies the global process under a distributed launcher. Only the
    # main process (RANK unset or "0") starts the wandb run and uses
    # wandb.config as the config; other ranks use a bare class object purely
    # as an attribute container.
    if os.environ.get("RANK", "0") == "0":
        run = wandb.init(project="project_name", name="run_name")
        cfg = wandb.config
    else:
        cfg = type("obj", (object,), {})

    run_settings = {
        "seed": 1,
        # Pretrained model name or checkpoint path
        "model_name": "example_model_name",
        # Training dataset (CSV) path
        "train_path": "example_file.csv",
        # Directory checkpoints are saved to
        "ckpt_dir": "./example_save_dir",
        # Checkpoint to resume from (optimizer / lr_scheduler state etc.)
        "resume_checkpoint": "",
        # Training parameters
        "per_gpu_bsz": 4,
        "gradient_accumulation_steps": 8,
        "lr": 1e-4,
        "logging_steps": 1,
        "n_epochs": 300,
        # NOTE(review): 1.0 is an unusually large AdamW weight decay; confirm.
        "weight_decay": 1.0,
        "warmup_ratio": 0.01,
        # NOTE(review): appears unused, since evaluation is commented out in main().
        "eval_steps": 50,
        "save_steps": 50,
    }
    for setting_name, setting_value in run_settings.items():
        setattr(cfg, setting_name, setting_value)

    main(cfg)
