#!/usr/bin/env python

import argparse
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets.dataset_factory import build_dataset, get_num_classes
from models.model_factory import build_model
from utils.metrics import compute_traditional_ood
from utils.utils import is_debug_session, load_config_yml, set_deterministic
from utils.cal_logits import cal_logits
from utils.feat_extract import extract_features
from utils.get_knn_score import get_knn_score


# Default device handed to helpers that take an explicit device argument
# (cal_logits, extract_features): CUDA when present, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def calculate_nap_score(model, transform, dataset_name, batch_size, use_gpu, use_tqdm, dataset_dir):
    """Run ``model`` over the test split of a dataset and collect its per-sample NAP/SNR scores.

    Args:
        model: network whose forward pass returns a ``(logits, snr)`` pair; only ``snr`` is kept.
        transform: preprocessing passed to ``build_dataset``.
        dataset_name: dataset identifier understood by ``build_dataset``.
        batch_size: DataLoader batch size.
        use_gpu: if true, each batch is moved to CUDA before the forward pass.
        use_tqdm: if true, a progress bar is shown.
        dataset_dir: root directory the dataset is loaded from.

    Returns:
        NumPy array holding the ``snr`` outputs of all batches concatenated along dim 0,
        in dataset order (the loader does not shuffle).

    Raises:
        RuntimeError: if the dataset yields no batches.
    """
    print(f'Processing {dataset_name} dataset.')
    dataset = build_dataset(dataset_dir, dataset_name, transform, train=False)
    g, seed_worker = set_deterministic()

    # Worker processes and pinned memory are only used on a CUDA machine outside a
    # debug session; the generator / worker_init_fn come from set_deterministic.
    kwargs = {}
    if torch.cuda.is_available() and not is_debug_session():
        kwargs = {'num_workers': 5, 'pin_memory': True, 'generator': g, 'worker_init_fn': seed_worker}

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)

    snr_batches = []
    # The context manager closes the progress bar even if the forward pass raises;
    # disable=True turns it into a no-op when no bar is wanted.
    with torch.no_grad(), tqdm(total=len(dataloader), disable=not use_tqdm) as progress_bar:
        for samples in dataloader:
            images = samples[0]
            if use_gpu:
                # non_blocking lets the host-to-device copy run asynchronously
                # when the loader uses pinned memory.
                images = images.cuda(non_blocking=True)

            _, snr = model(images)
            # Move each batch off the device right away so GPU memory does not
            # grow with the size of the dataset.
            snr_batches.append(snr.cpu())
            progress_bar.update()

    if not snr_batches:
        raise RuntimeError(f'Dataset {dataset_name} yielded no samples; cannot compute NAP scores.')
    return torch.cat(snr_batches, dim=0).numpy()



