import os.path
from tqdm import tqdm
from trainer import val_pascal_dataloader
from trainer import train_fewshot_pascal_dataloader
from evaluate.reasoning_dataloader import *
import torchvision.transforms as T
from evaluate.mae_utils import *
import argparse
from pathlib import Path
from evaluate.segmentation_utils import *
from PIL import Image
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import time
import random
import wandb
from SegGPT.seggpt_engine import inference_image
from SegGPT import models_seggpt


def get_args():
    """Build the command-line parser for the SegGPT in-context evaluation script.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself.

    NOTE(review): ``train()`` in this file also reads ``args.selected_label``
    and ``args.random_prompt``; neither option is defined here, so callers
    must attach them to the namespace themselves. Their expected types are
    not visible from this file -- confirm against
    ``val_pascal_dataloader.DatasetPASCAL`` before adding them as options.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    parser = argparse.ArgumentParser('Multi-example ICL baseline segmentation for SegGPT', add_help=False)
    parser.add_argument('--mae_model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--seggpt-model', type=str, help='dir to ckpt',
                        default='seggpt_vit_large_patch16_input896x448')
    parser.add_argument('--seg_type', type=str, help='embedding for segmentation types',
                        choices=['instance', 'semantic'], default='semantic')
    # argparse never checks a default against `choices`. The previous default
    # ('no_vp') was not a valid choice, so it slipped through parsing and
    # train() then matched neither mode branch. Default to a supported mode.
    parser.add_argument("--mode", type=str, default='icl',
                        choices=['panicl', 'icl'],
                        help="mode of adding vp on img.")
    parser.add_argument("--anchor-mode", type=str, default='query',
                        choices=['itself', 'query', 'random', 'seq'],
                        help="The anchor mode for building prompt pool.")
    parser.add_argument('--output_dir', default='./output_samples')
    parser.add_argument('--device', default='cuda:0',
                        help='device to use for training / testing')
    parser.add_argument('--base_dir', default='./pascal-5i', help='pascal base dir')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--t', default=[0, 0, 0], type=float, nargs='+')
    parser.add_argument('--task', default='segmentation', choices=['segmentation', 'detection'])
    parser.add_argument('--ckpt', default='./SegGPT/seggpt_vit_large.pth', help='model checkpoint')
    parser.add_argument('--dataset_type', default='pascal')
    parser.add_argument('--fold', default=0, type=int)
    parser.add_argument('--split', default='trn', type=str)
    parser.add_argument('--purple', default=0, type=int)
    parser.add_argument('--flip', default=0, type=int)
    parser.add_argument('--feature_name', default='features_vit-laion2b_feature-level_val', type=str)
    parser.add_argument('--percentage', default='', type=str)
    parser.add_argument('--cluster', action='store_true')
    parser.add_argument('--random', action='store_true')
    parser.add_argument("--random-example", action='store_true',
                        help="whether using random/in-context pair retriever.")
    parser.add_argument('--save-examples', action='store_true', help='whether save the visual examples')
    parser.add_argument('--cls-base', action='store_true')
    parser.add_argument("--k", type=int, default=5,
                        help="number of knn.")

    # testing settings for PANICL
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Number of images sent to the network in one step.")
    parser.add_argument("--arr", type=str, default='a1',
                        help="the setting of arrangements of canvas")
    parser.add_argument('--save-model-path',
                        help='model checkpoint')
    parser.add_argument("--n-shot", type=int, default=4,
                        help="Number of images for fsl.")
    parser.add_argument("--temp", type=float, default=1.0,
                        help="Temperature scaling factor")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Balanced factor")

    return parser


def train(args):
    """Run SegGPT in-context segmentation over a PASCAL split and log metrics.

    Despite its name, no optimisation step happens here: the function builds
    the few-shot and validation datasets, loads a SegGPT model, runs
    ``inference_image`` once per validation sample and accumulates the
    per-sample metrics into a running mean. Each sample's metrics, and the
    final mean, are appended to ``log.txt`` in the examples directory; the
    predicted masks are saved as PNGs when ``args.save_examples`` is set.

    Args:
        args: parsed options as produced by ``get_args()``.
            NOTE(review): ``args.selected_label`` and ``args.random_prompt``
            are read below but are not defined by ``get_args()`` in this
            file; the caller has to provide them.

    Raises:
        ValueError: if ``args.mode`` is neither ``'icl'`` nor ``'panicl'``.
    """
    # Fail fast on an unsupported mode. Previously this only surfaced as an
    # unbound `examples_save_path` after the datasets and model were loaded.
    supported_modes = ('icl', 'panicl')
    if args.mode not in supported_modes:
        raise ValueError(
            f"Unsupported mode {args.mode!r}; expected one of {supported_modes}.")
    use_panicl = args.mode == 'panicl'

    padding = 1
    image_transform = T.Compose(
        [T.Resize((224 // 2 - padding, 224 // 2 - padding), 3),
         T.ToTensor()])
    mask_transform = T.Compose(
        [T.Resize((224 // 2 - padding, 224 // 2 - padding), 3),
         T.ToTensor()])

    train_dataset = {
        'pascal': train_fewshot_pascal_dataloader.DatasetPASCAL,
    }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split,
                         image_transform=image_transform,
                         mask_transform=mask_transform,
                         flipped_order=args.flip, purple=args.purple, random=args.random, cluster=args.cluster,
                         feature_name=args.feature_name, percentage=args.percentage, seed=args.seed, arr=args.arr,
                         n_shot=args.n_shot, cls_base=args.cls_base, random_example=args.random_example,
                         anchor_mode=args.anchor_mode)

    dataloaders = {}

    # The train loader is only used for the size report below; the loop
    # further down iterates the validation loader.
    dataloaders['train'] = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)

    print('train datalaoder: ', len(dataloaders['train']))

    print("load data over")

    # SegGPT model
    seg_model = prepare_seggpt_model(models_seggpt, args.ckpt, args.seggpt_model, args.seg_type).to(args.device)

    # Dataset initialization
    img_path = os.path.join(args.base_dir, 'VOC2012/JPEGImages/')
    ann_path = os.path.join(args.base_dir, 'VOC2012/SegmentationClassAug/')

    eg_save_path = f'{args.output_dir}/seg_gpt_output_examples/{args.task}_fold_{args.fold}_shot_{args.n_shot}'

    print(f'We adopt the arrangement of {args.arr}.')
    print("*" * 50)

    support_name_dic = {}

    for simidx in [0]:
        eval_dict = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0}
        # NOTE(review): `selected_label` / `random_prompt` are not CLI options
        # in this file (see the docstring); left as plain attribute reads so a
        # missing value is reported rather than silently defaulted.
        val_dataset = {
            'pascal': val_pascal_dataloader.DatasetPASCAL,
        }[args.dataset_type](args.base_dir, fold=args.fold, split=args.split, image_transform=image_transform,
                             mask_transform=mask_transform,
                             flipped_order=args.flip, purple=args.purple, random=args.random, cluster=args.cluster,
                             feature_name=args.feature_name, percentage=args.percentage, seed=args.seed,
                             mode=args.mode,
                             arr=args.arr, cls_base=args.cls_base, selected_label=args.selected_label, simidx=simidx,
                             support_name_dic=support_name_dic, random_prompt=args.random_prompt)

        dataloaders['val'] = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
        num_val_samples = len(val_dataset)
        print('length of val dataset: ', num_val_samples)

        # args.mode is 'icl' or 'panicl' here, so this yields the same
        # '/icl_<simidx>/' or '/panicl_<simidx>/' directory as before.
        examples_save_path = eg_save_path + f'/{args.mode}_{simidx}/'
        print('examples_save_path: ', examples_save_path)
        os.makedirs(examples_save_path, exist_ok=True)
        log_path = os.path.join(examples_save_path, 'log.txt')
        image_number = 0

        # Formal test
        for data in tqdm(dataloaders["val"]):
            # With batch_size=1 each collated entry has a single element.
            support_names = [tup[0] for tup in data['support_names']]
            query_name = data['query_name']
            query_class = data['query_class']
            support_class = data['support_class']

            query_image_path, support_image_path, query_mask, support_masks = add_image_path(
                img_path, ann_path, query_name, support_names, query_class, support_class)

            output_image_mask, _, _, _ = inference_image(seg_model, args.device,
                                                         query_image_path, support_image_path,
                                                         support_masks, panicl=use_panicl)

            if isinstance(query_mask, Image.Image):
                image_mask = query_mask.convert("RGB")
            else:
                # Close the file handle once the RGB copy has been made.
                with Image.open(query_mask) as mask_file:
                    image_mask = mask_file.convert("RGB")

            # Evaluate prediction: both masks go through round_image with the
            # WHITE/BLACK palette before being compared.
            image_mask = np.uint8(np.array(image_mask))
            image_mask = round_image(image_mask, [WHITE, BLACK])
            output_image_mask = round_image(output_image_mask, [WHITE, BLACK], t=args.t)

            current_metric = new_calculate_metric(image_mask, output_image_mask)

            if args.save_examples:
                Image.fromarray(output_image_mask.cpu().numpy().astype(np.uint8)).save(
                    examples_save_path + f'/{image_number}_{simidx}_{query_name[0]}_{float(current_metric["iou"] * 100):.2f}.png')

            with open(log_path, 'a') as log:
                log.write(str(image_number) + '\t' + str(current_metric) + '\n')
            image_number += 1

            # Running mean: each sample contributes 1/N of its metric value.
            for metric_name, metric_value in current_metric.items():
                eval_dict[metric_name] += (metric_value / num_val_samples)

        print('val metric: {}'.format(eval_dict))
        with open(log_path, 'a') as log:
            log.write('all\t' + str(eval_dict) + '\n')


if __name__ == '__main__':
    # Script entry point: pick the multiprocessing start method, parse the
    # command line, seed the random generators, then run the evaluation.
    mp.set_start_method('spawn')

    parser = get_args()
    args = parser.parse_args()

    # Apply the same seed to torch, NumPy and the stdlib RNG, in that order.
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(args.seed)

    # Create the output directory (and any missing parents) unless it is empty.
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    train(args)
