import os
import json
import tqdm
import torch
import argparse
from eval import *
import numpy as np
from model import SSOD
import torch.distributed as dist
import torch.utils.data.distributed
from torch.utils.data import DataLoader
from cifar_data import get_train_val_dataset, get_cifar_ood_dataset


def _collect_scores(net, inputs, device, label, tag, max_batches=None):
    """Run ``net.ood_infer`` over a stream of input batches and gather scores.

    Args:
        net: the unwrapped model; ``net.ood_infer(x)`` must return
            ``(max_softmax, pred_label, rectified_p)``.
        inputs: iterable yielding one input tensor (batch) per item.
        device: device every batch is moved to before inference.
        label: ID/OOD label recorded for every sample of this stream
            (``main`` uses 1 for ID and 0 for OOD).
        tag: prefix of the per-batch debug print, e.g. ``'OOD'`` or ``'ID'``.
        max_batches: stop after this many batches; ``None`` consumes ``inputs``.

    Returns:
        ``(msp, ssod, labels, batch_msp, batch_ssod)``: the first three are flat
        per-sample lists of equal length (as long as both scores hold one value
        per sample); the last two hold the per-batch mean of each score.
    """
    msp, ssod, labels = [], [], []
    batch_msp, batch_ssod = [], []
    for n_batches, x in enumerate(inputs, start=1):
        x = x.float().to(device)
        max_softmax, _, rectified_p = net.ood_infer(x)

        # reshape(-1) always yields a 1-D list, so a batch of one sample needs
        # no separate scalar code path (squeeze() would produce a 0-d tensor).
        scores_msp = max_softmax.detach().reshape(-1).cpu().tolist()
        msp.extend(scores_msp)
        ssod.extend(rectified_p.detach().reshape(-1).cpu().tolist())
        labels.extend([label] * len(scores_msp))

        print('%s Conf:' % tag, rectified_p.mean())
        batch_msp.append(max_softmax.mean().item())
        batch_ssod.append(rectified_p.mean().item())

        if max_batches is not None and n_batches >= max_batches:
            break
    return msp, ssod, labels, batch_msp, batch_ssod


