# **********************
import copy
import logging
import os
from pathlib import Path
from functools import partial
import wandb
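# wandb.login()  # reads the key from the WANDB_API_KEY environment variable; never hard-code API keys here.
# NOTE(security): a real API key was previously committed on this line — rotate/revoke it.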

import torch

from openpi.models.model import preprocess_observation_and_to_device
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from tqdm.auto import tqdm
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import load_file

from openpi.models.pi0 import Pi0
from openpi.models.ema_model import EMAModel
from openpi.models.model import preprocess_observation
from openpi.utils import format_big_number
from openpi.training.config import TrainConfig, cli
from openpi.training.data_loader import create_data_loader
from openpi.training.utils import build_cosine_decay_schedule_with_wramup

import transformers
from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaMLP, GemmaDecoderLayer
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention, SiglipMLP

def _count_moe_params(model: torch.nn.Module) -> int:
    """Return the number of elements in parameters that look like action-expert MoE weights.

    Name-based heuristic: a parameter counts when its lower-cased name contains
    "mlp", "gemma_expert" and "layers", and also "experts" or "gate". It
    therefore depends on Pi0's parameter naming scheme.
    """
    total = 0
    for name, param in model.named_parameters():
        lowered = name.lower()
        if (
            "mlp" in lowered
            and "gemma_expert" in lowered
            and "layers" in lowered
            and ("experts" in lowered or "gate" in lowered)
        ):
            total += param.numel()
    return total


def _load_pretrained_weights(model: torch.nn.Module, model_file) -> None:
    """Load a ``torch.save``-d state dict from ``model_file`` into ``model`` (non-strict).

    Keys absent from the checkpoint or unknown to the model are reported
    rather than silently ignored, so a mismatched checkpoint is visible.
    """
    # Deserialize on CPU: load_state_dict copies into the model's existing
    # parameters anyway, so keeping the raw checkpoint off the GPU avoids a
    # second full copy of the weights in device memory.
    state_dict = torch.load(model_file, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Missing keys when loading {model_file}: {len(missing)} (first: {missing[:10]})")
    if unexpected:
        print(f"Unexpected keys when loading {model_file}: {len(unexpected)} (first: {unexpected[:10]})")


def main(config: TrainConfig):
    """Train a Pi0 model on a single CUDA device according to ``config``.

    Steps: build the model (optionally initialised from
    ``config.pretrained_model_name_or_path``), build AdamW and a cosine
    warmup/decay schedule, then run ``config.num_train_steps`` optimizer steps
    under bfloat16 autocast. If ``config.checkpoint_dir`` is set, the final
    ``state_dict`` is written to ``<checkpoint_dir>/final_model.pth``.

    Raises:
        RuntimeError: if the data loader yields no batches, which would
            otherwise make the training loop spin forever.
    """
    if config.checkpoint_dir is not None:
        os.makedirs(config.checkpoint_dir, exist_ok=True)

    device = torch.device("cuda")
    device_type = device.type
    # Autocast compute dtype; the parameters themselves stay in their default dtype.
    weight_dtype = torch.bfloat16

    # Build the model.
    model = Pi0(config.model)
    model.to(device)

    num_moe_params = _count_moe_params(model)
    # "参数量" means "parameter count"; the runtime string is kept as-is.
    print(f"MoE 参数量: {num_moe_params/1e6:.2f} M")
    num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Num learnable params= ({num_learnable_params})")

    if config.pretrained_model_name_or_path is not None:
        _load_pretrained_weights(model, config.pretrained_model_name_or_path)

    # TODO: activation checkpointing and EMA tracking (openpi.models.ema_model.EMAModel)
    # are currently disabled; re-enable them behind their config flags when needed.

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Optimizer and LR schedule.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer_lr,
        betas=config.optimizer_betas,
        weight_decay=config.optimizer_weight_decay,
        eps=config.optimizer_eps,
    )
    lr_scheduler = build_cosine_decay_schedule_with_wramup(
        optimizer,
        peak_lr=config.optimizer_lr,
        decay_lr=config.scheduler_decay_lr,
        num_warmup_steps=config.scheduler_warmup_steps,
        num_decay_steps=config.scheduler_decay_steps,
    )

    # Dataset.
    data_loader, num_frames, num_episodes = create_data_loader(
        config,
        num_workers=config.num_workers,
        shuffle=True,
    )
    print(f"Num frames = {num_frames}, num episodes = {num_episodes}")

    global_step = 0
    progress_bar = tqdm(range(global_step, config.num_train_steps))
    progress_bar.set_description("Steps")

    model.train()
    try:
        # The data loader is re-iterated (one pass per "epoch") until the step budget is spent.
        while global_step < config.num_train_steps:
            steps_at_epoch_start = global_step
            for batch in data_loader:
                # batch[0] is the raw observation and batch[1] the target actions.
                observation = preprocess_observation_and_to_device(batch[0], train=True)
                actions = batch[1].to(device, dtype=torch.bfloat16)

                with torch.autocast(device_type=device_type, dtype=weight_dtype):
                    loss = model(
                        observation["images"],
                        observation["image_masks"],
                        observation["tokenized_prompt"],
                        observation["tokenized_prompt_mask"],
                        observation["state"],
                        actions,
                    )
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=config.set_grads_to_none)

                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss.detach().item(), lr=lr_scheduler.get_last_lr()[0])

                if global_step >= config.num_train_steps:
                    break

            # Guard against an empty loader: without progress the while-loop never ends.
            if global_step == steps_at_epoch_start:
                raise RuntimeError("Data loader yielded no batches; cannot make training progress.")
    finally:
        progress_bar.close()

    if config.checkpoint_dir is not None:
        save_path = os.path.join(config.checkpoint_dir, "final_model.pth")
        torch.save(model.state_dict(), save_path)
        print(f"Saved model to {save_path}")


if __name__ == "__main__":
    # Parse the TrainConfig from the command line, then hand it to the trainer.
    train_config = cli()
    main(train_config)
