from pathlib import Path
import inspect
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
from utils.metrices import *
import sys
base_dir = Path(__file__).parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)


from utils import render
from utils.saver import Saver
from utils.iou import IoU

from data.Imagenet import Imagenet_Segmentation

from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_new import vit_base_patch16_224
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

import torch.nn.functional as F

plt.switch_backend('agg')

import lovely_tensors as lt
lt.monkey_patch()

# hyperparameters
num_workers = 0
batch_size = 1


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. The empty string still maps to
    ``False``, as it did under ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Args
parser = argparse.ArgumentParser(description='Training multi-class classifier')
parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                    help='Model architecture')
parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                    help='Testing Dataset')
# NOTE: argparse does not validate `default` against `choices`. The previous
# default ('grad_rollout') was not a listed choice and has no
# `<method>_from_cache` function in this file, so a run without --method
# failed with a KeyError inside the evaluation loop. The default is now the
# first method that is actually implemented here.
parser.add_argument('--method', type=str,
                    default='predmap15',
                    choices=[ 'rollout', 'lrp','transformer_attribution', 'full_lrp', 'lrp_last_layer',
                              'attn_last_layer', 'attn_gradcam', 
                              'attn_last_layer2',
                              'attn_last_layer3',
                              'attn_last_layer4',
                              'attn_last_layer5',
                              'attn_last_layer6',
                              'attn_last_layer7',
                              'attn_last_layer8',
                              'predmap',
                              'predmap2',
                              'predmap3',
                              'predmap4',
                              'predmap5',
                              'predmap6',
                              'predmap7',
                              'predmap8',
                              'predmap9',
                              'predmap10',
                              'predmap11',
                              'predmap12_attn_prev',
                              'predmap13',
                              'predmap14',
                              'predmap15',
                              'predmap16',
                              'predmap_temperature1_1',
                              'predmap_temperature0_9',
                              'predmap_temperature1_2',
                              'predmap_temperature1_3',
                              'predmap_temperature1_35',
                              'predmap_temperature1_4',
                              'predmap_temperature1_45',
                              'predmap_temperature1_5',
                              'predmap_temperature1_6',
                              ],
                    help='')
parser.add_argument('--thr', type=float, default=0.,
                    help='threshold')
parser.add_argument('--K', type=int, default=1,
                    help='new - top K results')
parser.add_argument('--save-img', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-ia', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fgx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-m', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-reg', action='store_true',
                    default=False,
                    help='')
# Still takes an explicit value (`--is-ablation True|False`) so existing
# command lines keep working, but the value is now parsed correctly.
parser.add_argument('--is-ablation', type=_str_to_bool,
                    default=False,
                    help='')
parser.add_argument('--imagenet-seg-path', type=str, required=True)
args = parser.parse_args()



alpha = 2

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver
# Results are grouped per method under cached_data/<method>.
args.checkname = "cached_data/" + args.method
saver = Saver(args)

print(f"Experiment Directory: {Path(saver.experiment_dir).resolve()}")

# Data
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Masks are resized with nearest-neighbour interpolation so no new label
# values are introduced by the resize.
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), Image.NEAREST),
])

