import os.path
from tqdm import tqdm
from evaluate.reasoning_dataloader import *
import torchvision.transforms as T
from evaluate.mae_utils import *
import argparse
from pathlib import Path
from evaluate.segmentation_utils import *
from PIL import Image
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from trainer.train_models import _generate_result_by_knn_indice_weight_prob, MetaTrn
import random
from evaluate_detection.canvas_ds import CanvasDataset4Val
from evaluate_detection.canvas_ds import CanvasDataset4ValKNN
from evaluate_detection.box_ops import to_rectangle


def get_args():
    """Build (but do not parse) the argument parser for PANICL detection inference.

    The parser is created with ``add_help=False``, so ``-h``/``--help`` are not
    registered. Options are registered in the order listed in ``option_specs``.

    Returns:
        argparse.ArgumentParser: call ``.parse_args()`` on it to obtain the options.
    """
    parser = argparse.ArgumentParser('PANICL inference for detection', add_help=False)

    # (flag, keyword options for add_argument), in registration order.
    option_specs = [
        ('--mae_model', dict(default='mae_vit_large_patch16', type=str, metavar='MODEL',
                             help='Name of model to train')),
        ('--output_dir', dict(default='./output_samples')),
        ('--device', dict(default='cuda:0', help='device to use for training / testing')),
        ('--base_dir', dict(default='./pascal-5i', help='pascal base dir')),
        ('--seed', dict(default=0, type=int)),
        ('--t', dict(default=[0, 0, 0], type=float, nargs='+')),
        ('--task', dict(default='detection', choices=['segmentation', 'detection'])),
        ('--ckpt', dict(default='./weights/checkpoint-1000.pth', help='model checkpoint')),
        ('--dataset_type', dict(default='pascal_det')),
        ('--fold', dict(default=0, type=int)),
        ('--split', dict(default='trn', type=str)),
        ('--purple', dict(default=0, type=int)),
        ('--flip', dict(default=0, type=int)),
        ('--feature_name', dict(default='features_vit-laion2b_pixel-level_val_all_detection', type=str)),
        ('--percentage', dict(default='', type=str)),
        ('--cluster', dict(action='store_true')),
        ('--random', dict(action='store_true')),
        ('--random-example', dict(action='store_true',
                                  help="whether using random/in-context pair retriever.")),
        ('--ensemble', dict(action='store_true')),
        ('--save-examples', dict(action='store_true', help='whether save the example in val')),
        ('--cls-base', dict(action='store_true')),
        ('--k', dict(type=int, default=5, help="knn number.")),
        ('--batch-size', dict(type=int, default=1,
                              help="Number of images sent to the network in one step.")),
        ('--arr', dict(type=str, default='a1', help="the setting of arrangements of canvas")),
        ('--n-shot', dict(type=int, default=4, help="Number of images for prompt pool.")),
        ('--temp', dict(type=float, default=1.0, help="Temperature scaling factor")),
        ('--alpha', dict(type=float, default=0.7, help="Balanced factor")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def _canvas_inputs(data, device):
    """Return the canvas tensors of one collated sample, moved to ``device`` as float32.

    The tuple order matches the positional arguments of ``MetaTrn.form_cavas``:
    (support_img, support_mask, query_img, query_mask, grid_stack).
    """
    return tuple(
        data[key].to(device, dtype=torch.float32)
        for key in ('support_img', 'support_mask', 'query_img', 'query_mask', 'grid_stack')
    )


# Inference only: autograd is disabled so outputs kept in the datastore do not retain
# computation graphs. NOTE(review): assumes no callee needs gradients — confirm.
@torch.no_grad()
def _build_datastore(args, loader, panicl):
    """Run every sample of ``loader`` through ``panicl`` and collect its datastore entries.

    Returns a dict keyed by ``"{query_name_val}_{support_name_datastore}"``. Each value is
    the dict returned by ``forward_meta_test_indice``. When a key is seen again, the list
    stored under each of its fields is extended with the new sample's entries.
    """
    datastore = {}
    for batch_list in tqdm(loader):
        # Each loader item is iterated as a sequence of collated sample dicts.
        for data in batch_list:
            key = f"{data['query_name_val'][0]}_{data['support_name_datastore'][0]}"
            _, canvas_pred = panicl.form_cavas(*_canvas_inputs(data, args.device))
            data_store_dic, _ = panicl.forward_meta_test_indice(canvas_pred, data['support_name'])

            if key not in datastore:
                datastore[key] = data_store_dic
            else:
                for field, stored in datastore[key].items():
                    stored.extend(data_store_dic[field])
    return datastore


@torch.no_grad()
def _evaluate(args, loader, panicl, vqgan, datastore, examples_save_path, num_samples):
    """Score every batch of ``loader`` against ``datastore`` and return averaged metrics.

    Per-image metrics are appended to ``<examples_save_path>/log.txt``, followed by a
    final ``all`` line with the averages. Each per-image metric is divided by
    ``num_samples``, so the averages assume one generated image per dataset item.
    """
    eval_dict = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0}
    image_number = 0
    # One handle for the whole run instead of reopening the log for every image.
    with open(os.path.join(examples_save_path, 'log.txt'), 'a') as log:
        for data in tqdm(loader):
            canvas_label, canvas_pred = panicl.form_cavas(*_canvas_inputs(data, args.device))

            query_names = data['query_name']
            datastore_names = data['support_name_datastore']
            assert len(datastore_names) == len(query_names)
            key_list = [f"{query}_{support}" for query, support in zip(query_names, datastore_names)]

            original_image_list, generated_result_list = _generate_result_by_knn_indice_weight_prob(
                args, vqgan, canvas_pred, canvas_label, args.arr, datastore, key_list, args.k)

            for index, raw_original in enumerate(original_image_list):
                generated_result = generated_result_list[index]
                # Only rows/columns from 113 onward are rounded with the --t thresholds, in
                # place; presumably this is the predicted cell of the canvas — confirm.
                generated_result[113:, 113:] = round_image(
                    generated_result[113:, 113:], [WHITE, BLACK], t=args.t)
                original_image = round_image(raw_original, [WHITE, BLACK])

                if args.task == 'detection':
                    generated_result = to_rectangle(generated_result)
                if args.save_examples:
                    Image.fromarray(generated_result.cpu().numpy().astype(np.uint8)).save(
                        os.path.join(examples_save_path, f'generated_image_{image_number}.png'))

                current_metric = calculate_metric(args, original_image, generated_result,
                                                  fg_color=WHITE, bg_color=BLACK)
                log.write(f'{image_number}\t{current_metric}\n')
                image_number += 1

                for metric_name, value in current_metric.items():
                    eval_dict[metric_name] += value / num_samples

        print('eval_dict: ', eval_dict)
        log.write('all\t' + str(eval_dict) + '\n')
    return eval_dict


def test(args):
    """Evaluate PANICL using a kNN datastore built from a ``CanvasDataset4ValKNN`` pass.

    Steps:
      1. Run every ``CanvasDataset4ValKNN`` sample through ``MetaTrn`` and collect a
         datastore keyed by ``"{query}_{support}"``.
      2. For each ``CanvasDataset4Val`` batch, generate results with
         ``_generate_result_by_knn_indice_weight_prob``, passing that datastore and
         ``args.k`` (the ``--k`` "knn number").
      3. Round results to the WHITE/BLACK palette, convert them with ``to_rectangle``
         for detection, and score them with ``calculate_metric``. Per-image metrics go to
         ``<output_dir>/output_examples/<task>_fold_<fold>/shot_<n_shot>/log.txt``.

    Args:
        args: parsed namespace from ``get_args()``.

    Returns:
        dict: metric name -> sum of per-image values divided by ``len(val_dataset)``.
    """
    padding = 1
    cell_size = (224 // 2 - padding, 224 // 2 - padding)
    image_transform = T.Compose([T.Resize(cell_size, 3), T.ToTensor()])
    mask_transform = T.Compose([T.Resize(cell_size, 3), T.ToTensor()])

    train_dataset = {
        'pascal_det': CanvasDataset4ValKNN
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split, image_transform=image_transform,
                         mask_transform=mask_transform, flipped_order=args.flip, purple=args.purple,
                         random=args.random, cluster=args.cluster, feature_name=args.feature_name,
                         percentage=args.percentage, seed=args.seed, arr=args.arr, n_shot=args.n_shot,
                         random_example=args.random_example)
    dataloaders = {'train': DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)}
    print('train dataloader: ', len(dataloaders['train']))
    print("load data over")

    # MAE_VQGAN model, moved to the device once. Assuming prepare_model returns an
    # nn.Module, .to() moves it in place and returns it, so MetaTrn and the generation
    # helper share this same on-device model.
    vqgan = prepare_model(args.ckpt, arch=args.mae_model).to(args.device)
    panicl = MetaTrn(args=args, vqgan=vqgan, arr=args.arr)
    panicl.to(args.device)
    panicl.eval()

    print(f'We adopt the arrangement of {args.arr}.')
    print("*" * 50)

    datastore = _build_datastore(args, dataloaders['train'], panicl)
    print('datastore shape: ', len(datastore))
    # Sanity probe of one entry. It is guarded so a datastore with fewer than two entries,
    # or without the '105' field, no longer raises after the full datastore pass.
    sample_entry = next(iter(datastore.values()), None)
    if sample_entry:
        probe_field = '105' if '105' in sample_entry else next(iter(sample_entry))
        print('datastore per samples: ', len(sample_entry[probe_field]))

    val_dataset = {
        'pascal_det': CanvasDataset4Val
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split, image_transform=image_transform,
                         mask_transform=mask_transform, flipped_order=args.flip, purple=args.purple, random=args.random,
                         cluster=args.cluster, feature_name=args.feature_name, percentage=args.percentage,
                         seed=args.seed, arr=args.arr, random_example=args.random_example)
    dataloaders['val'] = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    print('length of val dataset: ', len(val_dataset))

    examples_save_path = f'{args.output_dir}/output_examples/{args.task}_fold_{args.fold}/shot_{args.n_shot}'
    print('examples_save_path: ', examples_save_path)
    os.makedirs(examples_save_path, exist_ok=True)

    return _evaluate(args, dataloaders['val'], panicl, vqgan, datastore,
                     examples_save_path, len(val_dataset))


if __name__ == '__main__':
    # The multiprocessing start method is fixed here, in the entry point only, before any
    # DataLoader workers are launched by test().
    mp.set_start_method('spawn')
    args = get_args().parse_args()

    # Seed torch, NumPy and Python's `random`, in that order, from the single --seed value.
    seed = args.seed
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    test(args)
