from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import time
import numpy as np
import random
from configs import build_config
from utils import setup_seed, SnapshotLogger
from log import get_logger

from model import XModel, LinearPredictor
from dataset import *

from train import train_func
from test import test_func
from infer import infer_func
import argparse
from copy import deepcopy

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '4'
import pynvml


def get_idel_gpu():
    """Return the index of the GPU with the lowest current NVML utilization.

    Initializes NVML, reads ``nvmlDeviceGetUtilizationRates(handle).gpu`` for
    every device NVML reports, and returns the index of the smallest value
    (the first such index on ties). NVML is always shut down again, even if
    one of the queries raises.

    NOTE(review): NVML enumerates devices in PCI bus order and ignores
    ``CUDA_VISIBLE_DEVICES``; the returned index only matches the CUDA ordinal
    expected by ``torch.cuda.set_device`` when ``CUDA_DEVICE_ORDER=PCI_BUS_ID``
    and no device masking is in effect — confirm for the target machines.

    Returns:
        int: NVML device index of the least-utilized GPU.

    Raises:
        RuntimeError: if NVML reports no GPU devices.
        pynvml.NVMLError: propagated from any failing NVML call.
    """
    pynvml.nvmlInit()
    try:
        gpu_usage = [
            pynvml.nvmlDeviceGetUtilizationRates(
                pynvml.nvmlDeviceGetHandleByIndex(i)).gpu
            for i in range(pynvml.nvmlDeviceGetCount())
        ]
    finally:
        # Release NVML even when a device query fails part-way through.
        pynvml.nvmlShutdown()

    if not gpu_usage:
        raise RuntimeError('NVML reports no GPU devices.')
    return gpu_usage.index(min(gpu_usage))

def load_checkpoint(model, ckpt_path, logger):
    """Copy matching tensors from a checkpoint file into ``model`` in place.

    Weights are read from the ``'model'`` entry of the loaded dict; if that
    entry is absent, the loaded object itself is treated as a state dict.
    A leading ``'module.'`` prefix (as produced by ``DataParallel``-style
    wrappers) is stripped from each key. Keys missing from the model, or
    whose tensor size differs, are reported with ``print`` and skipped.

    Args:
        model: ``torch.nn.Module`` whose ``state_dict()`` tensors are updated.
        ckpt_path: path of a file readable by ``torch.load``.
        logger: unused; kept so existing call sites keep working.
    """
    if not os.path.isfile(ckpt_path):
        print('Not found pretrained checkpoint file.')
        return

    print('loading pretrained checkpoint from {}.'.format(ckpt_path))
    # Map to CPU so loading does not require the device the checkpoint was
    # saved from; copy_() below moves each tensor onto the model's device.
    weight_dict = torch.load(ckpt_path, map_location='cpu')
    if isinstance(weight_dict, dict) and 'model' in weight_dict:
        state = weight_dict['model']
    else:
        state = weight_dict

    model_dict = model.state_dict()
    prefix = 'module.'
    for name, param in state.items():
        # Only strip a *leading* wrapper prefix: a substring test would also
        # mangle legitimate names such as 'submodule.weight'.
        if name.startswith(prefix):
            name = name[len(prefix):]
        if name not in model_dict:
            print('{} not found in model dict.'.format(name))
            continue
        target = model_dict[name]
        if param.size() == target.size():
            target.copy_(param)
        else:
            print('{} size mismatch: load {} given {}'.format(
                name, param.size(), target.size()))


