# **********************
import copy
import logging
import os
from pathlib import Path
from functools import partial
import wandb
# Authenticate with Weights & Biases using a key supplied by the environment.
# Credentials must never be embedded in source; when WANDB_API_KEY is unset,
# `key=None` lets wandb fall back to its own credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from tqdm.auto import tqdm
from safetensors.torch import load_model as load_model_as_safetensor

from openpi.models.pi0 import Pi0
from openpi.models.model import preprocess_observation
from openpi.utils import format_big_number
from openpi.training.config import PretrainConfig, cli
from openpi.training.data_loader import create_pretrain_data_loader
from openpi.training.utils import build_cosine_decay_schedule_with_wramup

import transformers
from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaMLP, GemmaDecoderLayer
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention, SiglipMLP

def main(config: PretrainConfig):
    """Run the Pi0 pretraining loop under Hugging Face Accelerate.

    The function builds the accelerator (optionally with a DeepSpeed plugin),
    constructs the model, optionally loads pretrained weights, sets up the
    AdamW optimizer / LR schedule / data loader, optionally resumes from a
    ``checkpoint-<step>`` directory, and then trains until
    ``config.num_train_steps`` optimizer steps have been performed.

    Args:
        config: Run configuration. The attributes read here include
            ``checkpoint_dir``, ``logging_dir``, ``deepspeed``,
            ``gradient_accumulation_steps``, ``mixed_precision``, ``report_to``,
            ``exp_name``, ``seed``, ``model``, ``pretrained_model_name_or_path``,
            ``enable_gradient_checkpointing``, the ``optimizer_*`` and
            ``scheduler_*`` fields, ``num_workers``, ``batch_size``,
            ``num_train_steps``, ``resume_from_checkpoint``,
            ``checkpointing_period``, ``max_grad_norm`` and
            ``set_grads_to_none``.

    Raises:
        FileNotFoundError: When resuming, if ``accelerator.load_state`` fails
            and no model-only checkpoint file is found either (raised by
            ``torch.load``).
    """
    logger = get_logger(__name__)
    logging_dir = Path(config.checkpoint_dir, config.logging_dir)
    accelerator_project_config = ProjectConfiguration(total_limit=config.checkpoints_total_limit)
    deepspeed_plugin = None
    if config.deepspeed is not None:
        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=config.deepspeed)
    accelerator = Accelerator(
        deepspeed_plugin=deepspeed_plugin,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_dir=logging_dir,
        project_config=accelerator_project_config,
    )
    accelerator.init_trackers(project_name=config.exp_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
    else:
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if config.seed is not None:
        set_seed(config.seed)

    # Create the checkpoint directory once, from the main process.
    if accelerator.is_main_process:
        if config.checkpoint_dir is not None:
            os.makedirs(config.checkpoint_dir, exist_ok=True)

    # Autocast dtype used for the forward pass in the training loop.
    weight_dtype = torch.bfloat16

    model = Pi0(config.model)
    # Parameter counts are only used for the startup log below.
    num_total_params = sum(p.numel() for p in model.parameters())
    num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    if config.pretrained_model_name_or_path is not None:
        logger.info("Constructing model from pretrained checkpoint.")
        model_file = config.pretrained_model_name_or_path
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source. Staging on CPU avoids every rank materialising the
        # tensors on whichever device they were saved from.
        state_dict = torch.load(model_file, map_location="cpu")
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # strict=False silently tolerates mismatches, so surface them.
        if missing:
            logger.warning(f"Missing keys when loading pretrained weights: {missing}")
        if unexpected:
            logger.warning(f"Unexpected keys when loading pretrained weights: {unexpected}")

    def save_model_hook(models, weights, output_dir):
        """Pre-hook for ``accelerator.save_state``: call ``save_pretrained`` on each model.

        Only the main process writes. NOTE(review): this assumes the unwrapped
        model exposes ``save_pretrained`` -- confirm for ``Pi0``.
        """
        if accelerator.is_main_process:
            for wrapped in models:
                model_to_save = wrapped.module if hasattr(wrapped, "module") else wrapped  # type: ignore
                if isinstance(model_to_save, type(accelerator.unwrap_model(wrapped))):
                    model_to_save.save_pretrained(output_dir)

    accelerator.register_save_state_pre_hook(save_model_hook)

    if config.enable_gradient_checkpointing:
        non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
        checkpointed_types = (GemmaAttention, GemmaMLP, SiglipSdpaAttention, SiglipMLP)

        def check_fn(submodule: nn.Module) -> bool:
            """Select the attention / MLP submodules to wrap with activation checkpointing."""
            return isinstance(submodule, checkpointed_types)

        apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Optimizer and learning-rate schedule.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer_lr,
        betas=config.optimizer_betas,
        weight_decay=config.optimizer_weight_decay,
        eps=config.optimizer_eps,
    )
    lr_scheduler = build_cosine_decay_schedule_with_wramup(
        optimizer,
        peak_lr=config.optimizer_lr,
        decay_lr=config.scheduler_decay_lr,
        num_warmup_steps=config.scheduler_warmup_steps,
        num_decay_steps=config.scheduler_decay_steps,
    )

    # Dataset / data loader.
    data_loader, num_frames, num_episodes = create_pretrain_data_loader(
        config,
        num_workers=config.num_workers,
        shuffle=True,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, data_loader, lr_scheduler
    )

    # Store the run configuration with the trackers.
    # NOTE(review): trackers were already initialised above with
    # `config.exp_name`; this second call uses a different project name --
    # confirm which of the two is intended.
    if accelerator.is_main_process:
        config_dict = {k: v for k, v in vars(config).items() if k != 'total_configs'}
        accelerator.init_trackers("roboticDiffusionTransformer", config=config_dict)

    # Train!
    total_batch_size = config.batch_size * accelerator.num_processes * config.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num episodes each epoch = {num_episodes}")
    logger.info(f"  Num frames each epoch = {num_frames}")
    logger.info(f"  Num train steps= ({config.num_train_steps})")

    logger.info(f"  Instantaneous batch size per device = {config.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")

    logger.info(f"  Num total params = ({format_big_number(num_total_params)})")
    logger.info(f"  Num learnable params= ({format_big_number(num_learnable_params)})")

    global_step = 0
    # Potentially load in the weights and states from a previous save.
    if config.resume_from_checkpoint:
        if config.resume_from_checkpoint != "latest":
            path = os.path.basename(config.resume_from_checkpoint)
        else:
            # Get the most recent `checkpoint-<step>` directory.
            dirs = [d for d in os.listdir(config.checkpoint_dir) if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            accelerator.print(
                f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            config.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            resume_dir = os.path.join(config.checkpoint_dir, path)
            try:
                accelerator.load_state(resume_dir)
            except Exception:
                # Full training state is unavailable; fall back to model weights
                # only. Keep the traceback so the root cause is not lost.
                logger.info(
                    "Resuming training state failed. Attempting to only load from model checkpoint.",
                    exc_info=True,
                )
                deepspeed_file = os.path.join(resume_dir, "pytorch_model", "mp_rank_00_model_states.pt")
                # This is the file the periodic checkpointing below writes.
                plain_file = os.path.join(resume_dir, "pytorch_model.pth")
                if not os.path.exists(deepspeed_file) and os.path.exists(plain_file):
                    resume_weights = torch.load(plain_file, map_location="cpu")
                else:
                    # DeepSpeed layout; raises FileNotFoundError if absent.
                    resume_weights = torch.load(deepspeed_file, map_location="cpu")["module"]
                # unwrap_model also works when the model has no `.module` wrapper.
                accelerator.unwrap_model(model).load_state_dict(resume_weights)

            # The directory name encodes the optimizer step it was saved at;
            # continue counting from there so step numbers stay consistent
            # with the checkpoint names.
            global_step = int(path.split("-")[1])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, config.num_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    device_type = "cuda" if "cuda" in str(accelerator.device) else "cpu"
    log_every = 50  # optimizer steps between tracker logs
    running_loss = 0.0  # sum of micro-batch losses since the last tracker log
    running_count = 0  # number of micro-batch losses in `running_loss`
    model.train()
    while global_step < config.num_train_steps:
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                observation = preprocess_observation(batch[0], train=True)
                actions = batch[1]
                with torch.autocast(device_type=device_type, dtype=weight_dtype):
                    loss = model(
                        observation["images"],
                        observation["image_masks"],
                        observation["tokenized_prompt"],
                        observation["tokenized_prompt_mask"],
                        observation["state"],
                        actions,
                    )
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=config.set_grads_to_none)
                # Read the scalar once; it is reused for the running average
                # and for the progress-bar postfix.
                loss_value = loss.detach().item()
                running_loss += loss_value
                running_count += 1

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % config.checkpointing_period == 0:
                    save_path = os.path.join(config.checkpoint_dir, f"checkpoint-{global_step}")
                    # Build the state dict on every rank (in case it is a
                    # collective operation), but write the file only once.
                    model_state = accelerator.unwrap_model(model).state_dict()
                    if accelerator.is_main_process:
                        os.makedirs(save_path, exist_ok=True)
                        torch.save(model_state, os.path.join(save_path, 'pytorch_model.pth'))
                    logger.info(f"Saved state to {save_path}")

                # Log once per `log_every` optimizer steps, averaging over the
                # micro-batches actually accumulated since the previous log.
                if global_step % log_every == 0:
                    accelerator.log(
                        {"loss": running_loss / running_count, "lr": lr_scheduler.get_last_lr()[0]},
                        step=global_step,
                    )
                    running_loss = 0.0
                    running_count = 0

            progress_bar.set_postfix(loss=loss_value, lr=lr_scheduler.get_last_lr()[0])

            if global_step >= config.num_train_steps:
                break

    # Save the final accelerator state once every process has finished training.
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    accelerator.save_state(os.path.join(config.checkpoint_dir, "accelerator"))
    logger.info(f"Saved Model to {config.checkpoint_dir}")

    torch.cuda.empty_cache()
    accelerator.end_training()


# Script entry point: build the run configuration with `cli()` and pass it to
# `main`. NOTE(review): `cli()` presumably parses command-line arguments into a
# PretrainConfig -- confirm in openpi.training.config.
if __name__ == "__main__":
    main(cli())