def main(opt):
    """Evaluate an SSOD checkpoint: CIFAR validation set (ID) vs ``opt.ood_type`` (OOD).

    Every launched process joins the process group and builds the DDP model,
    but only local rank 0 runs inference. Rank 0 prints FPR95 and AUROC for the
    MSP and SSOD confidence scores and writes the per-batch mean confidences to
    ``./conf_results/cifar/MSP_<ood_type>.json`` and ``<ood_type>.json``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.device_ids
    use_cuda = torch.cuda.is_available()

    # NCCL only works with GPUs; use gloo so the CPU fallback can actually start.
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo',
                            init_method=opt.init_method, world_size=opt.n_gpus)

    batch_size = opt.batch_size
    if use_cuda:
        # Bind this process to its own GPU before any CUDA/NCCL work.
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
    else:
        device = torch.device('cpu')
    print('Using device:{}'.format(device))

    # load dataset
    _, val_set = get_train_val_dataset()
    ood_set = get_cifar_ood_dataset(ood_type=opt.ood_type)

    # Only rank 0 runs inference, so its samplers must cover the full datasets.
    # A sampler sharded over world_size would silently drop
    # (n_gpus - 1) / n_gpus of the data. num_replicas=1 / rank=0 keeps
    # DistributedSampler's seeded, reproducible shuffle without sharding.
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=1, rank=0, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, sampler=val_sampler, num_workers=6)

    ood_sampler = torch.utils.data.distributed.DistributedSampler(
        ood_set, num_replicas=1, rank=0, shuffle=True)
    ood_loader = DataLoader(ood_set, batch_size=batch_size, sampler=ood_sampler, num_workers=6)

    model = SSOD(depth=opt.depth, num_classes=opt.num_classes)

    # Load the checkpoint on rank 0 only. DDP broadcasts rank 0's parameters
    # when it is constructed below.
    if opt.local_rank == 0:
        try:
            state_dict = torch.load(opt.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=True)
        except Exception as err:  # best-effort, but report why loading failed
            print('Could not load checkpoint %s (%s); evaluating randomly initialised weights...'
                  % (opt.checkpoint, err))

    model = torch.nn.parallel.DistributedDataParallel(
        model.to(device),
        device_ids=[opt.local_rank] if use_cuda else None,
        output_device=opt.local_rank if use_cuda else None,
        broadcast_buffers=False,
        find_unused_parameters=True,
    )

    # evaluation
    model.eval()
    # These stay 0 while the ID classification pass below is commented out;
    # they are kept so the report line keeps its format.
    val_cls_acc, val_loss = 0, 0
    if opt.local_rank != 0:
        return

    # ID inference
    # with torch.no_grad():
    #     for x, y in tqdm.tqdm(val_loader):
    #         x = x.float().to(device)
    #         y = y.long().to(device)
    #         _, cls_logits, loss = model.module.loss(x, y, ood_weight=opt.ood_weight, get_feat=True, ood_loss=True)
    #
    #         # record accuracy
    #         cls_acc = ACC(cls_logits, y)
    #         val_cls_acc += cls_acc
    #         val_loss += loss.item()

    # OOD inference: OOD samples are labelled 0, ID samples 1.
    net = model.module
    with torch.no_grad():
        ood_msp, ood_ssod, ood_labels, ood_vis_conf_msp, ood_vis_conf_ssod = _collect_scores(
            net, tqdm.tqdm(ood_loader), device, label=0, tag='OOD')
        # Cap the ID pass at the number of OOD batches to keep the two sides balanced.
        id_msp, id_ssod, id_labels, id_vis_conf_msp, id_vis_conf_ssod = _collect_scores(
            net, (x for x, _ in tqdm.tqdm(val_loader)), device, label=1, tag='ID',
            max_batches=len(ood_loader))

    id_ood_conf_msp = np.array(ood_msp + id_msp)
    id_ood_conf_ssod = np.array(ood_ssod + id_ssod)
    id_ood_label = np.array(ood_labels + id_labels)
    assert len(id_ood_conf_msp) == len(id_ood_conf_ssod) == len(id_ood_label)

    FPR_msp = FPR(id_ood_conf_msp, id_ood_label, threshold=0.95)
    FPR_ssod = FPR(id_ood_conf_ssod, id_ood_label, threshold=0.95)

    AUROC_msp = AUROC(id_ood_conf_msp, id_ood_label)
    AUROC_ssod = AUROC(id_ood_conf_ssod, id_ood_label)

    val_loss = val_loss / len(val_loader)
    val_cls_acc = val_cls_acc / len(val_loader)

    print('Dataset : %s | Val Loss : %.4f | Val Cls Acc : %.4f | FPR95(MSP) : %.4f | FPR95(SSOD) : %.4f | AUROC(MSP) : %.4f | AUROC(SSOD) : %.4f'
          % (opt.ood_type, val_loss, val_cls_acc, FPR_msp, FPR_ssod, AUROC_msp, AUROC_ssod))

    # Per-batch mean confidences, one file per score type.
    save_path = './conf_results/cifar'
    os.makedirs(save_path, exist_ok=True)
    for file_name, id_conf, ood_conf in (
        ('MSP_%s.json' % opt.ood_type, id_vis_conf_msp, ood_vis_conf_msp),
        ('%s.json' % opt.ood_type, id_vis_conf_ssod, ood_vis_conf_ssod),
    ):
        with open(os.path.join(save_path, file_name), 'w') as f:
            json.dump({'id_conf': id_conf, 'ood_conf': ood_conf}, f)
        

if __name__ == '__main__':
    parser = argparse.ArgumentParser('SSOD Evaluation')
    parser.add_argument('--batch_size', type=int, default=16)
    # Older torch.distributed.launch passes --local_rank; torch >= 2.0 passes
    # --local-rank; torchrun only sets LOCAL_RANK in the environment.
    # Accept all three; without any of them the default stays -1.
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)))
    parser.add_argument('--init_method', default='env://')

    parser.add_argument('--n_gpus', type=int, default=4)
    parser.add_argument('--device_ids', type=str, default='0,1,2,3')

    parser.add_argument('--depth', type=int, default=18)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--ood_weight', type=float, default=0.1)
    parser.add_argument('--ood_type', type=str, default='SVHN')
    # parser.add_argument('--ood_type', type=str, default='iSUN')
    # parser.add_argument('--ood_type', type=str, default='Places')
    # parser.add_argument('--ood_type', type=str, default='Texture')
    # parser.add_argument('--ood_type', type=str, default='LSUN')
    parser.add_argument('--checkpoint', type=str, default='./saved_models/cifar_10/ssod_svhn/epoch_19_cls_0.9440_fpr95_ssod_0.0212_fpr95_msp_0.7780_auroc_ssod_0.9944.pth')
    # parser.add_argument('--checkpoint', type=str, default='./saved_models/cifar_10/ssod_svhn/epoch_9_cls_0.9230_fpr95_ssod_0.0288_fpr95_msp_0.6744_auroc_ssod_0.9938.pth')

    opt = parser.parse_args()
    if opt.local_rank == 0:
        print('opt:', opt)

    main(opt)

# If the address is already in use, pass a different (random) --master_port.
# python3 -m torch.distributed.launch --master_port 9998 --nproc_per_node=4 cifar_test.py --n_gpus=4