#-*- coding:utf-8 -*-

import sys 
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from consistency_policy.utils import append_dims
from dataset.robomimic_lowdim_dataset import RobomimicReplayLowdimDataset 
from dataset.pusht_dataset import PushTStateDataset
from dataset.tasks import *
from env import PushTEnv

from consistency_policy.ctm_unet import CTMConditionalUnet1D
from consistency_policy.diffusion import CTM_Scheduler, Huber_Loss
from diffusion import ConditionalKarrasUnet1D, Denoiser

from train.checkpoint_manager import CheckPointManager
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from diffusers.optimization import get_scheduler
import torch.distributed as dist
import torch.multiprocessing as mp

from typing import List
from functools import partial
from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np
import argparse
import random
import torch.nn as nn
import torch
import gdown
import wandb
import copy
import os 

# Control type handed to every task constructor below.
CONTROL_TYPE = ControlType.STATE

# Command-line interface. Options spelled with dashes are exposed on `opt`
# with underscores (e.g. `--save-per-epochs` -> `opt.save_per_epochs`).
parser = argparse.ArgumentParser()
# General options
parser.add_argument('-e', '--epochs', type=int, default=10)
# Global batch size; train() divides it by the number of workers.
parser.add_argument('-b', '--batchsize', type=int, default=256)
# Passed to the CTM scheduler as its number of bins.
parser.add_argument('--diffusion_timesteps', type=int, default=40) # Different from training
# Used as the default decay of ema_update().
parser.add_argument('--ema_power', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--save-per-epochs', type=int, default=50)
# Only recorded in the wandb config; the checkpoint manager is given "CNN" directly.
parser.add_argument('-m', '--model_type', type=str, default="CNN")

# Diffusion Policy options
parser.add_argument('--task_type', type=str, default="PUSHT")
parser.add_argument('--task_tag', type=str, default="")
parser.add_argument('--save-last-checkpoints', type=int, default=1)
parser.add_argument('--warmup_steps', type=int, default=500)

# Karras (EDM) options
# The teacher checkpoint is always loaded by train(), so fail at parse time
# instead of deep inside every spawned worker when it is missing.
parser.add_argument('--teacher_path', type=str, required=True,
                    help="path to the teacher model state dict to distill from")
parser.add_argument('--sigma_data', type=float, default=0.5)
parser.add_argument('--sigma_min', type=float, default=0.002)
parser.add_argument('--sigma_max', type=float, default=80.0)
parser.add_argument('--rho', type=float, default=7)

# CTM options
parser.add_argument('--dsm_scheduler', type=str, default='karras')
parser.add_argument('--dsm_weights', type=float, default=1.0)
parser.add_argument('--ctm_weights', type=float, default=1.0)
parser.add_argument('--huber_delta', type=float, default=0.0068305197)
# The two density options below are only recorded in the wandb config by this script.
parser.add_argument('--sigma_sample_density_mean', type=float, default=-1.2)
parser.add_argument('--sigma_sample_density_std', type=float, default=1.2)
parser.add_argument('--seed', type=int, default=42)
# An empty string means "start a fresh run".
parser.add_argument('--resume-model-path', type=str, default="")
opt = parser.parse_args()

# Seed torch, numpy and the stdlib RNG (in that order) from the same CLI value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(opt.seed)

# --task_tag value -> TaskTags member.
_TAGS_BY_FLAG = {
    "": TaskTags.NONE,
    "PH": TaskTags.PH,
    "MH": TaskTags.MH,
}
if opt.task_tag not in _TAGS_BY_FLAG:
    raise NotImplementedError(f"Task Tag {opt.task_tag} Not implemented")
task_tag = _TAGS_BY_FLAG[opt.task_tag]

# --task_type value -> (TaskTypes member, factory for the task description).
# Factories are lambdas so that only the selected task is constructed.
# PushT and ToolHang take no tag; the other tasks receive the resolved tag.
_TASKS_BY_FLAG = {
    'PUSHT': (TaskTypes.PUSHT, lambda: PushT(ctype=CONTROL_TYPE)),
    'LIFT': (TaskTypes.LIFT, lambda: Lift(ctype=CONTROL_TYPE, tag=task_tag)),
    'CAN': (TaskTypes.CAN, lambda: Can(ctype=CONTROL_TYPE, tag=task_tag)),
    'SQUARE': (TaskTypes.SQUARE, lambda: Square(ctype=CONTROL_TYPE, tag=task_tag)),
    'TRANSPORT': (TaskTypes.TRANSPORT, lambda: Transport(ctype=CONTROL_TYPE, tag=task_tag)),
    'TOOLHANG': (TaskTypes.TOOLHANG, lambda: ToolHang(ctype=CONTROL_TYPE)),
}
if opt.task_type not in _TASKS_BY_FLAG:
    raise NotImplementedError(f"Task {opt.task_type} Not implemented")
task_type, _build_task = _TASKS_BY_FLAG[opt.task_type]
task = _build_task()

# Module-level default; train() shadows it with the worker's rank.
device = 'cuda'

def create_env():
    """Return a PushT environment that has been seeded, reset and stepped once.

    The environment is seeded with 1000, reset, and then advanced by a single
    randomly sampled action before being handed back to the caller.
    """
    pusht = PushTEnv()

    # Seeds 0-200 are used for the demonstration dataset, so pick one outside
    # that range for the initial state.
    pusht.seed(1000)

    # The environment must be reset before its first use.
    pusht.reset()

    # One standard gym step with a sampled action (2D position in [0, 512]);
    # the returned transition is discarded.
    _obs, _reward, _done, _info = pusht.step(pusht.action_space.sample())
    return pusht

def create_model(
        diffusion_step_embed_dim:int = 256,
        down_dims:List[int]=[256,512,1024],
        kernel_size:int=3,
        n_groups:int=8,
        cond_predict_scale:bool=False,
        dropout_rate:float=.0,
    ):
    """Build the CTM student network for the globally selected task.

    The action dimension and the conditioning size (``obs_dim * obs_horizon``)
    are read from the module-level ``task``; no local conditioning is used.

    Args:
        diffusion_step_embed_dim: Forwarded to ``CTMConditionalUnet1D``.
        down_dims: Forwarded to ``CTMConditionalUnet1D``.
        kernel_size: Forwarded to ``CTMConditionalUnet1D``.
        n_groups: Forwarded to ``CTMConditionalUnet1D``.
        cond_predict_scale: Forwarded to ``CTMConditionalUnet1D``.
        dropout_rate: Forwarded to ``CTMConditionalUnet1D``.

    Returns:
        A freshly constructed ``CTMConditionalUnet1D``.
    """
    input_dim = task.action_dim
    global_cond_dim = task.obs_dim * task.obs_horizon
    model = CTMConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            # Pass a copy: the default list object is shared by every call, so
            # the network must never get a reference it could mutate.
            down_dims=list(down_dims),
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
            dropout_rate=dropout_rate,
    )
    return model

