# **********************
import copy
import logging
import os
from pathlib import Path
from functools import partial
import wandb
# wandb.login()  # reads the key from the WANDB_API_KEY environment variable; never commit API keys (rotate any key that was committed here).

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.optimization import get_cosine_schedule_with_warmup
import torch.distributed as dist
import torch.nn as nn

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.optim import AdamW
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, size_based_auto_wrap_policy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from tqdm.auto import tqdm
from safetensors.torch import load_model as load_model_as_safetensor

from openpi.models.pi0 import Pi0
from openpi.models.ema_model import EMAModel
from openpi.models.model import preprocess_observation_and_to_device
from openpi.utils import format_big_number
from openpi.training.config import PretrainConfig, cli
from openpi.training.data_loader import create_pretrain_data_loader
from openpi.training.utils import build_cosine_decay_schedule_with_wramup

import transformers
from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaMLP, GemmaDecoderLayer
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention, SiglipMLP
from transformers import PaliGemmaForConditionalGeneration

def setup_distributed():
    """Join the default process group and pin this process to its local GPU.

    The rendezvous uses ``env://``, so the launcher (torchrun) is expected to
    provide the environment variables, including ``LOCAL_RANK`` which selects
    the CUDA device for this process.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    # NCCL is the recommended backend for GPU communication; env:// lets
    # torchrun pass the rendezvous settings through environment variables.
    dist.init_process_group(backend="nccl", init_method="env://")

    # Local process index on this machine, used as the GPU index to bind to.
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)

    print(f"Process {dist.get_rank()} initialized on GPU {torch.cuda.current_device()}")

# Prefixes that the FSDP wrapper and the activation-checkpoint wrapper insert
# into parameter names once a module has been wrapped.
_WRAPPER_NAME_PREFIXES = ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module.")


def _strip_wrapper_prefixes(param_name):
    """Return ``param_name`` with FSDP / checkpoint-wrapper prefixes removed."""
    for prefix in _WRAPPER_NAME_PREFIXES:
        param_name = param_name.replace(prefix, "")
    return param_name


def _build_fsdp_model(config, sharding_strategy="full-shard"):
    """Instantiate ``Pi0`` and wrap it with FSDP (plus optional checkpointing).

    Args:
        config: training config; ``config.model`` and
            ``config.enable_gradient_checkpointing`` are read.
        sharding_strategy: ``"full-shard"`` or ``"shard-grad-op"``.

    Returns:
        ``(model, original_ndims)`` where ``original_ndims`` maps each
        parameter name of the *unwrapped* model to its number of dimensions.

    Raises:
        ValueError: if ``sharding_strategy`` is not one of the supported names.
    """
    strategies = {
        "shard-grad-op": ShardingStrategy._HYBRID_SHARD_ZERO2,
        "full-shard": ShardingStrategy.HYBRID_SHARD,
    }
    if sharding_strategy not in strategies:
        raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!")

    model = Pi0(config.model)
    model.to(dtype=torch.float32)

    # Record the true parameter ranks now: with use_orig_params=True FSDP
    # exposes sharded parameters as flattened 1-D views, so ndim is no longer
    # meaningful once the model has been wrapped.
    original_ndims = {name: param.ndim for name, param in model.named_parameters()}

    precision_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # Only modules of class Pi0 are auto-wrapped; no finer-grained layer class
    # is listed in the policy.
    wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Pi0})

    model = FSDP(
        model,
        sharding_strategy=strategies[sharding_strategy],
        auto_wrap_policy=wrap_policy,
        mixed_precision=precision_policy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        use_orig_params=True,
        sync_module_states=True,
    )

    if config.enable_gradient_checkpointing:
        checkpointed_types = (GemmaAttention, GemmaMLP, SiglipSdpaAttention, SiglipMLP)

        def check_fn(submodule: nn.Module) -> bool:
            return isinstance(submodule, checkpointed_types)

        non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)

    return model, original_ndims


def _build_param_groups(model, original_ndims, weight_decay):
    """Split trainable parameters into decayed and non-decayed optimizer groups.

    Bias terms and parameters that had fewer than two dimensions before FSDP
    wrapping (e.g. normalization weights) are exempt from weight decay.

    Args:
        model: the (wrapped) model whose parameters are optimized.
        original_ndims: name -> ndim mapping captured on the unwrapped model.
        weight_decay: decay coefficient for the decayed group.

    Returns:
        A list of two parameter-group dicts suitable for ``AdamW``.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        clean_name = _strip_wrapper_prefixes(name)
        # Fall back to the live ndim if the name cannot be matched.
        ndim = original_ndims.get(clean_name, param.ndim)
        if ndim <= 1 or clean_name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def _save_checkpoint(model, checkpoint_dir, global_step):
    """Write a full (unsharded) model state dict from rank 0 only.

    Must be called by every rank: gathering the state dict of an FSDP model is
    a collective operation. The file is written to
    ``<checkpoint_dir>/checkpoint-<global_step>/pytorch_model.pth``.
    """
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # Gather the full state dict onto rank 0's CPU memory only, instead of
    # materializing a full copy on every GPU.
    full_state_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_config):
        state_dict = model.state_dict()

    if dist.get_rank() == 0:
        save_path = os.path.join(config_dir_join(checkpoint_dir, global_step))
        os.makedirs(save_path, exist_ok=True)
        torch.save(state_dict, os.path.join(save_path, 'pytorch_model.pth'))

    # Wait for the write to finish before any rank proceeds.
    dist.barrier()


