
import numpy as np
import random, sys, os, time, glob, math, itertools, yaml, pickle
import parse
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from PIL import ImageFilter
from torchvision import transforms

from functools import partial
from scipy import ndimage

import IPython

USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
dtype = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor

# config/jobinfo.txt holds a single line "<experiment>, <base_dir>".
# Read it with a context manager so the file handle is closed promptly.
with open("config/jobinfo.txt") as _jobinfo_file:
    EXPERIMENT, BASE_DIR = _jobinfo_file.read().strip().split(', ')
# Experiment name with its last "_"-separated segment removed.
JOB = "_".join(EXPERIMENT.split("_")[0:-1])

MODELS_DIR = f"{BASE_DIR}/shared/models"
DATA_DIRS = ["/taskonomy-data/taskonomydata"]
RESULTS_DIR = f"/workspace/shared/results_{EXPERIMENT}"
SHARED_DIR = f"{BASE_DIR}/shared"
OOD_DIR = "/scratch-data/ood_set_expanded/"
USE_RAID = False

# Create the results directory without going through a shell, so names with
# spaces or shell metacharacters are handled correctly. Failure is reported
# but not fatal, so importing this module still succeeds.
try:
    os.makedirs(RESULTS_DIR, exist_ok=True)
except OSError as err:
    print(f"Warning: could not create {RESULTS_DIR}: {err}", file=sys.stderr)


def both(x, y):
    """Return a new dict holding the items of `x`, overridden by those of `y`.

    `x` is only read through `.items()`; neither argument is modified.
    """
    result = {}
    result.update(x.items())
    result.update(y)
    return result

def elapsed(last_time=[time.time()]):
    """ Return the seconds passed since the previous call.

    The default list is shared across calls on purpose: its single slot holds
    the timestamp of the last call (initially the time this function was
    defined). Passing your own one-element list gives an independent timer.
    """
    now = time.time()
    previous, last_time[0] = last_time[0], now
    return now - previous

