import torch
import torch.nn as nn
from torch.utils import data
from torch import autograd
import torch.nn.functional as F

import pylab, os, random, math

from models.op import conv2d_gradfix


def d_logistic_loss(real_pred, fake_pred):
    """
    Logistic discriminator loss.
    Input:  real_pred: discriminator outputs for real samples
            fake_pred: discriminator outputs for generated samples
    Output: mean softplus(-real_pred) plus mean softplus(fake_pred)
    """
    loss_on_real = F.softplus(-real_pred).mean()
    loss_on_fake = F.softplus(fake_pred).mean()
    return loss_on_real + loss_on_fake

def d_r1_loss(real_pred, real_img):
    """
    Gradient penalty on real samples.
    Input:  real_pred: discriminator outputs computed from real_img
            real_img: the input tensor the gradient is taken with respect to
                      (it must be part of real_pred's autograd graph)
    Output: the squared gradient, summed over all non-batch dims and
            averaged over dim 0
    """
    # The gradient is taken inside conv2d_gradfix.no_weight_gradients();
    # presumably this skips weight-gradient computation -- confirm in models.op.
    with conv2d_gradfix.no_weight_gradients():
        (grad_real,) = autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )
    batch_size = grad_real.shape[0]
    squared_norm = grad_real.pow(2).reshape(batch_size, -1).sum(1)
    return squared_norm.mean()

def g_nonsaturating_loss(fake_pred):
    """
    Non-saturating generator loss: the mean of softplus(-fake_pred).
    Input:  fake_pred: discriminator outputs for generated samples
    """
    return F.softplus(-fake_pred).mean()

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """
    Path length regularization term for the generator.
    Input:  fake_img: generated tensor; dims 2 and 3 are used to scale the noise
            latents: tensor fake_img was computed from; the gradient is reduced
                     with sum over dim 2 and mean over dim 1
            mean_path_length: running mean of the path lengths
            decay: step size used to move the running mean towards the
                   current batch mean
    Output: (penalty, updated running mean detached from the graph,
             per-sample path lengths)
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    # Gradient of a random projection of the image with respect to the latents.
    projection = (fake_img * noise).sum()
    (grad,) = autograd.grad(
        outputs=projection, inputs=latents, create_graph=True, allow_unused=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths

def idx2image(embed_ind, num_colors):
    """
    The utility function for plotting the vq label mask as an RGB image
    Input:  embed_ind: a BxHxW tensor, each element is an index value
            num_colors: number of embedding indices, such as 5,6,7
    Output: a Bx3xHxW float tensor on the CPU; every index is replaced by
            the RGB triple of its colour in the 'gist_rainbow' colormap
    Raises: ValueError if embed_ind is not 3-dimensional
    """
    if embed_ind.dim() != 3:
        raise ValueError(
            'embed_ind must be a BxHxW tensor, got %d dims' % embed_ind.dim()
        )

    cmap = pylab.get_cmap('gist_rainbow')
    # One RGB row per index; the alpha channel returned by the colormap is dropped.
    palette = torch.tensor(
        [list(cmap(1. * i / num_colors))[:3] for i in range(num_colors)]
    ).reshape(-1, 3)

    # A single gather replaces the per-pixel Python loop. The indices are moved
    # to the CPU as int64 because the palette itself lives on the CPU.
    indices = embed_ind.detach().to(device='cpu', dtype=torch.long)
    return palette[indices].permute(0, 3, 1, 2)


class CustomDDP(nn.parallel.DistributedDataParallel):
    """
    The wrapper class for multi-gpu training which makes it able to call 
    custom forward functions of a model other than the default "forward" function
    """
    def __init__(self, module, device_ids=None,
                 output_device=None, dim=0, broadcast_buffers=True,
                 process_group=None, bucket_cap_mb=25,
                 find_unused_parameters=False,
                 check_reduction=False) -> None:
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim, broadcast_buffers=broadcast_buffers, process_group=process_group, bucket_cap_mb=bucket_cap_mb, find_unused_parameters=find_unused_parameters, check_reduction=check_reduction)

        self.org_forward = self.module.forward

    def __getattr__(self, name):
        if name == 'module':
            return super().__getattr__('module')
        else:
            return getattr(self.module, name)

    def set_forward(self, name):
        custom_fn = getattr(self.module, name)
        self.module.forward = custom_fn
    
    def reset_forward(self):
        self.module.forward = self.org_forward

def ddp_runner(model, fn_name, *args, **kwargs):
    """
    The wrapper function that cooperates with the CustomDDP class.
    Input:  model: a CustomDDP instance or any other object exposing `fn_name`
            fn_name: name of the method to run
            *args, **kwargs: forwarded to that method
    Output: whatever the called method returns

    For a CustomDDP model the method is temporarily installed as the
    module's forward and invoked by calling the wrapper itself; any other
    model has the method called directly.
    """
    if not isinstance(model, CustomDDP):
        return getattr(model, fn_name)(*args, **kwargs)

    model.set_forward(fn_name)
    try:
        return model(*args, **kwargs)
    finally:
        # Restore the original forward even when the call raises, so a failed
        # step does not leave the module stuck on the custom function.
        model.reset_forward()


###******** Training related helper functions ********
def data_sampler(dataset, shuffle, distributed):
    """
    Pick the sampler for a dataset.
    Input:  dataset: the dataset to sample from
            shuffle: whether samples are drawn in random order
            distributed: if true, a DistributedSampler is returned
    Output: a DistributedSampler, RandomSampler or SequentialSampler
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)


