#-*- coding:utf-8 -*-

import sys 
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from diffusion import ConditionalKarrasUnet1D, rand_log_normal, Denoiser, ModelType, MultiModelType
from diffusion.dit import *
from diffusion.transformer import TransformerForDiffusion
from dataset.robomimic_lowdim_dataset import RobomimicReplayLowdimDataset 
from dataset.pusht_dataset import PushTStateDataset
from dataset.tasks import *
from env import PushTEnv

from train.checkpoint_manager import CheckPointManager
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler

from functools import partial
from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np
import argparse
import random
import torch.nn as nn
import torch
import gdown
import wandb
import os 

# Every run of this script conditions the policy on low-dimensional state observations.
CONTROL_TYPE = ControlType.STATE

parser = argparse.ArgumentParser(
    description="Train an EDM (Karras) diffusion policy on state observations with DDP."
)
# General options
parser.add_argument('-e', '--epochs', type=int, default=10,
                    help="number of training epochs")
parser.add_argument('-b', '--batchsize', type=int, default=256,
                    help="global batch size; each process loads batchsize // world_size samples")
parser.add_argument('--diffusion_timesteps', type=int, default=40,
                    help="diffusion sampling steps; not used for training here, only logged to the wandb config")
parser.add_argument('--ema_power', type=float, default=0.75,
                    help="`power` argument of the diffusers EMAModel")
parser.add_argument('--lr', type=float, default=1e-4,
                    help="AdamW learning rate")
parser.add_argument('--save-per-epochs', type=int, default=50,
                    help="forwarded to CheckPointManager as save_per_epochs")
# Validate here so a typo fails before any worker process is spawned.
parser.add_argument('-m', '--model_type', type=str, default="CNN",
                    choices=["CNN", "minGPT", "DiT_S", "DiT_B", "DiT_L", "DiT_XL"],
                    help="denoising network architecture")

# Diffusion Policy options
parser.add_argument('--task_type', type=str, default="PUSHT",
                    choices=["PUSHT", "LIFT", "CAN", "SQUARE", "TRANSPORT", "TOOLHANG"],
                    help="task whose demonstration dataset is trained on")
parser.add_argument('--task_tag', type=str, default="",
                    choices=["", "PH", "MH"],
                    help="dataset variant tag; the empty string means no tag")
parser.add_argument('--save-last-checkpoints', type=int, default=10,
                    help="forwarded to CheckPointManager as save_last_epochs")
parser.add_argument('--warmup_steps', type=int, default=500,
                    help="linear warmup steps of the cosine learning-rate schedule")

# Karras (EDM) options
parser.add_argument('--sigma_data', type=float, default=0.5,
                    help="sigma_data passed to the EDM Denoiser")
parser.add_argument('--sigma_sample_density_mean', type=float, default=-1.2,
                    help="loc passed to rand_log_normal when sampling training sigmas")
parser.add_argument('--sigma_sample_density_std', type=float, default=1.2,
                    help="scale passed to rand_log_normal when sampling training sigmas")
parser.add_argument('--seed', type=int, default=42,
                    help="seed for torch, numpy and Python's random module")
parser.add_argument('--resume-model-path', type=str, default="",
                    help="checkpoint to resume from; its file name must end in _<wandb_id>.pth")
opt = parser.parse_args()

# Seed every RNG the script touches (PyTorch, NumPy, Python's random) with the same value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(opt.seed)
del _seed_fn

# --task_tag -> name of the TaskTags member; the attribute is looked up only for the chosen tag.
_tag_attr_by_cli = {"": "NONE", "PH": "PH", "MH": "MH"}
if opt.task_tag not in _tag_attr_by_cli:
    raise NotImplementedError(f"Task Tag {opt.task_tag} Not implemented")
task_tag = getattr(TaskTags, _tag_attr_by_cli[opt.task_tag])

# --task_type -> factory returning (TaskTypes member, task instance).
# PushT and ToolHang are constructed without a tag; the other tasks receive task_tag.
_task_factories = {
    'PUSHT': lambda: (TaskTypes.PUSHT, PushT(ctype=CONTROL_TYPE)),
    'LIFT': lambda: (TaskTypes.LIFT, Lift(ctype=CONTROL_TYPE, tag=task_tag)),
    'CAN': lambda: (TaskTypes.CAN, Can(ctype=CONTROL_TYPE, tag=task_tag)),
    'SQUARE': lambda: (TaskTypes.SQUARE, Square(ctype=CONTROL_TYPE, tag=task_tag)),
    'TRANSPORT': lambda: (TaskTypes.TRANSPORT, Transport(ctype=CONTROL_TYPE, tag=task_tag)),
    'TOOLHANG': lambda: (TaskTypes.TOOLHANG, ToolHang(ctype=CONTROL_TYPE)),
}
if opt.task_type not in _task_factories:
    raise NotImplementedError(f"Task {opt.task_type} Not implemented")
