import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
from utils.metrices import *
import sys
from pathlib import Path
# Repository root: two directory levels above this script.
# Resolve to an absolute path *before* changing directory; a relative entry
# appended to sys.path would point somewhere else once os.chdir() has run.
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)

from utils import render
from utils.saver import Saver
from utils.iou import IoU

from data.Imagenet import Imagenet_Segmentation

from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_new import vit_base_patch16_224
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

import torch.nn.functional as F

# Use the non-interactive Agg backend so figures can be written to disk
# without a display.
plt.switch_backend('agg')

# NOTE(review): lovely_tensors is monkey-patched in globally; presumably this
# only changes how tensors are printed while debugging -- confirm it is wanted
# in evaluation runs.
import lovely_tensors as lt
lt.monkey_patch()

# hyperparameters
# NOTE(review): the DataLoader built later in this file passes its own literal
# worker count, so `num_workers` appears to be unused -- confirm which value is
# intended.
num_workers = 0
# Number of images per batch; also used when reshaping relevance maps.
batch_size = 1

# Human-readable class names (20 entries).
# NOTE(review): not referenced anywhere else in the visible file. 'motobike'
# looks like a misspelling of 'motorbike'; left unchanged because consumers of
# these strings are not visible here.
cls = ['airplane',
       'bicycle',
       'bird',
       'boat',
       'bottle',
       'bus',
       'car',
       'cat',
       'chair',
       'cow',
       'dining table',
       'dog',
       'horse',
       'motobike',
       'person',
       'potted plant',
       'sheep',
       'sofa',
       'train',
       'tv'
       ]

# Args
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    spellings (case-insensitive) and rejects everything else.

    Args:
        value: the raw token from the command line (or an actual ``bool``).

    Returns:
        The parsed boolean. An empty string maps to ``False``, matching what
        ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Training multi-class classifier')
parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                    help='Model architecture')
parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                    help='Testing Dataset')
# NOTE(review): the default below is not among `choices`. argparse does not
# validate defaults against `choices`, so omitting --method hands an
# unsupported value to the evaluation code -- confirm the intended default.
parser.add_argument('--method', type=str,
                    default='grad_rollout',
                    choices=[
                        'rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
                        'attn_last_layer', 'attn_gradcam',
                        'attn_last_layer2',
                        'attn_last_layer3',
                        'attn_last_layer4',
                        'attn_last_layer5',
                        'attn_last_layer6',
                        'attn_last_layer7',
                        'attn_last_layer8',
                        'predmap',
                        'predmap2',
                        'predmap3',
                        'predmap4',
                        'predmap5',
                        'predmap6',
                        'predmap7',
                        'predmap8',
                        'predmap9',
                        'predmap9_all_layers',
                        'predmap10',
                        'predmap11',
                        'predmap12_attn_prev',
                        'attn_all_layers',
                        'predmap13',
                        'predmap14',
                        'predmap_temperature1_1',
                        'predmap_temperature0_9',
                        'predmap_temperature1_2',
                        'predmap_temperature1_3',
                        'predmap_temperature1_35',
                        'predmap_temperature1_4',
                        'predmap_temperature1_45',
                        'predmap_temperature1_5',
                        'predmap_temperature1_6',
                    ],
                    help='')
parser.add_argument('--thr', type=float, default=0.,
                    help='threshold')
parser.add_argument('--K', type=int, default=1,
                    help='new - top K results')
parser.add_argument('--save-img', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-ia', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fgx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-m', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-reg', action='store_true',
                    default=False,
                    help='')
# Still takes an explicit value (e.g. `--is-ablation True`) so existing command
# lines keep working, but "False" / "0" / "no" now really mean False.
parser.add_argument('--is-ablation', type=_str2bool,
                    default=False,
                    help='')
parser.add_argument('--imagenet-seg-path', type=str, required=True)
parser.add_argument('--use-median', action='store_true', default=False)
args = parser.parse_args()

# Name under which this run is stored (presumably consumed by Saver to build
# the experiment directory -- verify against utils.saver).
args.checkname = 'all_layers/' + args.method

alpha = 2

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, 'results')
args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')

# Create every output directory up front. `exist_ok=True` replaces the
# check-then-create pattern, which could fail if the directory appeared between
# the existence check and the makedirs call.
for _out_dir in (
    saver.results_dir,
    os.path.join(saver.results_dir, 'input'),
    os.path.join(saver.results_dir, 'explain'),
    args.exp_img_path,
    args.exp_np_path,
):
    os.makedirs(_out_dir, exist_ok=True)

# Data
# Both images and masks are resized to the same fixed spatial size.
_INPUT_SIZE = (224, 224)

