import numpy as np
import torch
import tqdm

from src import utils
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.heads import get_classification_head, get_multihead_classification
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageClassifier, MultiHeadImageClassifier, MultiHeadImageClassifier_full
from src.datasets.utils import apply_trigger, apply_triggerV2, corner_mask_generation
from torchvision import transforms


def eval_single_dataset(image_encoder, dataset_name, args, data_loader=None):
    """Compute top-1 accuracy of ``image_encoder`` on one dataset.

    The encoder is paired with the classification head returned by
    ``get_classification_head(args, dataset_name)`` and run in eval mode
    under ``torch.no_grad()``.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier``.
        dataset_name: Name passed to the head and dataset registries.
        args: Namespace providing ``device`` and, when no loader is supplied,
            ``data_location``, ``batch_size`` and ``seed``.
        data_loader: Optional pre-built loader. When given it is used as-is
            and the dataset is not constructed at all.

    Returns:
        ``{"top1": accuracy}`` with accuracy in ``[0, 1]``. An empty loader
        yields ``0.0`` (same convention as ``eval_ASR``) rather than a
        ``ZeroDivisionError``.
    """
    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.eval()

    if data_loader is not None:
        dataloader = data_loader
    else:
        # Only build the dataset when it is actually needed for the loader.
        dataset = get_dataset(
            dataset_name,
            model.val_preprocess,
            location=args.data_location,
            batch_size=args.batch_size,
            seed=args.seed,
        )
        dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
    device = args.device

    correct, n = 0, 0
    with torch.no_grad():
        for data in tqdm.tqdm(dataloader):
            data = maybe_dictionarize(data)
            x = data["images"].to(device)
            y = data["labels"].to(device)

            logits = utils.get_logits(x, model)
            pred = logits.argmax(dim=1, keepdim=True)

            correct += pred.eq(y.view_as(pred)).sum().item()
            n += y.size(0)

    top1 = correct / n if n > 0 else 0.0

    metrics = {"top1": top1}
    print(f"Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%")

    return metrics


def evaluate(image_encoder, args):
    """Evaluate ``image_encoder`` on every configured dataset.

    The datasets are ``args.eval_datasets`` followed by
    ``args.control_dataset`` when the latter is not ``None``.

    Returns:
        A dict mapping ``"<dataset>:top1"`` to the top-1 accuracy reported by
        ``eval_single_dataset``, or ``None`` when ``args.eval_datasets`` is
        ``None``.
    """
    if args.eval_datasets is None:
        return None

    datasets_to_run = args.eval_datasets
    if args.control_dataset is not None:
        datasets_to_run = datasets_to_run + [args.control_dataset]

    top1_by_dataset = {}
    for dataset_name in datasets_to_run:
        print("Evaluating on", dataset_name)
        top1 = eval_single_dataset(image_encoder, dataset_name, args)["top1"]
        print(f"{dataset_name} Top-1 accuracy: {top1:.4f}")
        top1_by_dataset[dataset_name + ":top1"] = top1

    return top1_by_dataset


def evaluate_task_vector_at_coef(
    task_vector, pretrained_checkpoint, args, scaling_coef, posthoc_linearization=False
):
    """Evaluate the model obtained by applying ``task_vector`` at one coefficient.

    Args:
        task_vector: Object exposing ``apply_to(checkpoint, scaling_coef=...)``.
        pretrained_checkpoint: Passed through to ``task_vector.apply_to``.
        args: Evaluation namespace; ``eval_datasets`` and
            ``finetuning_accuracies`` are read for the aggregate metrics.
        scaling_coef: Coefficient forwarded to ``apply_to``.
        posthoc_linearization: When true, the edited encoder is wrapped in a
            ``LinearizedImageEncoder`` whose ``init_encoder`` is the task
            vector applied with coefficient ``0.0``.

    Returns:
        The per-dataset metrics from ``evaluate`` extended with normalized
        accuracies plus ``"avg_normalized_top1"`` and ``"avg_top1"``, both
        averaged over ``args.eval_datasets``.
    """
    edited_encoder = task_vector.apply_to(
        pretrained_checkpoint, scaling_coef=scaling_coef
    )
    if posthoc_linearization:
        edited_encoder = LinearizedImageEncoder(
            init_encoder=task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.0),
            image_encoder=edited_encoder,
            args=args,
        )

    metrics = add_normalized_accuracy(evaluate(edited_encoder, args), args)

    # Aggregate each per-dataset metric into its average, in insertion order.
    aggregates = (
        ("avg_normalized_top1", ":normalized_top1"),
        ("avg_top1", ":top1"),
    )
    for summary_key, suffix in aggregates:
        metrics[summary_key] = np.mean(
            [metrics[name + suffix] for name in args.eval_datasets]
        )

    return metrics


def evaluate_task_vector(
    task_vector, pretrained_checkpoint, args, posthoc_linearization=False
):
    """Sweep the scaling coefficient over ``[0, 1]`` and evaluate at each point.

    ``args.n_eval_points`` evenly spaced coefficients are generated with
    ``np.linspace``.

    Returns:
        A dict keyed by coefficient whose values are the results of
        ``evaluate_task_vector_at_coef``.
    """
    results_by_coef = {}
    coefficients = np.linspace(0.0, 1.0, args.n_eval_points)
    for coef in coefficients:
        print(f"Evaluating for scaling coefficient {coef:.2f}")
        results_by_coef[coef] = evaluate_task_vector_at_coef(
            task_vector, pretrained_checkpoint, args, coef, posthoc_linearization
        )
    return results_by_coef