def requires_grad(model, flag=True):
    """Set `requires_grad` to `flag` on every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = flag


def accumulate(model1, model2, decay=0.999, buffer=True):
    """
    Move model1 towards model2 in place with an exponential moving average:
        model1 <- decay * model1 + (1 - decay) * model2
    Input:  model1: the model that is updated
            model2: the model that supplies the new values
            decay: weight kept from model1's current values
            buffer: if true, buffers are averaged as well as parameters
    Raises: KeyError if a parameter of model1 has no counterpart in model2

    Buffers are handled on a best-effort basis: one that is missing from
    model2, is not floating point / complex, or cannot be combined with its
    counterpart is left untouched.
    """
    target_params = dict(model1.named_parameters())
    source_params = dict(model2.named_parameters())

    for name, target in target_params.items():
        # Look the source up first so a missing key raises before the target
        # has been scaled.
        source = source_params[name].data
        target.data.mul_(decay).add_(source, alpha=1 - decay)

    if not buffer:
        return

    target_buffers = dict(model1.named_buffers())
    source_buffers = dict(model2.named_buffers())

    for name, target in target_buffers.items():
        source = source_buffers.get(name)
        if source is None:
            continue
        # Integer / bool buffers cannot take a fractional blend.
        if not (target.is_floating_point() or target.is_complex()):
            continue
        try:
            # Computed out of place so that a failure leaves the buffer intact.
            blended = torch.add(target.data * decay, source.data, alpha=1 - decay)
        except RuntimeError:
            # e.g. incompatible shapes or devices
            continue
        if blended.shape == target.shape:
            target.data.copy_(blended)
        
def sample_data(loader):
    """Yield batches from `loader` forever, restarting it whenever it is exhausted."""
    while True:
        yield from loader


def make_noise(batch, latent_dim, n_noise, device):
    """
    Draw standard-normal noise.
    Input:  batch: number of rows per noise tensor
            latent_dim: number of columns per noise tensor
            n_noise: how many noise tensors to draw
            device: device the noise is created on
    Output: a single (batch, latent_dim) tensor when n_noise is 1, otherwise
            a tuple of n_noise such tensors
    """
    if n_noise != 1:
        stacked = torch.randn(n_noise, batch, latent_dim, device=device)
        return stacked.unbind(0)

    return torch.randn(batch, latent_dim, device=device)


def mixing_noise(batch, latent_dim, prob, device):
    """
    Draw the latent noise for one step, optionally two tensors for mixing.
    Input:  batch, latent_dim, device: forwarded to make_noise
            prob: probability of drawing two noise tensors instead of one
    Output: what make_noise returns for two tensors when mixing is chosen,
            otherwise a one-element list holding a single noise tensor
    """
    # random.random() is only consulted when prob is positive.
    use_mixing = prob > 0 and random.random() < prob
    if not use_mixing:
        return [make_noise(batch, latent_dim, 1, device)]

    return make_noise(batch, latent_dim, 2, device)


def set_grad_none(model, targets):
    """Drop the gradient of every parameter of `model` whose name is in `targets`."""
    for name, param in model.named_parameters():
        if name not in targets:
            continue
        param.grad = None


def resize(image, size):
    """
    Resize `image` to `size` with adaptive average pooling.
    The input is returned untouched when its last dimension already equals `size`.
    """
    if image.shape[-1] != size:
        return F.adaptive_avg_pool2d(image, size)
    return image

def get_dir(args, task_name=None):
    """
    Create the experiment directory and snapshot the current code into it.
    Input:  args: object whose __dict__ is dumped as JSON to args.txt; its
                  `name` attribute is used when task_name is not given
            task_name: name of the folder created under 'experiments/'

    Creates 'experiments/<task_name>' with 'sample' and 'checkpoint'
    sub-folders, copies the 'models', 'utilstrain' and 'utilseval' folders and
    every file in the working directory whose name contains '.py' or '.sh',
    then writes args.txt. The code snapshot is best-effort: an entry that
    cannot be copied is skipped without affecting the others.
    """
    import json
    from shutil import copytree, copy

    if task_name is None:
        task_name = args.name

    root = 'experiments/' + task_name
    for sub in ('', '/sample', '/checkpoint'):
        os.makedirs(root + sub, exist_ok=True)

    for folder in ['models', 'utilstrain', 'utilseval']:
        try:
            copytree(folder, root + '/' + folder)
        except OSError:
            # Source folder missing, or it was already copied by an earlier
            # run; keep going with the remaining entries.
            pass

    for f in os.listdir('.'):
        if ('.py' in f or '.sh' in f) and os.path.isfile(f):
            try:
                copy(f, root + '/' + f)
            except OSError:
                pass

    with open(os.path.join(root, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

def detach(list_of_tensor):
    """Return a new list holding the detached version of every tensor given."""
    return [tensor.detach() for tensor in list_of_tensor]

def weight_scheduler(curr_iter, start_iter, total_iter):
    """
    Linearly decay a weight from 1 to 0 between start_iter and total_iter.
    Input:  curr_iter: the current iteration
            start_iter: iteration at which the decay starts
            total_iter: iteration at which the weight reaches 0
    Output: 1 before start_iter, 0 from total_iter onwards, and the linear
            interpolation in between
    """
    if curr_iter < start_iter:
        return 1
    # ">=" rather than ">": the value at total_iter is 0 either way, and this
    # avoids a division by zero when start_iter equals total_iter.
    if curr_iter >= total_iter:
        return 0
    return 1 - (curr_iter - start_iter) / (total_iter - start_iter)


###******** 
# self-supervision data augmentation methods to force the Generator to
# disentangle the pose-related features into certain layers 
###********
def scale_flip_augment(real_image):
    """
    Build two half-augmented copies of a batch using a random up-scale,
    a random crop back to the input size and a random horizontal flip.
    Input:  real_image: a BxCxHxW float tensor
    Output: a pair of BxCxHxW tensors; the first keeps the first half of the
            batch untouched and augments the rest, the second augments the
            first half and keeps the rest untouched
    """
    b, c, h, w = real_image.shape

    # One scale factor in [1.05, 1.25) applied to both spatial dims, so
    # non-square inputs are always at least as large as the crop window.
    factor = random.random() * 0.2 + 1.05
    scale_h, scale_w = int(h * factor), int(w * factor)
    real_scale = F.interpolate(real_image, size=(scale_h, scale_w), mode='bilinear')

    # Clamp the upper bounds: for small images the scaled size can equal the
    # input size, which would otherwise hand randint an empty range.
    crop_top = random.randint(0, max(scale_h - h - 1, 0))
    crop_left = random.randint(0, max(scale_w - w - 1, 0))
    real_scale = real_scale[:, :, crop_top:crop_top + h, crop_left:crop_left + w]

    if random.randint(0, 1) == 1:
        # flip along the last (width) dim
        real_scale = torch.flip(real_scale, (3,))

    half = b // 2
    return torch.cat([real_image[:half], real_scale[half:]]), \
           torch.cat([real_scale[:half], real_image[half:]])


def random_cutout_augment(real_image):
    """
    Zero out two randomly placed rectangles of a copy of the batch.
    Input:  real_image: a BxCxHxW tensor (left unmodified)
    Output: a clone of real_image with the two patches set to 0

    The patch positions are drawn once per call and shared by the whole batch.
    The variable names suggest the patches are meant to cover the eye and
    mouth areas of an aligned face -- confirm against the dataset.
    """
    b, c, h, w = real_image.shape

    new_image = real_image.clone()

    # Upper patch: top-left corner drawn from rows [0.3h, 0.4h) and
    # columns [0.2w, 0.3w); patch size is 0.2h x 0.6w.
    eye_part = [0.3*h, 0.4*h, 0.2*w, 0.3*w]
    eye_size = [int(0.2*h), int(0.6*w)]

    eye_h = int(eye_part[0] + random.random() * (eye_part[1] - eye_part[0]))
    eye_w = int(eye_part[2] + random.random() * (eye_part[3] - eye_part[2]))
    new_image[:, :, eye_h:eye_h+eye_size[0], eye_w:eye_w+eye_size[1]] = 0

    # Lower patch: top-left corner drawn from rows [0.6h, 0.7h) and
    # columns [0.3w, 0.4w); patch size is 0.3h x 0.3w. The width is taken
    # from w so the patch scales with the image width on non-square inputs.
    mouth_part = [0.6*h, 0.7*h, 0.3*w, 0.4*w]
    mouth_size = [int(0.3*h), int(0.3*w)]

    mouth_h = int(mouth_part[0] + random.random() * (mouth_part[1] - mouth_part[0]))
    mouth_w = int(mouth_part[2] + random.random() * (mouth_part[3] - mouth_part[2]))
    new_image[:, :, mouth_h:mouth_h+mouth_size[0], mouth_w:mouth_w+mouth_size[1]] = 0
    return new_image
