import os
import sys
import os.path as osp
# append parent path to environment
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
import json
import math
import copy
import torch
import random
import logging
import argparse
import importlib
import itertools
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import torch.multiprocessing as mp
import torchvision.utils as tvutils
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from io import BytesIO
from PIL import Image
from easydict import EasyDict
from functools import partial
from importlib import reload
from torch.utils.data.sampler import Sampler
from collections import defaultdict
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel
from torchvision.transforms.functional import InterpolationMode

# Resolve the experiment config at import time: `--config` is the dotted import
# path of a Python module (e.g. `configs.my_exp`) that defines a module-level `cfg`.
# Unrecognised command-line arguments are tolerated and only printed.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config', type=str, required=True,
    help='dotted module path of a config .py file that defines `cfg`.')
args, unknown = parser.parse_known_args()
print(f'unknown args: {unknown}')
print(args.config)
cfg = getattr(importlib.import_module(args.config), 'cfg')

from diffusion_tools.diffusion.diffusion import GaussianDiffusion, beta_schedule
# The UNet implementation is pluggable through `cfg.unet_mod`; fall back to the default module.
if cfg.get('unet_mod') is None:
    cfg.unet_mod = 'diffusion_tools.module.networks.unet'
UNet = getattr(importlib.import_module(cfg.unet_mod), 'UNet')

def enforce_zero_terminal_snr(betas):
    """Rescale a beta schedule so that its final timestep has zero SNR.

    ``sqrt(alpha_bar)`` is shifted so its last entry becomes exactly 0 and then
    rescaled so its first entry keeps its original value. The result is
    converted back into per-step betas, so the last returned beta equals 1 and
    the first one is (numerically) unchanged.

    Args:
        betas: ``torch.Tensor`` of betas, indexed along dim 0 (presumably 1-D).

    Returns:
        torch.Tensor: rescaled betas with the same length as ``betas``.
    """
    sqrt_ab = (1 - betas).cumprod(0).sqrt()
    first = sqrt_ab[0].clone()
    last = sqrt_ab[-1].clone()

    # shift so the last entry is 0, then scale so the first entry is unchanged
    sqrt_ab = (sqrt_ab - last) * (first / (first - last))

    alpha_bar = sqrt_ab ** 2
    # undo the cumulative product: alpha_t = alpha_bar_t / alpha_bar_{t-1}
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas

def adjust_learning_rate(cfg, optimizer, eiters):
    """Set the learning rate of every param group of ``optimizer`` for step ``eiters``.

    Schedule:
      * ``eiters < cfg.warmup_steps``: linear warmup from ``cfg.minum_lr`` to ``cfg.lr``.
      * otherwise, if ``cfg.anneal_lr``: decay from ``cfg.lr`` to ``cfg.minum_lr``
        over the remaining ``cfg.num_steps - cfg.warmup_steps`` steps, using
        ``cfg.decay_type`` (``'linear'`` or ``'cosine'``).
      * otherwise: constant ``cfg.lr``.

    Args:
        cfg: Config exposing ``lr``, ``minum_lr``, ``warmup_steps``, ``num_steps``,
            ``anneal_lr`` and ``decay_type``.
        optimizer: Object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        eiters: Current iteration count.

    Returns:
        The learning rate written into every param group.

    Raises:
        ValueError: if annealing is enabled and ``cfg.decay_type`` is unknown.
    """
    if eiters < cfg.warmup_steps:
        lr = (cfg.lr - cfg.minum_lr) * float(eiters) / cfg.warmup_steps + cfg.minum_lr
    elif cfg.anneal_lr:
        decay_steps = cfg.num_steps - cfg.warmup_steps
        if decay_steps > 0:
            progress = (eiters - cfg.warmup_steps) / float(decay_steps)
            progress = float(np.clip(progress, 0.0, 1.0))
        else:
            # no decay window (num_steps <= warmup_steps): the decay is already finished
            progress = 1.0
        if cfg.decay_type == 'linear':
            lr = cfg.minum_lr + (cfg.lr - cfg.minum_lr) * (1.0 - progress)
        elif cfg.decay_type == 'cosine':
            lr = cfg.minum_lr + (cfg.lr - cfg.minum_lr) * 0.5 * (1. + math.cos(math.pi * progress))
        else:
            raise ValueError('Unknown lr type {}'.format(cfg.decay_type))
    else:
        lr = cfg.lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def ceil_divide(a, b):
    """Return ``ceil(a / b)`` as an ``int``.

    Python ints use exact integer arithmetic, so results stay correct beyond
    float precision (> 2**53). Other numeric types fall back to float division.
    """
    if isinstance(a, int) and isinstance(b, int):
        return -(-a // b)
    return int(math.ceil(a / b))

def all_gather(tensor, uniform_size=True, **kwargs):
    """Gather ``tensor`` from every rank and return the per-rank tensors.

    Args:
        tensor: Local tensor to contribute.
        uniform_size: If True, every rank must pass a tensor of the same shape.
            If False, tensors may differ in their first dimension (all other
            dimensions must match). They are zero-padded for the exchange and
            trimmed back to each rank's true length afterwards.
        **kwargs: Forwarded to ``torch.distributed.all_gather`` (e.g. ``group``).

    Returns:
        list of tensors, one per rank of the (given or default) group, in rank order.
    """
    group = kwargs.get('group')
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [tensor]

    if uniform_size:
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, **kwargs)
        return tensor_list

    # exchange first-dimension sizes so every rank knows how much to trim
    local_size = torch.tensor([tensor.shape[0]], dtype=torch.long, device=tensor.device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group=group)
    sizes = [int(s.item()) for s in size_list]
    max_size = max(sizes)

    # dist.all_gather needs identical shapes: pad the local tensor to the largest size
    if tensor.shape[0] < max_size:
        padding = tensor.new_zeros((max_size - tensor.shape[0], *tensor.shape[1:]))
        tensor = torch.cat([tensor, padding], dim=0)
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, **kwargs)
    return [t[:n] for t, n in zip(tensor_list, sizes)]