def add_normalized_accuracy(results, args):
    """Add ``"<dataset>:normalized_top1"`` entries to ``results`` in place.

    For each name in ``args.eval_datasets`` the stored ``"<dataset>:top1"``
    value is divided by ``args.finetuning_accuracies[<dataset>]``.

    Returns:
        The same ``results`` mapping that was passed in.
    """
    reference = args.finetuning_accuracies
    for name in args.eval_datasets:
        raw_top1 = results[name + ":top1"]
        results[name + ":normalized_top1"] = raw_top1 / reference[name]
    return results


def nonlinear_advantage(acc_linear, acc_nonlinear, num_classes):
    """Return the error-rate gap between two models, rescaled by class count.

    The gap is ``(1 - acc_linear) - (1 - acc_nonlinear)``, so it is positive
    when the nonlinear model is the more accurate one. It is then multiplied
    by ``num_classes / (num_classes - 1)``.
    """
    error_gap = (1 - acc_linear) - (1 - acc_nonlinear)
    return error_gap * num_classes / (num_classes - 1)



def _build_asr_trigger(mode):
    """Return the callable that stamps the backdoor trigger selected by ``mode``.

    Raises:
        ValueError: if ``mode`` is not one of the supported trigger modes.
    """
    if mode == 'random':
        return apply_triggerV2(patch_size=16, patch_location='random')
    if mode == 'fixed':
        return apply_trigger(patch_size=16, patch_location='top_left_corner')
    if mode in ('blended', 'SIG', 'warped'):
        # For these modes the mode name doubles as location and patch type.
        return apply_triggerV2(patch_size=16, patch_location=mode, patch_type=mode)
    if mode == 'badmerge_on':
        # NOTE: relative path, so it is resolved against the current working
        # directory of the process.
        trigger_path = '../large_scale/saved_triggers/On_CIFAR100_Tgt_1_L_22.npy'
        trigger = torch.from_numpy(np.load(trigger_path))
        # NOTE(review): presumably embeds the trigger in a corner of a
        # 3x224x224 canvas and returns a matching mask -- confirm in
        # src.datasets.utils.
        applied_patch, mask, _, _ = corner_mask_generation(
            trigger, image_size=(3, 224, 224)
        )
        mask = torch.from_numpy(mask).type(torch.FloatTensor)
        applied_patch = torch.from_numpy(applied_patch).type(torch.FloatTensor)
        # The masked patch does not depend on the image: compute it once.
        stamped_patch = torch.mul(mask, applied_patch)

        def trigger_applicator(x):
            # Keep the image where the mask is 0, overwrite it where it is 1.
            background = torch.mul(1 - mask.expand(x.shape), x.type(torch.FloatTensor))
            return stamped_patch + background

        return trigger_applicator
    raise ValueError(
        f"Unknown trigger mode {mode!r}; expected one of 'random', 'fixed', "
        "'blended', 'SIG', 'warped', 'badmerge_on'."
    )


def eval_ASR(image_encoder, dataset_name, args, target=None, mode='random', data_loader=None):
    """Measure the attack success rate (ASR) of a backdoor trigger.

    The ASR is the fraction of samples whose true label is *not* the target
    class that the model nevertheless predicts as the target class.

    Args:
        image_encoder: Encoder wrapped into an ``ImageClassifier`` together
            with the head from ``get_classification_head``.
        dataset_name: Name passed to the head and dataset registries.
        args: Namespace providing ``data_location``, ``batch_size``, ``seed``
            and ``device``.
        target: Class name of the attack target; it is looked up in
            ``dataset.classnames``. Required despite the ``None`` default.
        mode: Trigger type: ``'random'``, ``'fixed'``, ``'blended'``,
            ``'SIG'``, ``'warped'`` or ``'badmerge_on'``.
        data_loader: Optional pre-built loader. It is iterated as-is, so the
            trigger transform built here is NOT applied to its samples; the
            caller must supply already-triggered data.

    Returns:
        The ASR as a float in ``[0, 1]``; ``0.0`` when no non-target sample
        was seen.

    Raises:
        ValueError: if ``target`` is ``None`` or absent from the dataset's
            class names, or if ``mode`` is not supported.
    """
    # Validate cheap arguments before loading any model or data.
    if target is None:
        raise ValueError("eval_ASR requires `target`, the class name of the attack target.")
    trigger = _build_asr_trigger(mode)

    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.eval()

    # Trigger is applied after the model's own validation preprocessing.
    new_transforms = transforms.Compose([model.val_preprocess, trigger])

    # The dataset is always built because its class names are needed to
    # resolve the target label, even when a loader is supplied.
    dataset = get_dataset(
        dataset_name,
        new_transforms,
        location=args.data_location,
        batch_size=args.batch_size,
        seed=args.seed,
    )
    if data_loader is not None:
        dataloader = data_loader
    else:
        dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
    device = args.device

    poison_label = dataset.classnames.index(target)
    print(f"Poison label: {poison_label}")

    correct, n = 0, 0
    with torch.no_grad():
        for data in tqdm.tqdm(dataloader):
            data = maybe_dictionarize(data)
            x = data["images"].to(device)
            y = data["labels"].to(device)

            # Samples already belonging to the target class cannot count as
            # attack successes, so they are excluded from the statistic.
            non_target = y.view(-1) != poison_label
            if not non_target.any():
                continue

            logits = utils.get_logits(x, model)
            pred = logits.argmax(dim=1)

            correct += (pred[non_target] == poison_label).sum().item()
            n += non_target.sum().item()

    top1 = correct / n if n > 0 else 0.0

    print(f"Done evaluating on {dataset_name}. ASR: {100*top1:.2f}%")

    return top1