def create_teacher_model():
    """Build the teacher: a Karras U-Net wrapped in a ``Denoiser``.

    Sizes come from the module-level ``task``; ``sigma_data`` comes from the
    command line. Weights are not loaded here.
    """
    # Observations of the whole observation horizon form the global condition.
    cond_dim = task.obs_dim * task.obs_horizon
    backbone = ConditionalKarrasUnet1D(
        input_dim=task.action_dim,
        global_cond_dim=cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024]
    )
    return Denoiser(inner_model=backbone, sigma_data=opt.sigma_data)

@torch.no_grad()
def ema_update(model, ema, ema_decay:float=opt.ema_power):
    """Blend the parameters of ``model`` into ``ema`` in place.

    Every parameter of ``ema`` becomes
    ``ema_decay * ema + (1 - ema_decay) * model``. Only ``parameters()`` are
    touched. The default decay is the ``--ema_power`` value, captured once
    when this function is defined.
    """
    live_weights = [w.data for w in model.parameters()]
    averaged_weights = [w.data for w in ema.parameters()]

    # Scale the running average first, then add the weighted live weights.
    torch._foreach_mul_(averaged_weights, ema_decay)
    torch._foreach_add_(averaged_weights, live_weights, alpha=1 - ema_decay)

@torch.no_grad()
def teacher_heun_solver(teacher_model, samples, t, next_t, global_cond, clamp = False):
    """Advance ``samples`` from time ``t`` to ``next_t`` with one Heun step.

    The teacher is evaluated twice: once at the start point and once at the
    Euler estimate of the end point; the two slopes are averaged.

    Args:
        teacher_model: Callable ``(x, time, global_cond=...)`` returning the
            denoised estimate of ``x``.
        samples: Noisy samples at time ``t``.
        t: Current times; broadcast to ``samples.ndim`` via ``append_dims``.
        next_t: Target times; broadcast the same way.
        global_cond: Conditioning forwarded to the teacher.
        clamp: If true, clamp each denoised estimate to ``[-1, 1]``.

    Returns:
        The samples advanced to ``next_t``.
    """
    ndim = samples.ndim

    def _denoise(x, when):
        estimate = teacher_model(x, when, global_cond=global_cond)
        # Clamping should only happen at inference time.
        return estimate.clamp(-1.0, 1.0) if clamp else estimate

    h = append_dims((next_t - t), ndim)

    # Slope at the start point, followed by a plain Euler step.
    slope_start = (samples - _denoise(samples, t)) / append_dims(t, ndim)
    euler_guess = samples + h * slope_start

    # Slope at the Euler estimate, then average both slopes.
    slope_end = (euler_guess - _denoise(euler_guess, next_t)) / append_dims(next_t, ndim)
    return samples + h * (slope_start + slope_end) / 2