def generalized_all_gather(data, group=None):
    """Gather an arbitrary picklable object from every rank.

    Args:
        data: Any picklable object.
        group: Optional process group; ``None`` uses the default (global) group.

    Returns:
        list: ``data`` from each rank of ``group``, in rank order.
    """
    if dist.get_world_size() == 1:
        return [data]
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    # all_gather_object pickles/unpickles `data` (only use between trusted peers).
    # With NCCL it stages the bytes on torch.cuda.current_device(), so each rank
    # must have called torch.cuda.set_device() beforehand.
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data, group=group)
    return data_list

class BatchSampler(Sampler):
    """Infinite, rank-aware batch sampler.

    Each rank owns the interleaved indices ``rank, rank + num_replicas, ...``
    below ``dataset_size`` (optionally permuted). It cycles through them forever,
    yielding ``batch_size`` entries at a time and reshuffling after each pass
    when ``shuffle`` is set. When ``list_file`` is given, batches contain the
    parsed list entries (each line split on its first comma) instead of
    integer indices.

    Args:
        batch_size: Number of entries per yielded batch.
        dataset_size: Number of samples; ignored when ``list_file`` is given.
        list_file: Optional text file with one entry per line.
        num_replicas: Total number of ranks; defaults to the process-group size.
        rank: This process's rank; defaults to the process-group rank.
        shuffle: Permute indices at construction and reshuffle after each pass.
        seed: RNG seed (offset by ``rank``). When ``None``, a seed shared by all
            ranks is drawn.

    Raises:
        ValueError: if neither a positive ``dataset_size`` nor a ``list_file``
            is given, or (on iteration) if this rank received no samples.
    """
    def __init__(self, batch_size, dataset_size=0, list_file=None,
                 num_replicas=None, rank=None, shuffle=False, seed=None):
        if dataset_size <= 0 and list_file is None:
            raise ValueError('BatchSampler needs a positive dataset_size or a list_file')
        self.list_file = list_file
        if list_file is not None:
            with open(list_file) as f:
                self.items = [u.split(',', 1) for u in f.read().strip().split('\n')]
            self.dataset_size = len(self.items)
        else:
            self.dataset_size = dataset_size
            self.items = None
        self.batch_size = batch_size
        # `is None` checks so that an explicit rank=0 / seed=0 is honoured
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()
        self.shuffle = shuffle
        self.seed = seed if seed is not None else self._shared_random_seed()
        self.rng = np.random.default_rng(self.seed + self.rank)
        # ceil(dataset_size / (num_replicas * batch_size)) in exact integer arithmetic
        self.batches_per_rank = -(-self.dataset_size // (self.num_replicas * self.batch_size))
        self.samples_per_rank = self.batches_per_rank * self.batch_size

        # rank indices: local positions mapped to global ones, out-of-range dropped
        indices = self.rng.permutation(self.samples_per_rank) if shuffle else np.arange(self.samples_per_rank)
        indices = indices * self.num_replicas + self.rank
        indices = indices[indices < self.dataset_size]
        self.indices = indices

    @staticmethod
    def _shared_random_seed():
        """Draw a random seed on every rank and return rank 0's so all ranks agree."""
        seed = int(np.random.randint(2 ** 31))
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            seeds = [None] * dist.get_world_size()
            dist.all_gather_object(seeds, seed)
            seed = seeds[0]
        return seed

    def __iter__(self):
        num_indices = len(self.indices)
        if num_indices == 0:
            raise ValueError(
                f'rank {self.rank} received no samples '
                f'(dataset_size={self.dataset_size}, num_replicas={self.num_replicas})')
        start = 0
        while True:
            # wrap around the end of this rank's indices
            batch = [self.indices[i % num_indices] for i in range(start, start + self.batch_size)]
            if self.items is not None:
                batch = [self.items[i] for i in batch]
            if self.shuffle and (start + self.batch_size) >= num_indices:
                self.rng.shuffle(self.indices)
            start = (start + self.batch_size) % num_indices
            yield batch

class Compose(object):
    r"""Chain of callables applied left to right; unlike ``T.Compose`` it can be sliced.

    ``compose[i]`` returns the i-th transform itself, while ``compose[a:b]``
    returns a new :class:`Compose` over that sub-range.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __getitem__(self, index):
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        if not isinstance(index, int):
            raise TypeError(f'invalid argument {index}')
        return self.transforms[index]

class RandomCrop(object):
    """Random square crop followed by a resize to ``size`` x ``size``.

    The crop side is ``ratio * min(w, h)`` with ``ratio ~ U(min_crop, max_crop)``.
    It is never smaller than ``min(size, w, h)`` and never larger than
    ``min(w, h)``, so the crop always fits inside the image. Expects a PIL-style
    image (``.size``, ``.crop``, ``.resize``, ``.width``, ``.height``).
    """

    def __init__(self, size=256, min_crop=0.5, max_crop=1.0):
        self.size = size
        self.min_crop = min_crop
        self.max_crop = max_crop

    def __call__(self, img):
        # random crop
        w, h = img.size
        short_side = min(w, h)
        ratio = random.uniform(self.min_crop, self.max_crop)
        side = max(int(ratio * short_side), min(self.size, short_side))
        # clamp: max_crop > 1.0 must not request a crop larger than the image
        side = min(side, short_side)
        x1 = random.randint(0, w - side)
        y1 = random.randint(0, h - side)
        img = img.crop((x1, y1, x1 + side, y1 + side))

        # resize
        if img.width != self.size or img.height != self.size:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        return img

class ImageFolder(Dataset):
    """Dataset of RGB images whose paths (relative to ``root_dir``) are listed one per line.

    Each item is the image after ``transforms[:-2]``, converted with ``np.array``
    (HWC for an RGB PIL image) and scaled from ``[0, 255]`` to float32 ``[-1, 1]``.
    The last two transforms are skipped and replaced by this manual scaling.
    If an image fails to load, the error is printed and the last successfully
    loaded sample is returned instead. If nothing has loaded yet, the error is
    re-raised.

    Args:
        root_dir: Directory the listed paths are relative to.
        list_file: Text file with one relative image path per line.
        transforms: Sliceable transform chain (``transforms[:-2]`` must be callable).
        rank, world_size: Stored on the instance; not used by this class.
    """

    def __init__(self, root_dir, list_file, transforms, rank, world_size):
        self.root_dir = root_dir
        self.list_file = list_file
        self.transforms = transforms
        self.rank = rank
        self.world_size = world_size
        # last successfully loaded sample, returned in place of unreadable ones
        self._placeholder = None

        with open(self.list_file, mode='r') as f:
            self.image_paths = f.read().splitlines()

    def __getitem__(self, index):
        try:
            # read image
            image_path = osp.join(self.root_dir, self.image_paths[index])
            with open(image_path, mode='rb') as f:
                img_bytes = f.read()
            img = Image.open(BytesIO(img_bytes))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img = self.transforms[:-2](img)
            img = np.array(img).astype(np.uint8)
            img = (img / 127.5 - 1.0).astype(np.float32)

            # output
            self._placeholder = img
        except Exception as e:
            print(f'Processing {self.image_paths[index]} failed with error {e}', flush=True)
            if self._placeholder is None:
                # nothing to fall back to yet: surface the real error
                raise
        return self._placeholder

    def __len__(self):
        return len(self.image_paths)

def main(**kwargs):
    """Launch training on every GPU of this machine.

    ``kwargs`` override entries of the module-level ``cfg``. This machine's
    rank and the number of machines come from the ``RANK`` and ``WORLD_SIZE``
    environment variables. One :func:`worker` runs per local GPU: in-process
    when the whole job has a single GPU, otherwise via ``mp.spawn``.

    Returns:
        The updated module-level ``cfg``.

    Raises:
        KeyError: if ``RANK`` or ``WORLD_SIZE`` is not set.
        RuntimeError: if no CUDA device is visible.
    """
    cfg.update(**kwargs)
    cfg.pmi_rank = int(os.environ['RANK'])
    cfg.pmi_world_size = int(os.environ['WORLD_SIZE'])
    cfg.gpus_per_machine = torch.cuda.device_count()
    if cfg.gpus_per_machine == 0:
        # mp.spawn(nprocs=0) would return immediately without training anything
        raise RuntimeError('no CUDA device is visible; training requires at least one GPU')
    cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
    if cfg.world_size == 1:
        worker(0, cfg)
    else:
        mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, ))
    return cfg

def worker(gpu, cfg):
    """Train the diffusion UNet in one process (one GPU) of the job.

    Joins the NCCL process group and sets up logging (rank 0 only), the data
    pipeline, the model/EMA, the diffusion and the optimizer. It then runs
    ``cfg.num_steps`` optimisation steps with periodic loss logging, sample
    visualisation and checkpointing into ``cfg.log_dir``.

    Args:
        gpu: Local GPU index; also the process index passed by ``mp.spawn``.
        cfg: Config populated by :func:`main` (``pmi_rank``, ``gpus_per_machine``,
            ``world_size``) plus the training options read below. ``gpu``,
            ``rank`` and ``log_dir`` are written back into it, and so are
            ``log_file``/``log_local`` on rank 0.

    Raises:
        ValueError: if ``cfg.unet`` is not set.
    """
    import shutil

    cfg.gpu = gpu
    cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu

    # init distributed processes
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True
    dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)

    # every rank runs this concurrently; exist_ok avoids a check-then-create race
    os.makedirs(cfg.log_dir, exist_ok=True)

    # logging (rank 0 only; all ranks adopt rank 0's log_dir)
    os.makedirs('logs', exist_ok=True)
    cfg.log_dir = generalized_all_gather(cfg.log_dir)[0]
    if cfg.rank == 0:
        name = osp.basename(cfg.log_dir)
        log_file = osp.join(cfg.log_dir, '{}_rank{}.log'.format(name, cfg.rank))
        log_local = osp.join('logs', osp.basename(log_file))
        cfg.log_file = log_file
        cfg.log_local = log_local
        reload(logging)
        logging.basicConfig(
            level=logging.INFO,
            format='[%(asctime)s] %(levelname)s: %(message)s',
            handlers=[
                logging.FileHandler(filename=log_local),
                logging.StreamHandler(stream=sys.stdout)])
        logging.info(cfg)

    # [data] training
    debug = cfg.world_size < 16  # NOTE: to avoid loading very large list files
    transforms = Compose([
        RandomCrop(cfg.resolution, cfg.min_crop, cfg.max_crop),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    dataset = ImageFolder(
        root_dir=cfg.root_dir,
        list_file=cfg.list_file,
        transforms=transforms,
        rank=cfg.rank,
        world_size=1024 if debug else cfg.world_size)
    # ImageFolder does not shard; BatchSampler splits the indices across ranks
    sampler = BatchSampler(
        dataset_size=len(dataset),
        batch_size=cfg.batch_size,
        num_replicas=cfg.world_size,
        rank=cfg.rank,
        shuffle=True,
        seed=cfg.seed)
    loader_kwargs = {'num_workers': cfg.num_workers, 'pin_memory': True}
    if cfg.num_workers > 0:
        # DataLoader raises if prefetch_factor is passed while loading in-process
        loader_kwargs['prefetch_factor'] = cfg.prefetch_factor
    dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, **loader_kwargs)
    rank_iter = iter(dataloader)

    # [model] unet
    if cfg.get('unet') is None:
        raise ValueError('cfg.unet must be set to build the UNet model')
    model = UNet(
        **cfg.unet,
        num_classes=None,
        use_fp16=cfg.use_fp16,
    ).to(gpu)
    model = DistributedDataParallel(model, device_ids=[gpu])
    if cfg.use_ema:
        model_ema = copy.deepcopy(model.module).eval().requires_grad_(False)

    # mark model size
    if cfg.rank == 0:
        logging.info(f'Created a model with {int(sum(p.numel() for p in model.parameters()) / (1024 ** 2))}M parameters')

    # diffusion
    if cfg.schedule == "zero_terminal_snr":
        cfg.schedule = "linear"
        betas = beta_schedule(cfg.schedule, cfg.num_timesteps)
        betas = enforce_zero_terminal_snr(betas)
    else:
        betas = beta_schedule(cfg.schedule, cfg.num_timesteps)
    diffusion = GaussianDiffusion(
        betas=betas,
        mean_type=cfg.mean_type,
        var_type=cfg.var_type,
        loss_type=cfg.loss_type,
        rescale_timesteps=False)

    # optimizer
    # NOTE(review): adjust_learning_rate() is defined in this file but never called,
    # so the learning rate stays at cfg.lr for the whole run.
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # fp16 gradient scaler
    scaler = amp.GradScaler(enabled=cfg.use_fp16)

    # global variables
    viz_num = cfg.batch_size if cfg.get('viz_num') is None else cfg.viz_num

    # run training
    for step in range(1, cfg.num_steps + 1):
        model.train().requires_grad_(True)

        # read batch: (N, H, W, C) arrays from ImageFolder -> (N, C, H, W)
        batch = next(rank_iter)
        imgs = batch.to(gpu)
        imgs = imgs.permute(0, 3, 1, 2)
        t = torch.randint(0, cfg.num_timesteps, (imgs.shape[0], ), dtype=torch.long, device=gpu)

        # forward
        with amp.autocast(enabled=cfg.use_fp16):
            loss = diffusion.loss(x0=imgs, t=t, model=model)
            loss = loss.mean()

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # ema update; without no_grad the in-place copy would attach an autograd
        # graph to the EMA params that keeps growing from step to step
        if cfg.use_ema:
            with torch.no_grad():
                for p_ema, p in zip(model_ema.parameters(), model.module.parameters()):
                    p_ema.copy_(p.lerp(p_ema, cfg.ema_decay))

        # metrics: average the (detached) loss over all ranks
        loss = loss.detach()
        dist.all_reduce(loss)
        loss = loss / cfg.world_size

        # logging
        if cfg.rank == 0 and (step == 1 or step == cfg.num_steps or step % cfg.log_interval == 0):
            logging.info(f'Step: {step}/{cfg.num_steps} Loss: {loss.item():.3f} scale: {scaler.get_scale():.1f}')

        # visualization (every rank must take part in the gather below)
        if step == 1 or step == cfg.num_steps or step % cfg.viz_interval == 0:
            # sample images
            with torch.no_grad(), amp.autocast(enabled=cfg.use_fp16):
                gen_imgs = diffusion.p_sample_loop(
                    noise=torch.randn_like(imgs[:viz_num]),
                    model=model_ema if cfg.use_ema else model.eval().requires_grad_(False),
                    clamp=cfg.clamp,
                    guide_scale=None)

            viz_imgs = torch.cat(all_gather(gen_imgs[:viz_num]), dim=0)

            # save images
            # NOTE(review): newer torchvision names this argument `value_range`
            # instead of `range`; confirm against the pinned torchvision version.
            if cfg.rank == 0:
                key = osp.join(cfg.log_dir, f'samples_step_{step}.jpg')
                tvutils.save_image(viz_imgs[:viz_num], key, nrow=cfg.nrow, normalize=True, range=(-1, 1))
                key = osp.join(cfg.log_dir, f'images_step_{step}.jpg')
                tvutils.save_image(imgs[:viz_num], key, nrow=cfg.nrow, normalize=True, range=(-1, 1))

        # checkpoint
        if cfg.rank == 0 and (step == cfg.num_steps or step % cfg.save_interval == 0):
            # current model
            key = osp.join(cfg.log_dir, f'checkpoints_step_{step}.pth')
            torch.save(model.module.state_dict(), key)
            # ema
            if cfg.use_ema:
                key = osp.join(cfg.log_dir, f'checkpoints_ema_step_{step}.pth')
                torch.save(model_ema.state_dict(), key)

    if cfg.rank == 0:
        logging.info('Congratulations! The training is completed!')
        for handler in logging.getLogger().handlers:
            handler.flush()
        # publish the local log next to the checkpoints, then remove the local copy
        # (skipped when both paths point to the same file)
        if osp.abspath(log_local) != osp.abspath(log_file):
            shutil.copyfile(log_local, log_file)
            if osp.exists(log_local):
                os.remove(log_local)
        # completion marker indicating the training finished
        with open(osp.join(cfg.log_dir, 'completed.log'), 'w') as f:
            f.write('completed\n')

    # synchronize to finish some processes
    torch.cuda.synchronize()
    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    # script entry point: train with the `cfg` loaded from `--config` at import time
    main()