import os
import json
import tqdm
import torch
import numpy as np
import utils
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.templates import get_templates
from heads import get_classification_head, build_classification_head
from modeling import ImageClassifier, ImageEncoder, ClassificationHead
from src.datasets.registry import get_dataset
import torchvision.utils as vutils
from src.utils import *

def eval_single_dataset(image_encoder, dataset_name, args, backdoor_info=None):
    """Evaluate ``image_encoder`` on the test split of ``dataset_name``.

    The encoder is combined with the head returned by
    ``get_classification_head(args, dataset_name)``. When ``backdoor_info`` is
    supplied, the trigger patch is blended into every test batch before the
    forward pass, so the reported top-1 accuracy is measured on triggered
    images and the attack statistics are added to the result.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        dataset_name: Name forwarded to ``get_classification_head`` and
            ``get_dataset``.
        args: Namespace providing ``device``, ``data_location`` and
            ``batch_size``.
        backdoor_info: Optional mapping with the keys ``'mask'`` and
            ``'applied_patch'`` (tensors) and ``'target_cls'``.

    Returns:
        dict: ``{'top1': float}``; with a backdoor also ``'backdoored_acc'``
        (fraction of non-target-class samples predicted as the target class),
        ``'backdoored_cnt'`` and ``'non_target_cnt'``. Ratios are reported as
        ``0.0`` when their denominator is zero instead of raising.
    """
    print("")
    device = args.device

    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.eval()
    model = model.to(device)

    test_dataset, test_loader = get_dataset(
        dataset_name,
        'test',
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )
    print("Evaluation Size:", len(test_dataset))

    #### Backdoor Attack ####
    is_backdoor = backdoor_info is not None
    if is_backdoor:
        print(f"========== Evaluate backdoor attack on {dataset_name} ==========")
        mask = backdoor_info['mask'].to(device, dtype=torch.float32)
        applied_patch = backdoor_info['applied_patch'].to(device, dtype=torch.float32)
        target_cls = torch.tensor(backdoor_info['target_cls']).to(device)
        complement_mask = 1 - mask
        # Loop-invariant part of the blend; computed once instead of per batch.
        patch_term = mask * applied_patch

    # Plain Python counters: every per-batch statistic is read back with
    # .item(), so the totals are always ints/floats (also for an empty loader).
    correct = 0
    n = 0
    non_target_cnt = 0
    backdoored_cnt = 0

    with torch.no_grad():
        for data in tqdm.tqdm(test_loader):
            data = maybe_dictionarize(data)
            x = data['images'].to(device, non_blocking=True)
            y = data['labels'].to(device, non_blocking=True)

            #### Backdoor Attack ####
            if is_backdoor:
                # The patch is blended directly into the loader output; no
                # inverse normalisation is applied before or after.
                x = patch_term + complement_mask.expand_as(x) * x

            logits = utils.get_logits(x, model)
            pred = logits.argmax(dim=1, keepdim=True).to(device)
            correct += pred.eq(y.view_as(pred)).sum().item()
            n += y.size(0)

            #### Backdoor Attack ####
            if is_backdoor:
                # Only samples whose true label differs from the target class
                # can count as successful attacks.
                non_target_mask = (y != target_cls)
                non_target_cnt += non_target_mask.sum().item()
                is_target = pred[non_target_mask] == target_cls
                backdoored_cnt += is_target.sum().item()

    top1 = correct / n if n else 0.0

    metrics = {'top1': top1}
    print(f'Accuracy: {100*top1:.2f}%')

    #### Backdoor Attack ####
    if is_backdoor:
        backdoored_acc = backdoored_cnt / non_target_cnt if non_target_cnt else 0.0
        metrics['backdoored_acc'] = backdoored_acc
        metrics['backdoored_cnt'] = backdoored_cnt
        metrics['non_target_cnt'] = non_target_cnt
        print(f'Backdoored accuracy: {100*backdoored_acc:.2f}% ({backdoored_cnt}/{non_target_cnt})')
        print("")
    return metrics

