import argparse
import logging
import os
from posixpath import join

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils.data_vis import plot_img_and_mask
from dataset.basic_dataset import BasicDataset

from model.deeplabv3.deeplab import *
from model.unet import UNet
from model.coeffnet.coeffnet_deeplab import Coeffnet_Deeplab


def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a boolean segmentation mask for a single image.

    Args:
        net: Segmentation network. It is switched to eval mode, called on a
            float32 batch, and its ``n_classes`` attribute selects softmax
            (more than one class) or sigmoid (otherwise).
        full_img: PIL image; ``full_img.size`` is read as ``(width, height)``.
        device: Torch device the input batch is moved to.
        scale_factor: Forwarded unchanged to ``BasicDataset.preprocess``.
        out_threshold: Probability above which a pixel counts as foreground.

    Returns:
        Boolean numpy array (``probability > out_threshold``) resized to the
        height and width of ``full_img``.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    # NOTE(review): the image is duplicated into a batch of two and only
    # output[0] is used below -- presumably the network needs a batch larger
    # than one; confirm against the model definition.
    img = torch.cat([img, img], dim=0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            # NOTE(review): output[0] has already dropped the leading
            # dimension, so dim=1 may not be the class axis -- confirm.
            probs = F.softmax(output[0], dim=1)
        else:
            probs = torch.sigmoid(output[0])

        probs = probs.squeeze(0)

        width, height = full_img.size
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                # Pass an explicit (height, width): a bare int only fixes the
                # shorter edge, which yields a wrongly sized mask for
                # portrait images.
                transforms.Resize((height, width)),
                transforms.ToTensor()
            ]
        )

        probs = tf(probs.cpu())
        # ToTensor already returns a CPU tensor, so no further transfer is
        # needed before converting to numpy.
        full_mask = probs.squeeze().numpy()

    return full_mask > out_threshold


def get_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace with the attributes ``model``, ``input``,
        ``output``, ``viz``, ``no_save``, ``mask_threshold`` and ``scale``.
        ``input`` is required; ``output`` is ``None`` when not supplied.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)
    # The metavar used to read INPUT here (copy-paste slip), which made the
    # usage text for --output misleading.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed",
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks",
                        default=False)
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=1)

    return parser.parse_args()


def get_output_filenames(args, in_files):
    """Determine the output path for every input image.

    Args:
        args: Parsed command-line namespace; only ``args.output`` is read.
        in_files: Input image paths, in processing order.

    Returns:
        List of output paths, one per entry of ``in_files``. When no output
        names were given, each is the input path with ``_OUT`` inserted before
        the extension; otherwise ``args.output`` is returned as is.

    Raises:
        SystemExit: With exit status 1 when explicit output names were given
            but their count differs from the number of input files.
    """
    if not args.output:
        out_files = ["{}_OUT{}".format(*os.path.splitext(f)) for f in in_files]
    elif len(in_files) != len(args.output):
        logging.error("Input files and output files are not of the same length")
        # A bare SystemExit() exits with status 0; report the failure to the
        # calling shell instead.
        raise SystemExit(1)
    else:
        out_files = args.output
    print(out_files)
    return out_files


def mask_to_image(mask):
    """Convert a mask array to a PIL image by scaling it by 255 as uint8."""
    pixels = (mask * 255).astype(np.uint8)
    return Image.fromarray(pixels)

def mask_on_img(img, mask, alpha=0.8):
    """Overlay a mask on an image by blending it into the channel at index 2.

    Args:
        img (PIL Image): input target rgb image, H*W*C
        mask (bool array): predicted mask of img, H*W
        alpha (float, optional): blend weight of the mask. Defaults to 0.8.

    Returns:
        numpy array: copy of ``img`` in which the channel at index 2 is
        blended toward 255 wherever ``mask`` is set; ``img`` is not modified.
    """
    blended = np.array(img.copy())
    weight = mask[:, :] * alpha
    channel = blended[:, :, 2]
    # The right-hand side is evaluated in full before being written back
    # into the array.
    blended[:, :, 2] = weight * 255 + (1 - weight) * channel
    return blended

def parse_input(in_files):
    """Expand the given input paths into a flat list of file paths.

    Args:
        in_files: Iterable of paths; each may be a file or a directory.

    Returns:
        List of file paths. A file is kept as is. A directory contributes the
        regular files directly inside it (not recursive), in sorted order.
        Paths that are neither are skipped with a warning.
    """
    file_list = []
    for path in in_files:
        if os.path.isfile(path):
            file_list.append(path)
        elif os.path.isdir(path):
            # os.listdir returns entries in arbitrary order; sort so that
            # inputs pair deterministically with output names.
            entries = (os.path.join(path, name)
                       for name in sorted(os.listdir(path)))
            # Keep regular files only; a sub-directory cannot be opened as
            # an image.
            file_list.extend(entry for entry in entries if os.path.isfile(entry))
        else:
            logging.warning("Skipping input %s: not a file or directory", path)
    return file_list


if __name__ == "__main__":
    # The root logger defaults to WARNING, so without this every
    # logging.info() call below would be silently discarded.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    args = get_args()
    in_files = parse_input(args.input)
    out_files = get_output_filenames(args, in_files)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Alternative architectures that were used with this script:
    # net = UNet(n_channels=3, n_classes=1)
    # net = DeepLab(num_classes = 1, backbone = 'resnet', output_stride = 16)
    # NOTE(review): the bases directory is a hard-coded absolute path on the
    # author's machine; it must exist for the network to be built.
    net = Coeffnet_Deeplab("/home/pancy/IP/Object-Pursuit/Segmentation/Bases", device)

    logging.info("Loading model {}".format(args.model))

    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(tqdm(in_files)):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            # The saved result is the input image with the mask blended in,
            # not the bare mask (mask_to_image would give the latter).
            res = mask_on_img(img, mask)
            Image.fromarray(res).save(out_files[i])

            logging.info("Mask saved to {}".format(out_files[i]))

        if args.viz:
            logging.info("Visualizing results for image {}, close to continue ...".format(fn))
            plot_img_and_mask(img, mask)