# Per-channel normalisation with mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Image pipeline: resize -> tensor -> normalise.
_image_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    normalize,
]
test_img_trans = transforms.Compose(_image_steps)

# Mask pipeline: nearest-neighbour resize only, so label values are never
# blended by interpolation.
_label_steps = [transforms.Resize(_INPUT_SIZE, Image.NEAREST)]
test_lbl_trans = transforms.Compose(_label_steps)

ds = Imagenet_Segmentation(
    args.imagenet_seg_path,
    transform=test_img_trans,
    target_transform=test_lbl_trans,
)
# Deterministic order, no samples dropped.
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

# TODO: DELETE
# Plain ViT model + `baselines` wrapper (currently disabled).
# NOTE(review): the 'rollout' and 'attn_gradcam' branches of eval_batch refer
# to `baselines`; they raise NameError while these lines stay commented out.
# model = vit_base_patch16_224(pretrained=True).cuda()
# baselines = Baselines(model)

# LRP-enabled ViT, moved to the GPU and put in eval mode; `lrp` wraps it for
# relevance generation.
model_LRP = vit_LRP(pretrained=True).cuda()
model_LRP.eval()
lrp = LRP(model_LRP)

# Original-LRP ViT variant (currently disabled).
# NOTE(review): the 'full_lrp', 'lrp_last_layer' and 'attn_last_layer*'
# branches of eval_batch refer to `orig_lrp`; they raise NameError while these
# lines stay commented out.
# model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
# model_orig_LRP.eval()
# orig_lrp = LRP(model_orig_LRP)

# metric = IoU(2, ignore_index=-1)

# Progress bar over the evaluation DataLoader.
iterator = tqdm(dl)

# model.eval()
#

def compute_pred(output, num_classes=1000):
    """Return a one-hot encoding of the top-scoring class of every sample.

    Compared with the previous implementation this works for any batch size
    (the old squeeze/expand_dims round-trip only broadcast correctly for a
    single sample) and no longer crashes on hosts without CUDA.

    Args:
        output: 2-D tensor of class scores indexed as (sample, class).
        num_classes: width of the one-hot encoding. Defaults to 1000, the
            value that used to be hard-coded.

    Returns:
        A float32 tensor of shape (samples, num_classes) holding 1.0 at the
        index of each row's maximum and 0.0 elsewhere. A row whose winning
        index is >= ``num_classes`` is all zeros. The result lives on the
        current CUDA device when CUDA is available (as before), otherwise on
        the device of ``output``.
    """
    # Index of the maximum score per sample -> shape (samples,).
    winners = output.detach().max(1)[1]

    # Compare every winner against 0..num_classes-1 to build the one-hot rows.
    class_ids = torch.arange(num_classes, device=winners.device)
    one_hot = (winners.unsqueeze(1) == class_ids).to(torch.float32)

    if torch.cuda.is_available():
        return one_hot.cuda()
    return one_hot


def _save_input_images(image, labels, index):
    """Write the min-max rescaled input image and its mask under ``input/``."""
    img = image[0].permute(1, 2, 0).data.cpu().numpy()
    img = 255 * (img - img.min()) / (img.max() - img.min())
    img = img.astype('uint8')
    Image.fromarray(img, 'RGB').save(
        os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))

    mask = (labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8')
    Image.fromarray(mask, 'RGB').save(
        os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))


def _relevance_maps(image):
    """Run the explanation method selected by ``args.method`` on ``image``."""
    method = str(args.method)

    # Per-block methods: these return one token map per block, which is the
    # layout the rest of eval_batch works with.
    if method.startswith('predmap'):
        return model_LRP.predmap9_all_layers(image.cuda(), idx=None)  # (blocks, tokens-1)
    if method.startswith('attn_all_layers'):
        return model_LRP.attn_all_layers(image.cuda(), idx=None)  # (blocks, tokens-1)

    # Single-map baselines, kept as they were.
    # NOTE(review): these depend on module-level `baselines` / `orig_lrp`
    # objects and return one (batch, 1, H, W) map, which the per-block
    # pipeline in eval_batch does not reshape -- confirm whether they are
    # still meant to be supported by this script.
    if method == 'rollout':
        return baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14)
    if method == 'full_lrp':
        return orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224)
    if method == 'transformer_attribution':
        return lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution")\
            .reshape(batch_size, 1, 14, 14)
    if method == 'lrp_last_layer':
        return orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\
            .reshape(batch_size, 1, 14, 14)
    if method == 'attn_last_layer':
        return orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\
            .reshape(batch_size, 1, 14, 14)
    if method.startswith('attn_last_layer'):
        return orig_lrp.generate_LRP(image.cuda(), method=args.method, is_ablation=args.is_ablation)\
            .reshape(batch_size, 1, 14, 14)
    if method == 'attn_gradcam':
        return baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)

    # Previously an unmatched method surfaced later as an unbound-local error.
    raise ValueError('unsupported --method for this evaluation: %r' % (args.method,))