task_type, task = _task_factories[opt.task_type]()

# Drop the lookup tables so only the resolved values remain as module globals.
del _tag_attr_by_cli, _task_factories

device = 'cuda'

def create_env():
    """Build a PushT environment, seed it with 1000, reset it and take one random step.

    Returns:
        The PushTEnv instance after that single random step.
    """
    environment = PushTEnv()

    # Seeds 0-200 produced the demonstration dataset, so use one outside that range.
    environment.seed(1000)

    # The environment has to be reset before it can be stepped.
    environment.reset()

    # One gym-style step with a randomly sampled 2D positional action;
    # the step result is unpacked (4-tuple) but otherwise unused.
    random_action = environment.action_space.sample()
    _obs, _reward, _done, _info = environment.step(random_action)
    return environment

def create_model(model_type:MultiModelType = MultiModelType.CNN):
    """Build the inner denoising network for the module-level ``task``.

    Args:
        model_type: which architecture to instantiate.

    Returns:
        The un-wrapped network (later wrapped by ``Denoiser``).

    Raises:
        NotImplementedError: for an unsupported ``model_type``.
    """
    action_dim = task.action_dim
    obs_dim = task.obs_dim

    if model_type == MultiModelType.CNN:
        return ConditionalKarrasUnet1D(
            input_dim=action_dim,
            # The training loop flattens the observation window before passing it here.
            global_cond_dim=obs_dim * task.obs_horizon,
            diffusion_step_embed_dim=256,
            down_dims=[256, 512, 1024],
        )
    if model_type == MultiModelType.MINGPT:
        return TransformerForDiffusion(
            input_dim=action_dim,
            output_dim=action_dim,
            horizon=task.pred_horizon,
            n_obs_steps=task.obs_horizon,
            cond_dim=obs_dim,  # per-step observation size (not flattened)
            causal_attn=True,
            n_cond_layers=4,
        )

    # All DiT variants share the same constructor arguments.
    dit_kwargs = dict(
        input_size=action_dim,
        horizon=task.pred_horizon,
        obs_horizon=task.obs_horizon,
        obs_dim=obs_dim,
    )
    if model_type == MultiModelType.DiT_S:
        return DiT_S_2(**dit_kwargs)
    if model_type == MultiModelType.DiT_B:
        return DiT_B_2(**dit_kwargs)
    if model_type == MultiModelType.DiT_L:
        return DiT_L_2(**dit_kwargs)
    if model_type == MultiModelType.DiT_XL:
        # Previously built DiT_L_2 by mistake, so "DiT_XL" silently trained the L model.
        return DiT_XL_2(**dit_kwargs)
    raise NotImplementedError(f"Model type {model_type} not implemented")