def calculate_other_score(model, transform, args, config):
    """Compute a baseline OOD score (MSP, energy or KNN) for the ID test set and every OOD set.

    Args:
        model: network used for logit computation / feature extraction.
        transform: preprocessing passed to ``build_dataset``.
        args: CLI arguments; reads ``id_dataset_dir``, ``ood_dataset_dir``, ``use_gpu``
            and ``use_tqdm``.
        config: config dict; reads ``scoring_method``, ``id_dataset``, ``ood_datasets`` and
            ``batch_size`` (the KNN path also hands the whole dict to ``extract_features``
            and ``get_knn_score``).

    Returns:
        ``(id_scores, ood_scores)``: the in-distribution scores and a list with one score
        array per entry of ``config['ood_datasets']``, in that order. "msp" uses the maximum
        softmax probability and "energy" the logsumexp of the logits. For "knn", both values
        come from ``get_knn_score``.

    Raises:
        ValueError: if ``scoring_method`` is not one of "msp", "energy", "knn".
    """
    scoring_method = config['scoring_method']
    print(f'scoring method: {scoring_method}')
    g, seed_worker = set_deterministic()
    kwargs = {}
    if torch.cuda.is_available() and not is_debug_session():
        kwargs = {'num_workers': 5, 'pin_memory': True, 'generator': g, 'worker_init_fn': seed_worker}

    def make_loader(root, dataset_name, train=False):
        # Loaders are sequential (shuffle=False) so scores keep the dataset order.
        dataset = build_dataset(root, dataset_name, transform, train=train)
        return DataLoader(dataset, batch_size=config['batch_size'], shuffle=False, **kwargs)

    if scoring_method in ("msp", "energy"):
        print("calculating in-distribution logits")
        id_logits = cal_logits(model, make_loader(args.id_dataset_dir, config['id_dataset']),
                               args.use_gpu, args.use_tqdm, device)

        print("calculating out-distribution logits")
        all_ood_logits = [
            cal_logits(model, make_loader(args.ood_dataset_dir, name), args.use_gpu, args.use_tqdm, device)
            for name in config['ood_datasets']
        ]

        if scoring_method == "msp":
            def to_score(logits):
                # Maximum softmax probability per sample.
                return np.max(F.softmax(logits, dim=1).detach().cpu().numpy(), axis=1)
        else:
            def to_score(logits):
                # Energy-style score: logsumexp over the class dimension.
                return torch.logsumexp(logits.data.cpu(), dim=1).numpy()

        return to_score(id_logits), [to_score(logits) for logits in all_ood_logits]

    if scoring_method == "knn":
        num_classes = get_num_classes(config['id_dataset'])
        train_loader_in = make_loader(args.id_dataset_dir, config['id_dataset'], train=True)
        test_loader_in = make_loader(args.id_dataset_dir, config['id_dataset'])
        outloaders = [make_loader(args.ood_dataset_dir, name) for name in config['ood_datasets']]
        # NOTE(review): get_knn_score only receives config, so extract_features presumably
        # persists the features somewhere it can read them back - TODO confirm.
        extract_features(model, train_loader_in, test_loader_in, outloaders, num_classes,
                         config['batch_size'], device, config)
        id_scores, ood_scores = get_knn_score(config)
        return id_scores, ood_scores

    raise ValueError(f"Invalid scoring method: {scoring_method!r}")
    


def _build_models(config, num_classes):
    """Build the SNR (NAP) model and the model used for the other score.

    Returns ``(model_snr, model_other, transform)``. ``transform`` comes from the second
    ``build_model`` call.

    Raises:
        ValueError: if ``config['model_name']`` is not a densenet100 variant.
    """
    model_name = config['model_name']
    arch = model_name.split('_')[0]
    if arch != "densenet100":
        # Without this check both models stayed None and the script crashed later
        # with an AttributeError.
        raise ValueError(f"Unsupported model_name {model_name!r}: only densenet100 variants are supported")
    # The SNR model is built from the bare architecture name (suffix stripped). config is
    # mutated for that single call and restored even if build_model raises.
    config['model_name'] = arch
    try:
        model_snr, _ = build_model(config, num_classes=num_classes)
    finally:
        config['model_name'] = model_name
    model_other, transform = build_model(config, num_classes=num_classes, p=config.get('p'))
    return model_snr, model_other, transform


def _write_scores(path, scores):
    """Write one score per line to ``path``, overwriting any existing file."""
    with open(path, 'w') as f:
        f.writelines(f"{score}\n" for score in scores)