def _save_explanation_images(fg_mask, relevance, index):
    """Write the binary mask and a heat-map for index 0 along the blocks axis.

    Both inputs are (blocks, 224, 224). Bilinear ``F.interpolate`` needs a 4-D
    (N, C, H, W) tensor, so a channel axis is inserted first; calling it on the
    3-D tensors directly raised an error.
    """
    mask = F.interpolate(fg_mask[:1].unsqueeze(1), [64, 64], mode='bilinear', align_corners=False)
    mask = mask[0].squeeze().detach().cpu().numpy()
    mask = (255 * mask).astype('uint8')
    imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)

    heat = F.interpolate(relevance[:1].unsqueeze(1), [64, 64], mode='bilinear', align_corners=False)
    heat = heat[0].permute(1, 2, 0).detach().cpu().numpy()  # (64, 64, 1)
    hm = np.sum(heat, axis=-1)
    maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
    imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)


def eval_batch(image, labels, index, use_median=False):
    """Score the per-block relevance maps of one image as segmentations.

    The relevance maps are upsampled from 14x14 to 224x224, min-max normalised
    per block, and thresholded per block into foreground / background. Each
    block is then compared with the ground-truth mask.

    Args:
        image: input batch; the script uses a batch of one image.
        labels: ground-truth segmentation mask for ``image``.
        index: running index of the sample, used only in output file names.
        use_median: threshold every block at its own median instead of its own
            mean. The median is now taken per block, like the mean; before, a
            single global median over all blocks was used.

    Returns:
        A tuple ``(pix_acc, iou, inter, union, ap, pred, target, res_to_save)``:
        ``pix_acc`` (blocks,), ``iou`` (blocks,) averaged over the two classes,
        ``inter`` and ``union`` (blocks, 2), ``ap`` (blocks,), ``pred`` the
        flattened normalised relevance of all blocks, ``target`` the flattened
        labels, and ``res_to_save`` the raw maps as returned by the method.

    Raises:
        ValueError: if ``args.method`` matches none of the known methods.
    """
    if args.save_img:
        _save_input_images(image, labels, index)

    # Enable gradients w.r.t. the input (presumably required by the
    # gradient-based explanation methods).
    image = image.requires_grad_()

    res = _relevance_maps(image)  # (blocks, tokens-1)
    res_to_save = res

    # 14x14 token grid -> full 224x224 resolution, one map per block.
    res = res.unflatten(-1, (14, 14)).reshape(-1, 1, 14, 14)  # (blocks, 1, 14, 14)
    res = F.interpolate(res, scale_factor=16, mode='bilinear', align_corners=False).cuda()
    res = res.squeeze(1).flatten(-2)  # (blocks, 224*224)

    # Min-max normalise every block independently. A constant block divides by
    # zero and becomes NaN; those entries are zeroed further down.
    lo = res.min(dim=-1, keepdim=True).values
    hi = res.max(dim=-1, keepdim=True).values
    res = (res - lo) / (hi - lo)

    # Foreground / background threshold, one value per block.
    if use_median:
        thresh = res.median(dim=-1).values  # (blocks,)
    else:
        thresh = res.mean(dim=-1)  # (blocks,)
    thresh = thresh.reshape(-1, 1, 1)

    res = res.unflatten(-1, (224, 224))  # (blocks, 224, 224)

    # Hard masks. Comparisons with NaN are False, so these never contain NaN.
    Res_1 = res.gt(thresh).to(res.dtype)
    Res_0 = res.le(thresh).to(res.dtype)

    # Soft scores for average precision. NaNs are zeroed in both; the
    # background score is derived before the clean-up so it is 0 (not 1) there.
    nan_mask = torch.isnan(res)
    Res_1_AP = res.masked_fill(nan_mask, 0)
    Res_0_AP = (1 - res).masked_fill(nan_mask, 0)
    res = Res_1_AP

    pred = res.clamp(min=args.thr) / res.amax(dim=(-2, -1), keepdim=True)
    pred = pred.view(-1).detach().cpu().numpy()
    target = labels.view(-1).detach().cpu().numpy()

    output = torch.stack((Res_0, Res_1), dim=1)  # (blocks, 2, 224, 224)
    output_AP = torch.stack((Res_0_AP, Res_1_AP), dim=1)  # (blocks, 2, 224, 224)

    if args.save_img:
        _save_explanation_images(Res_1, res, index)

    # Segmentation results, one entry per block.
    correct_lst, labeled_lst, inter_lst, union_lst = [], [], [], []
    for block_out in output:
        block_out = block_out.detach().cpu()
        correct, labeled = batch_pix_accuracy(block_out, labels[0])
        inter, union = batch_intersection_union(block_out, labels[0], 2)
        correct_lst.append(correct)
        labeled_lst.append(labeled)
        inter_lst.append(inter)
        union_lst.append(union)

    ap_lst = [
        np.nan_to_num(get_ap_scores(block_scores.unsqueeze(0), labels))
        for block_scores in output_AP
    ]

    eps = np.spacing(1, dtype=np.float64)  # guards against division by zero
    inter_lst_np = np.array(inter_lst)  # (blocks, 2)
    union_lst_np = np.array(union_lst)  # (blocks, 2)
    pix_acc = np.float64(1.0) * np.array(correct_lst) / (eps + np.array(labeled_lst))  # (blocks,)
    iou = np.float64(1.0) * inter_lst_np / (eps + union_lst_np)  # (blocks, 2)
    iou = iou.mean(axis=-1)  # (blocks,)
    ap = np.array(ap_lst).squeeze(-1)  # (blocks,)

    return pix_acc, iou, inter_lst_np, union_lst_np, ap, pred, target, res_to_save


