import os
import tqdm
import json
import torch
import argparse
from eval import *
import numpy as np
from model import SSOD
import torch.distributed as dist
import torch.utils.data.distributed
from torch.utils.data import DataLoader
from data import get_train_val_dataset, get_imagenet_ood_dataset


def _record_batch_scores(batch_size, max_softmax, rectified_p, label,
                         msp_scores, ssod_scores, labels):
    """Append one batch of per-sample confidences and labels to the running lists.

    Args:
        batch_size: number of samples in the input batch (``x.shape[0]``).
        max_softmax: first value returned by ``ood_infer`` for the batch.
        rectified_p: third value returned by ``ood_infer`` for the batch.
        label: 0 for an OOD batch, 1 for an ID batch.
        msp_scores, ssod_scores, labels: lists that are extended in place.
    """
    if batch_size > 1:
        msp_scores.extend(max_softmax.detach().squeeze().cpu().numpy().tolist())
        ssod_scores.extend(rectified_p.detach().squeeze().cpu().numpy().tolist())
        labels.extend(np.full(max_softmax.shape[0], float(label)).tolist())
    else:
        # A single-sample batch is read with .item() (presumably because
        # squeeze() leaves a 0-d tensor whose tolist() is not a list).
        msp_scores.append(max_softmax.detach().squeeze().cpu().item())
        ssod_scores.append(rectified_p.detach().squeeze().cpu().item())
        labels.append(label)