def cycle(iterable):
    """ Cycles through iterable without making extra copies.

    Unlike itertools.cycle, which saves a copy of every element, each pass
    re-iterates `iterable` itself. Iterables that change between passes are
    therefore re-read every time.

    The generator stops if a whole pass yields nothing. That happens for an
    empty iterable, or for a one-shot iterator once it is exhausted. Without
    this check the loop would spin forever without yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return

def average(arr):
    """Arithmetic mean of `arr`: anything that supports both sum() and len().

    An empty sequence raises ZeroDivisionError.
    """
    total = sum(arr)
    count = len(arr)
    return total / count

# def random_resize(iterable, vals=[128, 192, 256, 320]):
#    """ Cycles through iterable while randomly resizing batch values. """
#     from transforms import resize
#     while True:
#         for X, Y in iterable:
#             val = random.choice(vals)
#             yield resize(X.to(DEVICE), val=val).detach(), resize(Y.to(DEVICE), val=val).detach()


def get_files(exp, data_dirs=DATA_DIRS, recursive=False):
    """ Glob `exp` under every directory in `data_dirs`, de-duplicated by relative path.

    A file is identified by its path with the data-directory prefix removed.
    When the same relative path exists in several data directories, only the
    match from the earliest directory in `data_dirs` is returned.

    Args:
        exp: glob pattern, relative to each data directory.
        data_dirs: directories to search, in priority order.
        recursive: forwarded to glob.glob (enables "**").

    Returns:
        List of full paths, in discovery order.
    """
    matches = []
    seen_suffixes = set()
    for root in data_dirs:
        prefix_len = len(root)
        for path in glob.glob(f'{root}/{exp}', recursive=recursive):
            suffix = path[prefix_len:]
            if suffix in seen_suffixes:
                continue
            seen_suffixes.add(suffix)
            matches.append(path)
    return matches


def get_finetuned_model_path(parents):
    """ Checkpoint path (.pth) for a model finetuned along the chain `parents`.

    The file name joins each parent's `.name` with "_" in reverse order.
    When BASE_DIR is "/" the file lives in RESULTS_DIR; otherwise it lives in
    MODELS_DIR/finetuned.
    """
    folder = f"{RESULTS_DIR}/" if BASE_DIR == "/" else f"{MODELS_DIR}/finetuned/"
    stem = "_".join([parent.name for parent in parents[::-1]])
    return folder + stem + ".pth"


def plot_images(model, logger, test_set, dest_task="normal",
        ood_images=None, show_masks=False, loss_models=None,
        preds_name=None, target_name=None, ood_name=None,
    ):
    """ Log predictions on `test_set`, and optional extras, through `logger`.

    Args:
        model: must provide predict_with_data(test_set), returning
            (images, preds, targets, losses, _), and predict(images).
        logger: sink for plots (via the task's plot_func and logger.images).
        test_set: passed unchanged to model.predict_with_data.
        dest_task: task object, or a task name resolved with
            task_configs.get_task. Its plot_func renders preds and targets.
        ood_images: optional extra inputs whose predictions are also plotted.
        show_masks: if dest_task is an ImageTask, also log target masks
            (built with tolerance 1e-3, resized to 64).
        loss_models: optional mapping name -> callable(preds, targets, images).
            Each output is plotted under its name, without gradients.
            Defaults to no loss models.
        preds_name, target_name, ood_name: override the default plot names.
    """
    from task_configs import get_task, ImageTask

    # Use None rather than a mutable {} default shared across calls.
    if loss_models is None:
        loss_models = {}

    test_images, preds, targets, _, _ = model.predict_with_data(test_set)

    if isinstance(dest_task, str):
        dest_task = get_task(dest_task)

    if show_masks and isinstance(dest_task, ImageTask):
        test_masks = ImageTask.build_mask(targets, dest_task.mask_val, tol=1e-3)
        logger.images(test_masks.float(), f"{dest_task}_masks", resize=64)

    dest_task.plot_func(preds, preds_name or f"{dest_task.name}_preds", logger)
    dest_task.plot_func(targets, target_name or f"{dest_task.name}_target", logger)

    if ood_images is not None:
        ood_preds = model.predict(ood_images)
        dest_task.plot_func(ood_preds, ood_name or f"{dest_task.name}_ood_preds", logger)

    for name, loss_model in loss_models.items():
        with torch.no_grad():
            output = loss_model(preds, targets, test_images)
            # Outputs that carry a `.task` know how to plot themselves;
            # anything else is logged as an image clipped to [0, 1].
            if hasattr(output, "task"):
                output.task.plot_func(output, name, logger, resize=128)
            else:
                logger.images(output.clamp(min=0, max=1), name, resize=128)


def gaussian_filter(channels=3, kernel_size=5, sigma=1.0, device=0):
    """ Normalized 2-D Gaussian kernel, replicated once per channel.

    Args:
        channels: number of copies stacked along dim 0.
        kernel_size: side length of the square kernel; the peak sits at its centre.
        sigma: standard deviation, in pixels.
        device: accepted but not used by this function.

    Returns:
        Float tensor of shape (channels, 1, kernel_size, kernel_size) whose
        spatial entries sum to 1.
    """
    coords = torch.arange(kernel_size).float()
    cols = coords.repeat(kernel_size, 1)  # cols[i, j] == j
    rows = cols.t()                       # rows[i, j] == i

    center = (kernel_size - 1) / 2.
    variance = sigma ** 2.
    sq_dist = (cols - center) ** 2. + (rows - center) ** 2.

    # The analytic prefactor cancels after normalization. It is kept so the
    # floating-point results are not perturbed.
    kernel = (1. / (2. * math.pi * variance)) * torch.exp(-sq_dist / (2 * variance))
    kernel = kernel / torch.sum(kernel)
    return kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)


def motion_blur_filter(kernel_size=15, channels=3):
    """ Horizontal motion-blur kernel, replicated once per channel.

    The square kernel's middle row is 1/kernel_size and every other entry is
    0, so convolving averages `kernel_size` pixels along the width axis.

    Args:
        kernel_size: side length of the square kernel (length of the blur).
        channels: number of copies along dim 0 (default 3, matching
            gaussian_filter's default).

    Returns:
        Float tensor of shape (channels, 1, kernel_size, kernel_size), i.e. the
        weight layout F.conv2d expects with groups=channels.
    """
    kernel = torch.zeros((kernel_size, kernel_size))
    kernel[(kernel_size - 1) // 2, :] = 1.0
    kernel = kernel / kernel_size
    kernel = kernel.view(1, 1, kernel_size, kernel_size)
    return kernel.repeat(channels, 1, 1, 1)


def sobel_kernel(x):
    """ Sobel edge-magnitude map for each sample in a batch.

    Each sample is processed as follows:
      1. Average over its first axis (assumed channels, i.e. (C, H, W) -- TODO confirm)
         to get one 2-D image.
      2. Gaussian-blur it with sigma=2.
      3. Apply Sobel filters along both image axes (constant padding).
      4. Return the gradient magnitude as a single-channel map.

    The work is done in NumPy/SciPy on the CPU, so no gradient flows back to `x`.

    Returns:
        Tensor stacking one (1, H, W) map per sample along dim 0, moved to
        DEVICE with requires_grad enabled.
    """
    def sobel_transform(sample):
        image = sample.detach().cpu().numpy().mean(axis=0)
        # Use the top-level scipy.ndimage function; scipy.ndimage.filters is a
        # deprecated alias namespace.
        blur = ndimage.gaussian_filter(image, sigma=2)
        sx = ndimage.sobel(blur, axis=0, mode='constant')
        sy = ndimage.sobel(blur, axis=1, mode='constant')
        return torch.FloatTensor(np.hypot(sx, sy)).unsqueeze(0)

    edges = torch.stack([sobel_transform(sample) for sample in x], dim=0)
    return edges.to(DEVICE).requires_grad_()


def binarized_kernel(x):
    """ Threshold every sample of `x` at 0.5 into a float 0/1 tensor.

    Values strictly greater than 0.5 become 1.0 and all others 0.0. The
    stacked batch is moved to DEVICE with requires_grad enabled.
    """
    def binarize(sample):
        return ((sample > 0.5) * 1.0).float()

    thresholded = [binarize(sample) for sample in x]
    batch = torch.stack(thresholded, dim=0)
    return batch.to(DEVICE).requires_grad_()


class SobelKernel(nn.Module):
    """ nn.Module wrapper that applies `sobel_kernel` to its input.

    It has no parameters. The nn.Module constructor is inherited unchanged.
    """

    def forward(self, x):
        return sobel_kernel(x)


def laplace_kernel(x):
    """ Laplacian of a Gaussian-blurred image, for each sample in a batch.

    Each sample is processed as follows:
      1. Average over its first axis (assumed channels -- TODO confirm) to get
         one 2-D image.
      2. Gaussian-blur it with sigma=2.
      3. Filter it with ndimage.laplace.

    The work is done in NumPy/SciPy on the CPU, so no gradient flows back to `x`.

    Returns:
        Tensor stacking one single-channel map per sample along dim 0, moved
        to DEVICE with requires_grad enabled.
    """
    def laplace_transform(sample):
        image = sample.detach().cpu().numpy().mean(axis=0)
        # Use the top-level scipy.ndimage function; scipy.ndimage.filters is a
        # deprecated alias namespace.
        blur = ndimage.gaussian_filter(image, sigma=2)
        return torch.FloatTensor(ndimage.laplace(blur)).unsqueeze(0)

    edges = torch.stack([laplace_transform(sample) for sample in x], dim=0)
    return edges.to(DEVICE).requires_grad_()

def gauss_kernel(x):
    """ Gaussian-blur (sigma=4) the first three channels of every sample.

    Each channel is blurred independently with SciPy on the CPU, so no
    gradient flows back to `x`. Channels beyond the third are not included in
    the output.

    Returns:
        Tensor stacking one 3-channel blurred image per sample along dim 0,
        moved to DEVICE with requires_grad enabled.
    """
    def gauss_transform(sample):
        channels = sample.detach().cpu().numpy()
        # Use the top-level scipy.ndimage function; scipy.ndimage.filters is a
        # deprecated alias namespace.
        blurred = [ndimage.gaussian_filter(channels[c], sigma=4) for c in range(3)]
        return torch.FloatTensor(np.stack(blurred, axis=0))

    images = torch.stack([gauss_transform(sample) for sample in x], dim=0)
    return images.to(DEVICE).requires_grad_()


# 3x3 emboss kernel in F.conv2d weight layout (out=1, in=1, 3, 3).
# F.conv2d is a cross-correlation, so each output pixel is the input pixel
# minus its lower-left neighbour.
emboss_weights = torch.tensor([[0., 0., 0.], [0., 1.0, 0.], [-1., 0., 0.]])
# Place on DEVICE: an unconditional .cuda() fails at import on CPU-only machines.
emboss_weights = emboss_weights.view(1, 1, 3, 3).to(DEVICE)


def emboss_kernel(x):
    """ Emboss filter applied to the greyscale version of each sample.

    Each sample is processed as follows:
      1. Average over its first axis (assumed channels -- TODO confirm).
      2. Quantize to 0..255.
      3. Convolve with `emboss_weights` (zero padding keeps the size).
      4. Shift by +128, clip to [0, 255] and rescale to [0, 1].

    Returns:
        Tensor stacking one single-channel image per sample along dim 0, moved
        to DEVICE with requires_grad enabled.
    """
    def emboss_transform(sample):
        grey = sample.mean(0, keepdim=True)
        # Quantize and add a batch axis, as conv2d expects (N, C, H, W).
        grey = (grey * 255).round().unsqueeze(0)
        # Match the input's device so CPU batches also work when DEVICE is CUDA.
        image = F.conv2d(grey, emboss_weights.to(grey.device), padding=1)
        image = image + 128.0
        image = image.clamp(min=0.0, max=255.0)
        image = image / 255.0
        return image.squeeze(0)

    embossed = torch.stack([emboss_transform(sample) for sample in x], dim=0)
    return embossed.to(DEVICE).requires_grad_()

# def emboss_kernel(x):
#     def emboss_transform(x):
#         x = x.mean(0,keepdim=True)
#         image = transforms.ToPILImage()(x.cpu())
#         imageEmboss = image.filter(ImageFilter.EMBOSS)
#         image = transforms.ToTensor()(imageEmboss)
    
#         return image

#     x = torch.stack([emboss_transform(y) for y in x], dim=0)
#     return x.to(DEVICE).requires_grad_() 

def greyscale(x):
    """ Average each sample over its first axis, keeping that axis with size 1.

    The stacked batch is moved to DEVICE with requires_grad enabled.
    """
    per_sample = [sample.mean(0, keepdim=True) for sample in x]
    batch = torch.stack(per_sample, dim=0)
    return batch.to(DEVICE).requires_grad_()


from pytorch_wavelets import DWTForward, DWTInverse

# 3-level discrete wavelet transform ('db1' wavelet, zero padding).
# Placed on DEVICE: an unconditional .cuda() fails at import on CPU-only machines.
xfm = DWTForward(J=3, mode='zero', wave='db1').to(DEVICE)


def wav_kernel(x):
    """ Multi-scale wavelet feature stack for each sample in a batch.

    Each sample goes through `xfm`. The following bands are bilinearly resized
    to 256x256 and concatenated along the channel axis:
      - the lowpass band;
      - at each of the three levels, the first of the three detail sub-bands.
    A C-channel sample therefore yields 4*C channels.

    Returns:
        Tensor stacking one feature map per sample along dim 0, moved to
        DEVICE with requires_grad enabled.
    """
    def resize(band):
        # align_corners=False is PyTorch's default; it is passed explicitly.
        return F.interpolate(band, size=256, mode='bilinear', align_corners=False)

    def wav_transform(sample):
        # Assumes sample is (C, H, W) -- TODO confirm. Add a batch axis and move
        # it to the transform's device.
        yl, yh = xfm(sample.unsqueeze(0).to(DEVICE))
        # yh holds one tensor per level; [:, :, 0, :] keeps the first detail sub-band.
        bands = [resize(yl)] + [resize(level[:, :, 0, :]) for level in yh]
        # squeeze(0) removes only the batch axis. A bare squeeze() would also drop
        # the channel axis of single-channel input, and cat would then join along height.
        return torch.cat([band.squeeze(0) for band in bands], dim=0)

    features = torch.stack([wav_transform(sample) for sample in x], dim=0)
    return features.to(DEVICE).requires_grad_()

def set_seed(seed):
    """ Seed the random number generators used in this module with `seed`.

    Seeds, in order: Python's `random`, NumPy's global RNG, PyTorch's CPU RNG,
    and the RNG of every CUDA device.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,               # cpu vars
        torch.cuda.manual_seed_all,      # gpu vars
    )
    for seeder in seeders:
        seeder(seed)
