import argparse
import numpy as np
import os
import time
from pathlib import Path
import torch.nn.functional as F
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision

from torch.utils.data import DataLoader


import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler


from models import DMLR
from sklearn.metrics import roc_auc_score
from omegaconf import OmegaConf
import logging
def get_args_parser():
    """Build the command-line parser for the LFDN evaluation script.

    The parser is created with ``add_help=False`` and returned unparsed; the
    caller invokes ``parse_args()`` on it.

    Numeric options that have no meaningful built-in value default to ``None``
    rather than ``''``: argparse applies ``type`` to string defaults, so an
    empty-string default on an ``int``/``float`` option aborts parsing with
    "invalid int value: ''" whenever that option is omitted.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('LFDN training', add_help=False)
    # Required: main() multiplies batch_size into the effective batch size and
    # passes it to every DataLoader, so there is no usable "unset" value.
    parser.add_argument('--batch_size', required=True, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * world_size)')
    parser.add_argument('--epochs', default=None, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model / input options.
    parser.add_argument('--input_size', default=None, type=int,
                        help='images input size')
    parser.add_argument('--use_class_label', action='store_true', help='use class label as condition.')

    parser.add_argument('--LFDN_model_cfg', default='', type=str)
    parser.add_argument('--pretrained_LFDN_cfg', default='', type=str)
    parser.add_argument('--LFDN_steps', default=None, type=int)
    parser.add_argument('--eta', default=1.0, type=float)
    parser.add_argument('--log_dir', default='./output_dir_logging', help='Directory where to save logs')

    # Optimizer options.
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # lr defaults to None so that main() can derive it from --blr.
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=None, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # Dataset options.
    parser.add_argument('--dataset', default='', type=str,
                        help='dataset')
    parser.add_argument('--data_path', default='', type=str,
                        help='dataset path')
    parser.add_argument('--augmentation', default='randresizedcrop', type=str,
                        help='Augmentation type')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem share one destination; pinning is on by default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Distributed options.
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    # NOTE(review): a default local rank of 3 is unusual -- confirm it is intended.
    parser.add_argument('--local_rank', default=3, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser

def setup_logging(log_dir):
    """Send root-logger output to ``<log_dir>/runtime.log`` and to the console.

    The directory is created if missing. Setup is best-effort: any failure is
    printed and swallowed so that the caller keeps running without logging.

    The function is safe to call more than once. ``logging.basicConfig`` is a
    no-op when the root logger already has handlers, and the console handler
    is tagged with a name so that repeated calls do not stack duplicates
    (which would print every message several times).

    Args:
        log_dir: directory that receives ``runtime.log``.
    """
    console_handler_name = 'runtime_console'
    try:
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            filename=os.path.join(log_dir, "runtime.log"),
            filemode='a',
            format='[%(asctime)s] %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            level=logging.INFO
        )

        root_logger = logging.getLogger('')
        console_installed = any(
            handler.get_name() == console_handler_name
            for handler in root_logger.handlers
        )
        if not console_installed:
            console = logging.StreamHandler()
            console.set_name(console_handler_name)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')
            console.setFormatter(formatter)
            root_logger.addHandler(console)
        logging.info("Logging is set up at {}".format(log_dir))
    except Exception as e:
        # Deliberately broad: logging setup must never abort the run.
        print("Error setting up logging:", e)
        
def evaluate_ood_detection(model, total_id, ood_data_loaders, device, dataset, config_path):
    """Log FPR95 and AUROC of each OOD loader against the in-distribution scores.

    For every loader in ``ood_data_loaders`` the per-sample scores are obtained
    from ``compute_metrics``. Two numbers are then logged:

    * the fraction of OOD scores that fall below the 95th percentile of
      ``total_id`` (logged as "FPR95");
    * the AUROC of the combined scores, with in-distribution samples labelled
      0 and OOD samples labelled 1.

    Args:
        model: model forwarded to ``compute_metrics``.
        total_id: per-sample scores of the in-distribution test set.
        ood_data_loaders: iterable of OOD data loaders, reported as 1, 2, ...
        device: device forwarded to ``compute_metrics``.
        dataset: dataset name, used only in the opening log line.
        config_path: configuration path, used only in the opening log line.
    """
    logging.info(f"Evaluating OOD detection for dataset: {dataset} with config: {config_path}")

    # Cut-off below which 95% of the in-distribution scores lie.
    id_cutoff = np.percentile(total_id, 95)
    id_labels = [0] * len(total_id)

    for position, ood_loader in enumerate(ood_data_loaders):
        loader_number = position + 1
        ood_scores = compute_metrics(model, ood_loader, device)

        below_cutoff = np.array(ood_scores) < id_cutoff
        fpr95 = np.sum(below_cutoff) / len(ood_scores)

        logging.info(f"Evaluating OOD Dataset FPR95% {loader_number}:")
        logging.info(f"Total Cosine Similarity FPR95: {fpr95:.4f}")

        labels = id_labels + [1] * len(ood_scores)
        scores = np.concatenate([total_id, ood_scores])
        auroc = roc_auc_score(labels, scores)

        logging.info(f"Total Cosine Similarity AUROC: {auroc:.4f}")
        logging.info("--")



def compute_metrics(model, data_loader, device):
    """Score every sample yielded by ``data_loader`` with ``model``.

    The model is switched to eval mode and run without gradients under CUDA
    autocast. Each batch is scored by ``cosine_similarity_per_block`` applied
    to the first and third outputs of ``model(samples, class_label)``.

    Args:
        model: callable returning a 3-tuple; the middle element is ignored.
        data_loader: iterable of ``(samples, class_label)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        list: one score per sample, in loader order.
    """
    model.eval()

    per_batch_scores = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                sampled_rep, _, origan = model(images, labels)
                batch_scores = cosine_similarity_per_block(sampled_rep, origan)

            per_batch_scores.append(batch_scores.cpu().numpy())

    # Flatten the per-batch arrays into a single per-sample list.
    return [score for chunk in per_batch_scores for score in chunk]

def cosine_similarity_per_block(sampled_rep, origan):
    """Return the negated mean block-wise cosine similarity of two tensors.

    Columns are split into five consecutive blocks at the fixed boundaries
    0, 24, 56, 112, 272 and 720. Cosine similarity is taken along ``dim=1``
    inside each block, the five (unit-weighted) results are summed, divided
    by 5 and negated.

    Args:
        sampled_rep: tensor indexed as ``[:, start:end]``.
        origan: tensor indexed the same way as ``sampled_rep``.

    Returns:
        Tensor with one value per row: ``-(sum of block similarities) / 5``.
    """
    boundaries = (0, 24, 56, 112, 272, 720)
    block_weights = (1, 1, 1, 1, 1)

    weighted_sims = [
        F.cosine_similarity(sampled_rep[:, lo:hi], origan[:, lo:hi], dim=1) * w
        for lo, hi, w in zip(boundaries, boundaries[1:], block_weights)
    ]
    summed = torch.stack(weighted_sims).sum(dim=0)
    return -summed / 5

def calculate_auroc(id_scores, ood_scores):
    """Compute the AUROC separating OOD samples (label 1) from ID samples (label 0).

    The two score collections are joined with ``np.concatenate`` rather than
    ``+``. ``+`` only concatenates plain lists; on NumPy arrays it adds
    element-wise (or fails to broadcast), which would leave the scores out of
    step with the labels.

    Args:
        id_scores: 1-D sequence or array of in-distribution scores.
        ood_scores: 1-D sequence or array of out-of-distribution scores.

    Returns:
        The value returned by ``roc_auc_score`` for the combined scores.
    """
    id_array = np.asarray(id_scores).ravel()
    ood_array = np.asarray(ood_scores).ravel()

    labels = [0] * len(id_array) + [1] * len(ood_array)
    scores = np.concatenate([id_array, ood_array])

    return roc_auc_score(labels, scores)


def _build_eval_loader(dataset, args):
    """Wrap ``dataset`` in an unshuffled DataLoader configured from ``args``."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        # NOTE(review): drop_last discards the final partial batch, so those
        # samples are excluded from the reported metrics. Kept as in the
        # original configuration.
        drop_last=True,
    )