def config_dir_join(checkpoint_dir, global_step):
    """Return the directory path used for the checkpoint at ``global_step``."""
    return os.path.join(checkpoint_dir, f"checkpoint-{global_step}")


def main(config: PretrainConfig):
    """Run FSDP pre-training of ``Pi0`` according to ``config``.

    Initializes the distributed process group, builds the sharded model,
    optimizer, LR schedule and data loader, then trains for
    ``config.num_train_steps`` optimizer steps, saving a checkpoint every
    ``config.checkpointing_period`` steps.

    Args:
        config: pre-training configuration.

    Raises:
        RuntimeError: if the data loader yields no batches, since training
            could then never make progress.
    """
    setup_distributed()
    torch.cuda.empty_cache()

    local_rank = int(os.environ["LOCAL_RANK"])
    is_main_process = dist.get_rank() == 0
    logger = logging.getLogger(__name__)

    # NOTE: no random seed is set anywhere in this function.
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    model, original_ndims = _build_fsdp_model(config)
    dist.barrier()

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    optimizer = AdamW(
        _build_param_groups(model, original_ndims, config.optimizer_weight_decay),
        lr=config.optimizer_lr,
        betas=config.optimizer_betas,
        eps=config.optimizer_eps,
    )

    # NOTE(review): `scheduler_decay_steps` is used as the number of *warmup*
    # steps here; confirm this is the intended config field.
    num_warmup_steps = config.scheduler_decay_steps
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, config.num_train_steps)
    # Start from lr=0; the first lr_scheduler.step() overwrites this value.
    for param_group in optimizer.param_groups:
        param_group["lr"] = 0.0

    # NOTE(review): if the loader shuffles through a DistributedSampler, its
    # set_epoch() should be called once per pass — confirm in the loader.
    data_loader, _num_frames, _num_episodes = create_pretrain_data_loader(
        config,
        num_workers=config.num_workers,
        shuffle=True,
    )

    log_every = 50
    global_step = 0
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, config.num_train_steps), disable=local_rank != 0)
    progress_bar.set_description("Steps")

    window_loss = 0.0
    model.train()
    while global_step < config.num_train_steps:
        step_at_epoch_start = global_step
        for batch in data_loader:
            observation = preprocess_observation_and_to_device(batch[0], train=True)
            # NOTE(review): `actions` is passed on exactly as loaded; confirm
            # that the loader or the model places it on the GPU.
            actions = batch[1]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                loss = model(
                    observation["images"],
                    observation["image_masks"],
                    observation["tokenized_prompt"],
                    observation["tokenized_prompt_mask"],
                    observation["state"],
                    actions,
                )
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=config.set_grads_to_none)

            # .item() synchronizes with the GPU, so read the loss once per step.
            step_loss = loss.detach().item()
            window_loss += step_loss
            progress_bar.update(1)
            global_step += 1

            if global_step % config.checkpointing_period == 0:
                _save_checkpoint(model, config.checkpoint_dir, global_step)

            current_lr = lr_scheduler.get_last_lr()[0]
            progress_bar.set_postfix(loss=step_loss, lr=current_lr)

            if global_step % log_every == 0:
                # One optimizer step is taken per batch, so the window mean is
                # the accumulated loss divided by the number of steps.
                if is_main_process:
                    logger.info(
                        "step %d: mean loss over last %d steps = %.6f, lr = %.3e",
                        global_step,
                        log_every,
                        window_loss / log_every,
                        current_lr,
                    )
                window_loss = 0.0

            if global_step >= config.num_train_steps:
                break

        if global_step == step_at_epoch_start:
            # Without this guard an empty loader would spin forever.
            raise RuntimeError("Data loader yielded no batches; training cannot make progress.")

    progress_bar.close()

    # This function created the process group, so it also tears it down.
    dist.barrier()
    dist.destroy_process_group()
    # TODO: optional push-to-hub upload of the final checkpoint.


if __name__ == "__main__":
    # Script entry point: obtain the config from `cli()` and start training.
    pretrain_config = cli()
    main(pretrain_config)