ds = Imagenet_Segmentation(args.imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
# shuffle=False is required: samples are matched to cache entries by position.
# The worker count now honours the `num_workers` hyperparameter declared at the
# top of the file instead of a separate hard-coded literal.
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

# LRP
# Use the detected device rather than an unconditional .cuda(), which raised on
# CPU-only hosts even though `device` was already computed above.
model_LRP = vit_LRP(pretrained=True).to(device)
model_LRP.eval()
lrp = LRP(model_LRP)
cache_dir = Path(__file__).resolve().parents[1] / 'cache'
# The cached tensors are only consumed on the CPU below; mapping them to the
# CPU also lets a cache that was written from GPU tensors load without CUDA.
cache_data = torch.load(cache_dir / "cache.pt", map_location="cpu")

iterator = tqdm(dl)

def predmap15_from_cache(all_attnmap, xpredmap):
    """Build a per-patch relevance row from every cached attention head.

    Args:
        all_attnmap: attention maps, expected as (blocks, heads, tokens-1).
        xpredmap: per-token predictions, expected as (tokens, classes). Row 0
            selects the target class (argmax); the remaining rows are the
            per-patch predictions.

    Returns:
        Tensor of shape (1, tokens-1) holding the relevance of each patch for
        the selected class.
    """
    # Class with the highest score in row 0 of the prediction map.
    target_class = xpredmap[..., 0, :].argmax(dim=-1).item()

    # Treat every head of every block as an independent attention map.
    head_maps = all_attnmap.flatten(0, 1)  # (blocks*heads, tokens-1)
    patch_preds = xpredmap[1:, :]  # (tokens-1, classes)

    # Score each head per class, then normalise the scores across heads.
    head_weights = torch.softmax(head_maps @ patch_preds, dim=0)  # (blocks*heads, classes)

    # Mix the heads with those weights and gate by the patch predictions.
    relevance = (head_maps.T @ head_weights) * patch_preds  # (tokens-1, classes)

    return relevance[..., target_class].unsqueeze(0)  # (1, tokens-1)

def predmap16_from_cache(all_attnmap, xpredmap):
    """Build a per-patch relevance row from the second-to-last attention block.

    Args:
        all_attnmap: attention maps, expected as (blocks, heads, tokens-1).
            Only the block at index -2 is used.
        xpredmap: per-token predictions, expected as (tokens, classes). Row 0
            selects the target class (argmax); the remaining rows are the
            per-patch predictions.

    Returns:
        Tensor of shape (1, tokens-1) holding the relevance of each patch for
        the selected class.
    """
    # Class with the highest score in row 0 of the prediction map.
    target_class = xpredmap[..., 0, :].argmax(dim=-1).item()

    # Heads of the penultimate block only.
    head_maps = all_attnmap[-2]  # (heads, tokens-1)
    patch_preds = xpredmap[1:, :]  # (tokens-1, classes)

    # Score each head per class, then normalise the scores across heads.
    head_weights = torch.softmax(head_maps @ patch_preds, dim=0)  # (heads, classes)

    # Mix the heads with those weights and gate by the patch predictions.
    relevance = (head_maps.T @ head_weights) * patch_preds  # (tokens-1, classes)

    return relevance[..., target_class].unsqueeze(0)  # (1, tokens-1)

def eval_batch(all_attnmap, predmap, labels, method, index):
    """Score the relevance map of one cached sample against its mask.

    Args:
        all_attnmap: cached attention maps for the sample; passed unchanged to
            the method function.
        predmap: cached per-token predictions for the sample; passed unchanged
            to the method function.
        labels: ground-truth mask batch; ``labels[0]`` is scored for accuracy
            and IoU, the whole batch for AP.
        method: method name; the function ``f"{method}_from_cache"`` is looked
            up in this module's globals.
        index: unused; kept so existing callers do not break.

    Returns:
        Tuple ``(correct, labeled, inter, union, ap, f1, pred, target)``.
        ``pred`` and ``target`` are flat numpy arrays of per-pixel scores and
        labels; ``f1`` is always 0 (the F1 computation is disabled).

    Raises:
        NotImplementedError: if no ``<method>_from_cache`` function exists.
    """
    method_function = f"{method}_from_cache"
    relevance_fn = globals().get(method_function)
    if relevance_fn is None:
        # Many CLI choices have no cached implementation; fail with a message
        # that names the missing function instead of a bare KeyError.
        raise NotImplementedError(
            f"method {method!r} has no cached implementation "
            f"(expected a function named {method_function!r})")

    Res = relevance_fn(all_attnmap, predmap)
    # One relevance value per 16x16 patch of the 224x224 input.
    Res = Res.reshape(1, 1, 14, 14)
    # interpolate to full image size (224,224)
    Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear', align_corners=False)

    # Min-max normalise. A constant map yields 0/0 = NaN everywhere here; the
    # NaN handling below turns that into an all-zero map.
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    # threshold between FG and BG is the mean
    ret = Res.mean()
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Soft scores for AP: background = 1 - foreground. NaNs are zeroed in both.
    # The comparisons above can never produce NaN, so the hard masks need no
    # such clean-up.
    Res_0_AP = 1 - Res
    Res[torch.isnan(Res)] = 0
    Res_0_AP[torch.isnan(Res_0_AP)] = 0
    Res_1_AP = Res

    # Per-pixel scores for the global precision/recall curve. After the
    # normalisation the peak is 1, except for a degenerate (all-zero) map where
    # it is 0: dividing by it would give NaN/inf scores, which later make
    # precision_recall_curve reject the whole run. Skip the division then.
    peak = Res.max()
    pred = Res.clamp(min=args.thr)
    if peak > 0:
        pred = pred / peak
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    output = torch.cat((Res_0, Res_1), 1) # (batch, 2, 224, 224)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1) # (batch, 2, 224, 224)

    # Segmentation results (hard masks) for the first sample of the batch.
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)

    # Average precision on the soft scores; NaN results count as 0.
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    # F1 is currently disabled and reported as 0.
    f1 = 0

    return correct, labeled, inter, union, ap, f1, pred, target




# Running totals over the whole dataset.
total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
total_ap, total_f1 = [], []

# Per-pixel scores and labels of every sample, for the global PR curve.
predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):
    # The cache is indexed by the loader position, so this relies on the loader
    # yielding one sample per batch, in dataset order.
    all_attnmap = cache_data['all_attnmap'][batch_idx]
    predmap = cache_data['xpredmap'][batch_idx]

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(all_attnmap, predmap, labels, args.method, batch_idx)

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype('int64')
    total_label += labeled.astype('int64')
    total_inter += inter.astype('int64')
    total_union += union.astype('int64')
    total_ap.append(ap)
    total_f1.append(f1)

    # np.spacing(1) keeps the denominators non-zero.
    pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
    # Named `class_iou` so the `IoU` class imported from utils.iou is not
    # overwritten by an array.
    class_iou = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union) # (2, )
    mIoU = class_iou.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))


# Store the source of the evaluated method next to its results so the numbers
# can be traced back to the exact code that produced them.
method_func = globals()[f"{args.method}_from_cache"]
method_code = inspect.getsource(method_func)
method_code_path = Path(saver.experiment_dir) / "method_code.py"
method_code_path.write_text(method_code)

# Precision/recall curve over all pixels of all samples.
predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, _ = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)


plt.figure()
plt.plot(rc, pr)
pr_curve_path = os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method))
print(f"saving pr_curve to {pr_curve_path}")
plt.savefig(pr_curve_path)

# Final metrics: echoed to stdout and written to a file named after the mIoU.
summary_lines = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
]
for line in summary_lines:
    print(line)

txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
# The context manager closes the file even if a write fails.
with open(txtfile, 'w') as fh:
    fh.writelines(summary_lines)
