import torch
from torchvision.models import swin_b, Swin_B_Weights

from .ViT_explanation_generator import Baselines, LRP
from .ViT_LRP import deit_base_patch16_224, vit_base_patch16_224 as vit_LRP
from .ViT_new import vit_base_patch16_224 as vit_new_LRP
from .ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

from src.diffusion import create_diffusion_model
from src.utilities import min_max_normal
from src.algorithms import dds


def load_baseline_models(model_name, device):
    """Build every model needed by the baseline attribution methods.

    Args:
        model_name: Backbone used for the ``'vit'`` and ``'baselines'``
            entries. One of ``"vit"``, ``"deit"`` or ``"swin"``.
        device: Device all networks built here are moved to (anything
            accepted by ``torch.nn.Module.to``).

    Returns:
        A dict with the following keys:

        * ``'vit'``: the selected backbone, in eval mode.
        * ``'baselines'``: ``Baselines`` wrapping that backbone.
        * ``'lrp'``: ``LRP`` wrapping ``vit_base_patch16_224`` from
          ``ViT_LRP`` (independent of ``model_name``).
        * ``'orig_lrp'``: ``LRP`` wrapping ``vit_base_patch16_224`` from
          ``ViT_orig_LRP`` (independent of ``model_name``).
        * ``'diffusion'``: whatever ``create_diffusion_model()`` returns.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Select the backbone first so an unsupported name fails fast, before any
    # pretrained weights are loaded.
    if model_name == "vit":
        backbone = vit_new_LRP(pretrained=True)
    elif model_name == "deit":
        backbone = deit_base_patch16_224(pretrained=True)
    elif model_name == "swin":
        # torchvision expects a member of the weights enum, not the enum
        # class itself.
        backbone = swin_b(weights=Swin_B_Weights.DEFAULT)
    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected one of 'vit', 'deit', 'swin'."
        )

    models = {}
    models['vit'] = backbone.to(device)
    models['vit'].eval()
    models['baselines'] = Baselines(models['vit'])

    # LRP: placed on the requested device rather than unconditionally on CUDA
    # so it matches the inputs moved to `device` by the caller.
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    models['lrp'] = LRP(model_LRP)

    # orig LRP
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    models['orig_lrp'] = LRP(model_orig_LRP)

    # NOTE(review): create_diffusion_model() takes no device here; its
    # placement is decided inside that helper.
    models['diffusion'] = create_diffusion_model()

    return models


def run_baselines(models, method, images, device, index, is_ablation=False):
    """Compute a normalised relevance map for ``images`` with one baseline.

    Args:
        models: Dict as returned by ``load_baseline_models``; the keys
            ``'baselines'``, ``'lrp'``, ``'orig_lrp'`` and ``'diffusion'``
            are read depending on ``method``.
        method: One of ``'rollout'``, ``'attn_gradcam'``, ``'lrp'``,
            ``'transformer_attribution'``, ``'full_lrp'``,
            ``'lrp_last_layer'``, ``'attn_last_layer'`` or ``'dds'``.
        images: Input batch; moved to ``device`` before use.
        device: Device the inputs and the resulting map are moved to.
        index: Forwarded to ``generate_LRP`` only when ``method == 'lrp'``.
        is_ablation: Forwarded to ``generate_LRP`` only for
            ``'lrp_last_layer'`` and ``'attn_last_layer'``.

    Returns:
        The relevance map after ``min_max_normal``. For ``'full_lrp'`` the map
        is reshaped to ``(batch, 1, 224, 224)``; for every other method it is
        reshaped to ``(batch, 1, 14, 14)`` and bilinearly upsampled by a
        factor of 16.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    images = images.to(device)
    batch_size = images.shape[0]

    # attention rollout baseline
    if method == 'rollout':
        relevance = models['baselines'].generate_rollout(images, start_layer=1)

    # GradCam baseline (last attn layer)
    elif method == 'attn_gradcam':
        relevance = models['baselines'].generate_cam_attn(images)

    elif method == 'lrp':
        relevance = models['lrp'].generate_LRP(
            images, method="lrp", start_layer=1, index=index
        )

    # transformer attribution method
    # NOTE(review): `index` is not forwarded here (nor for 'dds'), unlike the
    # 'lrp' branch -- confirm this is intentional.
    elif method == 'transformer_attribution':
        relevance = models['lrp'].generate_LRP(
            images, start_layer=1, method="transformer_attribution"
        )

    # LRP baseline (this is full LRP, not partial); already at image
    # resolution, so it is reshaped here and skips the upsampling below
    elif method == 'full_lrp':
        relevance = models['orig_lrp'].generate_LRP(
            images, method="full"
        ).reshape(batch_size, 1, 224, 224)

    # partial LRP baseline (last attn layer)
    elif method == 'lrp_last_layer':
        relevance = models['orig_lrp'].generate_LRP(
            images, method="last_layer", is_ablation=is_ablation
        )

    # raw attention baseline (last attn layer)
    elif method == 'attn_last_layer':
        relevance = models['orig_lrp'].generate_LRP(
            images, method="last_layer_attn", is_ablation=is_ablation
        )

    # transformer attribution computed on the images returned by dds()
    elif method == 'dds':
        dds_images = dds(images, models['diffusion'], fast_predict=True)
        relevance = models['lrp'].generate_LRP(
            dds_images, start_layer=1, method="transformer_attribution"
        )

    else:
        # Without this branch an unknown method surfaced later as an
        # UnboundLocalError on the result variable.
        raise ValueError(f"Unsupported attribution method {method!r}.")

    relevance = relevance.to(device)

    if method != 'full_lrp':
        # patch-level map (14x14 grid) -> full image size (224, 224)
        relevance = relevance.reshape(batch_size, 1, 14, 14)
        relevance = torch.nn.functional.interpolate(
            relevance, scale_factor=16, mode='bilinear'
        )

    # rescale the map with the project's min-max normalisation helper
    relevance = min_max_normal(relevance)

    return relevance