import torch
import numpy as np
import random
from dataset import CoralDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from networks.benthiq import BenthIQ
import os
import argparse
from utils import *

parser = argparse.ArgumentParser()
# Expected to contain `imgs/` and `masks/` subdirectories (see the dataset setup below).
parser.add_argument('--root_path', type=str,
                    default='test_data', help='root dir for data')
# NOTE(review): not read anywhere in this script.
parser.add_argument('--dataset', type=str,
                    default='Coral', help='experiment_name')
parser.add_argument('--num_classes', type=int,
                    default=4, help='output channel of network')
# NOTE(review): not read anywhere in this script.
parser.add_argument('--output_dir', type=str, default="output", help='output dir')
# The effective batch size is batch_size * n_gpu.
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
# Required: the value is passed directly to torch.load, so a missing value would
# otherwise fail later with an obscure error instead of a clear CLI message.
parser.add_argument('--model_weights', type=str, required=True,
                    help='path to the saved model state_dict file to load')
# NOTE(review): not read anywhere in this script.
parser.add_argument('--mode', default="val", type=str, help='test or validation')

# Parsed at import time: importing this module parses sys.argv.
args = parser.parse_args()

def visualize(args, dataloaders, net, images=None, masks=None):
    """Predict on one batch and display the first sample with its error map.

    Args:
        args: Parsed CLI namespace; only ``args.num_classes`` is read here.
        dataloaders: Iterable yielding dicts with ``'image'`` and ``'label'``
            entries. One batch is drawn from it when neither ``images`` nor
            ``masks`` is given.
        net: Model forwarded to ``test_single_volume``.
        images: Optional pre-loaded image batch. It is indexed with four axes
            here, so a 4-D batch is required.
        masks: Optional pre-loaded mask batch. It is indexed with three axes
            here, so at least 3-D is required.

    Raises:
        ValueError: If exactly one of ``images`` / ``masks`` is provided.
    """
    # Use `is None`: `==` on an array/tensor is elementwise, and its truth
    # value is ambiguous (numpy raises ValueError).
    if images is None and masks is None:
        sample = next(iter(dataloaders))
        images, masks = sample['image'], sample['label']
    elif images is None or masks is None:
        raise ValueError("images and masks must be provided together")

    idx = 0  # only the first sample of the batch is displayed
    # NOTE(review): astype(int) assumes pixel values on a 0-255 scale; values
    # in [0, 1] would truncate to 0 -- confirm against CoralDataset.
    original_image = np.array(images[idx, :, :, :]).astype(int)
    ground_truth_mask = masks[idx, :, :]
    # The whole batch is predicted, and then sample `idx` is selected.
    predicted_mask = test_single_volume(images, net, args.num_classes)[idx, :, :]
    # Drop a leading singleton axis so single-channel images become 2-D.
    if original_image.shape[0] == 1:
        original_image = original_image[0]
    error_map = create_error_map(ground_truth_mask, predicted_mask)
    display_images(original_image=original_image, ground_truth_mask=ground_truth_mask,
                   predicted_mask=predicted_mask, error_map=error_map)

if __name__ == "__main__":
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    
    if torch.cuda.is_available():
        net = BenthIQ(img_size=args.img_size, num_classes=args.num_classes).cuda()
        net.load_state_dict(torch.load(args.model_weights))
    else: 
        net = BenthIQ(img_size=args.img_size, num_classes=args.num_classes)
        net.load_state_dict(torch.load(args.model_weights, map_location=torch.device('cpu')))
    net.eval()

    db_train = CoralDataset(image_dir=os.path.join(args.root_path, 'imgs'),
                            mask_dir=os.path.join(args.root_path, 'masks'),
                            image_names=sorted(os.listdir(os.path.join(args.root_path, 'imgs'))),
                            mask_names=sorted(os.listdir(os.path.join(args.root_path, 'masks'))),
                            transform=transforms.Compose(
                                    []))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                            worker_init_fn=worker_init_fn)
    visualize(args, trainloader, net)