import os.path
from tqdm import tqdm
from trainer import val_pascal_dataloader
from evaluate.reasoning_dataloader import *
import torchvision
from evaluate.mae_utils import *
import argparse
from pathlib import Path
from evaluate.segmentation_utils import *
from PIL import Image
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from trainer.train_models import _generate_result_for_large_canvas


def get_args():
    """Build the command-line parser for the multi-example ICL segmentation baseline.

    The parser is created with ``add_help=False`` and returned unparsed; the
    caller is responsible for invoking ``parse_args()``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    arg_parser = argparse.ArgumentParser('Multi-example ICL baseline segmentation', add_help=False)
    add = arg_parser.add_argument

    # Model, output location and runtime.
    add('--mae_model', type=str, default='mae_vit_large_patch16', metavar='MODEL',
        help='Name of model to train')
    add('--output_dir', default='./output_samples/')
    add('--device', default='cuda:0', help='device to use for training / testing')
    add('--base_dir', default='./pascal-5i', help='pascal base dir')
    add('--seed', type=int, default=0)
    # Passed as ``t=`` to ``round_image`` for the generated canvas.
    add('--t', type=float, nargs='+', default=[0, 0, 0])
    add('--task', default='segmentation', choices=['segmentation', 'detection'])
    add('--ckpt', default='./weights/checkpoint-1000.pth', help='model checkpoint')

    # Dataset selection and prompt-retrieval options.
    add('--dataset_type', default='pascal', choices=['pascal', 'pascal_det'])
    add('--fold', type=int, default=0)
    add('--split', type=str, default='trn')
    add('--purple', type=int, default=0)
    add('--flip', type=int, default=0)
    add('--feature_name', type=str, default='features_vit-laion2b_feature-level_trn')
    add('--percentage', type=str, default='')
    add('--cluster', action='store_true')
    add('--random', action='store_true')
    add('--save-examples', action='store_true', help='whether save the visual examples')

    # Large-canvas batching / layout.
    add("--batch-size", type=int, default=1,
        help="Number of images sent to the network in one step.")
    add("--arr", type=str, default='a1',
        help="the setting of arrangements of canvas")
    add("--n-shot", type=int, default=4,
        help="Number of examples for large canvas.")
    add("--random-example", action='store_true',
        help="whether using random/in-context pair retriever.")

    return arg_parser


def test_for_generate_results(args):
    """Run multi-example in-context inference on the validation split and log metrics.

    Builds the large-canvas validation dataset selected by ``args.dataset_type``,
    loads the model via ``prepare_model``, and generates predictions for every
    canvas with ``_generate_result_for_large_canvas``. Each prediction is scored
    with ``calculate_large_canvas_metric``. Per-image metrics and the final
    averages are written to
    ``<output_dir>/output_examples/fold_<fold>_multi_example_ICL/<setting>/log.txt``.
    When ``args.save_examples`` is set, the generated images are also saved
    next to the log.

    Args:
        args: namespace produced by ``get_args().parse_args()``.

    Returns:
        None. The averaged metrics are printed and appended to the log.

    Raises:
        KeyError: if ``args.dataset_type`` is not registered below. Only
            ``'pascal'`` is registered, although the CLI also accepts
            ``'pascal_det'``.
    """
    image_transform = torchvision.transforms.Compose(
        [torchvision.transforms.Resize((48, 48), 3),
         torchvision.transforms.ToTensor()])
    mask_transform = torchvision.transforms.Compose(
        [torchvision.transforms.Resize((48, 48), 3),
         torchvision.transforms.ToTensor()])

    val_dataset = {
        'pascal': val_pascal_dataloader.DatasetPASCALLargeCanvas,
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split, image_transform=image_transform,
                         mask_transform=mask_transform,
                         flipped_order=args.flip, purple=args.purple, random=args.random, cluster=args.cluster,
                         feature_name=args.feature_name, percentage=args.percentage, seed=args.seed,
                         arr=args.arr, n_shot=args.n_shot, random_example=args.random_example)

    dataloaders = {}
    dataloaders['val'] = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    print('val datalaoder: ', len(dataloaders['val']))
    print("load data over")

    # MAE_VQGAN model. Moved to the target device once, not on every batch.
    vqgan = prepare_model(args.ckpt, arch=args.mae_model)
    vqgan = vqgan.to(args.device)

    # ImageNet normalization constants. They are loop-invariant, so they are built once.
    # The [:, None, None] reshape assumes a channel-first canvas (..., 3, H, W). TODO confirm.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(args.device)[:, None, None]
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(args.device)[:, None, None]

    setting = f'fold{args.fold}_{args.task}_{args.arr}_prompts_{args.n_shot}'
    eg_save_path = f'{args.output_dir}/output_examples/fold_{args.fold}_multi_example_ICL'
    os.makedirs(eg_save_path, exist_ok=True)
    examples_save_path = eg_save_path + f'/{setting}/'
    os.makedirs(examples_save_path, exist_ok=True)
    log_path = os.path.join(examples_save_path, 'log.txt')

    # Running sums. They are divided by the number of scored images at the end.
    metric_sums = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0}
    image_number = 0

    # Open the log once for the whole run. Autograd is disabled because this is pure inference.
    with open(log_path, 'w') as log, torch.no_grad():
        log.write(str(args) + '\n')

        # Inference phase
        for data in tqdm(dataloaders["val"]):
            grid_stack = data['grid_stack'].to(args.device, dtype=torch.float32)

            canvas_pred = grid_stack.clone()
            if args.dataset_type != 'pascal_det':
                # convert to imagenet distribution.
                canvas_pred = (canvas_pred - imagenet_mean) / imagenet_std
            # The label canvas starts as an independent copy of the (normalized) input canvas.
            canvas_label = canvas_pred.clone()

            original_image_list, generated_result_list = _generate_result_for_large_canvas(
                args, vqgan, canvas_pred, canvas_label, args.arr)

            for index, original in enumerate(original_image_list):
                generated = generated_result_list[index]
                if args.save_examples:
                    # for qualitative comparison.
                    Image.fromarray(generated).save(
                        examples_save_path + f'generated_image_{image_number}.png')
                original_image = round_image(original, [WHITE, BLACK])
                generated_result = round_image(generated, [WHITE, BLACK], t=args.t)

                current_metric = calculate_large_canvas_metric(
                    args, original_image, generated_result, fg_color=WHITE, bg_color=BLACK)
                log.write(f'{image_number}\t{current_metric}\n')
                # Flush so per-image results survive an interrupted run, as the
                # per-image reopen/close used to guarantee.
                log.flush()
                image_number += 1
                for metric_name, value in current_metric.items():
                    metric_sums[metric_name] += value

        # Average over the images actually scored, not the dataset length. The two
        # differ whenever a batch yields a different number of results than items.
        eval_dict = {
            metric_name: (total / image_number if image_number else 0)
            for metric_name, total in metric_sums.items()
        }
        print('val metric: {}'.format(eval_dict))
        log.write('all\t' + str(eval_dict) + '\n')


if __name__ == '__main__':
    # NOTE(review): 'spawn' is presumably chosen so the DataLoader workers
    # (num_workers=4) are not forked from a process that already touched CUDA — confirm.
    mp.set_start_method('spawn')
    args = get_args().parse_args()

    # Seed torch first, then numpy, both from the single --seed flag.
    seed = args.seed
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)

    # Make sure the top-level output directory exists before any sub-folders are created.
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    test_for_generate_results(args)