def eval_single_dataset_adv(image_encoder, dataset_name, args, attack=None):
    """Evaluate clean accuracy and adversarial robustness on one dataset.

    For every test batch ``attack.run(...)`` is executed against the
    classifier and the samples it reports as successfully attacked are
    counted; the same batch is then evaluated unperturbed for clean accuracy
    and cross-entropy loss.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        dataset_name: Name forwarded to ``get_classification_head`` and
            ``get_dataset``.
        args: Namespace providing ``device``, ``data_location``,
            ``batch_size`` and ``adversary_task``.
        attack: Object exposing ``run(images=, labels=, model=,
            precomputed_original_representations=, return_step_by_step=)``
            that returns ``(adv_img, success, loss_adv)``.

    Returns:
        dict: ``'accuracy'`` (clean top-1), ``'robustness'`` (fraction of
        samples the attack did not flag as successful), ``'loss'`` and
        ``'loss_adv'`` (sums of the per-batch loss values). Ratios are ``0.0``
        when no sample was evaluated.

    Raises:
        ValueError: If ``attack`` is ``None``.
    """
    if attack is None:
        raise ValueError("eval_single_dataset_adv requires an `attack` object with a run() method")

    print("")
    device = args.device

    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.eval()
    # Move the model once instead of on every batch.
    model = model.to(device)

    test_dataset, test_loader = get_dataset(
        dataset_name,
        'test',
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size,
        fast=args.adversary_task == "CIFAR100"
    )
    print("Evaluation Size:", len(test_dataset))

    criterion = torch.nn.CrossEntropyLoss()

    correct = 0
    correct_n = 0
    total = 0
    total_loss = 0.0
    total_loss_adv = 0.0
    for data in tqdm.tqdm(test_loader):
        data = maybe_dictionarize(data)
        x = data['images']
        y = data['labels']

        # The attack runs outside torch.no_grad(): it is given the model and
        # may need gradients. It receives its own `.to(device)` results,
        # separate from the tensors used for the clean pass below.
        _, success, loss_adv = attack.run(
            images=x.to(device),
            labels=y.to(device),
            model=model,
            precomputed_original_representations=None,
            return_step_by_step=False
        )
        correct_n += y.size(0) - success.sum().item()
        total += y.size(0)
        print(f"Correct_n: {correct_n/total}")

        x = x.to(device)
        y = y.to(device)
        # The clean pass only produces metrics, so no autograd graph is needed.
        with torch.no_grad():
            logits = utils.get_logits(x, model)
            loss = criterion(logits, y)
        pred = logits.argmax(dim=1, keepdim=True).to(device)

        # Accumulate plain floats so no graph is kept alive across batches and
        # the totals are well-defined even when the loader is empty.
        total_loss += loss.item()
        total_loss_adv += loss_adv.item() if hasattr(loss_adv, 'item') else float(loss_adv)
        correct += pred.eq(y.view_as(pred)).sum().item()

    top1 = correct / total if total else 0.0
    top1_n = correct_n / total if total else 0.0

    metrics = {'accuracy': top1, 'robustness': top1_n, 'loss': total_loss, 'loss_adv': total_loss_adv}
    print(f'Accuracy: {100*top1:.2f}%')
    print(f'Robustness: {100*top1_n:.2f}%')
    print(f'Total loss: {total_loss:.2f}')
    print(f'Total loss adv: {total_loss_adv:.2f}')
    return metrics

def eval_single_dataset_with_frozen_text_encoder(image_encoder, dataset_name, args, backdoor_info=None):
    """Evaluate ``image_encoder`` with a head built from a fresh CLIP model.

    The classification head is produced by ``build_classification_head`` from
    ``ImageEncoder(args, keep_lang=True).model`` and the dataset templates,
    rather than being fetched with ``get_classification_head``. When
    ``backdoor_info`` is supplied, the trigger patch is blended into every
    batch before the forward pass and attack statistics are reported.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        dataset_name: Name forwarded to ``get_templates``,
            ``build_classification_head`` and ``get_dataset``.
        args: Namespace providing ``device``, ``data_location`` and
            ``batch_size`` (and whatever ``ImageEncoder`` reads).
        backdoor_info: Optional mapping with the keys ``'mask'`` and
            ``'applied_patch'`` (tensors) and ``'target_cls'``.

    Returns:
        dict: ``{'top1': float}``; with a backdoor also ``'backdoored_acc'``,
        ``'backdoored_cnt'`` and ``'non_target_cnt'``. Ratios are ``0.0`` when
        their denominator is zero instead of raising.
    """
    print("")
    device = args.device

    pretrained_clip_model = ImageEncoder(args, keep_lang=True).model
    template = get_templates(dataset_name)
    classification_head = build_classification_head(pretrained_clip_model, dataset_name, template, args.data_location, device)
    model = ImageClassifier(image_encoder, classification_head)
    model.eval()

    test_dataset, test_loader = get_dataset(
        dataset_name,
        'test',
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )
    print("Evaluation Size:", len(test_dataset))

    #### Backdoor Attack ####
    is_backdoor = backdoor_info is not None
    if is_backdoor:
        print(f"========== Evaluate backdoor attack on {dataset_name} ==========")
        target_cls = backdoor_info['target_cls']
        # Convert mask and patch once; the blend itself stays on CPU in
        # float32 exactly as before, only the per-batch conversions are gone.
        mask = backdoor_info['mask'].type(torch.FloatTensor)
        patch_term = torch.mul(mask, backdoor_info['applied_patch'].type(torch.FloatTensor))
        complement_mask = 1 - mask

    correct = 0
    n = 0
    non_target_cnt = 0
    backdoored_cnt = 0

    with torch.no_grad():
        for data in tqdm.tqdm(test_loader):
            data = maybe_dictionarize(data)
            x = data['images']
            y = data['labels']

            #### Backdoor Attack ####
            if is_backdoor:
                # The patch is blended directly into the loader output; no
                # inverse normalisation is applied before or after.
                x = patch_term + torch.mul(complement_mask.expand(x.shape), x.type(torch.FloatTensor))

            # Honour args.device instead of a hard-coded .cuda(), so inputs,
            # labels and predictions always end up on the same device.
            x = x.to(device)
            y = y.to(device)
            logits = utils.get_logits(x, model)
            pred = logits.argmax(dim=1, keepdim=True).to(device)
            correct += pred.eq(y.view_as(pred)).sum().item()
            n += y.size(0)

            #### Backdoor Attack ####
            if is_backdoor:
                # Only samples whose true label differs from the target class
                # can count as successful attacks.
                non_target_mask = (y != target_cls)
                non_target_cnt += non_target_mask.sum().item()
                is_target = pred[non_target_mask] == target_cls
                backdoored_cnt += is_target.sum().item()

    top1 = correct / n if n else 0.0

    metrics = {'top1': top1}
    print(f'Accuracy: {100*top1:.2f}%')

    #### Backdoor Attack ####
    if is_backdoor:
        backdoored_acc = backdoored_cnt / non_target_cnt if non_target_cnt else 0.0
        metrics['backdoored_acc'] = backdoored_acc
        metrics['backdoored_cnt'] = backdoored_cnt
        metrics['non_target_cnt'] = non_target_cnt
        print(f'Backdoored accuracy: {100*backdoored_acc:.2f}% ({backdoored_cnt}/{non_target_cnt})')
        print("")
    return metrics