def main(opt):
    """Evaluate an SSOD checkpoint on one OOD dataset and save batch confidences.

    Every process joins the NCCL process group and wraps the model in
    DistributedDataParallel, but only the process with ``opt.local_rank == 0``
    runs inference. That process scores the OOD loader, then at most
    ``len(ood_loader)`` batches of the validation (ID) loader, prints
    FPR (threshold 0.95) and AUROC for both the max-softmax and the rectified
    score, and writes the per-batch mean max-softmax confidences to
    ``./conf_results/MSP_<ood_type>.json``.

    Args:
        opt: parsed command-line options. Reads ``device_ids``, ``init_method``,
            ``n_gpus``, ``batch_size``, ``local_rank``, ``ood_type``,
            ``num_classes`` and ``checkpoint``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.device_ids

    dist.init_process_group(backend='nccl', init_method=opt.init_method, world_size=opt.n_gpus)

    batch_size = opt.batch_size
    # The conditional must choose between two complete devices; placing it
    # inside the index argument produced torch.device('cuda', 'cpu') on
    # machines without CUDA.
    if torch.cuda.is_available():
        device = torch.device('cuda', opt.local_rank)
    else:
        device = torch.device('cpu')
    print('Using device:{}'.format(device))

    # load dataset
    train_set, val_set = get_train_val_dataset()
    ood_set = get_imagenet_ood_dataset(ood_type=opt.ood_type)

    # prepare dataloader
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, sampler=val_sampler, num_workers=6)

    ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set, shuffle=True)
    ood_loader = DataLoader(ood_set, batch_size=batch_size, sampler=ood_sampler, num_workers=6)

    model = SSOD(num_classes=opt.num_classes)

    # The checkpoint is loaded on rank 0 only; DistributedDataParallel
    # broadcasts rank 0's parameters to the other ranks when it is constructed.
    if opt.local_rank == 0:
        try:
            state_dict = torch.load(opt.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=True)
        except Exception as err:
            # Still best-effort (evaluation continues), but report what went
            # wrong instead of hiding it behind a bare except.
            print('WARNING: could not load checkpoint {!r} ({}: {}); '
                  'evaluating randomly initialised weights.'.format(
                      opt.checkpoint, type(err).__name__, err))

    model = torch.nn.parallel.DistributedDataParallel(model.to(device), device_ids=[opt.local_rank],
                                                      output_device=opt.local_rank, broadcast_buffers=False,
                                                      find_unused_parameters=True)

    # evaluation
    model.eval()
    if opt.local_rank != 0:
        # Only rank 0 runs inference and writes results.
        return

    id_conf_collects, ood_conf_collects = list(), list()
    # NOTE: the ID classification pass below is disabled, so both values stay
    # at 0 and are reported as 0 in the summary line.
    val_cls_acc, val_loss = 0, 0
    # ID inference
    # with torch.no_grad():
    #     for x, y in tqdm.tqdm(val_loader):
    #         x = x.float().to(device)
    #         y = y.long().to(device)
    #         _, cls_logits, loss = model.module.loss(x, y, ood_weight=opt.ood_weight, get_feat=True, ood_loss=True)
    #
    #         # record accuracy
    #         cls_acc = ACC(cls_logits, y)
    #         val_cls_acc += cls_acc
    #         val_loss += loss.item()

    # OOD inference
    id_ood_conf_msp, id_ood_conf_BayesAug, id_ood_label = list(), list(), list()
    n_ood_batches = len(ood_loader)
    with torch.no_grad():
        # ood loader (label 0)
        for x in tqdm.tqdm(ood_loader):
            x = x.float().to(device)
            max_softmax, pred_label, rectified_p = model.module.ood_infer(x)
            _record_batch_scores(x.shape[0], max_softmax, rectified_p, 0,
                                 id_ood_conf_msp, id_ood_conf_BayesAug, id_ood_label)

            print('OOD Conf:', rectified_p.mean())
            # Saved value is the batch-mean max-softmax; use
            # rectified_p.mean().item() here to save the rectified score instead.
            ood_conf_collects.append(max_softmax.mean().item())

        # id loader (label 1), capped at the number of OOD batches
        for batch_idx, (x, _) in enumerate(tqdm.tqdm(val_loader), start=1):
            x = x.float().to(device)
            max_softmax, pred_label, rectified_p = model.module.ood_infer(x)
            _record_batch_scores(x.shape[0], max_softmax, rectified_p, 1,
                                 id_ood_conf_msp, id_ood_conf_BayesAug, id_ood_label)

            print('ID Conf:', rectified_p.mean())
            id_conf_collects.append(max_softmax.mean().item())

            # Check the cap only after the batch mean is recorded; breaking
            # first dropped the last batch from id_conf_collects although its
            # per-sample scores had already been stored.
            if batch_idx >= n_ood_batches:
                break

    assert len(id_ood_conf_msp) == len(id_ood_conf_BayesAug) == len(id_ood_label)

    msp_scores = np.array(id_ood_conf_msp)
    ssod_scores = np.array(id_ood_conf_BayesAug)
    labels = np.array(id_ood_label)

    FPR_msp = FPR(msp_scores, labels, threshold=0.95)
    FPR_BayesAug = FPR(ssod_scores, labels, threshold=0.95)

    AUROC_msp = AUROC(msp_scores, labels)
    AUROC_BayesAug = AUROC(ssod_scores, labels)

    val_loss = val_loss / len(val_loader)
    val_cls_acc = val_cls_acc / len(val_loader)

    print('Dataset : %s | Val Loss : %.4f | Val Cls Acc : %.4f | FPR95(MSP) : %.4f | FPR95(SSOD) : %.4f | AUROC(MSP) : %.4f | AUROC(SSOD) : %.4f'
        % (opt.ood_type, val_loss, val_cls_acc, FPR_msp, FPR_BayesAug, AUROC_msp, AUROC_BayesAug))

    print(len(id_conf_collects), len(ood_conf_collects))

    save_dict = {'id_conf': id_conf_collects, 'ood_conf': ood_conf_collects}

    save_path = './conf_results'
    os.makedirs(save_path, exist_ok=True)
    # The 'MSP_' prefix matches the max-softmax means collected above;
    # '%s/%s.json' was the name used when saving the rectified score.
    save_name = '%s/MSP_%s.json' % (save_path, opt.ood_type)
    with open(save_name, 'w') as f:
        json.dump(save_dict, f)
        

if __name__ == '__main__':
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser('SSOD Evaluation')
    parser.add_argument('--batch_size', type=int, default=3)
    # torch.distributed.launch passes --local_rank (PyTorch < 2.0) or
    # --local-rank (PyTorch >= 2.0); torchrun passes neither and exports
    # LOCAL_RANK instead. Accept all three so the script works with any of
    # them; without a launcher the default is still -1.
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)))
    parser.add_argument('--init_method', default='env://')

    # parser.add_argument('--n_gpus', type=int, default=7)
    # parser.add_argument('--device_ids', type=str, default='1,2,3,4,5,6,7')

    parser.add_argument('--n_gpus', type=int, default=4)
    parser.add_argument('--device_ids', type=str, default='0,1,2,3')

    parser.add_argument('--depth', type=int, default=50)
    parser.add_argument('--num_classes', type=int, default=1000)
    parser.add_argument('--ood_weight', type=float, default=0.1)
    # Other OOD sets previously used as the default:
    # 'iNaturalist', 'SUN', 'Places', 'Texture', 'openimage-o'
    parser.add_argument('--ood_type', type=str, default='imagenet-o')
    # Other checkpoints previously used as the default:
    # ./saved_models/ssod_sun/epoch_1_cls_0.7529_fpr95_ssod_0.2940_fpr95_msp_0.6840_auroc_ssod_0.9316_auroc_msp_0.8158.pth
    # ./saved_models/ssod_sun/epoch_0_cls_0.7564_fpr95_ssod_0.2992_fpr95_msp_0.6892_auroc_ssod_0.9296_auroc_msp_0.8129.pth
    # ./saved_models/ssod_texture/epoch_0_cls_0.7574_fpr95_ssod_0.4532_fpr95_msp_0.6716_auroc_ssod_0.8702_auroc_msp_0.7942.pth
    parser.add_argument('--checkpoint', type=str, default='./saved_models/ssod_sun/epoch_2_cls_0.7536_fpr95_ssod_0.2756_fpr95_msp_0.6956_auroc_ssod_0.9362_auroc_msp_0.8113.pth')

    opt = parser.parse_args()
    # Print the parsed options once (rank 0 only) to avoid duplicated output.
    if opt.local_rank == 0:
        print('opt:', opt)

    main(opt)

# If the address is already in use, choose a different, unused master_port:
# python3 -m torch.distributed.launch --master_port 9998 --nproc_per_node=4 test.py --n_gpus=4