def setup(rank, world_size):
    """Join the default process group as worker ``rank`` of ``world_size``.

    The rendezvous address is fixed to ``localhost:12355`` and the ``gloo``
    backend is used.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

    # Initialize the process group
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the process group that `setup` initialised for this worker."""
    dist.destroy_process_group()

def train(rank, world_size, run):
    """Distill the teacher into a CTM student on one distributed worker.

    Intended as the target of ``mp.spawn``. Each worker joins the process
    group, builds the dataset and its shard sampler, creates the student
    (wrapped in DDP), a frozen teacher and an EMA copy of the student, and
    then optimises ``ctm_weights * loss_ctm + dsm_weights * loss_dsm``.
    Logging and checkpointing are driven from rank 0.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of workers.
        run: wandb run used for logging; its ``id`` is stored in checkpoints.
    """
    # parameters
    pred_horizon = task.pred_horizon
    obs_horizon = task.obs_horizon
    action_horizon = task.action_horizon
    #|o|o|                             observations: 2
    #| |a|a|a|a|a|a|a|a|               actions executed: 8
    #|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16
    setup(rank, world_size)
    device = rank


    # create dataset from file
    if task_type == TaskTypes.PUSHT:
        # download demonstration data from Google Drive
        dataset_path = "./data/pusht_cchi_v7_replay.zarr.zip"
        # Only rank 0 downloads, so workers do not write the same file
        # concurrently; everyone else waits at the barrier below.
        if rank == 0 and not os.path.isfile(dataset_path):
            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
            file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
            gdown.download(id=file_id, output=dataset_path, quiet=False)
        dist.barrier()
        dataset = PushTStateDataset(
            dataset_path=dataset_path,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            action_horizon=action_horizon
        )
    else:
        dataset = RobomimicReplayLowdimDataset(
            dataset_path=task.dataset_path,
            horizon=task.pred_horizon,
            obs_keys=task.obs_keys,
            abs_action=True,
            pad_before=1,
            pad_after=7,
        )

    # each worker iterates over its own shard of the dataset
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=opt.seed
    )

    # create dataloader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        # --batchsize is the global batch size
        batch_size=int(opt.batchsize // world_size),
        num_workers=1,
        # shuffling is handled by the sampler
        # accelerate cpu-gpu transfer
        pin_memory=True,
        # don't kill worker process after each epoch
        persistent_workers=True,
        sampler=sampler
    )

    # visualize data in batch
    batch = next(iter(dataloader))
    print("batch['obs'].shape:", batch['obs'].shape)
    print("batch['action'].shape", batch['action'].shape)


    check_point_manager = CheckPointManager(
        epochs=opt.epochs,
        save_last_epochs=opt.save_last_checkpoints,
        save_per_epochs=opt.save_per_epochs,
        task_name=task_type.value,
        task_tag=task_tag.value,
        control_type=CONTROL_TYPE.value,
        model_type="CNN",
        base_name=f"CTM",
        ignore_make_folder=(opt.resume_model_path != ""),
        rank=rank
    )

    model = create_model()
    _ = model.to(device)
    ddp_nets = DDP(model, device_ids=[device])

    def _forward(_model,
            sample: torch.Tensor,
            timestep: torch.Tensor,
            stop_time: torch.Tensor,
            local_cond=None, global_cond=None, clamp=False):
        # Map `sample` from `timestep` to `stop_time` with `_model`, using the
        # scheduler's output parameterisation.
        denoise = lambda x, t, s: _model(x, t, s, local_cond=local_cond, global_cond=global_cond)
        return noise_scheduler.CTM_calc_out(denoise, sample, timestep, stop_time, clamp=clamp)

    noise_scheduler = CTM_Scheduler(
        time_min=opt.sigma_min,
        time_max=opt.sigma_max,
        rho=opt.rho,
        bins=opt.diffusion_timesteps,
        solver='heun',
        scaling='boundary',
        use_c_in=True,
        data_std=opt.sigma_data,
        time_sampler='ctm',
        clamp=True,
        ode_steps_max=1,
    )
    delta = opt.huber_delta
    dsm_scheduler = opt.dsm_scheduler
    ctm_weights = opt.ctm_weights
    dsm_weights = opt.dsm_weights

    teacher_model = create_teacher_model()

    # Map the checkpoint straight onto this worker's GPU; a bare 'cuda' would
    # make every rank stage the teacher on the current default device first.
    state_dict = torch.load(opt.teacher_path, map_location=f'cuda:{device}')
    teacher_model.load_state_dict(state_dict)
    teacher_model.eval()
    teacher_model.requires_grad_(False)
    teacher_model.to(device)
    print("Using teacher: ", opt.teacher_path)

    # frozen copy of the student that is updated by ema_update()
    ema = copy.deepcopy(ddp_nets.module)
    ema.requires_grad_(False)

    optimizer = torch.optim.AdamW(
        params=ddp_nets.parameters(),
        lr=opt.lr, weight_decay=1e-6
    )

    # Cosine LR schedule with linear warmup
    lr_scheduler = get_scheduler(
        name='cosine',
        optimizer=optimizer,
        num_warmup_steps=opt.warmup_steps,
        num_training_steps=len(dataloader) * opt.epochs
    )

    resume_epoch = 0
    wandb_id = run.id

    # Resume Configuration
    if opt.resume_model_path != "":
        dist.barrier()
        # checkpoints reference cuda:0; remap them to this worker's device
        map_location = {'cuda:%d' % 0: 'cuda:%d' % device}
        checkpoint = torch.load(opt.resume_model_path, map_location=map_location)
        ddp_nets.module.load_state_dict(checkpoint['model'])
        print("Checkpoint Model Loaded")

        # The EMA copy above was taken from the freshly initialised student;
        # refresh it so the target network also starts from the checkpoint.
        ema.load_state_dict(ddp_nets.module.state_dict())

        resume_epoch = checkpoint['epoch']
        print(f"Start from Epoch {resume_epoch}")

        check_point_manager.folder_path = checkpoint['save_folder_path']
        print("Working Directory:", check_point_manager.folder_path)

        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        print("LR scheduler loaded")

        optimizer.load_state_dict(checkpoint['optimizer'])
        print("Optimizer loaded")

    with tqdm(range(resume_epoch, opt.epochs), desc='Epoch') as tglobal:
        # epoch loop
        for epoch_idx in tglobal:
            # Without this the sampler reuses the same shuffle every epoch.
            sampler.set_epoch(epoch_idx)
            epoch_loss = list()
            # batch loop
            with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
                for nbatch in tepoch:
                    # Robomimic batches are normalised here; PushT batches are
                    # used as delivered by the dataset.
                    if task_type != TaskTypes.PUSHT:
                        nbatch = dataset.normalizer.normalize(nbatch)
                    nobs = nbatch['obs'].to(device)
                    naction = nbatch['action'].to(device)

                    # observation as FiLM conditioning
                    # (B, obs_horizon, obs_dim) -> (B, obs_horizon * obs_dim)
                    obs_cond = nobs[:,:obs_horizon,:].flatten(start_dim=1)

                    # ---- CTM loss ----
                    t, s, u = noise_scheduler.sample_times(naction, time_sampler="ctm")
                    times = noise_scheduler.timesteps_to_times(t)
                    stops = noise_scheduler.timesteps_to_times(s)
                    u_times = noise_scheduler.timesteps_to_times(u)

                    noise_traj = noise_scheduler.add_noise(naction, times)

                    # Integrate the noisy sample from t towards u with the teacher.
                    u_noise_traj = noise_traj

                    for d in range(noise_scheduler.ode_steps_max):
                        # Per-sample clamp of the current/next step index into [t_i, u_i].
                        # NOTE(review): `.item()` forces a host sync per sample; if t and u
                        # are integer tensors this could be a single tensor-bound clamp.
                        ct = torch.stack([(t_i + d).clamp(int(t_i.item()), int(u_i.item())) for t_i, u_i in zip(t, u)])
                        nt = torch.stack([(t_i + d + 1).clamp(int(t_i.item()), int(u_i.item())) for t_i, u_i in zip(t, u)])

                        current_times = noise_scheduler.timesteps_to_times(ct)
                        next_times = noise_scheduler.timesteps_to_times(nt)

                        u_noise_traj = teacher_heun_solver(teacher_model, u_noise_traj, current_times, next_times, global_cond=obs_cond, clamp=False)

                    # t -> s
                    # Call the DDP wrapper rather than `.module`: only the wrapper's
                    # forward arms the gradient all-reduce for the backward pass.
                    pred = _forward(ddp_nets, noise_traj, times, stops,
                                            local_cond=None, global_cond=obs_cond)

                    # u -> s
                    target = _forward(ema, u_noise_traj, u_times, stops,
                                        local_cond=None, global_cond=obs_cond)

                    # now we take both back to 0
                    start = torch.tensor([noise_scheduler.time_min], device = device).expand(times.shape)

                    pred = _forward(ema, pred, stops, start,
                                        local_cond=None, global_cond=obs_cond)

                    target = _forward(ema, target, stops, start,
                                        local_cond=None, global_cond=obs_cond)


                    loss_ctm = Huber_Loss(pred, target, delta = delta, weights=None)

                    # ---- DSM loss ----
                    times, _ = noise_scheduler.sample_times(naction, time_sampler='ctm_dsm')
                    weights = noise_scheduler.get_weights(times, None, dsm_scheduler)
                    # Add noise to the clean actions according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_trajectory = noise_scheduler.add_noise(naction, times)

                    # Predict the initial state
                    stop = torch.tensor([noise_scheduler.time_min], device = device).expand(times.shape)
                    # Second student forward, again through the DDP wrapper.
                    pred = _forward(ddp_nets, noisy_trajectory, times, stop,
                                            local_cond=None, global_cond=obs_cond, clamp=False)

                    target = naction
                    loss_dsm = Huber_Loss(pred, target, delta = delta, weights=weights)

                    loss = ctm_weights * loss_ctm + dsm_weights * loss_dsm
                    loss = loss.mean()

                    # optimize
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    # step lr scheduler every batch
                    # this is different from standard pytorch behavior
                    lr_scheduler.step()

                    # update Exponential Moving Average of the model weights
                    ema_update(ddp_nets.module, ema)

                    # logging
                    loss_cpu = loss.item()
                    epoch_loss.append(loss_cpu)
                    tepoch.set_postfix(loss=loss_cpu)
            if rank == 0:
                run.log({"train-loss": np.mean(epoch_loss)})
            tglobal.set_postfix(loss=np.mean(epoch_loss))
            if rank == 0:
                check_point_manager.update(epoch_idx, ddp_nets.module)

            # NOTE(review): the resumable checkpoint is written every `world_size`
            # epochs; confirm this is intended rather than a save interval.
            if (epoch_idx % world_size == 0) and rank == 0:
                check_point_manager.save_checkpoint(ddp_nets.module, epoch_idx, optimizer, lr_scheduler, wandb_id)
                wandb.save(check_point_manager.folder_path)
        # NOTE(review): runs on every rank; presumably CheckPointManager uses the
        # `rank` it was given to avoid duplicate writes -- confirm.
        check_point_manager.save_training_end(ddp_nets.module)

    cleanup()

