import os
import time
import hydra
import swanlab as wandb
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tqdm import tqdm
import torch
import torch.nn as nn
import utils.utils as utils
from utils.logger import Logger
import gc
import random
import numpy as np
from utils import RoboModelWrapper, DataLoaderWithTimeWrapper
from datasets import create_engine, eval_libero
from utils import CosineAnnealingWarmUpRestarts
from utils import get_scheduler
import torch.distributed as dist
import datetime

from torch.utils.data import DataLoader, DistributedSampler
from diffusion_policy.common.pytorch_util import dict_apply

# Custom OmegaConf resolvers available to the Hydra configs:
#   ${eval:'<expr>'}     -> result of Python's eval() on the expression
#   ${eq:a,b}            -> a == b
#   ${when:cond,a,b}     -> a if cond else b
# All three use replace=True so that registering a name that is already
# registered (e.g. by another imported module) replaces it instead of raising.
# NOTE(review): the "eval" resolver executes arbitrary Python taken from config
# strings; only use it with trusted configs / command-line overrides.
OmegaConf.register_new_resolver("eval", eval, replace=True)
OmegaConf.register_new_resolver("eq", lambda a, b: a == b, replace=True)
OmegaConf.register_new_resolver("when", lambda cond, a, b: a if cond else b, replace=True)

def _seed_everything(seed):
    """Seed the Python, NumPy and torch random number generators with ``seed``."""
    random.seed(seed)
    # Exported for any child processes; the running interpreter's hash seed
    # is fixed at startup and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def _cast_state_dict_(state_dict, device, dtype):
    """Move every tensor value of ``state_dict`` to ``device`` (in place).

    Floating-point tensors are also cast to ``dtype``. Integer/bool tensors keep
    their dtype so that counters and index buffers are not turned into floats
    (which would lose precision for large int64 values).
    """
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        if torch.is_floating_point(value):
            state_dict[key] = value.to(device=device, dtype=dtype)
        else:
            state_dict[key] = value.to(device=device)


def _optimizer_state_to_device_(optimizer_state_dict, device):
    """Move the tensors in ``optimizer_state_dict['state']`` to ``device`` (in place).

    Each per-parameter entry may be a dict of tensors/scalars or a bare tensor;
    both shapes are handled, anything else is left untouched.
    """
    states = optimizer_state_dict['state']
    for param_id, param_state in states.items():
        if isinstance(param_state, dict):
            for k, v in param_state.items():
                if isinstance(v, torch.Tensor):
                    param_state[k] = v.to(device)
        elif isinstance(param_state, torch.Tensor):
            states[param_id] = param_state.to(device)


def _average_gradients(parameters, world_size):
    """All-reduce and average ``.grad`` of ``parameters`` across all ranks (in place).

    The training loop calls ``model.module.compute_loss`` directly, which bypasses
    ``DistributedDataParallel.forward``. DDP only prepares its gradient-reduction
    hooks inside its own forward, so without this explicit step every rank would
    step with only its local gradients and the replicas would drift apart.

    Parameters whose ``.grad`` is ``None`` are skipped. This assumes the same set
    of parameters receives gradients on every rank; otherwise the collectives
    would not line up. One collective is issued per parameter: simple but not
    bucketed like DDP's own reducer.
    """
    if world_size <= 1:
        return
    for param in parameters:
        if param.grad is None:
            continue
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad.div_(world_size)