def eval_single_dataset_head(image_encoder, head, dataset_name, args):
    """Report top-1 test accuracy of ``image_encoder`` paired with ``head``.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        head: Classification head passed to ``ImageClassifier``.
        dataset_name: Name forwarded to ``get_dataset``.
        args: Namespace providing ``device``, ``data_location`` and
            ``batch_size``.

    Returns:
        dict: ``{'top1': accuracy}`` with accuracy as a fraction.
    """
    classifier = ImageClassifier(image_encoder, head)
    classifier.eval()
    _, loader = get_dataset(
        dataset_name,
        'test',
        classifier.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size,
    )
    device = args.device

    num_correct = 0.
    num_seen = 0.
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            batch = maybe_dictionarize(batch)
            images = batch['images'].to(device)
            labels = batch['labels'].to(device)

            scores = utils.get_logits(images, classifier)
            predicted = scores.argmax(dim=1, keepdim=True).to(device)
            matches = predicted.eq(labels.view_as(predicted))

            num_correct += matches.sum().item()
            num_seen += labels.size(0)
    accuracy = num_correct / num_seen

    result = {'top1': accuracy}
    print(f'Done evaluating on {dataset_name}. Accuracy: {100 * accuracy:.2f}%')
    return result

def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args):
    """Report top-1 test accuracy of ``image_encoder`` paired with ``head``.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        head: Classification head passed to ``ImageClassifier``.
        dataset_name: Name forwarded to ``get_dataset``.
        args: Namespace providing ``device``, ``data_location`` and
            ``batch_size``.

    Returns:
        dict: ``{'top1': accuracy}`` with accuracy as a fraction.
    """
    model = ImageClassifier(image_encoder, head)
    model.eval()
    # Split name first, preprocessing transform second: the positional order
    # used by every other get_dataset call in this module (the two arguments
    # were previously swapped here).
    test_dataset, test_loader = get_dataset(
        dataset_name,
        'test',
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )
    device = args.device

    with torch.no_grad():
        correct, n = 0., 0.
        for data in tqdm.tqdm(test_loader):
            data = maybe_dictionarize(data)
            x = data['images'].to(device)
            y = data['labels'].to(device)
            logits = utils.get_logits(x, model)
            pred = logits.argmax(dim=1, keepdim=True).to(device)
            correct += pred.eq(y.view_as(pred)).sum().item()
            n += y.size(0)
        top1 = correct / n
    metrics = {'top1': top1}
    print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%')
    return metrics

def evaluate(image_encoder, args, backdoor_info=None):
    """Run ``eval_single_dataset`` on every dataset in ``args.eval_datasets``.

    Each metric is stored under ``"<dataset>:<metric>"`` for clean runs and
    ``"<dataset>-B:<metric>"`` when ``backdoor_info`` is given.

    Args:
        image_encoder: Encoder forwarded to ``eval_single_dataset``.
        args: Namespace with ``eval_datasets`` plus whatever
            ``eval_single_dataset`` reads.
        backdoor_info: Optional trigger description forwarded unchanged.

    Returns:
        ``None`` when ``args.eval_datasets`` is ``None``; otherwise
        ``vars(args)``, i.e. the namespace's own attribute dict, so the
        recorded metrics are also visible on ``args`` itself.
    """
    dataset_names = args.eval_datasets
    if dataset_names is None:
        return None

    # '-B:' marks results obtained with the trigger applied, ':' clean ones.
    separator = ':' if backdoor_info is None else '-B:'
    info = vars(args)

    for dataset_name in dataset_names:
        print('Evaluating on', dataset_name)
        results = eval_single_dataset(image_encoder, dataset_name, args, backdoor_info)

        for key, val in results.items():
            highlighted = 'worst' in key or 'f1' in key.lower() or 'pm0' in key
            if highlighted:
                print(f"{dataset_name} {key}: {val:.4f}")
            info[dataset_name + separator + key] = val

    return info