if __name__ == '__main__':
    # One worker is spawned per visible GPU. With zero GPUs `mp.spawn` would
    # start no process at all and the script would exit without training, so
    # fail loudly before a wandb run is created.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError("No CUDA device available: at least one GPU is required for training")

    if opt.resume_model_path != "":
        # The wandb run id is the last "_"-separated token of the checkpoint
        # file name, without the ".pth" extension.
        wandb_id = os.path.basename(opt.resume_model_path).split("_")[-1].replace(".pth", "")
        run = wandb.init(project = 'consistency_policy_ctm_state_conditioned', resume=True, id=wandb_id)
    else:
        run = wandb.init(project = 'consistency_policy_ctm_state_conditioned')

    # Record the run's hyper-parameters in the wandb config.
    config = run.config
    config.epochs = opt.epochs
    config.batchsize = opt.batchsize
    config.learning_rate = opt.lr
    config.action_horizon = task.action_horizon
    config.obs_horizon = task.obs_horizon
    config.diffusion_timesteps = opt.diffusion_timesteps
    config.model_type = opt.model_type

    config.sigma_data = opt.sigma_data
    config.sigma_sample_density_mean = opt.sigma_sample_density_mean
    config.sigma_sample_density_std = opt.sigma_sample_density_std
    config.task_name = task_type.value + "-" + task_tag.value + "-" + CONTROL_TYPE.value

    # Same rendezvous settings that setup() applies inside each worker.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    mp.spawn(train, args=(world_size, run), nprocs=world_size, join=True)

