import os
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
import math
import VADCLIP
from PIL import Image
from skimage import measure
from tabulate import tabulate
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve
from dataset import VisaDataset, MVTecDataset
from prompt_generation import prompt_encoder
from torchvision import transforms
import time
import warnings

def plot_sample_cv2(names, imgs, scores_: dict, gts, binary_mask , score_map, path, save_folder=None):
    """Save per-image visualisations as JPEG files.

    For every image ``idx`` these files are written into ``save_folder + path[idx]``
    (the directory is created if it does not exist):

    * ``{name}_ori.jpg``  - the input image;
    * ``{name}_gt.jpg``   - the input with ground-truth pixels painted red;
    * ``{name}_mk.jpg``   - JET-coloured ground truth blended 50/50 with the input;
    * ``{name}_bmk.jpg``  - JET-coloured binary prediction blended with the input;
    * ``{name}_{key}.jpg`` - for every score family ``key`` in ``scores_``.

    Args:
        names: output file stem for each image.
        imgs: RGB images, expected uint8 (H, W, 3). They are not modified.
        scores_: maps a family name to a sequence of per-image score maps. Each
            family is min-max scaled jointly over all of its images.
        gts: per-image ground-truth masks; pixels > 0.5 count as anomalous.
        binary_mask: per-image predicted masks, expected in [0, 255].
        score_map: unused; kept for call-site compatibility.
        path: per-image sub-folder appended to ``save_folder``.
        save_folder: output prefix, joined by string concatenation, so it
            should end with a path separator.
    """
    # Min-max scale each score family jointly to [0, 255]. Dividing by the range
    # (not the max) makes the result span exactly [0, 255]. Clipping before the
    # uint8 cast prevents wrap-around (e.g. 260 -> 4).
    scores = {}
    for key, maps in scores_.items():
        stacked = np.asarray(maps, dtype=np.float64)
        low, high = stacked.min(), stacked.max()
        span = high - low
        unit = (stacked - low) / span if span > 0 else np.zeros_like(stacked)
        scores[key] = np.clip(unit * 255, 0, 255).astype(np.uint8)

    for idx in range(len(imgs)):
        stem = names[idx]
        folder = save_folder + path[idx]
        os.makedirs(folder, exist_ok=True)  # cv2.imwrite cannot create folders
        # Work on a BGR copy so the caller's RGB image is left untouched.
        bgr = cv2.cvtColor(imgs[idx], cv2.COLOR_RGB2BGR)
        anomalous = gts[idx] > 0.5

        cv2.imwrite(os.path.join(folder, f'{stem}_ori.jpg'), bgr)

        gt_overlay = bgr.copy()
        gt_overlay[anomalous] = (0, 0, 255)  # red in BGR
        cv2.imwrite(os.path.join(folder, f'{stem}_gt.jpg'), gt_overlay)

        gt_u8 = np.where(anomalous, 255, 0).astype(np.uint8)
        gt_heat = cv2.applyColorMap(gt_u8, cv2.COLORMAP_JET)
        cv2.imwrite(os.path.join(folder, f'{stem}_mk.jpg'),
                    cv2.addWeighted(gt_heat, 0.5, bgr, 0.5, 0))

        # Resampled masks may overshoot [0, 255]; clip so the cast saturates.
        pred_u8 = np.clip(binary_mask[idx], 0, 255).astype(np.uint8)
        pred_heat = cv2.applyColorMap(pred_u8, cv2.COLORMAP_JET)
        cv2.imwrite(os.path.join(folder, f'{stem}_bmk.jpg'),
                    cv2.addWeighted(pred_heat, 0.5, bgr, 0.5, 0))

        for key, maps in scores.items():
            heat = cv2.applyColorMap(maps[idx], cv2.COLORMAP_JET)
            cv2.imwrite(os.path.join(folder, f'{stem}_{key}.jpg'),
                        cv2.addWeighted(heat, 0.5, bgr, 0.5, 0))
            