# Per-image results. Every entry holds one value (or one row) per block.
pix_acc_lst, iou_lst, ap_lst = [], [], []
inter_lst, union_lst = [], []
heatmap_lst = []
for batch_idx, (image, labels) in enumerate(iterator):
    images = image.cuda()
    labels = labels.cuda()

    # inter, union: (blocks, 2); res_to_save: (blocks, tokens-1).
    # The flattened per-pixel scores and targets returned by eval_batch are
    # intentionally not accumulated: the script exits right after saving the
    # arrays below, so keeping them only grew memory with every image.
    pix_acc, iou, inter, union, ap, _pred, _target, res_to_save = eval_batch(
        images, labels, batch_idx, args.use_median)
    pix_acc_lst.append(pix_acc)
    iou_lst.append(iou)
    ap_lst.append(ap)
    inter_lst.append(inter)
    union_lst.append(union)
    heatmap_lst.append(res_to_save.detach().cpu().numpy())

# Stack everything first so that a shape problem surfaces before any file is
# written, then save one .npy per metric into the experiment directory.
pix_acc_np = np.stack(pix_acc_lst, axis=0)  # (index, blocks)
iou_np = np.stack(iou_lst, axis=0)  # (index, blocks)
ap_np = np.stack(ap_lst, axis=0)  # (index, blocks)
inter_np = np.stack(inter_lst, axis=0)  # (index, blocks, 2)
union_np = np.stack(union_lst, axis=0)  # (index, blocks, 2)
heatmap_np = np.stack(heatmap_lst, axis=0)  # (index, blocks, tokens-1)

for _fname, _array in (
    ('pix_acc.npy', pix_acc_np),
    ('iou.npy', iou_np),
    ('ap.npy', ap_np),
    ('inter.npy', inter_np),
    ('union.npy', union_np),
    ('heatmap.npy', heatmap_np),
):
    np.save(os.path.join(saver.experiment_dir, _fname), _array)

# The run ends here; nothing after this line is executed. sys.exit is used
# because the bare `exit` helper is injected by the `site` module and is not
# meant for use in programs.
sys.exit(0)


# NOTE(review): this section is unreachable -- the script terminates before
# getting here -- and it reads mIoU, pixAcc, mAp and mF1, which no live code in
# this file assigns. It needs those aggregates (and the per-pixel predictions /
# targets) to be computed before it can be re-enabled.
predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)

plt.figure()
plt.plot(rc, pr)
pr_curve_path = os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method))
print(f"saving pr_curve to {pr_curve_path}")
plt.savefig(pr_curve_path)

txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
# txtfile = 'result_mIoU_%.4f.txt' % mIoU

# Build each summary line once; the same text goes to stdout and to the file.
summary_lines = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
]

# The context manager closes the file even if printing or writing fails.
with open(txtfile, 'w') as fh:
    for line in summary_lines:
        print(line)
    fh.writelines(summary_lines)

###########################
# pix_acc_lst = []
# iou_lst = []
# ap_lst = []
# def eval_batch():
#     generate_LRP() # (batch, layer, 14, 14)
#     pixAcc # (batch, layer)
#     iou # (batch, layer)
#     ap # (batch, layer)
# eval_batch()
