import os
import sys
sys.path.append("../")
import pickle
import numpy as np
import matplotlib.pyplot as plt
from data.MNIST_test import LitMNIST
from data.KMNIST_test import LitKMNIST
from data.FashionMNIST_test import LitFashionMNIST
from src.CifarCombinedLeNet import MagnitudeLayer,MagnitudeLayer_quant, MagnitudeLayer_grey
from src.MagLeNet import MagnitudeLayer_abl_1,MagnitudeLayer_abl_2,MagnitudeLayer_abl_3
from mpl_toolkits.axes_grid1 import ImageGrid
from skimage import feature
import wandb
import argparse
from data.CIFAR10_test import LitCIFAR
from PIL import Image
import torch
import torchvision.transforms.functional as F
from time import time
from scripts import EdgeDetection
from scripts.paper import patches_big

def dice(im1, im2, axis=(1, 2)):
    """Compute the Dice coefficient between two batches of binary masks.

    Both inputs are converted to boolean arrays (any non-zero value counts as
    foreground). The score ``2 * |A & B| / (|A| + |B|)`` is reduced over
    ``axis``; with the default ``(1, 2)`` the inputs are treated as
    ``(batch, height, width)`` and one score per batch element is returned.

    Parameters
    ----------
    im1, im2 : array-like
        Masks of identical shape.
    axis : int or tuple of int, optional
        Axes to reduce over. Defaults to ``(1, 2)``.

    Returns
    -------
    numpy.ndarray
        Float Dice scores with the reduced axes removed. When both masks of an
        element are empty, the score is defined as 1.0 (perfect agreement)
        instead of NaN.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the supported dtype alias.
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)

    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")

    intersection = np.logical_and(im1, im2).sum(axis=axis)
    total = im1.sum(axis=axis) + im2.sum(axis=axis)

    # Avoid 0/0 for pairs of empty masks: those stay at the pre-filled 1.0.
    scores = np.ones(np.shape(total), dtype=float)
    np.divide(2.0 * intersection, total, out=scores, where=total > 0)
    return scores

def evaluate_patch_one_image(data_path = "../data/x_ray", image="00000001_000.png", greyscale = True ,x_patches = 32 ,y_patches = 32,overlap = 1,x_rescale_factor = 1,y_rescale_factor = 1,l_pixel=100,l_grid=1.,threshold = 0.6, **kwargs):
    """Compute a thresholded magnitude edge map and a Canny edge map for one image.

    The image at ``data_path/image`` is divided by 255, optionally averaged over
    its channels (``greyscale``) and resized by ``(y_rescale_factor,
    x_rescale_factor)``. The magnitude map is computed patch by patch with
    ``patches_big.MagnitudeLayer``, stitched back together, min-max normalised
    and thresholded at ``threshold``. The Canny map is computed on the same
    rescaled image, so both edge maps can be compared with ``dice``.

    Parameters
    ----------
    data_path, image : directory and file name of the image to load.
    greyscale : bool
        Average multi-channel images over their channels before processing.
    x_patches, y_patches : patch layout passed to ``patches_big.make_patches``
        and ``patches_big.stitch_patches`` (TODO confirm: patch counts vs sizes).
    overlap : patch overlap passed to the same helpers.
    x_rescale_factor, y_rescale_factor : resize factors for width and height.
    l_pixel, l_grid : forwarded to ``patches_big.MagnitudeLayer``.
    threshold : float
        Threshold applied to the min-max normalised magnitude map.
    kwargs : ignored; accepted so callers can forward a shared configuration.

    Returns
    -------
    tuple
        ``(image, magnitude_edges, canny_edges)``:
        - the rescaled image as numpy, with singleton axes squeezed;
        - the float 0/1 magnitude edge map;
        - the Canny edge map of the first channel of the rescaled image.
    """
    mag_layer = patches_big.MagnitudeLayer(l_pixel=l_pixel,l_grid=l_grid)

    # ``with`` closes the file handle. ``np.array`` copies into a writable buffer,
    # which avoids torch's warning about non-writable arrays in ``from_numpy``.
    with Image.open(os.path.join(data_path,image)) as img:
        img_arr = np.array(img)

    print(f'The image {image} has dimensions (h,w,c): {img_arr.shape}')

    img_t = torch.from_numpy(img_arr)/255.
    if img_t.ndim == 3:
        # (h, w, c) -> (c, h, w)
        img_t = img_t.permute(2,0,1)
        if greyscale:
            img_t = img_t.mean(dim=0, keepdim=True)
    else:
        img_t = img_t.unsqueeze(0)
    img_t = img_t.unsqueeze(0)  # leading batch axis -> (1, c, h, w)

    new_size = (int(img_t.shape[2]*y_rescale_factor), int(img_t.shape[3]*x_rescale_factor))
    img_t_small = F.resize(img_t, new_size)

    print(f'The rescaled image {image} has dimensions (b,c,h,w): {img_t_small.shape}')

    tic = time()
    patches,patch_grid = patches_big.make_patches(img_t_small.permute(0,2,3,1).squeeze().numpy(),(x_patches,y_patches),overlap=overlap)
    patch_mags = []
    with torch.no_grad():  # inference only; keeps .numpy() valid on the layer output
        for p in patches:
            patch_t = torch.from_numpy(p).reshape(1,1,*p.shape).contiguous()
            patch_mags.append(mag_layer.forward(patch_t).squeeze().numpy())
    mag_img_approx = patches_big.stitch_patches(patch_mags,patch_grid,(x_patches,y_patches),overlap=overlap)
    mag_img_approx = patches_big.min_max(mag_img_approx)
    mag_img_thres = (mag_img_approx>threshold).astype(float)
    toc = time()
    print(f'Elapsed time during approximate magnitude calculation: {(toc-tic):.2f}s')

    # Run Canny on the *rescaled* image so its shape matches the magnitude map
    # (stitch_patches is expected to return the rescaled (h, w) — TODO confirm).
    # A full-resolution Canny map would make ``dice`` fail for any rescale factor != 1.
    canny = feature.canny(img_t_small[0,0].numpy())

    return img_t_small.squeeze().numpy(), mag_img_thres, canny


def evaluate_patch(threshold , l_pixel ,data_path = "../data/x_ray", test = False,  **kwargs):
    """Run ``evaluate_patch_one_image`` on every file of one split of ``data_path``.

    The directory listing is sorted. The first 80% of files form the train split
    and the remaining 20% form the test split.

    Parameters
    ----------
    threshold, l_pixel : forwarded to ``evaluate_patch_one_image``.
    data_path : str
        Directory containing the images; also forwarded so images are loaded
        from the same directory that was listed.
    test : bool
        Evaluate the test split instead of the train split.
    kwargs : further keyword arguments for ``evaluate_patch_one_image``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(canny_imgs, mag_imgs, imgs)``, each stacked along a new leading axis.

    Raises
    ------
    ValueError
        If the selected split contains no files.
    """
    # os.listdir order is arbitrary; sort it so the split is reproducible.
    files = sorted(os.listdir(data_path))
    cutoff = int(0.8*len(files))
    files = files[cutoff:] if test else files[:cutoff]
    if not files:
        split = "test" if test else "train"
        raise ValueError(f"No images found for the {split} split of {data_path!r}.")

    imgs = []
    mag_imgs = []
    canny_imgs = []
    for name in files:
        img, mag_img, canny = evaluate_patch_one_image(data_path = data_path, image = name, threshold = threshold, l_pixel = l_pixel, **kwargs)
        imgs.append(img)
        mag_imgs.append(mag_img)
        canny_imgs.append(canny)

    return np.stack(canny_imgs), np.stack(mag_imgs), np.stack(imgs)


def evaluate(l_grid, l_pixel, threshold, dataset, test = False, max_batches = 11):
    """Compare Canny edges with thresholded magnitude maps on a padded image dataset.

    Dataset selection:
    - ``dataset == "CIFAR"`` uses ``LitCIFAR`` with ``MagnitudeLayer_grey``
      (34x34 images).
    - Every other value falls back to ``LitFashionMNIST`` with
      ``MagnitudeLayer`` (30x30 images).

    For each sample, Canny runs on a greyscale version of the image. The
    magnitude map produced by the layer is min-max normalised and thresholded
    at ``threshold``.

    Parameters
    ----------
    l_grid, l_pixel : forwarded to the magnitude layer constructor.
    threshold : float
        Threshold on the min-max normalised magnitude map.
    dataset : str
        ``"CIFAR"`` or anything else (treated as FashionMNIST).
    test : bool
        Use the test dataloader instead of the train dataloader.
    max_batches : int
        Number of dataloader batches (batch size 10) to process. Default 11.

    Returns
    -------
    tuple of numpy.ndarray
        ``(imgs_canny, imgs_mag, imgs)``, each stacked along a leading sample axis.
        A one-pixel border is cropped from the last two axes of all three arrays
        (presumably the padding added by ``padding=True`` — TODO confirm).

    Raises
    ------
    ValueError
        If ``max_batches`` is smaller than 1.
    """
    if max_batches < 1:
        raise ValueError("max_batches must be at least 1.")

    is_cifar = dataset == "CIFAR"
    if is_cifar:
        data = LitCIFAR(padding = True)
        size_img = 34
        layer_0 = MagnitudeLayer_grey(l_grid=l_grid,p=1,power=1,l_pixel = l_pixel)
    else:
        size_img = 30
        data = LitFashionMNIST(padding = True)
        layer_0 = MagnitudeLayer(l_grid= l_grid,p=1,power=1,l_pixel = l_pixel)
    data.setup()

    data_loader = data.test_dataloader(bs=10) if test else data.train_dataloader(bs=10)

    def to_grey(x_):
        """Return a 2-D numpy view of one sample, suitable for ``feature.canny``."""
        if is_cifar:
            # assumes CIFAR samples are (c, h, w) with values in [0, 1] — TODO confirm
            temp = x_.permute(1,2,0).numpy()
            im = Image.fromarray((temp*255).astype(np.uint8))
            return np.array(im.convert('L'))
        # Same branch as the FashionMNIST fallback used for data loading above.
        return x_.squeeze().view(size_img,size_img).numpy()

    def binarize(mag):
        """Min-max normalise ``mag`` and threshold it; a constant map yields all zeros."""
        lo, hi = np.min(mag), np.max(mag)
        if hi == lo:
            # Normalising would compute 0/0; the resulting NaNs compare False anyway.
            return np.zeros(mag.shape, dtype=float)
        return (((mag-lo)/(hi-lo)) > threshold).astype(float)

    imgs_canny = []
    imgs_mag = []
    imgs_list = []
    with torch.no_grad():  # inference only; keeps .numpy() valid on the layer output
        for i,(x,_) in enumerate(data_loader):
            for sample in x:
                imgs_canny.append(feature.canny(to_grey(sample)).astype(float))
                mag = layer_0.forward(sample[None,...]).squeeze().view(size_img,size_img).numpy()
                imgs_mag.append(binarize(mag))
                imgs_list.append(sample)
            # Stop after ``max_batches`` batches without fetching another one.
            if i + 1 >= max_batches:
                break

    imgs_canny = np.stack(imgs_canny)[...,1:-1,1:-1]
    imgs_mag = np.stack(imgs_mag)[...,1:-1,1:-1]
    # Crop the images the same way so all three outputs share spatial dimensions.
    imgs = torch.stack(imgs_list).numpy()[...,1:-1,1:-1]

    return imgs_canny, imgs_mag, imgs

def main(args):
    """Compute the train-split Dice score between Canny and magnitude edges and log it to W&B.

    ``args`` is the namespace produced by the argparse parser in ``__main__``.
    All of its fields are forwarded as keyword arguments to the evaluator:
    ``dataset == 'Xray'`` selects the patch-based ``evaluate_patch``, and any
    other value selects ``evaluate``.
    """
    config = vars(args)
    run = wandb.init(
        project = "magnitude",
        entity = "edebrouwer",
        name = "edgedetection_" + args.dataset,
        config = config,
    )

    evaluator = evaluate_patch if args.dataset == 'Xray' else evaluate
    canny_edges, magnitude_edges, _ = evaluator(**config, test = False)

    wandb.run.summary["train_dice"] = dice(canny_edges, magnitude_edges).mean()
    run.finish()


if __name__=="__main__":
    # ``--help`` is enabled, and ``--dataset`` is restricted to the names main()
    # and evaluate() actually support, so typos fail at parse time.
    parser = argparse.ArgumentParser(description = "Compare Canny edges with magnitude-based edge maps and log the Dice score to W&B.")
    parser.add_argument('--l_grid', default = 1, type = int, help = "l_grid passed to the magnitude layer.")
    parser.add_argument('--l_pixel', default = 1, type = int, help = "l_pixel passed to the magnitude layer.")
    parser.add_argument('--threshold', default = 0.5, type = float, help = "Threshold applied to the min-max normalised magnitude map.")
    parser.add_argument('--dataset', default = "FashionMNIST", type = str, choices = ["FashionMNIST", "CIFAR", "Xray"], help = "Dataset to evaluate.")
    args = parser.parse_args()

    main(args)
