
import os
import timm
import torch
from tqdm import tqdm
from models.resnet_ssl import *
from poisoned_datasets import build_eval_dataset, build_extra_dataset, build_train_dataset, build_backdoor_dataset
from train_backdoor import get_args_parser, create_model
from torchvision.datasets import CIFAR10, CIFAR100, GTSRB, ImageFolder
from poisoned_datasets import build_eval_transform, CelebA_attr
from torch.utils.data import DataLoader
from torchvision import transforms

# ---------------------------------------------------------------------------
# Experiment configuration: edit these values to choose what gets scored.
# ---------------------------------------------------------------------------
# Scoring strategies handled below:
#   'knn'         - mean similarity of each target-class sample to its 50
#                   most similar target-class neighbours (feature space)
#   'mean'        - negated cosine similarity to the target-class centroid
#   'loss_ood'    - per-sample cross-entropy of the OOD model (200 OOD classes)
#   'loss_ood_10' - same, using the 10-class OOD model
strategy_list = ['knn', 'mean', 'loss_ood', 'loss_ood_10']
device = 'cuda:0'
metric = 'cosine'  # similarity used by the 'knn' strategy: 'cosine' or 'l2'
model_name = 'myresnet18'  # other values used: 'preactresnet18', 'vit_base_patch16_224'
dset = 'GTSRB'
target = 1  # target class: selects samples for knn/mean and names the OOD checkpoints

strategy = 'loss_ood_10'
# Fail fast on typos instead of crashing with a NameError much later.
if strategy not in strategy_list:
    raise ValueError(f'unknown strategy {strategy!r}; expected one of {strategy_list}')
if metric not in ('cosine', 'l2'):
    raise ValueError(f"unknown metric {metric!r}; expected 'cosine' or 'l2'")
# 'loss_ood_10' uses the 10-class OOD model; plain 'loss_ood' uses 200 classes.
ood_classes = 10 if strategy == 'loss_ood_10' else 200
# Loss-based strategies run at 32x32 (the --input-size passed to the training
# parser); feature-based strategies run at 224x224.
input_size = 32 if 'loss' in strategy else 224
# Fallback normalisation statistics, used when create_model() returns none.
CIFAR10_DEFAULT_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_DEFAULT_STD = [0.247, 0.243, 0.261]

# Reuse the training script's argument parser so the dataset/model helpers
# receive the fields they expect; only a few options are overridden here.
# NOTE(review): '--model myresnet18' is hard-coded, so for the loss-based
# strategies create_model() builds that architecture whatever model_name says
# (model_name then only tags the output file) - confirm this is intended.
parser = get_args_parser()
args = parser.parse_args(f'--data-set {dset} --input-size 32 --data-path data/ --model myresnet18'.split())
if 'loss_ood' in strategy:
    # One output per OOD class plus one extra class (label == ood_classes).
    args.nb_classes = ood_classes + 1
    model, mean, std = create_model(args)
    if mean is None:
        # create_model() supplied no normalisation stats; fall back to CIFAR-10's.
        mean, std = CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD
    # Plain 'loss_ood' checkpoints live in a directory without a class-count suffix.
    ckpt_dir = f'{dset}_{target}' if strategy == 'loss_ood' else f'{dset}_{target}_{ood_classes}'
    # map_location='cpu': load regardless of the device the checkpoint was saved on;
    # the model is moved to `device` later.
    model.load_state_dict(torch.load(f'resources/ood_model/{ckpt_dir}/best.pth', map_location='cpu'))
elif model_name == 'vicreg':
    # resnet50 presumably comes from the models.resnet_ssl star import; it returns a pair.
    model, _ = resnet50()
    model.load_state_dict(torch.load('pretrained/resnet50_vicreg.pth', map_location='cpu'))
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    model.forward_features = model.forward
elif model_name == 'resnet50':
    model = timm.create_model('resnet50', pretrained=True)
    mean, std = model.default_cfg['mean'], model.default_cfg['std']
else:
    # Previously this fell through and crashed with a NameError on `model`.
    raise ValueError(
        f'model {model_name!r} is not supported for strategy {strategy!r}; '
        "use 'vicreg' or 'resnet50' for knn/mean")
model.eval()

