import os
import gc
import tqdm
import torch
import shutil
import argparse
from eval import *
import numpy as np
import torch.nn as nn
from model import SSOD
from torch import optim
import torch.distributed as dist
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from cifar_data import get_train_val_dataset, get_cifar_ood_dataset


def _flatten_scores(scores):
    """Convert a per-sample score tensor into a flat Python list.

    ``reshape(-1)`` is used instead of ``squeeze()`` so that a batch holding a
    single sample still produces a one-element list rather than a bare float,
    which ``list.extend`` would reject.
    """
    return scores.detach().reshape(-1).cpu().numpy().tolist()


def main(opt):
    """Train SSOD with DistributedDataParallel and evaluate it on local rank 0.

    Every process trains for ``opt.epoch`` epochs, running at most
    ``opt.steps_each_epoch`` optimisation steps per epoch.  Every
    ``opt.eval_interval`` epochs, local rank 0 additionally computes the
    validation loss/accuracy and the FPR/AUROC metrics (for both the
    max-softmax and the SSOD confidence), prints a summary line, optionally
    logs to TensorBoard and saves the weights when the SSOD FPR improves.

    Args:
        opt: Parsed command-line namespace.  Fields read here: ``device_ids``,
            ``local_rank``, ``build_tensorboard``, ``logdir``, ``init_method``,
            ``n_gpus``, ``batch_size``, ``ood_type``, ``depth``,
            ``num_classes``, ``checkpoint``, ``lr``, ``epoch``, ``ood_weight``,
            ``train_cls``, ``steps_each_epoch``, ``eval_interval``,
            ``best_metric`` and ``save_path``.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES``, initialises the NCCL process group,
        deletes and recreates ``opt.logdir`` (rank 0, when TensorBoard is
        enabled), writes checkpoints under ``opt.save_path`` and mutates
        ``opt.build_tensorboard`` / ``opt.best_metric``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.device_ids
    is_main_rank = opt.local_rank == 0

    # ``writer`` stays None when TensorBoard logging is disabled (or on other
    # ranks) so the logging code below can simply test for it.
    writer = None
    if is_main_rank and opt.build_tensorboard:
        shutil.rmtree(opt.logdir, True)
        writer = SummaryWriter(logdir=opt.logdir)
        opt.build_tensorboard = False

    dist.init_process_group(backend='nccl', init_method=opt.init_method, world_size=opt.n_gpus)

    batch_size = opt.batch_size
    # The conditional must select between two complete devices; placing it
    # inside the call would build the invalid torch.device('cuda', 'cpu').
    if torch.cuda.is_available():
        device = torch.device('cuda', opt.local_rank)
    else:
        device = torch.device('cpu')
    print('Using device:{}'.format(device))

    # load dataset
    train_set, val_set = get_train_val_dataset()
    ood_set = get_cifar_ood_dataset(ood_type=opt.ood_type)

    # prepare dataloader
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, num_workers=24)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, sampler=val_sampler, num_workers=12)

    ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set, shuffle=True)
    ood_loader = DataLoader(ood_set, batch_size=batch_size, sampler=ood_sampler, num_workers=12)

    model = SSOD(depth=opt.depth, num_classes=opt.num_classes)

    # loading checkpoint on GPU 0
    if is_main_rank:
        try:
            state = torch.load(opt.checkpoint, map_location='cpu')
            model.load_state_dict(state, strict=False)
        except FileNotFoundError:
            print('No Checkpoint, training from scratch...')
        except Exception as err:
            # Still best-effort, but report why the checkpoint was rejected and
            # let KeyboardInterrupt / SystemExit propagate.
            print('Failed to load checkpoint (%s), training from scratch...' % err)

    model = torch.nn.parallel.DistributedDataParallel(model.to(device), device_ids=[opt.local_rank],
                                                      output_device=opt.local_rank, broadcast_buffers=False,
                                                      find_unused_parameters=True)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    for epoch in range(opt.epoch):
        train_loader.sampler.set_epoch(epoch)

        # only tqdm in rank 0
        data_loader = tqdm.tqdm(train_loader) if is_main_rank else train_loader

        train_loss, val_loss = 0, 0
        train_cls_acc, val_cls_acc = 0, 0

        model.train()
        current_steps = 0
        # classification training
        for x, y in data_loader:
            x = x.float().to(device)
            y = y.long().to(device)

            # NOTE(review): calling model.module.loss(...) goes around
            # DistributedDataParallel.forward; confirm that gradients are still
            # all-reduced across ranks in this setup.
            _, cls_logits, loss = model.module.loss(x, y, ood_weight=opt.ood_weight, train_cls=opt.train_cls)

            # record accuracy
            train_cls_acc += ACC(cls_logits, y)

            # optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # training several steps each epoch
            current_steps += 1
            if current_steps >= opt.steps_each_epoch:
                break

        # update learning rate
        scheduler.step()

        # evaluation
        if is_main_rank and epoch % opt.eval_interval == 0:
            model.eval()
            # ID inference
            with torch.no_grad():
                for x, y in tqdm.tqdm(val_loader):
                    x = x.float().to(device)
                    y = y.long().to(device)

                    _, cls_logits, loss = model.module.loss(x, y, ood_weight=opt.ood_weight, train_cls=opt.train_cls)

                    # record accuracy
                    val_cls_acc += ACC(cls_logits, y)
                    val_loss += loss.item()

            # OOD inference: OOD samples get label 0, in-distribution samples 1
            id_ood_conf_msp, id_ood_conf_ssod, id_ood_label = [], [], []
            with torch.no_grad():
                # ood loader
                for x in tqdm.tqdm(ood_loader):
                    x = x.float().to(device)
                    max_softmax, _, rectified_p = model.module.ood_infer(x)
                    print('OOD Conf: %.4f' % torch.mean(rectified_p))
                    id_ood_conf_msp.extend(_flatten_scores(max_softmax))
                    id_ood_conf_ssod.extend(_flatten_scores(rectified_p))
                    id_ood_label.extend([0.0] * max_softmax.shape[0])

                # id loader, capped at the number of OOD batches
                n_ood_batches = len(ood_loader)
                for id_count, (x, _) in enumerate(tqdm.tqdm(val_loader), start=1):
                    x = x.float().to(device)
                    max_softmax, _, rectified_p = model.module.ood_infer(x)
                    print('ID Conf: %.4f' % torch.mean(rectified_p))
                    id_ood_conf_msp.extend(_flatten_scores(max_softmax))
                    id_ood_conf_ssod.extend(_flatten_scores(rectified_p))
                    id_ood_label.extend([1.0] * max_softmax.shape[0])
                    if id_count >= n_ood_batches:
                        break

            assert len(id_ood_conf_msp) == len(id_ood_conf_ssod) == len(id_ood_label)

            conf_msp = np.array(id_ood_conf_msp)
            conf_ssod = np.array(id_ood_conf_ssod)
            labels = np.array(id_ood_label)

            FPR_msp = FPR(conf_msp, labels, threshold=0.95)
            FPR_ssod = FPR(conf_ssod, labels, threshold=0.95)

            AUROC_msp = AUROC(conf_msp, labels)
            AUROC_ssod = AUROC(conf_ssod, labels)

            # Average over the steps that actually ran: the training loop stops
            # after opt.steps_each_epoch batches, so len(train_loader) would
            # scale the reported loss/accuracy down.
            steps_run = max(current_steps, 1)
            train_loss = train_loss / steps_run
            train_cls_acc = train_cls_acc / steps_run

            val_batches = max(len(val_loader), 1)
            val_loss = val_loss / val_batches
            val_cls_acc = val_cls_acc / val_batches

            print('EPOCH : %04d | Train Loss : %.4f | Train Cls Acc : %.4f | Val Loss : %.4f | Val Cls Acc : %.4f | '
                  'FPR95(MSP) : %.4f | FPR95(SSOD) : %.4f | AUROC(MSP) : %.4f | AUROC(SSOD) : %.4f'
                % (epoch, train_loss, train_cls_acc, val_loss, val_cls_acc, FPR_msp, FPR_ssod, AUROC_msp, AUROC_ssod))

            # keep only checkpoints that improve the SSOD FPR while staying accurate
            if FPR_ssod <= opt.best_metric and val_cls_acc > 0.9:
                opt.best_metric = FPR_ssod
                model_name = 'epoch_%d_cls_%.4f_fpr95_ssod_%.4f_fpr95_msp_%.4f_auroc_ssod_%.4f.pth' % (epoch, val_cls_acc, FPR_ssod, FPR_msp, AUROC_ssod)
                os.makedirs(opt.save_path, exist_ok=True)
                torch.save(model.module.state_dict(), '%s/%s' % (opt.save_path, model_name))

            if writer is not None:
                writer.add_scalar('train/loss', train_loss, epoch)
                writer.add_scalar('train/train_cls_acc', train_cls_acc, epoch)

                writer.add_scalar('val/loss', val_loss, epoch)
                writer.add_scalar('val/val_cls_acc', val_cls_acc, epoch)

                writer.add_scalar('FPR95/MSP', FPR_msp, epoch)
                writer.add_scalar('FPR95/SSOD', FPR_ssod, epoch)

                writer.add_scalar('AUROC/MSP', AUROC_msp, epoch)
                writer.add_scalar('AUROC/SSOD', AUROC_ssod, epoch)

    # flush pending TensorBoard events before the process exits
    if writer is not None:
        writer.close()
        

def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``.  This converter
    accepts the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``true/false``, ``yes/no``, ``t/f``, ``y/n``, ``1/0``
            (case-insensitive).  The empty string maps to ``False``, matching
            what ``bool('')`` returned before.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('SSOD CIFAR-10')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epoch', type=int, default=3000)
    parser.add_argument('--steps_each_epoch', type=int, default=20)
    parser.add_argument('--lr', type=float, default=2e-4)
    # Accept both spellings the PyTorch launcher may pass, and fall back to the
    # LOCAL_RANK environment variable when no flag is given.
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)))
    parser.add_argument('--init_method', default='env://')

    parser.add_argument('--n_gpus', type=int, default=4)
    parser.add_argument('--device_ids', type=str, default='0,1,2,3')
    parser.add_argument('--eval_interval', type=int, default=1)
    parser.add_argument('--build_tensorboard', type=str2bool, default=True)
    parser.add_argument('--best_metric', type=float, default=0.5)

    parser.add_argument('--depth', type=int, default=18)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--ood_weight', type=float, default=1.2)
    parser.add_argument('--train_cls', type=str2bool, default=True)
    # alternative defaults for the SVHN OOD set:
    # parser.add_argument('--ood_type', type=str, default='SVHN')
    # parser.add_argument('--logdir', type=str, default='./tensorboard/ssod/CIFAR-10/SVHN')
    # parser.add_argument('--save_path', type=str, default='./saved_models/cifar_10/ssod_svhn')

    parser.add_argument('--ood_type', type=str, default='Places')
    parser.add_argument('--logdir', type=str, default='./tensorboard/ssod/CIFAR-10/Places')
    parser.add_argument('--save_path', type=str, default='./saved_models/cifar_10/ssod_places_1.2')

    parser.add_argument('--checkpoint', type=str, default='./saved_models/cifar_10/ssod_places_1.0/epoch_39_cls_0.9330_fpr95_ssod_0.0944_fpr95_msp_0.5532_auroc_ssod_0.9774.pth')

    opt = parser.parse_args()
    if opt.local_rank == 0:
        print('opt:', opt)

    main(opt)

# If the default master port is already in use, pass a different --master_port, e.g.:
# python3 -m torch.distributed.launch --master_port 9998 --nproc_per_node=4 cifar_train.py --n_gpus=4