def find_optimal_w(config, args):
    """Grid-search the exponent ``w`` that blends the other OOD score with the NAP score.

    For each ``w`` in 0.1, 0.2, ..., 0.9 the combined score
    ``other ** w * snr ** (1 - w)`` is written to a fresh timestamped ``output`` subdirectory:
    ``in_scores.txt`` for the ID set and ``<ood_dataset>.txt`` per OOD set. It is then
    evaluated with ``compute_traditional_ood``, whose per-``w`` report is appended to
    ``test.txt`` in the same directory.

    Args:
        config: parsed config. Reads ``id_dataset``, ``ood_datasets``, ``model_name``,
            ``batch_size``, ``method`` and ``scoring_method``, plus the optional
            ``train_restore_file`` and ``p``.
        args: parsed CLI arguments (``use_gpu``, ``use_tqdm``, ``id_dataset_dir``,
            ``ood_dataset_dir``, ``checkpoint_dir``).

    Returns:
        float: the ``w`` whose ``compute_traditional_ood`` result (presumably the average
        FPR) is smallest; ties go to the smaller ``w``.

    Raises:
        ValueError: if ``config['model_name']`` is not a supported architecture.
    """
    num_classes = get_num_classes(config['id_dataset'])
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # Filesystem-safe timestamp: str(datetime.now()) contains spaces and colons,
    # and colons are invalid in Windows paths.
    output_dir = os.path.join(base_dir, 'output', datetime.now().strftime('%Y-%m-%d_%H-%M-%S_%f'))
    os.makedirs(output_dir, exist_ok=True)

    model_snr, model_other, transform = _build_models(config, num_classes)

    restore_file = config.get('train_restore_file')
    if restore_file:
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dict = torch.load(os.path.join(args.checkpoint_dir, restore_file), map_location='cpu')
        model_snr.load_state_dict(state_dict)
        model_other.load_state_dict(state_dict)
    else:
        print('Warning: train_restore_file config not specified')
    model_snr.eval()
    model_other.eval()

    # The ash/react method is attached to the SNR model only.
    if config['method'].split('@')[0] in ['ash_b', 'ash_s', 'react', 'react_and_ash']:
        setattr(model_snr, 'ash_method', config['method'])

    if args.use_gpu:
        model_snr = model_snr.cuda()
        model_other = model_other.cuda()

    # NAP scores for the in-distribution test set and for every OOD dataset.
    id_snrs = calculate_nap_score(model_snr, transform, config['id_dataset'], config['batch_size'],
                                  args.use_gpu, args.use_tqdm, args.id_dataset_dir)
    ood_snrs = [
        calculate_nap_score(model_snr, transform, ood_dataset, config['batch_size'],
                            args.use_gpu, args.use_tqdm, args.ood_dataset_dir)
        for ood_dataset in config['ood_datasets']
    ]

    id_other_scores, ood_other_scores = calculate_other_score(model_other, transform, args, config)

    # 0.1 ... 0.9, rounded: np.arange(0.1, 1.0, 0.1) yields values such as
    # 0.30000000000000004, and index * 0.1 + 0.1 had the same problem.
    weights = [round(0.1 * k, 1) for k in range(1, 10)]
    name = f"{config['method']} - {config['scoring_method']} - {config['id_dataset']}"
    fprs = []
    for w in weights:
        print("process w:", w)
        print(name)

        # NOTE(review): np.power with a fractional exponent returns NaN for negative
        # bases. If the "other" scores can be negative (e.g. energy or knn), the
        # combined scores will contain NaN - TODO confirm the score ranges.
        _write_scores(os.path.join(output_dir, "in_scores.txt"),
                      np.power(id_other_scores, w) * np.power(id_snrs, 1 - w))

        for i, ood_dataset in enumerate(config['ood_datasets']):
            _write_scores(os.path.join(output_dir, f"{ood_dataset}.txt"),
                          np.power(ood_other_scores[i], w) * np.power(ood_snrs[i], 1 - w))

        with open(os.path.join(output_dir, "test.txt"), 'a') as f3:
            f3.write("the result when w = {} is:\n".format(w))
            avg_fpr = compute_traditional_ood(output_dir, config['ood_datasets'], config['scoring_method'], f3)
        fprs.append(avg_fpr)

    return weights[fprs.index(min(fprs))]



if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Grid-search the weight w that combines NAP scores with another OOD score.")
    parser.add_argument("--config", required=True, type=str, help="Path to config YML")
    # The original flags were store_true with default=True and so could never be turned
    # off. They are kept for backward compatibility, next to explicit opt-out flags.
    parser.add_argument("--use-gpu", dest="use_gpu", action="store_true",
                        help="Enables GPU (default: on when CUDA is available)")
    parser.add_argument("--no-gpu", dest="use_gpu", action="store_false", help="Disables GPU")
    parser.add_argument("--use-tqdm", dest="use_tqdm", action="store_true",
                        help="Enables progress bar (default: on)")
    parser.add_argument("--no-tqdm", dest="use_tqdm", action="store_false", help="Disables progress bar")
    # Default to the GPU only when one exists, instead of crashing on .cuda() on CPU-only hosts.
    parser.set_defaults(use_gpu=torch.cuda.is_available(), use_tqdm=True)
    parser.add_argument("--ood_dataset_dir", default='./data/CIFAR-10-C',
                        help="Root directory of the OOD datasets (default: the CIFAR-10-C "
                             "domain-shift data; pass e.g. ./data for datasets stored there)")
    parser.add_argument("--id_dataset_dir", default='./data',
                        help="Root directory of the in-distribution dataset")
    parser.add_argument("--checkpoint_dir", default='./checkpoints',
                        help="Directory containing the config's train_restore_file")
    args = parser.parse_args()
    config = load_config_yml(args.config)
    best_w = find_optimal_w(config, args)
    print("optimal w:", best_w)
