import random
import numpy as np
import torchvision
import argparse
import inspect
import logging
import math
import os
from pathlib import Path
from typing import Optional
import datasets
import torch
import torch.nn.functional as F
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from sklearn.neighbors import LocalOutlierFactor
from torchvision.utils import make_grid, save_image
from sampler_main import *
from torchmetrics.image import FrechetInceptionDistance as FID
from roc_tpr import cal_roc_tpr
from dataset import *
import lpips
import architectures
import utils
import dnnlib
import pickle


def parse_args(argv=None):
    """Parse command-line options for the projection-regret OOD evaluation.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing zero-argument calls behave
            exactly as before.

    Returns:
        argparse.Namespace with integer attributes ``t1``, ``t2``, ``n1``,
        ``n2`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Test ood with consistency models",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--t1', type=int, default=7,
                        help="forwarded to projection_regret as `t1`; also part of the output directory name")
    parser.add_argument('--t2', type=int, default=7,
                        help="forwarded to projection_regret as `t2`; also part of the output directory name")
    parser.add_argument('--n1', type=int, default=10,
                        help="forwarded to projection_regret as `n1`; also part of the output directory name")
    parser.add_argument('--n2', type=int, default=10,
                        help="forwarded to projection_regret as `n2`; also part of the output directory name")
    parser.add_argument('--seed', type=int, default=1557,
                        help="seed for torch, numpy and random; also part of the output directory name")
    return parser.parse_args(argv)


def load_architecture(name: str = "edm-ncsnpp-cifar10"):
    """Construct the network for the named architecture.

    Only ``"edm-ncsnpp-cifar10"`` is supported. It builds an
    ``architectures.SongUNet`` for 32x32, 3-channel inputs and outputs. The
    returned module is freshly constructed; no weights are loaded here.

    Raises:
        ValueError: if ``name`` is not a supported architecture.
    """
    # Reject unknown names up front so the happy path reads straight through.
    if name != "edm-ncsnpp-cifar10":
        raise ValueError(f"Unsupported architecture: {name}")

    return architectures.SongUNet(
        img_resolution=32,
        in_channels=3,
        out_channels=3,
        embedding_type="fourier",
        encoder_type="residual",
        decoder_type="standard",
        channel_mult_noise=2,
        resample_filter=[1, 3, 3, 1],
        model_channels=128,
        channel_mult=[2, 2, 2],
        dropout=0.0,
    )


def load_dataset(name: Optional[str] = "cifar10", config: Optional[str] = None, root: Optional[str] = None):
    cache_dir = "./data/huggingface/datasets"
    if name is not None:
        dataset = datasets.load_dataset(name, config, cache_dir=cache_dir, split="train")
    else:
        dataset = datasets.load_dataset("imagefolder", data_dir=root, cache_dir=cache_dir, split="train")

    # Preprocessing the datasets and DataLoaders creation.
    augmentation0 = transforms.Compose([
        transforms.ToTensor(),
    ])

    def transform_images(examples):
        test_images  = [augmentation0(image.convert("RGB")) for image in examples["img"]]
        
        return {"test_images": test_images}

    dataset.set_transform(transform_images)
    return dataset




def _load_pretrained_weights(model, path, device):
    """Copy every parameter and buffer of ``model`` from the state dict saved at ``path``.

    Raises:
        KeyError: if the checkpoint lacks an entry the model needs.
        ValueError: if a checkpoint tensor's shape differs from the model's.
            ``Tensor.copy_`` would otherwise broadcast it silently.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(path, map_location=device)
    targets = list(model.named_parameters()) + list(model.named_buffers())
    with torch.no_grad():
        for name, tensor in targets:
            # Explicit raise rather than `assert`, which is stripped under `python -O`.
            if name not in state_dict:
                raise KeyError(f"checkpoint {path} has no entry for {name!r}")
            source = state_dict[name]
            if source.shape != tensor.shape:
                raise ValueError(
                    f"shape mismatch for {name!r}: checkpoint {tuple(source.shape)} vs model {tuple(tensor.shape)}"
                )
            tensor.copy_(source)


def _build_eval_sets(transform):
    """Return ``(label, dataset, output file name)`` triples.

    The first entry is the in-distribution CIFAR-10 test set; the remaining
    entries are the OOD sets that are compared against it.
    """
    return [
        ('cf10', torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform),
         'cifar10_test.txt'),
        ('cf100', torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform),
         'cifar100.txt'),
        ('svhn', torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform),
         'svhn.txt'),
        ('lsun', torchvision.datasets.ImageFolder(root='./data/LSUN_pil', transform=transform),
         'lsun.txt'),
        ('imagenet', torchvision.datasets.ImageFolder(root='./data/Imagenet_pil', transform=transform),
         'imgn.txt'),
    ]


def main(args):
    """Score CIFAR-10 test data and four OOD sets with ``projection_regret`` and report results.

    Steps:
        1. Load the pretrained weights into the model.
        2. Score each dataset.
        3. Save each score array with ``np.savetxt`` under
           ``numpy_stats_new/projection_regret/<t1>_<t2>_<n1>_<n2>_<seed>/``.
        4. Print mean/std per dataset.
        5. Print ``cal_roc_tpr(cifar10_scores, ood_scores, 0.95)`` for each OOD set.

    Requires CUDA and the checkpoint / dataset directories referenced below.
    """
    print(args)
    device = 'cuda'
    model = load_architecture().to(device)

    # Seed every RNG before any scoring work so runs with the same --seed repeat.
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(args.seed)
    random.seed(args.seed)

    _load_pretrained_weights(model, "./init_cm/diffusion_pytorch_model.bin", device)
    model.eval()

    scheduler = utils.CMScheduler()
    num_steps = 18
    timesteps = scheduler.discretize_timesteps(num_steps, device=device)

    # Normalize(0.5, 0.5) maps ToTensor's output to [-1, 1] per channel.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    eval_sets = _build_eval_sets(transform_test)

    scores = {}
    for index, (label, dataset, _) in enumerate(eval_sets):
        print(label, len(dataset))
        if index == 0:
            # Separator printed once, after the in-distribution header.
            print('*' * 100)
        loader = torch.utils.data.DataLoader(dataset, batch_size=1000)
        scores[label] = projection_regret(loader, model, scheduler, timesteps, args.t1, device,
                                          t2=args.t2, n1=args.n1, n2=args.n2)

    pref = f"{args.t1}_{args.t2}_{args.n1}_{args.n2}_{args.seed}"
    out_dir = os.path.join('numpy_stats_new', 'projection_regret', pref)
    os.makedirs(out_dir, exist_ok=True)
    for label, _, filename in eval_sets:
        np.savetxt(os.path.join(out_dir, filename), scores[label])

    print(args)
    print('==normal stats====')
    for label, _, _ in eval_sets:
        print(np.mean(scores[label]), np.std(scores[label]))

    # Each OOD set is compared against the in-distribution (first) set.
    in_dist_scores = scores[eval_sets[0][0]]
    for label, _, _ in eval_sets[1:]:
        print(cal_roc_tpr(in_dist_scores, scores[label], 0.95))


if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run the full evaluation.
    args = parse_args()
    main(args)

