import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import vit_b_16,VisionTransformer
import torchvision.models as models
from torch.optim.lr_scheduler import ReduceLROnPlateau
import PIL.Image as Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
from captum.attr import IntegratedGradients
from captum.attr import NoiseTunnel

from copy import deepcopy


# Prefer the GPU with index 2 (the original choice), but only when that index
# actually exists. With fewer than three visible GPUs, a hard-coded 'cuda:2'
# produces a device that fails on first use, so fall back to the first GPU,
# or to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')

# Test-time preprocessing: convert each image to a tensor, then normalize the
# three channels with the fixed mean / std values below.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# CIFAR-100 with train=False, stored under ../../data (download=True lets
# torchvision fetch it when it is missing).
test_dataset = torchvision.datasets.CIFAR100(
    root='../../data',
    train=False,
    download=True,
    transform=transform_test,
)

# One image per batch, in dataset order, so saved attributions line up with
# dataset indices.
val_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# Hyperparameters

# Input geometry: each 32x32 image is divided into non-overlapping 4x4 patches.
image_size = 32
patch_size = 4
# Number of patches per image: patches per side, squared.
num_patches = (image_size // patch_size) * (image_size // patch_size)

# Transformer encoder configuration.
embedding_dim = 128  # Dimensionality of the patch embeddings
num_heads = 8  # Number of heads in multi-head attention
num_layers = 6  # Number of transformer layers
hidden_dim = 256  # Hidden dimension of the feed-forward network
dropout_rate = 0.1  # Dropout rate

# Classification target: CIFAR-100 has 100 classes.
num_classes = 100

# Custom Transformer for CIFAR-100
class TransformerClassifier(nn.Module):
    """Patch-based transformer encoder that classifies 3-channel images.

    The image is cut into non-overlapping square patches; each patch is
    flattened and linearly embedded, a learnable positional embedding is
    added, the sequence is passed through a stack of transformer encoder
    layers, and the mean over all patch embeddings is fed to a
    LayerNorm + Linear classification head.

    Args:
        num_patches: Number of patches per image; sizes the positional
            embedding, so it must match what ``forward`` extracts.
        embedding_dim: Width of each patch embedding (``d_model``).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        hidden_dim: Width of the feed-forward sub-layer.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability inside the encoder layers.
        patch_size: Side length of a square patch. When ``None`` (the
            default) the module-level ``patch_size`` hyperparameter is used,
            which is what this class always did before the argument existed.

    Submodule / parameter names are part of the checkpoint format
    (``state_dict`` keys) and must not be renamed.
    """

    def __init__(self, num_patches, embedding_dim, num_heads, num_layers, hidden_dim, num_classes, dropout_rate,
                 patch_size=None):
        super().__init__()

        # Keep the patch size on the instance so forward() no longer depends
        # on a module-level global being in sync with the constructor call.
        self.patch_size = self._default_patch_size() if patch_size is None else patch_size
        patch_dim = 3 * self.patch_size * self.patch_size

        # Patch embedding: flatten the patch and apply a linear layer
        self.patch_embedding = nn.Linear(patch_dim, embedding_dim)

        # Positional encoding: learnable positional embeddings
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, embedding_dim))

        # Transformer encoder: multiple layers of multi-head attention + feed-forward network
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                                   dim_feedforward=hidden_dim, dropout=dropout_rate,
                                                   batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, num_classes)
        )

    @staticmethod
    def _default_patch_size():
        """Return the module-level ``patch_size`` hyperparameter."""
        return patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch_size, num_classes)``.

        Args:
            x: Image batch of shape ``(batch_size, 3, H, W)``.
        """
        batch_size = x.shape[0]
        p = self.patch_size

        # Split image into patches: (batch_size, num_patches, 3 * p * p)
        x = x.unfold(2, p, p).unfold(3, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(batch_size, -1, 3 * p * p)

        # Apply patch embedding: (batch_size, num_patches, embedding_dim)
        x = self.patch_embedding(x)

        # Add positional encoding (out of place, so no activation is mutated)
        x = x + self.position_embedding

        # Pass through the transformer encoder
        x = self.transformer(x)

        # Mean-pool over the patch dimension to get one embedding per image
        # (there is no dedicated class token in this model).
        x = x.mean(dim=1)

        # Apply the classification head
        return self.mlp_head(x)



def explain_ori():
    """Compute and save attributions for every image served by ``val_loader``.

    Builds a ``TransformerClassifier`` from the module-level hyperparameters,
    restores its weights from the ``'net'`` entry of the fine-tuned
    checkpoint, and attributes each image's label using Integrated Gradients
    wrapped in a NoiseTunnel (``smoothgrad_sq``, 10 noisy samples). The
    per-image attributions are collected in loader order and written with
    ``np.save``.
    """
    net = TransformerClassifier(
        num_patches=num_patches,
        embedding_dim=embedding_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
    # Load on CPU first, then move the whole model to the target device.
    checkpoint = torch.load('../../data/cifar100/IG/transformer-finetune.pth', map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    net.to(device)
    net.eval()

    smoothed_ig = NoiseTunnel(IntegratedGradients(net))

    collected = []
    for images, targets in tqdm(val_loader):
        images = images.to(device)
        targets = targets.to(device)
        attribution = smoothed_ig.attribute(
            images,
            nt_samples=10,
            nt_type='smoothgrad_sq',
            target=targets.to(torch.int),
        )
        # Drop singleton dimensions (the loader yields batches of one image)
        # and hand the result over to NumPy.
        collected.append(attribution.squeeze().cpu().detach().numpy())

    np.save('../../data/cifar100/IG/transformer-explanations-finetune', collected)

def generate_explanations(random = 0.2):
    """Write randomly perturbed copies of the saved attribution maps.

    Loads the attributions saved by ``explain_ori``, sums them over axis 1,
    and stores that unperturbed result once as the ``_random0.000000_0``
    file. Then, for each of five seeds, it shuffles the values at a randomly
    chosen ``random`` fraction of positions of every map among themselves and
    saves the result as ``_random<random>_<seed>.npy`` next to the source.

    The random generator is seeded with the seed that appears in the output
    filename, so re-running the function reproduces the same files.

    Args:
        random: Fraction of positions per map to shuffle; expected to lie in
            ``[0, 1]``.
    """
    src_path = '../../data/cifar100/IG/transformer-explanations-finetune.npy'
    maps = np.load(src_path).sum(axis=1)

    # Keep an unperturbed reference copy; never overwrite an existing one.
    clean_path = src_path.replace('.npy', '_random%f_%d.npy' % (0.0, 0))
    if not os.path.exists(clean_path):
        np.save(clean_path, maps)

    for seed in range(5):
        # One generator per seed: different maps get different perturbations,
        # while the whole file is reproducible from its seed.
        rng = np.random.default_rng(seed)
        perturbed = []
        for explanation in maps:
            # flatten() always returns a copy, so the loaded maps stay intact.
            scores = explanation.flatten()

            num = int(random * scores.shape[0])
            idx_select = rng.choice(scores.shape[0], size=num, replace=False)
            # Permute the selected values among the selected positions only.
            scores[idx_select] = rng.permutation(scores[idx_select])

            perturbed.append(scores.reshape(explanation.shape))
        np.save(src_path.replace('.npy', '_random%f_%d.npy' % (random, seed)), perturbed)

if __name__ == "__main__":
    # Compute and store the attributions first, then derive the perturbed
    # variants for each shuffle ratio.
    explain_ori()
    for ratio in (0.2, 0.4, 0.6, 0.8, 1.0):
        generate_explanations(ratio)