@hydra.main(config_path="config", config_name="train_base", version_base=None)
def main(cfg):
    """Distributed training entry point (launch with ``torchrun``).

    Builds the policy from ``cfg.algo.policy`` and wraps it in
    DistributedDataParallel. It optionally loads / resumes from
    ``cfg.checkpoint_path``, then trains for epochs
    ``start_epoch .. cfg.training.n_epoch_iters`` (inclusive).
    Rank 0 checkpoints every ``save_interval`` epochs. When ``cfg.stage != 0``,
    an action-MSE is logged on a cached training batch every ``sample_every``
    epochs.

    Reads ``LOCAL_RANK`` from the environment (set by torchrun) and uses the NCCL
    backend, one CUDA device per process.
    """
    _seed_everything(cfg.seed)

    train_cfg = cfg.training

    # Apply debug overrides before the LR schedule is sized from n_epoch_iters
    # below, so the schedule matches the shortened debug run.
    if cfg.training.debug:
        cfg.training.n_epoch_iters = 2
        cfg.training.save_interval = 1
        cfg.training.sample_every = 1

    # Generous collective timeout: 360000 s = 100 hours.
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=100))
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    print(f'local_rank:{local_rank}')
    print(f'device:{device}')

    # create model
    model = instantiate(cfg.algo.policy)
    dtype = torch.float32
    model = model.to(device)
    # Constructing DDP broadcasts rank 0's initial parameters to every rank.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    policy = model.module

    n_parameters = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    print("model number of parameters: %.2fM" % (n_parameters / 1e6,))

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    print(f'world_size:{world_size}')
    print(f'rank:{rank}')

    val_dataloader = None
    if cfg.name == "libero":
        train_loader, agent = instantiate(cfg.task.dataset, world_size=world_size, global_rank=rank)
        train_loader = DataLoaderWithTimeWrapper(train_loader)
    else:
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        sampler = DistributedSampler(dataset, shuffle=cfg.dataloader.shuffle, num_replicas=world_size, rank=rank)
        train_loader = DataLoader(dataset, batch_size=cfg.dataloader.batch_size, sampler=sampler,
                                  pin_memory=cfg.dataloader.pin_memory, drop_last=cfg.dataloader.drop_last)
        normalizer = dataset.get_normalizer()

        # configure validation dataset
        val_dataset = dataset.get_validation_dataset()
        val_sampler = DistributedSampler(val_dataset, shuffle=cfg.val_dataloader.shuffle, num_replicas=world_size,
                                         rank=rank)
        val_dataloader = DataLoader(val_dataset, batch_size=cfg.dataloader.batch_size, sampler=val_sampler,
                                    pin_memory=cfg.dataloader.pin_memory, drop_last=cfg.dataloader.drop_last)

        policy.set_normalizer(normalizer)

    # build optimizer and lr_scheduler
    train_dataloader_len = len(train_loader)
    # NOTE(review): the epoch loop below runs n_epoch_iters + 1 epochs
    # (0..n_epoch_iters inclusive), one more than this schedule is sized for;
    # confirm how CosineAnnealingWarmUpRestarts behaves past T_max.
    num_iters = train_dataloader_len * train_cfg.n_epoch_iters

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=policy.parameters())
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_warmup=0.01 * num_iters,
                                              T_max=num_iters, eta_min=1e-5)

    experiment_dir, experiment_name = utils.get_experiment_dir(cfg)
    os.makedirs(experiment_dir, exist_ok=True)

    start_epoch, steps, wandb_id = 0, 0, None

    checkpoint_path = cfg.checkpoint_path
    if checkpoint_path is not None:
        checkpoint_path = utils.get_latest_checkpoint(checkpoint_path)
        print(f'loading from checkpoint {checkpoint_path}')
        state_dict = utils.load_state(checkpoint_path)
        loaded_state_dict = state_dict['model']
        _cast_state_dict_(loaded_state_dict, device, dtype)

        # Allows loading state dicts with some mismatched parameters.
        utils.soft_load_state_dict(policy, loaded_state_dict)

        # Same stage => resume training: restore optimizer, scheduler and counters.
        if cfg.stage == state_dict['stage']:
            print('loading from checkpoint')
            optimizer_state_dict = state_dict['optimizers']
            _optimizer_state_to_device_(optimizer_state_dict, device)
            optimizer.load_state_dict(optimizer_state_dict)
            scheduler.load_state_dict(state_dict['schedulers'])
            # The checkpoint is written after epoch `epoch` finished training,
            # so continue with the next epoch instead of repeating it.
            start_epoch = state_dict['epoch'] + 1
            steps = state_dict['steps']
    else:
        print('starting from scratch')

    wandb_run = None
    if rank == 0:
        wandb_run = wandb.init(
            dir=experiment_dir,
            name=experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
            id=wandb_id,
            **cfg.logging
        )

    # First training batch, cached and reused for the periodic sampling check.
    train_sampling_batch = None

    print('Training...')

    for epoch in range(start_epoch, train_cfg.n_epoch_iters + 1):
        # Reshuffle each epoch: DistributedSampler derives its order from the epoch.
        if cfg.name == "libero":
            train_loader.dataloader.sampler.set_epoch(epoch)
        else:
            train_loader.sampler.set_epoch(epoch)
            val_dataloader.sampler.set_epoch(epoch)

        t0 = time.time()
        model.train()
        step_log = dict()
        train_losses = list()

        for batch in tqdm(train_loader, disable=not train_cfg.use_tqdm, position=rank):
            if cfg.name == "libero":
                batch = utils.process_inputs(device, dtype, batch)
            else:
                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))

            if train_sampling_batch is None:
                train_sampling_batch = batch

            loss, info = policy.compute_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            # Must run before clipping so every rank clips/steps on identical gradients.
            _average_gradients(policy.parameters(), world_size)

            grad_norm = None
            if train_cfg.grad_clip is not None:
                grad_norm = nn.utils.clip_grad_norm_(policy.parameters(), train_cfg.grad_clip)

            optimizer.step()
            scheduler.step()

            # logging (loss is this rank's local loss)
            raw_loss_cpu = loss.item()
            train_losses.append(raw_loss_cpu)
            step_log = {
                'train_loss': raw_loss_cpu,
                'global_step': steps,
                'epoch': epoch,
                'lr': scheduler.get_last_lr()[0]
            }
            step_log.update(info)
            if grad_norm is not None:
                step_log['grad_norm'] = grad_norm.item()

            if rank == 0 and steps % train_cfg.log_interval == 0:
                wandb_run.log(step_log, step=steps)

            steps += 1

        # Guard against an empty loader (np.mean([]) warns and returns nan).
        train_loss = np.mean(train_losses) if train_losses else float('nan')
        t1 = time.time()
        print(
            f"[info] Epoch: {epoch:3d} | train loss: {train_loss:5.5f} | time: {(t1 - t0) / 60:4.2f}"
        )

        if epoch > 0 and epoch % train_cfg.save_interval == 0:
            if cfg.training.save_all_checkpoints:
                checkpoint_file = f"multitask_model_epoch_{epoch:06d}.pth"
            else:
                checkpoint_file = "multitask_model.pth"
            model_checkpoint_name_ep = os.path.join(experiment_dir, checkpoint_file)

            if rank == 0:
                utils.save_state({
                    'model': policy,
                    'optimizers': optimizer,
                    'schedulers': scheduler,
                    'epoch': epoch,
                    'stage': cfg.stage,
                    'steps': steps,
                    'experiment_dir': experiment_dir,
                    'experiment_name': experiment_name,
                    'config': OmegaConf.to_container(cfg, resolve=True)
                }, model_checkpoint_name_ep)
            # Keep other ranks from racing ahead while rank 0 writes.
            dist.barrier()

        model.eval()
        # run diffusion sampling on the cached training batch
        if ((epoch % train_cfg.sample_every) == 0 and cfg.stage != 0
                and train_sampling_batch is not None):
            with torch.no_grad():
                if cfg.name == "libero":
                    batch = train_sampling_batch
                    gt_action = batch['cur_actions']
                    obs_dict = batch['cur_images']
                    pred_action, _ = policy.generate(**batch)
                    result = pred_action.to(device)
                    mse = torch.nn.functional.mse_loss(result, gt_action)
                else:
                    batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
                    obs_dict = batch['obs']
                    gt_action = batch['action']
                    result = policy.predict_action(obs_dict)
                    pred_action = result['action_pred']
                    mse = torch.nn.functional.mse_loss(pred_action, gt_action)
                step_log['train_action_mse_error'] = mse.item()
                if rank == 0:
                    wandb_run.log(step_log, step=steps)

                # Drop references so the sampling tensors can be freed promptly.
                del batch, obs_dict, gt_action, result, pred_action, mse

    dist.barrier()
    dist.destroy_process_group()
    torch.cuda.empty_cache()
    gc.collect()

if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator on `main` builds `cfg`
    # from the config directory and command-line overrides.
    main()






