import argparse
import numpy as np
import os
import time
from pathlib import Path
import torch.nn.functional as F
import torch
import torch.backends.cudnn as cudnn

from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision

from torch.utils.data import DataLoader


import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

from torchvision import transforms




from models import DMLR
from sklearn.metrics import roc_auc_score


import logging
def get_args_parser():
    """Build the command-line parser for the LFDN / DMLR OOD-evaluation script.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not
    registered here.

    Returns:
        argparse.ArgumentParser: parser holding the run, model, optimizer,
        dataset, I/O and distributed-launch options.
    """
    parser = argparse.ArgumentParser('LFDN training', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model / input options.
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument('--use_class_label', action='store_true', help='use class label as condition.')

    # LFDN configuration.
    parser.add_argument('--LFDN_model_cfg',  default='', type=str)
    parser.add_argument('--pretrained_LFDN_cfg',  default='', type=str)
    parser.add_argument('--LFDN_steps', default=200, type=int)
    parser.add_argument('--eta', default=1.0, type=float)
    parser.add_argument('--log_dir', default='./output_dir_logging', help='Directory where to save logs')

    # Optimizer options.
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # NOTE(review): --lr has a numeric default, so code that derives lr from
    # --blr only when `args.lr is None` will never do so unless the default
    # is changed to None.
    parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # Dataset options.
    parser.add_argument('--dataset', default='', type=str,
                        help='dataset')
    parser.add_argument('--data_path', default='', type=str,
                        help='dataset path')
    parser.add_argument('--augmentation', default='randresizedcrop', type=str,
                        help='Augmentation type')

    # Output / runtime options.
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem toggle the same destination; pinning is on by default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Distributed-launch options.
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=3, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser

def setup_logging(log_dir):
    """Configure root logging to append to ``<log_dir>/runtime.log`` and echo to the console.

    Best-effort: any exception raised while setting up logging is reported
    with ``print`` instead of being propagated.

    Safe to call repeatedly: the console handler created here is tagged, and
    a second one is not added if a tagged handler is already on the root
    logger, so console messages are not duplicated.

    Args:
        log_dir: directory for ``runtime.log``; created if missing.
    """
    log_format = '[%(asctime)s] %(levelname)s: %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    try:
        os.makedirs(log_dir, exist_ok=True)

        logging.basicConfig(
            filename=os.path.join(log_dir, "runtime.log"),
            filemode='a',
            format=log_format,
            datefmt=date_format,
            level=logging.INFO
        )

        root = logging.getLogger('')
        already_installed = any(
            getattr(handler, '_setup_logging_console', False) for handler in root.handlers
        )
        if not already_installed:
            console = logging.StreamHandler()
            console.setLevel(logging.INFO)
            # Same timestamp format as the file log so both outputs line up.
            console.setFormatter(logging.Formatter(log_format, datefmt=date_format))
            # Tag so repeated calls can recognise the handler installed here.
            console._setup_logging_console = True
            root.addHandler(console)
        logging.info("Logging is set up at {}".format(log_dir))
    except Exception as e:
        print("Error setting up logging:", e)

def evaluate_ood_detection(model, MSE_id, ood_data_loaders, lvlb_weight_at_t, device, dataset, config_path):
    """Score each OOD loader against the in-distribution scores and log FPR95 and AUROC.

    Higher scores are treated as "more OOD". The threshold is the 95th
    percentile of ``MSE_id``, so 95% of ID samples score at or below it. FPR95
    is the fraction of OOD samples that also score at or below that threshold,
    i.e. OOD samples that would be accepted as ID. AUROC is computed with ID
    labelled 0 and OOD labelled 1.

    Args:
        model: model passed through to ``compute_metrics``.
        MSE_id: per-sample scores for the in-distribution test set.
        ood_data_loaders: iterable of loaders, one per OOD dataset; they are
            numbered from 1 in the log output.
        lvlb_weight_at_t: forwarded unchanged to ``compute_metrics``.
        device: device passed to ``compute_metrics``.
        dataset: ID dataset name (used only for logging).
        config_path: config path (used only for logging).

    Returns:
        None. Results are reported via ``logging``. If ``MSE_id`` is empty,
        or an OOD loader yields no scores, a warning is logged and that
        evaluation is skipped instead of crashing.
    """
    logging.info(f"Evaluating OOD detection for dataset: {dataset} with config: {config_path}")

    id_scores = np.asarray(MSE_id)
    if id_scores.size == 0:
        logging.warning("No in-distribution scores were computed; skipping OOD evaluation.")
        return

    threshold = np.percentile(id_scores, 95)

    for idx, data_loader_ood in enumerate(ood_data_loaders, start=1):
        ood_scores = np.asarray(compute_metrics(model, data_loader_ood, lvlb_weight_at_t, device))

        logging.info(f"Evaluating OOD Dataset FPR95% {idx}:")
        if ood_scores.size == 0:
            # e.g. drop_last=True with fewer samples than one batch.
            logging.warning(f"OOD dataset {idx} produced no scores; skipping.")
            logging.info("--")
            continue

        # Fraction of OOD samples that fall on the ID side of the threshold.
        ood_detected_fpr = np.mean(ood_scores <= threshold)
        logging.info(f"FPR95: {ood_detected_fpr:.4f}")

        labels = [0] * len(id_scores) + [1] * len(ood_scores)
        auroc = roc_auc_score(labels, np.concatenate([id_scores, ood_scores]))

        logging.info(f"AUROC: {auroc:.4f}")
        logging.info("--")



def compute_MSE(origan, sampled_rep, lvlb_weight_at_t):
    """Return the per-sample sum of squared errors between two tensors.

    Args:
        origan: reference tensor whose first dimension indexes samples.
        sampled_rep: tensor compared against ``origan``; expected to have the
            same shape.
        lvlb_weight_at_t: unused; kept so existing call sites keep working.

    Returns:
        torch.Tensor: 1-D tensor with one entry per sample. Entry ``i`` is the
        sum of ``(origan[i] - sampled_rep[i]) ** 2`` over all non-batch
        dimensions. For 2-D inputs this equals ``sum(dim=1)``.
    """
    squared_error = F.mse_loss(origan, sampled_rep, reduction='none')
    # Flatten every non-batch dimension so inputs with more than one feature
    # dimension still yield exactly one score per sample.
    return squared_error.flatten(start_dim=1).sum(dim=1)


def compute_metrics(model, data_loader, lvlb_weight_at_t, device):
    """Run ``model`` over ``data_loader`` and collect the ``compute_MSE`` scores.

    The model is switched to eval mode and gradients are disabled. Each batch
    is unpacked as ``(samples, class_label)`` and both are moved to ``device``.
    The model is called as ``model(samples, class_label)`` and must return a
    3-tuple. Its first and third items are scored with ``compute_MSE``; the
    middle item is ignored.

    Returns:
        list: the batch scores concatenated in loader order, each batch
        converted with ``.cpu().numpy().tolist()``.
    """
    model.eval()
    scores = []

    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            sampled, _, original = model(batch_inputs, batch_labels)
            batch_scores = compute_MSE(original, sampled, lvlb_weight_at_t)
            scores += batch_scores.cpu().numpy().tolist()

    return scores


def calculate_auroc(id_scores, ood_scores):
    """Return the AUROC for separating OOD (label 1) from ID (label 0) scores.

    Higher scores are treated as more OOD. The inputs are concatenated with
    ``np.concatenate`` rather than ``+``: ``+`` only concatenates Python
    lists, while for numpy arrays it adds element-wise and would silently
    produce wrong scores.

    Args:
        id_scores: 1-D sequence of in-distribution scores.
        ood_scores: 1-D sequence of out-of-distribution scores.

    Returns:
        float: ``roc_auc_score(labels, scores)``.
    """
    id_arr = np.asarray(id_scores)
    ood_arr = np.asarray(ood_scores)

    labels = [0] * len(id_arr) + [1] * len(ood_arr)
    scores = np.concatenate([id_arr, ood_arr])

    return roc_auc_score(labels, scores)


def main(args):
    """Evaluate a DMLR model for OOD detection and log FPR95 / AUROC per OOD set.

    Steps:
        1. Configure logging and distributed mode (via ``misc``).
        2. Seed the RNGs.
        3. Build the in-distribution datasets selected by ``args.dataset`` and
           a fixed list of OOD test sets.
        4. Score the ID test set with ``compute_metrics`` and pass the scores
           to ``evaluate_ood_detection``.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.

    Raises:
        ValueError: if ``args.dataset`` is not 'cifar10', 'cifar100' or 'CelebA'.
    """
    logging.getLogger().setLevel(logging.INFO)
    setup_logging(args.log_dir)
    logging.info("Program started")

    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # Offset the seed by the process rank.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    # Only rank 0 creates a TensorBoard writer; it is not written to below.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # RGB pipeline: resize to 224x224, convert to tensor, normalise with the
    # listed per-channel mean/std.
    transform_train = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Same pipeline with a Grayscale(3) step, used for the MNIST-style OOD sets.
    transform_train1 = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # In-distribution datasets. dataset_test0 is not used below.
    if args.dataset == 'cifar10':
        dataset_train = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
        dataset_test = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_train)
        dataset_test0 = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_train)
    elif args.dataset == 'cifar100':
        dataset_train = datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform_train)
        dataset_test = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_train)
        dataset_test0 = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_train)
    elif args.dataset == 'CelebA':
        dataset_train = datasets.ImageFolder(root='', transform=transform_train)
        dataset_test = datasets.ImageFolder(root='', transform=transform_train)
    else:
        # Without this, dataset_train/dataset_test would be undefined and the
        # run would fail later with an unrelated error.
        raise ValueError(
            f"Unsupported --dataset {args.dataset!r}; expected one of 'cifar10', 'cifar100', 'CelebA'"
        )

    # OOD test sets. NOTE(review): the empty `root=''` paths look like
    # placeholders that must be filled in before running — confirm.
    dataset_test2 = datasets.ImageFolder(root='', transform=transform_train)
    dataset_test1 = torchvision.datasets.SVHN(root='', split='test',
                                              download=True, transform=transform_train)
    dataset_test4 = datasets.FashionMNIST(root='', train=False, download=False, transform=transform_train1)
    dataset_test3 = datasets.MNIST(root='', train=False, download=False, transform=transform_train1)
    dataset_test5 = datasets.KMNIST(root='', train=False, download=False, transform=transform_train1)
    dataset_test6 = datasets.Omniglot(root='', download=False, background=False, transform=transform_train1)
    dataset_test7 = datasets.ImageFolder(root='', transform=transform_train1)

    print(dataset_train)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    # NOTE(review): with world_size > 1 each rank scores only its shard of the
    # ID test set, but every OOD loader below is unsharded — confirm intended.
    sampler_test = torch.utils.data.DistributedSampler(
        dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_test = %s" % str(sampler_test))

    # Not used below.
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True,
    )
    # NOTE(review): drop_last=True here and on the OOD loaders means up to
    # batch_size - 1 samples per set are never scored; confirm this is intended
    # for evaluation.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, sampler=sampler_test, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True,
    )

    ood_datasets = [dataset_test1, dataset_test2, dataset_test3, dataset_test4,
                    dataset_test5, dataset_test6, dataset_test7]
    ood_data_loaders = [
        DataLoader(ood_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,
                   pin_memory=args.pin_mem, drop_last=True)
        for ood_dataset in ood_datasets
    ]

    model = DMLR(
        use_class_label=args.use_class_label,
        LFDN_model_cfg=args.LFDN_model_cfg,
        pretrained_LFDN_cfg=args.pretrained_LFDN_cfg
    )

    model.to(device)

    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    print("Start test:----")
    start_time = time.time()

    MSE_id = compute_metrics(model, data_loader_test, 1, device)

    evaluate_ood_detection(model, MSE_id, ood_data_loaders, 1, device, args.dataset, args.pretrained_LFDN_cfg)

    logging.info("Total evaluation time: {:.1f}s".format(time.time() - start_time))

if __name__ == '__main__':
    # Parse the command line with the script's parser, then run the evaluation.
    args = get_args_parser().parse_args()
    main(args)