# Preprocessing: resize to a square input_size x input_size image, convert to a
# tensor and normalise with the statistics chosen for the model above.
# InterpolationMode.BILINEAR is the enum form of the legacy integer 2
# (PIL.Image.BILINEAR); integer interpolation codes are deprecated in torchvision.
transform = transforms.Compose([
    transforms.Resize((input_size, input_size), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Build the data set whose samples are scored, and a loader over it.
data_path = args.data_path
if 'loss_ood' in strategy:
    # Extra/OOD data set for the loss-based strategies; the positional True
    # presumably selects the training split - confirm in poisoned_datasets.
    dataset_train = build_extra_dataset(args, target, True, transform, num_classes=ood_classes)
elif dset == 'CIFAR10':
    dataset_train = CIFAR10(root=data_path, train=True, transform=transform)
elif dset == 'GTSRB':
    dataset_train = GTSRB(data_path, 'train', transform=transform)
elif dset == 'CELEBATTR':
    dataset_train = CelebA_attr(data_path, True, transform=transform)
elif dset == 'IMAGEWOOF':
    # NOTE(review): the original author questioned whether a test/val split was
    # meant here instead of 'train' - confirm.
    dataset_train = ImageFolder(
        os.path.join(data_path, 'imagewoof2-160', 'train'),
        transform=transform)
else:
    # Previously fell through and crashed with a NameError at the DataLoader.
    raise ValueError(f'unsupported data set {dset!r}')

# shuffle=False keeps the saved scores in the data set's iteration order.
loader = DataLoader(dataset_train, batch_size=128, num_workers=8, pin_memory=True, shuffle=False)

# Score the selected samples and save the scores under resources/:
# - 'loss_ood*': per-sample cross-entropy of the OOD model on samples whose
#   label is `ood_classes` (the last of the model's ood_classes + 1 outputs).
# - 'knn' / 'mean': pre-logit features of samples with label `target`, scored
#   by nearest-neighbour similarity or by similarity to their centroid.
model.to(device).eval()
os.makedirs('resources', exist_ok=True)
feat = []
if 'loss_ood' in strategy:
    target_class = ood_classes
    total, corr = 0, 0
    target_loss_list = []
    with torch.no_grad():
        for img, label in (pbar := tqdm(loader)):
            if not (label == target_class).any():
                continue  # nothing to score in this batch
            img, label = img.to(device), label.to(device)
            logits = model(img)
            mask = label == target_class
            # Every masked label equals target_class, so this is the loss w.r.t. the target label.
            target_loss_list.append(
                torch.nn.functional.cross_entropy(logits[mask], label[mask].long(), reduction='none'))
            # Running accuracy over the scored batches; used for the progress bar only.
            total += img.shape[0]
            corr += (logits.argmax(1) == label).sum().item()
            pbar.set_description(f'acc={corr / total:.4f}')
    if not target_loss_list:
        raise RuntimeError(f'no samples with label {target_class} found in the OOD data set')
    target_loss_list = torch.cat(target_loss_list)
    torch.save(target_loss_list.cpu(), f'resources/loss_ood_{ood_classes}_{dset}_{target}_{model_name}.pth')
else:
    with torch.no_grad():
        for img, label in tqdm(loader):
            mask = label == target
            if not mask.any():
                continue
            batch = img[mask].to(device)
            if model_name == 'vicreg':
                h_feat = model(batch)
            else:
                # timm-style API: pooled features before the classifier layer.
                h_feat = model.forward_head(model.forward_features(batch), pre_logits=True)
            feat.append(h_feat)
    if not feat:
        raise RuntimeError(f'no samples with label {target} found in the training set')
    feat = torch.cat(feat)

    if strategy == 'knn':
        if metric == 'cosine':
            feat = feat / feat.norm(2, dim=1, keepdim=True)
            score = feat @ feat.T
            score.fill_diagonal_(-1)  # exclude self-similarity (minimum cosine)
        elif metric == 'l2':
            sq_norm = feat.norm(2, dim=1) ** 2
            # Negated squared Euclidean distance, so larger means closer.
            score = -(sq_norm[:, None] + sq_norm[None, :] - 2 * feat @ feat.T)
            score.fill_diagonal_(float('-inf'))  # exclude self-distance
        else:
            raise ValueError(f"unknown metric {metric!r}; expected 'cosine' or 'l2'")
        # Mean of the k largest entries per row. With n <= 50 samples every
        # entry, including the masked diagonal, is averaged (as before).
        k = min(50, len(score))
        knn = score.topk(k, dim=1).values.mean(dim=1).cpu()
        torch.save(knn, f'resources/knn_{dset}_{target}_{model_name}.pth')
    elif strategy == 'mean':
        # `center` rather than `mean`, to avoid clobbering the normalisation stats.
        center = feat.mean(0)
        center = center / center.norm(2)
        feat = feat / feat.norm(2, dim=1, keepdim=True)
        score = feat @ center
        # Saved negated: samples less similar to the centroid get higher values.
        torch.save(-score.cpu(), f'resources/mean_{dset}_{target}_{model_name}.pth')
    else:
        raise ValueError(f'unknown strategy {strategy!r}')