def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    cuDNN determinism flags are intentionally left at their defaults; set
    ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False`` for bit-exact GPU runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def apply_ad_scoremap(image, scoremap, alpha=0.7):
    """Render ``scoremap`` with the JET colour map and blend it onto ``image``.

    Args:
        image: RGB image (array-like, same H x W as ``scoremap``).
        scoremap: score map, presumably in [0, 1]. It is scaled by 255 and cast
            to uint8, so values outside [0, 1] wrap around.
        alpha: weight of ``image`` in the blend; the colour map gets ``1 - alpha``.

    Returns:
        ``(blended, heat)``: the uint8 blend and the uint8 RGB colour map.
    """
    heat = cv2.applyColorMap((scoremap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)  # applyColorMap produces BGR
    blended = alpha * np.asarray(image, dtype=float) + (1 - alpha) * heat
    return blended.astype(np.uint8), heat.astype(np.uint8)

def apply_ad_bmap(image, scoremap, alpha=0.7):
    """Blend a JET-coloured ``scoremap`` onto ``image`` and return only the blend.

    This is the same computation as ``apply_ad_scoremap``; delegating avoids
    keeping two copies of the colour-map/blend code in sync.

    Args:
        image: RGB image (array-like, same H x W as ``scoremap``).
        scoremap: score map, presumably in [0, 1]; scaled by 255 and cast to uint8.
        alpha: weight of ``image`` in the blend.

    Returns:
        The uint8 blended image.
    """
    blended, _ = apply_ad_scoremap(image, scoremap, alpha=alpha)
    return blended

def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Normalised area under the per-region-overlap (PRO) curve up to ``expect_fpr``.

    Args:
        masks: binary ground-truth masks, shape (N, H, W).
        amaps: anomaly maps with the same shape as ``masks``.
        max_step: number of equally spaced thresholds between ``amaps.min()``
            and ``amaps.max()``.
        expect_fpr: only curve points with FPR below this value are kept. Their
            FPRs are rescaled to [0, 1] before integration.

    Returns:
        The area under the retained PRO-vs-FPR curve.

    Raises:
        ValueError: if ``amaps`` is constant, so no thresholds can be formed.
    """
    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
    min_th, max_th = amaps.min(), amaps.max()
    if not max_th > min_th:
        raise ValueError("cal_pro_score: anomaly maps are constant; the PRO curve is undefined")
    delta = (max_th - min_th) / max_step

    # Ground-truth regions and negatives do not depend on the threshold, so
    # compute them once instead of once per threshold.
    regions = []
    for i, mask in enumerate(masks):
        for region in measure.regionprops(measure.label(mask)):
            regions.append((i, region.coords[:, 0], region.coords[:, 1], region.area))
    inverse_masks = 1 - masks
    negative = inverse_masks.astype(bool)
    negative_total = inverse_masks.sum()

    pros, fprs, ths = [], [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps = amaps > th
        # Per-region overlap: fraction of each GT region predicted anomalous.
        pro = [binary_amaps[i, rows, cols].sum() / area for i, rows, cols, area in regions]
        fp_pixels = np.logical_and(negative, binary_amaps).sum()
        pros.append(np.array(pro).mean())
        fprs.append(fp_pixels / negative_total)
        ths.append(th)
    pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc
def _convert_image_to_rgb(image):
    return image.convert("RGB")
def get_data_transforms(size, isize):
    """Return ``(image_transform, mask_transform)`` preprocessing pipelines.

    Both pipelines resize to ``size`` x ``size`` and then centre-crop to
    ``isize``. The image pipeline also forces RGB, converts to a tensor and
    normalises with fixed per-channel mean/std (these appear to be the standard
    CLIP statistics). The mask pipeline only converts to a tensor.
    """
    def geometry():
        # Fresh transform instances for each pipeline.
        return [transforms.Resize((size, size)), transforms.CenterCrop(isize)]

    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    image_steps = geometry() + [
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    mask_steps = geometry() + [transforms.ToTensor()]
    return transforms.Compose(image_steps), transforms.Compose(mask_steps)

def weight_reset(m):
    """Re-initialise ``m`` in place if it is a Conv2d or Linear layer; otherwise do nothing.

    Suitable as the callback of ``torch.nn.Module.apply``.
    """
    resettable = (torch.nn.Conv2d, torch.nn.Linear)
    if not isinstance(m, resettable):
        return
    m.reset_parameters()

def resize_tokens(x):
    """Reshape a (B, N, C) token sequence into a (B, S, S, C) grid with S * S == N.

    For a contiguous input the result is a view, so in-place edits propagate
    to ``x`` exactly as with ``Tensor.view``.

    Raises:
        ValueError: if N is not a perfect square.
    """
    B, N, C = x.shape
    side = int(round(math.sqrt(N)))
    if side * side != N:
        raise ValueError(f"resize_tokens: expected a square number of tokens, got {N}")
    return x.reshape(B, side, side, C)
            
# Per-channel RGB statistics for input normalisation (the same values that
# get_data_transforms passes to transforms.Normalize). denormalization() uses
# them to map network inputs back to pixel values.
mean_train, std_train = (
    [0.48145466, 0.4578275, 0.40821073],
    [0.26862954, 0.26130258, 0.27577711],
)


def normalize(pred, max_value=None, min_value=None):
    """Min-max scale ``pred`` (NumPy array or tensor) to [0, 1].

    Args:
        pred: values to scale.
        max_value, min_value: explicit bounds. If either is None, the bounds
            are taken from ``pred`` itself.

    Returns:
        The scaled values. When the bounds come from ``pred`` and ``pred`` is
        constant, zeros are returned instead of the NaNs 0 / 0 would produce,
        because NaNs break downstream metrics such as precision_recall_curve.
    """
    if max_value is None or min_value is None:
        low, high = pred.min(), pred.max()
        if high == low:
            return (pred - low) * 0.0
        return (pred - low) / (high - low)
    return (pred - min_value) / (max_value - min_value)
def denormalization(x, mean=None, std=None):
    """Undo per-channel normalisation and return an 8-bit HWC image.

    Args:
        x: normalised image as a (C, H, W) array.
        mean, std: per-channel statistics used for the normalisation. They
            default to the module-level ``mean_train`` / ``std_train``.

    Returns:
        uint8 array of shape (H, W, C).
    """
    mean = mean_train if mean is None else mean
    std = std_train if std is None else std
    pixels = (x.transpose(1, 2, 0) * std + mean) * 255.
    # Round rather than truncate: float round-off otherwise turns 255 into 254.
    # Clip so out-of-range values saturate instead of wrapping in the uint8 cast.
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

def specify_resolution(image_list, score_list, mask_list, binary_mask , score_map, resolution: tuple=(400,400)):
    """Resize every visualisation input to ``resolution``.

    Images, score maps and colour score maps use bicubic interpolation. The
    ground-truth and binary prediction masks use nearest-neighbour, so they
    stay binary: bicubic overshoots and produces intermediate values.

    Args:
        image_list, score_list, mask_list, binary_mask, score_map: parallel
            per-image sequences; iteration stops at the shortest.
        resolution: target size passed to cv2.resize as (width, height).

    Returns:
        Five lists in the same order as the inputs.
    """
    dsize = (resolution[0], resolution[1])
    resize_image, resize_score, resize_mask = [], [], []
    resize_binary_mask, resize_score_map = [], []
    for image, score, mask, binary, smap in zip(image_list, score_list, mask_list, binary_mask , score_map):
        resize_image.append(cv2.resize(image, dsize, interpolation=cv2.INTER_CUBIC))
        resize_score.append(cv2.resize(score, dsize, interpolation=cv2.INTER_CUBIC))
        resize_mask.append(cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST))
        resize_binary_mask.append(cv2.resize(binary, dsize, interpolation=cv2.INTER_NEAREST))
        resize_score_map.append(cv2.resize(smap, dsize, interpolation=cv2.INTER_CUBIC))

    return resize_image, resize_score, resize_mask, resize_binary_mask, resize_score_map

class LinearLayer(torch.nn.Module):
    """Project intermediate ViT token sequences into the joint embedding space.

    ``forward`` reuses the model's own final LayerNorm (``model.visual.ln_post``)
    and projection matrix (``model.visual.proj``). The per-layer ``fc`` Linear
    layers are still created, so they remain in the state dict, but ``forward``
    does not use them.
    """

    def __init__(self, dim_in, dim_out, k, model_name, model):
        """
        Args:
            dim_in: token width of the vision backbone.
            dim_out: embedding width.
            k: number of feature layers (one ``fc`` per layer).
            model_name: backbone name. Names containing 'ViT' get equal-width
                ``fc`` layers; otherwise the input width grows by 2 ** (i + 2),
                presumably matching CNN stage widths.
            model: CLIP-style model exposing ``visual.ln_post`` and ``visual.proj``.
        """
        super().__init__()
        if 'ViT' in model_name:
            self.fc = torch.nn.ModuleList([torch.nn.Linear(dim_in, dim_out) for _ in range(k)])
        else:
            self.fc = torch.nn.ModuleList(
                [torch.nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)])
        self.ln = model.visual.ln_post
        self.proj = model.visual.proj

    def forward(self, tokens):
        """Project each (B, N, C) tensor in ``tokens`` after dropping its first token.

        Returns a new list; the input sequence is not modified.

        Raises:
            NotImplementedError: for inputs that are not 3-D (e.g. CNN feature maps).
        """
        projected = []
        for token in tokens:
            if token.dim() != 3:
                raise NotImplementedError("Not completed!")
            projected.append(self.ln(token[:, 1:, :]) @ self.proj)
        return projected
    
def _build_logger(txt_path):
    """Return the 'test' logger, writing INFO+ records to ``txt_path`` and the console.

    Root-logger handlers are removed and the root level is set to WARNING.
    Handlers that a previous call attached to 'test' are closed and detached,
    so repeated calls do not print every line twice.
    """
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('test')
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(txt_path, mode='a')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def _best_f1_binary_mask(gt, score):
    """Binarise ``score`` (1 / 0) at the threshold that maximises pixel F1 against ``gt``."""
    precision, recall, thresholds = precision_recall_curve(gt.ravel(), score.ravel())
    with np.errstate(divide='ignore', invalid='ignore'):
        f1 = (2 * precision * recall) / (precision + recall)
    # The final PR point has no threshold. 0/0 points become 0 so they cannot
    # win np.argmax, which treats NaN as the maximum.
    f1 = np.nan_to_num(f1[:-1])
    opt_th = thresholds[np.argmax(f1)]
    binary = np.copy(score)
    binary[binary >= opt_th] = 1
    binary[binary < opt_th] = 0
    return binary


def _text_anomaly_map(layer_tokens, text_features, text_features_local, img_size):
    """Text-driven anomaly probability map for one layer of projected patch tokens.

    Returns a NumPy array of shape (1, img_size, img_size) for batch element 0.
    """
    tokens = resize_tokens(layer_tokens)
    tokens = tokens / tokens.norm(dim=-1, keepdim=True)
    global_sim = tokens @ text_features
    local_sim = tokens @ text_features_local
    # Harmonic mean of the two text similarities.
    sim = 2. / (1. / global_sim + 1. / local_sim)
    anomaly_map = torch.softmax(sim, dim=-1)[:, :, :, 1].unsqueeze(1)
    # 3x3 mean smoothing with replicated borders, then upsample to image size.
    anomaly_map = F.pad(anomaly_map, (1, 1, 1, 1), 'replicate')
    anomaly_map = F.avg_pool2d(anomaly_map, 3, stride=1, padding=0, count_include_pad=False)
    anomaly_map = F.interpolate(anomaly_map, size=img_size, mode='bilinear', align_corners=True)
    return anomaly_map[0].detach().cpu().numpy()


def _few_shot_distance_map(image_features, gallery, img_size):
    """Distance map to the most similar reference image in ``gallery``.

    Returns a NumPy array of shape (1, img_size, img_size).
    """
    similarities = [F.cosine_similarity(image_features, ref, dim=-1).mean() for ref in gallery]
    _, best_index = max((sim, i) for i, sim in enumerate(similarities))
    distance_map = 1 - F.cosine_similarity(image_features, gallery[best_index], dim=-1)
    if distance_map.dim() == 2:
        distance_map = distance_map.unsqueeze(0)
    # Drop the first token (presumably CLS, as LinearLayer also does).
    distance_map = distance_map[:, :, 1:]
    distance_map = distance_map / distance_map.norm(dim=-1, keepdim=True)
    # Patch grid side derived from the token count (was hard-coded to 15).
    side = int(round(math.sqrt(distance_map.numel())))
    distance_map = distance_map.reshape(side, side)
    upsampled = F.interpolate(distance_map[None, None], size=(img_size, img_size),
                              mode='bilinear', align_corners=False)
    return upsampled.squeeze(0).detach().cpu().numpy()


def _pct(value):
    """Format a [0, 1] metric as a percentage string with one decimal."""
    return str(np.round(value * 100, decimals=1))


def test(args):
    """Evaluate zero-/few-shot anomaly detection and segmentation on MVTec or VisA.

    Logs the settings and a per-object metric table to ``args.save_path/log.txt``
    and the console. The table covers pixel AUROC/F1/AP, AUPRO, and image
    AUROC/F1/AUPR. Visualisations are written under
    ``args.data_save_path + args.dataset + '/'``.
    """
    img_size = args.image_size
    features_list = args.features_list
    dataset_dir = args.data_path
    save_path = args.save_path
    dataset_name = args.dataset
    os.makedirs(save_path, exist_ok=True)

    gpu_device = f"cuda:{args.gpu_id}"
    device = gpu_device if torch.cuda.is_available() else "cpu"
    txt_path = os.path.join(save_path, 'log.txt')

    # NOTE(review): `open_clip` is never imported in this file (only `VADCLIP`
    # is). Confirm which package provides the customised
    # create_model_and_transforms / get_tokenizer / encode_image.
    model, _, _ = open_clip.create_model_and_transforms(args.model, img_size, pretrained=args.pretrained)
    data_transform, _ = get_data_transforms(img_size, img_size)
    model.to(device)
    model.eval()
    tokenizer = open_clip.get_tokenizer(args.model)
    logger = _build_logger(txt_path)

    # record parameters
    for arg in vars(args):
        if args.mode == 'zero_shot' and arg in ('k_shot', 'few_shot_features'):
            continue
        logger.info(f'{arg}: {getattr(args, arg)}')

    # seg
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)
    linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
                              len(features_list), args.model, model).to(device)

    # dataset
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor()
    ])
    if dataset_name == 'mvtec':
        test_data = MVTecDataset(root=dataset_dir, transform=data_transform, target_transform=transform,
                                 aug_rate=-1, mode='test')
    else:
        test_data = VisaDataset(root=dataset_dir, transform=data_transform, target_transform=transform, mode='test')
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    obj_list = test_data.get_cls_names()

    # Few-shot reference features per object class.
    image_feature_gallery = {}
    if args.mode == 'few_shot':
        for item in obj_list:
            # Same dataset choice as the test set above.
            if dataset_name == 'mvtec':
                normal_data = MVTecDataset(root=dataset_dir, transform=data_transform, target_transform=transform,
                                           aug_rate=-1, mode='train', k_shot=args.k_shot,
                                           save_dir='', obj_name=item)
            else:
                normal_data = VisaDataset(root=dataset_dir, transform=data_transform, target_transform=transform,
                                          mode='train', k_shot=args.k_shot,
                                          save_dir='', obj_name=item)
            fewshot_dataloader = torch.utils.data.DataLoader(normal_data, batch_size=1, shuffle=False)
            image_features_list = []
            with torch.no_grad():
                for items in fewshot_dataloader:
                    image = items['img'].to(device)
                    image_features, _, _ = model.encode_image(image, features_list)
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    image_features_list.append(image_features)
            image_feature_gallery[item] = image_features_list

    results = {key: [] for key in ('cls_names', 'imgs_masks', 'anomaly_maps', 'gt_sp', 'pr_sp', 'fs_sp',
                                   'img', 'name', 'score', 'path', 'binary_mask', 'score_map')}
    for items in test_dataloader:
        image = items['img'].to(device)
        cls_name = items['cls_name']
        results['cls_names'].append(cls_name[0])
        gt_mask = items['img_mask']
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results['imgs_masks'].append(gt_mask)  # px
        results['gt_sp'].append(items['anomaly'].item())
        results['img'].append(denormalization(np.array(image.reshape(3, img_size, img_size).cpu())))
        name = items['name'][0].split('/')
        results['name'].append(name[0] + '-' + name[2] + '-' + name[-1].split('.')[0])
        results['path'].append(name[0])

        with torch.no_grad():
            image_features, patch_tokens, W_t = model.encode_image(image, features_list)
            patch_tokens_ln = linearlayer(patch_tokens)
        # NOTE(review): prompt_encoder's internals are not visible here, so it is
        # deliberately left outside torch.no_grad().
        text_prompts, _, text_prompts_local = prompt_encoder(model, cls_name, tokenizer, device, W_t,
                                                             dataset_name, args.gpt)

        with torch.no_grad():
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features = torch.stack([text_prompts[c] for c in cls_name], dim=0)
            text_features_local = torch.stack([text_prompts_local[c] for c in cls_name], dim=0).squeeze()
            # Only the last projected layer determines the pixel map.
            anomaly_map = _text_anomaly_map(patch_tokens_ln[-1], text_features, text_features_local, img_size)
            if args.mode == 'few_shot':
                am_np = _few_shot_distance_map(image_features, image_feature_gallery[cls_name[0]], img_size)
                # Harmonic mean of the text-based and few-shot maps.
                anomaly_map = 2. / (1. / anomaly_map + 1. / am_np)
            text_probs = (image_features @ text_features[0]).softmax(dim=-1)
            results['pr_sp'].append(text_probs.mean(dim=0)[0, 1].cpu().item())
        results['anomaly_maps'].append(anomaly_map)
        results['score'].append(normalize(anomaly_map))

        # visualization inputs
        mask = normalize(anomaly_map[0])
        binary_mask = _best_f1_binary_mask(gt_mask.numpy(), mask)
        # JET colour map of the score in BGR order for cv2.imwrite. Blending with
        # the source image is unnecessary because only the colour map is kept.
        score_map = cv2.applyColorMap((mask * 255).astype(np.uint8), cv2.COLORMAP_JET)
        results['binary_mask'].append(binary_mask * 255)
        results['score_map'].append(score_map)

    # metrics
    table_ls = []
    auroc_sp_ls, auroc_px_ls, f1_sp_ls, f1_px_ls = [], [], [], []
    aupro_ls, ap_sp_ls, ap_px_ls = [], [], []
    img_dir = args.data_save_path + dataset_name + "/"
    for obj in obj_list:
        idxs = [i for i, c in enumerate(results['cls_names']) if c == obj]
        gt_px = np.array([results['imgs_masks'][i].squeeze(1).numpy() for i in idxs])
        pr_px = np.array([results['anomaly_maps'][i] for i in idxs])
        gt_sp = np.array([results['gt_sp'][i] for i in idxs])
        pr_sp = np.array([results['pr_sp'][i] for i in idxs])
        img = [results['img'][i] for i in idxs]
        name = [results['name'][i] for i in idxs]
        score = np.array([results['score'][i] for i in idxs]).squeeze(1)
        path = [results['path'][i] for i in idxs]
        binary_mask = [results['binary_mask'][i] for i in idxs]
        score_map = [results['score_map'][i] for i in idxs]

        auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())
        auroc_sp = roc_auc_score(gt_sp, pr_sp)
        ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
        precisions, recalls, _ = precision_recall_curve(gt_sp, pr_sp)
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
        f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])
        aupr_sp = auc(recalls, precisions)
        precisions, recalls, _ = precision_recall_curve(gt_px.ravel(), pr_px.ravel())
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
        f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
        if gt_px.ndim == 4:
            gt_px = gt_px.squeeze(1)
        if pr_px.ndim == 4:
            pr_px = pr_px.squeeze(1)
        aupro = cal_pro_score(gt_px, pr_px)

        table_ls.append([obj, _pct(auroc_px), _pct(f1_px), _pct(ap_px), _pct(aupro),
                         _pct(auroc_sp), _pct(f1_sp), _pct(aupr_sp)])
        auroc_sp_ls.append(auroc_sp)
        auroc_px_ls.append(auroc_px)
        f1_sp_ls.append(f1_sp)
        f1_px_ls.append(f1_px)
        aupro_ls.append(aupro)
        ap_sp_ls.append(aupr_sp)
        ap_px_ls.append(ap_px)

        # Visualisations; output folders must exist before cv2.imwrite.
        for folder in set(path):
            os.makedirs(img_dir + folder, exist_ok=True)
        test_imgs, scores, gt_mask_list, binary_mask, score_map = specify_resolution(
            img, list(score), list(gt_px), binary_mask, score_map, resolution=(img_size, img_size))
        plot_sample_cv2(name, test_imgs, {'VAD': scores}, gt_mask_list, binary_mask, score_map, path,
                        save_folder=img_dir)

    # logger
    table_ls.append(['mean', _pct(np.mean(auroc_px_ls)), _pct(np.mean(f1_px_ls)), _pct(np.mean(ap_px_ls)),
                     _pct(np.mean(aupro_ls)), _pct(np.mean(auroc_sp_ls)), _pct(np.mean(f1_sp_ls)),
                     _pct(np.mean(ap_sp_ls))])
    report = tabulate(table_ls, headers=['objects', 'auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp',
                                         'f1_sp', 'pr_sp'], tablefmt="pipe")
    logger.info("\n%s", report)


def str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    With ``type=bool``, argparse turns every non-empty string, including
    "False", into True. This converter keeps ``--gpt True`` working and makes
    ``--gpt False`` mean False. An empty string is False, as ``bool('')`` was.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    warnings.filterwarnings('ignore')
    parser = argparse.ArgumentParser("LEADVAD", add_help=True)
    # paths
    parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset")
    parser.add_argument("--save_path", type=str, default='./results/test', help='path to save results')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16-plus-240.json', help="model configs")
    parser.add_argument("--data_save_path", type=str, default="./visualization/", help="visualization")
    # model
    parser.add_argument("--dataset", type=str, default='mvtec', help="test dataset")
    parser.add_argument("--model", type=str, default="ViT-B-16-plus-240", help="model used")
    parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used")
    parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9, 12], help="features used")
    parser.add_argument("--few_shot_features", type=int, nargs="+", default=[3, 6, 9, 12], help="features used for few shot")
    parser.add_argument("--image_size", type=int, default=240, help="image size")
    parser.add_argument("--mode", type=str, default="zero_shot", help="zero shot or few shot")
    # few shot
    parser.add_argument("--k_shot", type=int, default=10, help="e.g., 10-shot, 5-shot, 1-shot")
    parser.add_argument("--seed", type=int, default=111, help="random seed")
    parser.add_argument("--gpt", type=str2bool, default=False, help="add gpt prompts")
    parser.add_argument("--gpu_id", type=int, default=0)

    args = parser.parse_args()
    # NOTE(review): an empty CURL_CA_BUNDLE is a known way to make some versions
    # of `requests` skip TLS certificate verification (e.g. for weight
    # downloads). Confirm this is really needed.
    os.environ['CURL_CA_BUNDLE'] = ''
    # Respect a CUDA_VISIBLE_DEVICES chosen by the caller; otherwise expose GPUs 0-7.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', "0,1,2,3,4,5,6,7")

    setup_seed(args.seed)
    test(args)