def train(cfg, model, train_loader, test_loader, gt, logger):
    """Train ``model`` for ``cfg.max_epoch`` epochs, evaluating after each one.

    Uses Adam (``lr=cfg.lr``) with a cosine-annealing schedule, ``BCELoss``
    and batch-mean ``KLDivLoss`` criteria, and ``cfg.lamda`` as passed to
    ``train_func``. The model is evaluated once before training and after
    every epoch with ``test_func``; each epoch's model/optimizer/scheduler and
    score are handed to ``logger.check``.

    Args:
        cfg: config object; reads ``save_dir``, ``lr``, ``max_epoch``,
            ``lamda``, ``dataset`` and ``metrics``.
        model: the network to train (updated in place).
        train_loader: ``DataLoader`` over the training set.
        test_loader: ``DataLoader`` over the evaluation set.
        gt: ground-truth array passed through to ``test_func``.
        logger: snapshot logger exposing ``check(...)`` and ``best_auc``.

    Returns:
        ``logger.best_auc``. The locally tracked best score is only used for
        the final summary line.
    """
    # exist_ok avoids a check-then-create race between concurrent jobs.
    os.makedirs(cfg.save_dir, exist_ok=True)

    criterion = torch.nn.BCELoss()
    criterion2 = torch.nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
    # NOTE(review): T_max is fixed at 60 and is not tied to cfg.max_epoch —
    # confirm this is intended.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=0)

    # Baseline score of the untrained model.
    initial_auc, n_far = test_func(test_loader, model, gt, cfg.dataset)
    print('Random initialize {}:{:.4f} FAR:{:.5f}'.format(cfg.metrics, initial_auc, n_far))

    best_auc = 0.0
    auc_far = 0.0

    st = time.time()
    for epoch in range(cfg.max_epoch):
        loss1, loss2 = train_func(cfg, train_loader, model, optimizer, criterion, criterion2, cfg.lamda, epoch)
        scheduler.step()

        auc, far = test_func(test_loader, model, gt, cfg.dataset)
        if auc >= best_auc:
            best_auc = auc
            auc_far = far

        print('[Epoch:{}/{}]: loss1:{:.4f} loss2:{:.4f} | AUC:{:.4f} FAR:{:.5f}'.format(
            epoch + 1, cfg.max_epoch, loss1, loss2, auc, far))
        # The trailing True is forwarded as-is; its meaning is defined by
        # the logger's check() method.
        logger.check(model, optimizer, scheduler, auc, epoch, True)

    time_elapsed = time.time() - st
    print('Training completes in {:.0f}m {:.0f}s | best {}:{:.4f} FAR:{:.5f}\n'.
                format(time_elapsed // 60, time_elapsed % 60, cfg.metrics, best_auc, auc_far))
    return logger.best_auc


def _expand_config_grid(ori_cfg):
    """Return deep copies of ``ori_cfg`` covering every list-valued attribute.

    Each attribute whose value is a ``list`` is a sweep axis. Axes are expanded
    in ``ori_cfg.__dict__`` order; for each axis the new list is value-major
    (all configs for the first value, then all for the second, ...), so a
    given index always maps to the same combination.
    """
    cfg_list = [deepcopy(ori_cfg)]
    for item, value in ori_cfg.__dict__.items():
        if not isinstance(value, list):
            continue
        expanded = []
        for v in value:
            for c in cfg_list:
                new_cfg = deepcopy(c)
                setattr(new_cfg, item, v)
                expanded.append(new_cfg)
        cfg_list = expanded
    return cfg_list


def _build_datasets(cfg):
    """Return ``(train_data, test_data)`` for ``cfg.dataset``.

    Raises:
        RuntimeError: if ``cfg.dataset`` is not a supported dataset name.
    """
    if cfg.dataset == 'ucf-crime':
        return UCFDataset(cfg, test_mode=False), UCFDataset(cfg, test_mode=True)
    if cfg.dataset == 'xd-violence':
        return XDataset(cfg, test_mode=False), XDataset(cfg, test_mode=True)
    if cfg.dataset == 'shanghaiTech':
        return SHDataset(cfg, test_mode=False), SHDataset(cfg, test_mode=True)
    raise RuntimeError("Do not support this dataset!")


def main(cfg):
    """Run one training job and return its best score.

    Args:
        cfg: either a config object (used as-is) or an ``int`` index into the
            grid produced by ``_expand_config_grid(build_config('ucf'))``.

    Returns:
        The value returned by ``train`` (``logger.best_auc``).

    Raises:
        ValueError: if ``cfg`` is ``None``.
        IndexError: if an integer ``cfg`` is outside the config grid.
        RuntimeError: if ``cfg.dataset`` is not supported.
    """
    if cfg is None:
        raise ValueError('main() requires a config index or a config object, got None.')

    if isinstance(cfg, int):
        # Stagger start-up by index (mod 8). Presumably this keeps jobs launched
        # together from sampling GPU utilization at the same moment — confirm.
        time.sleep(10 * (cfg % 8))
        cfg_list = _expand_config_grid(build_config('ucf'))
        try:
            cfg = cfg_list[cfg]
        except IndexError:
            raise IndexError('config index {} out of range for a grid of {} configs'.format(
                cfg, len(cfg_list))) from None

    device_idx = get_idel_gpu()
    print(f'Setting CUDA device to {device_idx}')
    torch.cuda.set_device(device_idx)

    logger = SnapshotLogger(cfg, cfg.dataset, save_all=True)
    # Seed before the model is built so weight initialization is reproducible.
    setup_seed(cfg.seed)
    model = XModel(cfg)
    gt = np.load(cfg.gt)
    model = model.to(torch.device("cuda"))

    train_data, test_data = _build_datasets(cfg)
    train_loader = DataLoader(train_data, batch_size=cfg.train_bs, shuffle=True,
                              num_workers=cfg.workers, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=cfg.test_bs, shuffle=False,
                             num_workers=cfg.workers, pin_memory=True)

    return train(cfg, model, train_loader, test_loader, gt, logger)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='WeaklySupAnoDet')
    # NOTE(review): --dataset is parsed but main() always builds the 'ucf'
    # config, so this flag currently has no effect.
    parser.add_argument('--dataset', default='ucf', help='anomaly video dataset')
    # Required: main() cannot pick a grid entry without an index.
    parser.add_argument('--config-idx', type=int, required=True,
                        help='index into the hyper-parameter grid expanded from the config')
    args = parser.parse_args()
    main(args.config_idx)