def setup(rank, world_size):
    """Join the default torch.distributed process group for this worker.

    Rendezvous happens on localhost:12355 using the "gloo" backend.

    Args:
        rank: index of this process.
        world_size: total number of processes.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')

    # Initialize the process group
    dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

def cleanup():
    """Destroy the default process group created by ``setup``."""
    dist.destroy_process_group(group=None)

def _resolve_model_type(name):
    """Map the ``--model_type`` CLI string onto a ``MultiModelType`` member.

    Raises:
        NotImplementedError: if ``name`` is not a supported model type.
    """
    if name == 'CNN':
        return MultiModelType.CNN
    if name == 'minGPT':
        return MultiModelType.MINGPT
    if name == 'DiT_S':
        return MultiModelType.DiT_S
    if name == 'DiT_B':
        return MultiModelType.DiT_B
    if name == 'DiT_L':
        return MultiModelType.DiT_L
    if name == 'DiT_XL':
        return MultiModelType.DiT_XL
    raise NotImplementedError(f"Model type {name} not implemented")


def _build_dataset(rank, pred_horizon, obs_horizon, action_horizon):
    """Create the training dataset for the module-level ``task_type`` / ``task``.

    For PushT the demonstration archive is downloaded from Google Drive when
    missing. Only rank 0 downloads; all ranks then meet at a barrier so no
    process opens a partially written file.
    """
    if task_type == TaskTypes.PUSHT:
        dataset_path = "./data/pusht_cchi_v7_replay.zarr.zip"
        if rank == 0 and not os.path.isfile(dataset_path):
            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
            file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
            gdown.download(id=file_id, output=dataset_path, quiet=False)
        # Every rank takes this branch, so every rank reaches the barrier.
        dist.barrier()
        return PushTStateDataset(
            dataset_path=dataset_path,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            action_horizon=action_horizon
        )
    return RobomimicReplayLowdimDataset(
        dataset_path=task.dataset_path,
        horizon=task.pred_horizon,
        obs_keys=task.obs_keys,
        abs_action=True,
        pad_before=1,
        pad_after=7,
    )


def _prepare_batch(nbatch, dataset, device, obs_horizon, model_type):
    """Turn a dataloader batch into model inputs on ``device``.

    Returns:
        (naction, obs_cond, batch_size). ``obs_cond`` holds the first
        ``obs_horizon`` observation steps, flattened per sample for the CNN.
    """
    if task_type != TaskTypes.PUSHT:
        # Non-PushT batches are normalized here with the dataset's normalizer;
        # PushT batches are used exactly as PushTStateDataset returns them.
        nbatch = dataset.normalizer.normalize(nbatch)
    nobs = nbatch['obs'].to(device)
    naction = nbatch['action'].to(device)
    # (B, obs_horizon, obs_dim)
    obs_cond = nobs[:, :obs_horizon, :]
    if model_type == MultiModelType.CNN:
        # (B, obs_horizon * obs_dim): the U-Net takes a flat global condition.
        obs_cond = obs_cond.flatten(start_dim=1)
    return naction, obs_cond, nobs.shape[0]


def _all_reduce_gradients(module, world_size):
    """Average the gradients of ``module`` over all ranks, in place.

    The loss is computed via ``ddp_nets.module.loss(...)``, which bypasses
    ``DistributedDataParallel.forward``. DDP only synchronises gradients for
    backward passes that follow its own forward, so without this step every
    rank would train its own, diverging replica.
    """
    if world_size <= 1:
        return
    grads = [p.grad for p in module.parameters() if p.grad is not None]
    if not grads:
        return
    # A single collective over one flat buffer instead of one call per parameter.
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat.div_(world_size)
    offset = 0
    for g in grads:
        numel = g.numel()
        g.copy_(flat[offset:offset + numel].view_as(g))
        offset += numel


def _run_training(rank, world_size, run):
    """Body of one DDP worker: data, model, optimizer, optional resume, epoch loop."""
    # parameters
    pred_horizon = task.pred_horizon
    obs_horizon = task.obs_horizon
    action_horizon = task.action_horizon
    #|o|o|                             observations: 2
    #| |a|a|a|a|a|a|a|a|               actions executed: 8
    #|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16
    device = rank

    # Resolve the architecture first so an invalid name fails before any data loading.
    model_type = _resolve_model_type(opt.model_type)
    print("Model Type:", model_type)

    dataset = _build_dataset(rank, pred_horizon, obs_horizon, action_horizon)

    # Each rank iterates a disjoint shard; shuffling is driven by the sampler.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=opt.seed
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(opt.batchsize // world_size),
        num_workers=1,
        # accelerate cpu-gpu transfer
        pin_memory=True,
        # don't kill worker process after each epoch
        persistent_workers=True,
        sampler=sampler
    )

    # visualize data in batch
    batch = next(iter(dataloader))
    print("batch['obs'].shape:", batch['obs'].shape)
    print("batch['action'].shape", batch['action'].shape)

    check_point_manager = CheckPointManager(
        epochs=opt.epochs,
        save_last_epochs=opt.save_last_checkpoints,
        save_per_epochs=opt.save_per_epochs,
        task_name=task_type.value,
        task_tag=task_tag.value,
        control_type=CONTROL_TYPE.value,
        model_type=model_type.value,
        base_name=f"EDM",
        ignore_make_folder=(opt.resume_model_path != ""),
        rank=rank
    )

    inner_model = create_model(model_type=model_type)
    inner_model.to(device)
    noise_pred_net = Denoiser(inner_model=inner_model, sigma_data=opt.sigma_data)
    # Also move anything the Denoiser itself registers onto this rank's GPU.
    noise_pred_net.to(device)
    # DDP broadcasts rank 0's parameters at construction; gradient sync is done
    # explicitly by _all_reduce_gradients (see its docstring).
    ddp_nets = DDP(noise_pred_net, device_ids=[rank])

    sample_density = partial(rand_log_normal, loc=opt.sigma_sample_density_mean, scale=opt.sigma_sample_density_std)

    ema = EMAModel(
        model=ddp_nets.module.inner_model,
        power=opt.ema_power
    )

    optimizer = torch.optim.AdamW(
        params=ddp_nets.parameters(),
        lr=opt.lr, weight_decay=1e-6
    )

    # Cosine LR schedule with linear warmup
    lr_scheduler = get_scheduler(
        name='cosine',
        optimizer=optimizer,
        num_warmup_steps=opt.warmup_steps,
        num_training_steps=len(dataloader) * opt.epochs
    )

    resume_epoch = 0
    wandb_id = run.id

    # Resume Configuration
    if opt.resume_model_path != "":
        # Load onto CPU: without map_location every rank would materialise the
        # tensors on the GPU they were saved from. load_state_dict copies them
        # into this rank's parameters/optimizer state.
        checkpoint = torch.load(opt.resume_model_path, map_location='cpu')
        ddp_nets.module.load_state_dict(checkpoint['model'])
        print("Checkpoint Model Loaded")

        resume_epoch = checkpoint['epoch']
        print(f"Start from Epoch {resume_epoch}")

        check_point_manager.folder_path = checkpoint['save_folder_path']
        print("Working Directory:", check_point_manager.folder_path)

        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        print("LR scheduler loaded")

        optimizer.load_state_dict(checkpoint['optimizer'])
        print("Optimizer loaded")

    with tqdm(range(resume_epoch, opt.epochs), desc='Epoch') as tglobal:
        for epoch_idx in tglobal:
            # Without this the DistributedSampler yields the same order every epoch.
            sampler.set_epoch(epoch_idx)
            epoch_loss = []
            with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
                for nbatch in tepoch:
                    naction, obs_cond, B = _prepare_batch(
                        nbatch, dataset, device, obs_horizon, model_type
                    )

                    # sample noise to add to actions
                    noise = torch.randn(naction.shape, device=device)
                    # sample a noise level (sigma) for each data point
                    sigmas = sample_density([B], device=device)

                    loss = ddp_nets.module.loss(naction, noise, sigmas, global_cond=obs_cond).mean()

                    # optimize
                    loss.backward()
                    _all_reduce_gradients(ddp_nets.module, world_size)
                    optimizer.step()
                    optimizer.zero_grad()
                    # step lr scheduler every batch
                    # this is different from standard pytorch behavior
                    lr_scheduler.step()

                    # update Exponential Moving Average of the model weights
                    ema.step(inner_model)

                    # logging
                    loss_cpu = loss.item()
                    epoch_loss.append(loss_cpu)
                    tepoch.set_postfix(loss=loss_cpu)

            mean_loss = np.mean(epoch_loss)
            tglobal.set_postfix(loss=mean_loss)
            if rank == 0:
                run.log({"train-loss": mean_loss})
                check_point_manager.update(epoch_idx, ddp_nets.module)
                # NOTE(review): this writes a resumable checkpoint every `world_size`
                # epochs, which looks unintended (every epoch or every
                # save_per_epochs?) - confirm before relying on resume granularity.
                if epoch_idx % world_size == 0:
                    check_point_manager.save_checkpoint(ddp_nets.module, epoch_idx, optimizer, lr_scheduler, wandb_id)
                    wandb.save(check_point_manager.folder_path)

    # Only one process writes the final weights to avoid concurrent writes to the same files.
    if rank == 0:
        check_point_manager.save_training_end(ddp_nets.module)


def train(rank, world_size, run):
    """Per-process entry point launched by ``mp.spawn``.

    Joins the process group, binds the process to GPU ``rank``, trains, and
    always tears the process group down again, even if training raises.

    Args:
        rank: process index, also used as the CUDA device index.
        world_size: number of processes (one per GPU).
        run: the wandb run created by the parent process.
    """
    setup(rank, world_size)
    try:
        # Bind this process to its own GPU so implicit CUDA allocations and DDP
        # communication buffers land on the right device.
        torch.cuda.set_device(rank)
        _run_training(rank, world_size, run)
    finally:
        cleanup()

if __name__ == '__main__':
    # One training process is spawned per visible GPU. Check before creating a
    # wandb run: with no device there would be nothing to spawn and the script
    # would finish without training.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError("No CUDA device found: this script spawns one training process per GPU.")

    if opt.resume_model_path != "":
        # Assumes checkpoint file names end in "_<wandb_id>.pth"; the id is parsed
        # back out so the original wandb run is resumed.
        wandb_id = opt.resume_model_path.split("/")[-1].split("_")[-1].replace(".pth", "")
        run = wandb.init(project='diffusion_policy_edm_state_conditioned', resume=True, id=wandb_id)
    else:
        run = wandb.init(project='diffusion_policy_edm_state_conditioned')

    config = run.config
    config.epochs = opt.epochs
    config.batchsize = opt.batchsize
    config.learning_rate = opt.lr
    config.action_horizon = task.action_horizon
    config.obs_horizon = task.obs_horizon
    config.diffusion_timesteps = opt.diffusion_timesteps
    config.model_type = opt.model_type

    config.sigma_data = opt.sigma_data
    config.sigma_sample_density_mean = opt.sigma_sample_density_mean
    config.sigma_sample_density_std = opt.sigma_sample_density_std
    config.task_name = task_type.value + "-" + task_tag.value + "-" + CONTROL_TYPE.value

    # MASTER_ADDR / MASTER_PORT are set by setup() inside every spawned process.
    mp.spawn(train, args=(world_size, run), nprocs=world_size, join=True)