import os.path
from tqdm import tqdm
import random
from trainer import val_pascal_dataloader
from trainer import train_fewshot_pascal_dataloader
from evaluate.reasoning_dataloader import *
import torchvision.transforms as T
from evaluate.mae_utils import *
import argparse
from pathlib import Path
from evaluate.segmentation_utils import *
from PIL import Image
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from trainer.train_models import MetaTrn, _generate_result_by_knn_indice_weight_prob


def get_args():
    """Create the command-line parser used by the PANICL inference script.

    The parser is built with ``add_help=False``, so no ``-h/--help`` option is
    registered. Options are registered in three groups: model / prompt
    configuration, dataset and retrieval configuration, and the PANICL testing
    settings.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        'PANICL inference for foreground segmentation',
        add_help=False,
    )
    add = parser.add_argument

    vp_modes = ['no_vp', 'spimg_spmask', 'spimg', 'spimg_qrimg', 'qrimg', 'spimg_spmask_qrimg', 'icl']
    anchor_modes = ['itself', 'query', 'random', 'seq']
    task_names = ['segmentation', 'detection']

    # Model, prompt mode and runtime environment.
    add('--mae_model', type=str, default='mae_vit_large_patch16', metavar='MODEL',
        help='Name of model to train')
    add('--mode', type=str, default='no_vp', choices=vp_modes,
        help='mode of adding vp on img.')
    add('--anchor-mode', type=str, default='query', choices=anchor_modes,
        help='The anchor mode for building prompt pool.')
    add('--output_dir', default='./output_samples')
    add('--device', default='cuda:0',
        help='device to use for training / testing')
    add('--base_dir', default='./pascal-5i', help='pascal base dir')
    add('--seed', type=int, default=0)
    add('--t', type=float, nargs='+', default=[0, 0, 0])
    add('--task', default='segmentation', choices=task_names)
    add('--ckpt', default='./weights/checkpoint-1000.pth', help='model checkpoint')

    # Dataset selection and example retrieval.
    add('--dataset_type', default='pascal')
    add('--fold', type=int, default=0)
    add('--split', type=str, default='trn')
    add('--purple', type=int, default=0)
    add('--flip', type=int, default=0)
    add('--feature_name', type=str, default='features_vit-laion2b_pixel-level_val')
    add('--percentage', type=str, default='')
    add('--cluster', action='store_true')
    add('--random', action='store_true')
    add('--random-example', action='store_true',
        help='whether using random/in-context pair retriever.')
    add('--save-examples', action='store_true', help='whether save the visual examples')
    add('--cls-base', action='store_true')
    add('--k', type=int, default=5,
        help='number of knn.')

    # testing settings for PANICL
    add('--batch-size', type=int, default=1,
        help='Number of images sent to the network in one step.')
    add('--arr', type=str, default='a1',
        help='the setting of arrangements of canvas')
    add('--save-model-path',
        help='model checkpoint')
    add('--n-shot', type=int, default=4,
        help='Number of images for fsl.')
    add('--temp', type=float, default=1.0,
        help='Temperature scaling factor')
    add('--alpha', type=float, default=1.0,
        help='Balanced factor')

    return parser


def test(args):
    """Run PANICL inference: build the prompt-pool datastore, then score the queries.

    The function works in two passes:

    1. Prompt pooling branch - iterates the few-shot training loader, feeds each
       normalised canvas to ``PANICL.forward_meta_test_indice`` and groups the
       returned dictionaries in ``datastore`` under the key
       ``"{query_name_val}_{support_name_datastore}_{query_class_datastore}"``.
    2. Query branch - iterates the validation loader, builds the canvases with
       ``PANICL.form_cavas`` and calls
       ``_generate_result_by_knn_indice_weight_prob`` with the datastore. Every
       result is rounded to white/black and scored by ``calculate_metric``.

    Per-image metrics and the final averaged metrics are appended to
    ``log.txt`` under
    ``{output_dir}/output_examples/{task}_fold_{fold}_shot_{n_shot}/shot_{n_shot}``.
    Generated images are written to the same directory when
    ``args.save_examples`` is set.

    Args:
        args: Parsed namespace produced by the parser returned from ``get_args``.

    Returns:
        None. Results are printed and appended to the log file.
    """
    padding = 1
    image_transform = T.Compose(
        [T.Resize((224 // 2 - padding, 224 // 2 - padding), 3),
         T.ToTensor()])
    mask_transform = T.Compose(
        [T.Resize((224 // 2 - padding, 224 // 2 - padding), 3),
         T.ToTensor()])

    train_dataset = {
        'pascal': train_fewshot_pascal_dataloader.DatasetPASCAL,
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split,
                         image_transform=image_transform,
                         mask_transform=mask_transform,
                         flipped_order=args.flip, purple=args.purple, random=args.random, cluster=args.cluster,
                         feature_name=args.feature_name, percentage=args.percentage, seed=args.seed, arr=args.arr,
                         n_shot=args.n_shot, cls_base=args.cls_base, random_example=args.random_example,
                         anchor_mode=args.anchor_mode)

    dataloaders = {}
    dataloaders['train'] = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    print('train datalaoder: ', len(dataloaders['train']))
    print("load data over")

    # MAE_VQGAN model
    vqgan = prepare_model(args.ckpt, arch=args.mae_model)

    PANICL = MetaTrn(args=args, vqgan=vqgan.to(args.device), arr=args.arr)
    PANICL.to(args.device)
    PANICL.eval()

    eg_save_path = f'{args.output_dir}/output_examples/{args.task}_fold_{args.fold}_shot_{args.n_shot}'

    print(f'We adopt the arrangement of {args.arr}.')
    print("*" * 50)

    datastore = {}
    support_name_dic = {}
    eval_dict = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0}

    # ImageNet channel statistics, shaped to broadcast over (C, H, W).
    # They never change, so they are built once instead of once per batch.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(args.device)[:, None, None]
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(args.device)[:, None, None]

    # Prompt pooling branch for datastore.
    # Pure inference: disabling autograd keeps the long-lived datastore from
    # holding on to computation graphs.
    with torch.no_grad():
        for batch_list in tqdm(dataloaders['train']):
            for data in batch_list:
                # NOTE(review): only element [0] of each batched field is used for
                # the key, while the whole canvas batch is forwarded - confirm this
                # is intended when --batch-size > 1.
                query_name_val = data['query_name_val'][0]
                support_name_datastore = data['support_name_datastore'][0]
                query_class_datastore = data['query_class_datastore'][0]

                # transform to ImageNet distribution
                canvas_pred = data['grid_stack'].to(args.device, dtype=torch.float32)
                canvas_pred = (canvas_pred - imagenet_mean) / imagenet_std

                data_store_dic, _ = PANICL.forward_meta_test_indice(canvas_pred, data['support_name'])

                store_key = f"{query_name_val}_{support_name_datastore}_{query_class_datastore}"
                if store_key not in datastore:
                    datastore[store_key] = data_store_dic
                    support_name_dic[query_name_val] = support_name_datastore
                else:
                    # Same key seen again: append the new entries slot by slot.
                    for slot, bucket in datastore[store_key].items():
                        bucket.extend(data_store_dic[slot])

    print('datastore shape: ', len(datastore))
    # Diagnostic only. Prefer the second entry and the key '105' when they exist,
    # and fall back to whatever is available so that a small or differently
    # keyed datastore does not abort the run before evaluation starts.
    pooled_entries = list(datastore.values())
    if pooled_entries:
        probe = pooled_entries[1] if len(pooled_entries) > 1 else pooled_entries[0]
        probe_key = '105' if '105' in probe else next(iter(probe), None)
        if probe_key is not None:
            print('datastore per samples: ', len(probe[probe_key]))

    val_dataset = {
        'pascal': val_pascal_dataloader.DatasetPASCAL,
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split, image_transform=image_transform,
                         mask_transform=mask_transform,
                         flipped_order=args.flip, purple=args.purple, random=args.random, cluster=args.cluster,
                         feature_name=args.feature_name, percentage=args.percentage, seed=args.seed,
                         arr=args.arr, cls_base=args.cls_base,
                         support_name_dic=support_name_dic, random_example=args.random_example)

    dataloaders['val'] = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    num_val = len(val_dataset)
    print('length of val dataset: ', num_val)

    examples_save_path = eg_save_path + f'/shot_{args.n_shot}'
    os.makedirs(examples_save_path, exist_ok=True)
    log_path = os.path.join(examples_save_path, 'log.txt')
    image_number = 0

    # Query branch.
    with torch.no_grad():
        for data in tqdm(dataloaders["val"]):
            support_img = data['support_img'].to(args.device, dtype=torch.float32)
            support_mask = data['support_mask'].to(args.device, dtype=torch.float32)
            query_img = data['query_img'].to(args.device, dtype=torch.float32)
            query_mask = data['query_mask'].to(args.device, dtype=torch.float32)
            grid_stack = data['grid_stack'].to(args.device, dtype=torch.float32)
            query_name = data['query_name']
            query_class = data['query_class']
            support_name_datastore = data['support_name_datastore']

            canvas_label, canvas_pred = PANICL.form_cavas(support_img, support_mask, query_img, query_mask,
                                                          grid_stack)

            # convert to imagenet distribution.
            canvas_pred = (canvas_pred - imagenet_mean) / imagenet_std
            canvas_label = (canvas_label - imagenet_mean) / imagenet_std

            # One datastore key per sample, built the same way as in the pooling branch.
            assert len(support_name_datastore) == len(query_name) == query_class.size(0)
            key_list = [
                f"{name}_{support}_{cls}"
                for name, support, cls in zip(query_name, support_name_datastore, query_class)
            ]

            original_image_list, generated_result_list = _generate_result_by_knn_indice_weight_prob(
                args,
                vqgan.to(args.device),
                canvas_pred,
                canvas_label,
                args.arr,
                datastore,
                key_list,
                args.k)

            for original_raw, generated_raw in zip(original_image_list, generated_result_list):
                if args.save_examples:
                    # for qualitative comparison.
                    Image.fromarray(np.uint8(generated_raw)).save(
                        examples_save_path + f'/generated_image_{image_number}.png')
                original_image = round_image(original_raw, [WHITE, BLACK])
                generated_result = round_image(generated_raw, [WHITE, BLACK], t=args.t)
                current_metric = calculate_metric(args, original_image, generated_result,
                                                  fg_color=WHITE, bg_color=BLACK)
                with open(log_path, 'a') as log:
                    log.write(str(image_number) + '\t' + str(current_metric) + '\n')
                image_number += 1
                # Accumulate the dataset-level mean incrementally.
                for metric_name, metric_value in current_metric.items():
                    eval_dict[metric_name] += (metric_value / num_val)

    print('eval_dict: ', eval_dict)
    with open(log_path, 'a') as log:
        log.write('all\t' + str(eval_dict) + '\n')

if __name__ == '__main__':
    # Script entry point: pick the 'spawn' start method for torch
    # multiprocessing, parse the CLI, seed the RNGs, then run the evaluation.
    mp.set_start_method('spawn')
    args = get_args().parse_args()

    # Seed torch, NumPy and the stdlib RNG with the same value, in that order.
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(args.seed)

    # An empty --output_dir skips directory creation.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    test(args)