def main(args):
    """Score the in-distribution test set and report OOD metrics.

    Steps: configure logging, initialise distributed mode, seed the RNGs,
    build the CIFAR and OOD datasets, build the DMLR model, score the
    in-distribution test set with ``compute_metrics`` and pass the scores to
    ``evaluate_ood_detection``.

    Args:
        args: parsed command-line namespace from ``get_args_parser``.
            ``args.distributed`` and ``args.gpu`` are read here but are not
            parser options -- presumably ``misc.init_distributed_mode`` sets
            them; confirm against that helper.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    logging.getLogger().setLevel(logging.INFO)
    setup_logging(args.log_dir)
    logging.info("Program started")

    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # Offset the seed by the process rank so that ranks do not share RNG state.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    # Only rank 0 writes TensorBoard events, and only when a log dir is set.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    try:
        transform_train = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # NOTE(review): dataset_test0 (the other CIFAR variant) is constructed
        # but never evaluated below -- confirm whether it should be an OOD set.
        if args.dataset == 'cifar10':
            dataset_train = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
            dataset_test = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_train)
            dataset_test0 = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_train)
        elif args.dataset == 'cifar100':
            dataset_train = datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform_train)
            dataset_test = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_train)
            dataset_test0 = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_train)
        else:
            # Fail early with a clear message instead of a later NameError on
            # the undefined dataset variables.
            raise ValueError(
                "Unsupported dataset {!r}: expected 'cifar10' or 'cifar100'".format(args.dataset)
            )

        # NOTE(review): every OOD dataset root below is an empty string in this
        # file; real paths presumably have to be filled in before running.
        dataset_test3 = datasets.ImageFolder(root='', transform=transform_train)
        dataset_test1 = datasets.ImageFolder(root='', transform=transform_train)
        dataset_test5 = datasets.ImageFolder(root='', transform=transform_train)
        dataset_test6 = datasets.ImageFolder(root='', transform=transform_train)
        dataset_test2 = torchvision.datasets.SVHN(root='', split='test',
                                                  download=True, transform=transform_train)
        dataset_test4 = datasets.ImageFolder(root='', transform=transform_train)

        print(dataset_train)

        # The samplers are only reported; the evaluation loaders below iterate
        # their datasets sequentially without them.
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        sampler_test = torch.utils.data.DistributedSampler(
            dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_test = %s" % str(sampler_test))

        data_loader_test = _build_eval_loader(dataset_test, args)
        # Order defines the "OOD Dataset N" index used in the log output.
        ood_data_loaders = [
            _build_eval_loader(ood_dataset, args)
            for ood_dataset in (
                dataset_test1, dataset_test2, dataset_test3,
                dataset_test4, dataset_test5, dataset_test6,
            )
        ]

        model = DMLR(
            use_class_label=args.use_class_label,
            LFDN_model_cfg=args.LFDN_model_cfg,
            pretrained_LFDN_cfg=args.pretrained_LFDN_cfg
        )

        model.to(device)

        model_without_ddp = model
        print("Model = %s" % str(model_without_ddp))

        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

        # Derive the absolute lr from the base lr when only the latter is set.
        if args.lr is None and args.blr is not None:
            args.lr = args.blr * eff_batch_size / 256

        # The learning rate is only reported here, so skip the report when
        # neither --lr nor --blr was supplied.
        if args.lr is not None:
            print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
            print("actual lr: %.2e" % args.lr)

        print("accumulate grad iterations: %d" % args.accum_iter)
        print("effective batch size: %d" % eff_batch_size)

        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        n_params = sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad)
        print("Number of trainable parameters: {}M".format(n_params / 1e6))
        # Guard on the writer itself: rank 0 has no writer when log_dir is None.
        if log_writer is not None:
            log_writer.add_scalar('num_params', n_params / 1e6, 0)

        total_id = compute_metrics(model, data_loader_test, device)
        print("Start test:----")

        evaluate_ood_detection(model, total_id, ood_data_loaders, device, args.dataset, args.pretrained_LFDN_cfg)
    finally:
        # Close the TensorBoard writer even if evaluation fails.
        if log_writer is not None:
            log_writer.close()

if __name__ == '__main__':
    # Script entry point: build the CLI parser, parse the command line and
    # hand the resulting namespace to main().
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)

