if __name__ == "__main__":
    # Only when executed directly as a script: take the third-level parent of
    # this file's path as the root, make it importable, and make it the
    # working directory.
    import os
    import pathlib
    import sys

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    os.chdir(ROOT_DIR)
    sys.path.append(ROOT_DIR)

import os
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from torch.utils.data import DataLoader
import copy
import random
import tqdm
import numpy as np
import pickle

from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.diffusion_transformer_timm_policy import DiffusionTransformerTimmPolicy
from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
from diffusion_policy.common.json_logger import JsonLogger
from diffusion_policy.model.diffusion.ema_model import EMAModel
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from accelerate import Accelerator

# Register a custom OmegaConf resolver named "eval", backed by Python's
# built-in eval; replace=True overwrites any resolver already registered under
# that name.
# NOTE(review): this evaluates arbitrary Python expressions found in config
# values, so only load configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval, replace=True)

# %%
class TrainDiffusionTransformerTimmWorkspace_di(BaseWorkspace):
    """Training workspace for ``DiffusionTransformerTimmPolicy`` using Accelerate.

    For every training batch the policy is called once per expert index
    (``self.model(batch, expert_idx)``) and each call is followed by its own
    backward pass. Rollouts and checkpointing happen at epoch boundaries.
    """

    # attributes pickled into the checkpoint in addition to state dicts
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf):
        """Seed the RNGs and build the model, optional EMA copy and optimizer.

        Args:
            cfg: Hydra/OmegaConf config. Reads ``training.seed``,
                ``training.use_ema``, ``training.resume``, ``policy`` and
                ``optimizer``.
        """
        super().__init__(cfg)

        # set seed
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # configure model
        self.model: DiffusionTransformerTimmPolicy
        self.model = hydra.utils.instantiate(cfg.policy)

        self.ema_model: DiffusionTransformerTimmPolicy = None
        if cfg.training.use_ema:
            self.ema_model = copy.deepcopy(self.model)

        # configure training state
        if not cfg.training.resume:
            self.optimizer = self.model.get_optimizer(**cfg.optimizer)
        else:
            self.optimizer = self._build_resume_optimizer(cfg)
        self.global_step = 0
        self.epoch = 0

        # do not save optimizer if resume=False
        if not cfg.training.resume:
            self.exclude_keys = ['optimizer']

        # Always define these attributes: run() reads self.num_experts
        # unconditionally, so leaving them unset when cfg.policy.di is falsy
        # would raise AttributeError. Without "di" a single pass
        # (expert index 0) is made per batch.
        self.use_di = False
        self.num_experts = 1
        if cfg.policy.di:
            print("use Di!!!!!!!!!!!!!!!!!")
            self.use_di = True
            self.num_experts = cfg.policy.num_experts

    def _build_resume_optimizer(self, cfg):
        """Build the AdamW optimizer used when ``cfg.training.resume`` is set.

        Two parameter groups are created: ``self.model.model`` at the
        configured learning rate, and the trainable parameters of
        ``self.model.obs_encoder`` at a learning rate that is scaled by 0.1
        when ``cfg.policy.obs_encoder.pretrained`` is truthy.

        Returns:
            A ``torch.optim.AdamW`` instance.
        """
        obs_encoder_lr = cfg.optimizer.lr
        if cfg.policy.obs_encoder.pretrained:
            obs_encoder_lr *= 0.1
            print('==> reduce pretrained obs_encorder\'s lr')
        obs_encoder_params = [
            param for param in self.model.obs_encoder.parameters()
            if param.requires_grad
        ]
        print(f'obs_encorder params: {len(obs_encoder_params)}')

        # AdamW is constructed directly, so the hydra `_target_` key must go.
        optimizer_cfg = OmegaConf.to_container(cfg.optimizer, resolve=True)
        optimizer_cfg.pop('_target_')

        # NOTE(review): 'initial_lr' is presumably provided because the LR
        # scheduler in run() is created with last_epoch != -1 after a resume.
        param_groups = [
            {'params': self.model.model.parameters(), 'initial_lr': optimizer_cfg['lr']},
            {'params': obs_encoder_params, 'lr': obs_encoder_lr, 'initial_lr': obs_encoder_lr}
        ]
        return torch.optim.AdamW(params=param_groups, **optimizer_cfg)

    @staticmethod
    def _register_nan_hooks(model):
        """Attach forward/backward hooks to every submodule that print NaN/Inf hits."""

        def detect_nan_forward(module, inputs, output):
            if torch.is_tensor(output) and (torch.isnan(output).any() or torch.isinf(output).any()):
                print(f"[NaN DETECTED - Forward] {module.__class__.__name__}")
                if isinstance(inputs, (tuple, list)):
                    for i, t in enumerate(inputs):
                        if torch.is_tensor(t):
                            print(f"  Input[{i}] min={t.min().item():.4e}, max={t.max().item():.4e}")
                print(f"  Output min={output.min().item():.4e}, max={output.max().item():.4e}")

        def detect_nan_backward(module, grad_input, grad_output):
            for g in grad_output:
                if torch.is_tensor(g) and (torch.isnan(g).any() or torch.isinf(g).any()):
                    print(f"[NaN DETECTED - Backward] {module.__class__.__name__}")
                    print(f"  grad_output min={g.min().item():.4e}, max={g.max().item():.4e}")

        for module in model.modules():
            module.register_forward_hook(detect_nan_forward)
            module.register_full_backward_hook(detect_nan_backward)

    def _load_shared_normalizer(self, accelerator, dataset):
        """Compute the normalizer on the main process and load it on all processes.

        The main process pickles ``dataset.get_normalizer()`` to
        ``<output_dir>/normalizer.pkl``; after a barrier every process reads
        that file back.

        Returns:
            The unpickled normalizer object.
        """
        # The path must be defined before the main-process branch: it is
        # needed for writing there and for reading on every process.
        normalizer_path = os.path.join(self.output_dir, 'normalizer.pkl')
        if accelerator.is_main_process:
            normalizer = dataset.get_normalizer()
            with open(normalizer_path, 'wb') as f:
                pickle.dump(normalizer, f)

        # load normalizer on all processes
        accelerator.wait_for_everyone()
        with open(normalizer_path, 'rb') as f:
            return pickle.load(f)

    def run(self):
        """Train the policy.

        Sets up Accelerate (wandb tracker, bf16 mixed precision), the dataset,
        normalizer, LR scheduler, optional EMA, env runner and top-k checkpoint
        manager, then loops over ``cfg.training.num_epochs`` epochs. After each
        epoch it optionally runs rollouts and saves checkpoints (main process
        only), and logs the combined metrics.
        """
        cfg = copy.deepcopy(self.cfg)

        accelerator = Accelerator(log_with='wandb', mixed_precision='bf16')
        wandb_cfg = OmegaConf.to_container(cfg.logging, resolve=True)
        # the project name is passed separately as project_name below
        wandb_cfg.pop('project')

        accelerator.init_trackers(
            project_name=cfg.logging.project,
            config=OmegaConf.to_container(cfg, resolve=True),
            init_kwargs={"wandb": wandb_cfg}
        )

        # resume training
        if cfg.training.resume:
            # NOTE(review): the checkpoint path used to be a hard-coded empty
            # string, which cannot be loaded. This now uses the workspace's
            # latest checkpoint (BaseWorkspace.get_checkpoint_path) — confirm
            # that this is the checkpoint you intend to resume from.
            lastest_ckpt_path = pathlib.Path(self.get_checkpoint_path())
            if lastest_ckpt_path.is_file():
                print(f"Resuming from checkpoint {lastest_ckpt_path}")
                self.load_checkpoint(path=lastest_ckpt_path)
            else:
                print(f"No checkpoint found at {lastest_ckpt_path}, nothing to resume from")

        # configure dataset
        dataset: BaseImageDataset
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, (BaseImageDataset, BaseDataset))
        train_dataloader = DataLoader(dataset, **cfg.dataloader)

        # optional NaN/Inf debugging hooks on every submodule
        if cfg.training.debug_nan:
            self._register_nan_hooks(self.model)

        # compute normalizer on the main process, share it through disk
        normalizer = self._load_shared_normalizer(accelerator, dataset)

        # NOTE: the validation dataset / dataloader is currently not configured.

        self.model.set_normalizer(normalizer)
        if cfg.training.use_ema:
            self.ema_model.set_normalizer(normalizer)

        # configure lr scheduler
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=(
                len(train_dataloader) * cfg.training.num_epochs) \
                    // cfg.training.gradient_accumulate_every,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=self.global_step-1
        )

        # configure ema
        ema: EMAModel = None
        if cfg.training.use_ema:
            ema = hydra.utils.instantiate(
                cfg.ema,
                model=self.ema_model)

        # configure env
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

        # configure checkpoint
        topk_manager = TopKCheckpointManager(
            save_dir=os.path.join(self.output_dir, 'checkpoints'),
            **cfg.checkpoint.topk
        )

        # accelerator
        self.train_dataloader, self.model, self.optimizer, lr_scheduler = accelerator.prepare(
            train_dataloader, self.model, self.optimizer, lr_scheduler
        )
        # Ask the accelerator for the device: after prepare() self.model may
        # be a wrapper that does not expose the policy's attributes.
        device = accelerator.device
        if self.ema_model is not None:
            self.ema_model.to(device)

        # save batch for sampling
        # (the sampling evaluation that consumes it is currently disabled)
        train_sampling_batch = None

        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1

        def log_action_mse(step_log, category, pred_action, gt_action):
            # Adds overall / position / rotation / width MSE entries to
            # step_log. The last action axis is viewed as groups of 10 values:
            # [:3] -> "pos", [3:9] -> "rot", [9] -> "width".
            # Currently unused: the sampling evaluation that called it is disabled.
            B, T, _ = pred_action.shape
            pred_action = pred_action.view(B, T, -1, 10)
            gt_action = gt_action.view(B, T, -1, 10)
            step_log[f'{category}_action_mse_error'] = torch.nn.functional.mse_loss(pred_action, gt_action)
            step_log[f'{category}_action_mse_error_pos'] = torch.nn.functional.mse_loss(pred_action[..., :3], gt_action[..., :3])
            step_log[f'{category}_action_mse_error_rot'] = torch.nn.functional.mse_loss(pred_action[..., 3:9], gt_action[..., 3:9])
            step_log[f'{category}_action_mse_error_width'] = torch.nn.functional.mse_loss(pred_action[..., 9], gt_action[..., 9])

        # training loop
        log_path = os.path.join(self.output_dir, 'logs.json.txt')
        with JsonLogger(log_path) as json_logger:
            for _ in range(cfg.training.num_epochs):
                self.model.train()

                step_log = dict()
                # ========= train for this epoch ==========
                if cfg.training.freeze_encoder:
                    # go through the unwrapped policy so this also works when
                    # prepare() wrapped the model
                    obs_encoder = accelerator.unwrap_model(self.model).obs_encoder
                    obs_encoder.eval()
                    obs_encoder.requires_grad_(False)

                train_losses = list()
                with tqdm.tqdm(self.train_dataloader, desc=f"Training epoch {self.epoch}",
                        leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        # device transfer, once per batch; every expert pass
                        # below reuses the same tensors
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                        # always use the latest batch
                        train_sampling_batch = batch

                        expert_loss_dict = {}
                        gating_loss_dict = {}

                        for use_expert_i in range(self.num_experts):
                            with accelerator.accumulate(self.model):
                                # compute loss
                                raw_loss, gating_loss, expert_loss = self.model(batch, use_expert_i)
                                raw_loss_cpu = raw_loss.item()
                                expert_loss_cpu = expert_loss.item()
                                gating_loss_cpu = gating_loss.item()
                                expert_loss_dict[f"expert_loss_{use_expert_i}"] = expert_loss_cpu
                                gating_loss_dict[f"gating_loss_{use_expert_i}"] = gating_loss_cpu
                                loss = raw_loss / cfg.training.gradient_accumulate_every
                                accelerator.backward(loss)

                                if accelerator.sync_gradients:
                                    params_to_clip = self.model.parameters()
                                    accelerator.clip_grad_norm_(params_to_clip, 0.1) # 1.0

                                # step optimizer
                                # NOTE(review): the Accelerator is created
                                # without gradient_accumulation_steps, so
                                # stepping is gated manually on global_step,
                                # which only advances once per batch (not per
                                # expert pass). Confirm this is intended.
                                if self.global_step % cfg.training.gradient_accumulate_every == 0:
                                    self.optimizer.step()
                                    self.optimizer.zero_grad()
                                    lr_scheduler.step()

                            # update ema
                            if cfg.training.use_ema:
                                ema.step(accelerator.unwrap_model(self.model))

                            # progress bar
                            tepoch.set_postfix(loss=expert_loss_cpu, gating_loss=gating_loss_cpu, refresh=False)

                        # logging: 'train_loss' is the raw loss of the last
                        # expert pass of this batch
                        train_losses.append(raw_loss_cpu)
                        step_log = {
                            'train_loss': raw_loss_cpu,
                            'global_step': self.global_step,
                            'epoch': self.epoch,
                            'lr': lr_scheduler.get_last_lr()[0]
                        }
                        step_log.update(expert_loss_dict)
                        step_log.update(gating_loss_dict)

                        is_last_batch = (batch_idx == (len(self.train_dataloader)-1))
                        if not is_last_batch:
                            # log of last step is combined with validation and rollout
                            accelerator.log(step_log, step=self.global_step)
                            json_logger.log(step_log)
                            self.global_step += 1

                        if (cfg.training.max_train_steps is not None) \
                            and batch_idx >= (cfg.training.max_train_steps-1):
                            break

                # at the end of each epoch
                # replace train_loss with epoch average
                train_loss = np.mean(train_losses)
                step_log['train_loss'] = train_loss

                # ========= eval for this epoch ==========
                policy = accelerator.unwrap_model(self.model)
                if cfg.training.use_ema:
                    policy = self.ema_model
                policy.eval()

                # run rollout
                if (self.epoch % cfg.training.rollout_every) == 0:
                    runner_log = env_runner.run(policy)
                    # log all
                    step_log.update(runner_log)

                # NOTE: the validation pass and the diffusion-sampling
                # evaluation on a training batch are currently disabled.

                # checkpoint
                if (self.epoch % cfg.training.checkpoint_every) == 0 and accelerator.is_main_process:
                    # unwrap the model to save ckpt
                    model_ddp = self.model
                    self.model = accelerator.unwrap_model(self.model)

                    # checkpointing
                    if cfg.checkpoint.save_last_ckpt:
                        self.save_checkpoint()
                    if cfg.checkpoint.save_last_snapshot:
                        self.save_snapshot()

                    # sanitize metric names
                    metric_dict = {
                        key.replace('/', '_'): value
                        for key, value in step_log.items()
                    }

                    # We can't copy the last checkpoint here
                    # since save_checkpoint uses threads.
                    # therefore at this point the file might have been empty!
                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)

                    if topk_ckpt_path is not None:
                        self.save_checkpoint(path=topk_ckpt_path)

                    # recover the DDP model
                    self.model = model_ddp
                # ========= eval end for this epoch ==========
                # end of epoch
                # log of last step is combined with validation and rollout
                accelerator.log(step_log, step=self.global_step)
                json_logger.log(step_log)
                self.global_step += 1
                self.epoch += 1

        accelerator.end_training()

@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), 
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    """Hydra entry point: build the workspace from ``cfg`` and run training.

    The config is looked up in the ``config`` directory next to this file's
    parent directory, under a name equal to this file's stem.
    """
    # The workspace class defined in this file carries the `_di` suffix; the
    # un-suffixed name is not defined here and raised NameError.
    workspace = TrainDiffusionTransformerTimmWorkspace_di(cfg)
    workspace.run()

if __name__ == "__main__":
    main()
