import warnings
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

import numpy as np
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
import sys
sys.path.append(os.getcwd())

import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

from utils.metrices import *
from utils import render
from utils.saver import Saver

from data.Imagenet import Imagenet_Segmentation
from baselines.misc_functions import *
from baselines.models import run_baselines, load_baseline_models

from src.algorithms import dds, batch_pgd_attack
from src.diffusion import create_diffusion_model

# Use matplotlib's non-interactive 'agg' backend: this script only writes
# figures to disk (the PR curve is saved with savefig in main), so no GUI
# window is ever needed.
plt.switch_backend('agg')

# Evaluation hyperparameters.
num_workers = 0  # worker processes for the test DataLoader
batch_size = 1   # eval_batch scores pixel accuracy / IoU / F1 on sample 0 only

# These appear to be the 20 Pascal VOC object categories ('motobike' was a
# typo for VOC's 'motorbike'). Not referenced by the code visible in this file.
cls = ['airplane',
       'bicycle',
       'bird',
       'boat',
       'bottle',
       'bus',
       'car',
       'cat',
       'chair',
       'cow',
       'dining table',
       'dog',
       'horse',
       'motorbike',
       'person',
       'potted plant',
       'sheep',
       'sofa',
       'train',
       'tv'
       ]

# NOTE(review): not referenced by the code visible in this file.
alpha = 2

def compute_pred(output, num_classes=1000, device="cuda"):
    """Return a one-hot encoding of the arg-max class of each row of ``output``.

    Args:
        output: Classifier scores with the class dimension at dim 1
            (presumably ``(batch, num_classes)`` logits -- confirm with callers).
        num_classes: Width of the one-hot vectors. 1000 was the original
            hard-coded value and remains the default.
        device: Device the result is moved to; defaults to ``"cuda"`` as before.

    Returns:
        A ``float32`` tensor of shape ``(batch, num_classes)`` on ``device``
        with a single 1.0 per row at the predicted class. A predicted index
        ``>= num_classes`` yields an all-zero row, as in the original.

    The previous NumPy implementation squeezed the prediction and broadcast it
    against ``arange(1000)``, which only worked for a batch of one; this
    version handles any batch size and gives identical output for batch 1.
    """
    # Index of the maximal score per row; gradients are not needed here.
    pred = output.detach().argmax(dim=1)
    class_ids = torch.arange(num_classes, device=pred.device)
    # (batch, 1) == (num_classes,) broadcasts to (batch, num_classes).
    one_hot = (pred.unsqueeze(1) == class_ids).float()
    return one_hot.to(device)


def _results_dir_from_args():
    """Recover the ``results`` directory from ``args.exp_img_path``.

    ``main`` sets ``args.exp_img_path`` to ``<results_dir>/explain/img``, so the
    results directory is two levels up. (``saver`` is local to ``main`` and is
    not reachable from here.)
    """
    return os.path.dirname(os.path.dirname(os.path.normpath(args.exp_img_path)))


def _save_input(image, labels, index):
    """Write the min-max normalised input image and its ground-truth mask as PNGs."""
    input_dir = os.path.join(_results_dir_from_args(), 'input')
    img = image[0].permute(1, 2, 0).data.cpu().numpy()
    img = 255 * (img - img.min()) / (img.max() - img.min())
    Image.fromarray(img.astype('uint8'), 'RGB').save(
        os.path.join(input_dir, '{}_input.png'.format(index)))
    mask = labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255
    Image.fromarray(mask.astype('uint8'), 'RGB').save(
        os.path.join(input_dir, '{}_mask.png'.format(index)))


def _save_explanation(Res, Res_1, index):
    """Write the 64x64 binary mask and relevance heat-map for sample ``index``."""
    mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
    mask = (255 * mask[0].squeeze().data.cpu().numpy()).astype('uint8')
    imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)

    relevance = F.interpolate(Res, [64, 64], mode='bilinear')
    relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
    hm = np.sum(relevance, axis=-1)
    maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
    imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)


def eval_batch(models, image, labels, evaluator, index, device):
    """Explain one batch with ``args.method`` and score it as a segmentation.

    Reads the module-level ``args`` (``save_img``, ``method``, ``is_ablation``,
    ``thr`` and ``exp_img_path``), which the ``__main__`` block and ``main``
    populate.

    Args:
        models: Baseline models as returned by ``load_baseline_models``.
        image: Input batch; assumed ``(B, 3, H, W)`` -- TODO confirm. Only
            ``image[0]`` is saved when ``args.save_img`` is set.
        labels: Ground-truth masks; pixel accuracy, IoU and F1 use ``labels[0]``.
        evaluator: Model whose gradients are zeroed and which is run forward
            on ``image`` before the explanation is computed.
        index: Sample index, used in output file names and passed to
            ``run_baselines``.
        device: Passed through to ``run_baselines``.

    Returns:
        ``(correct, labeled, inter, union, ap, f1, pred, target)``; ``pred`` and
        ``target`` are flattened NumPy arrays for the precision/recall curve.
    """
    evaluator.zero_grad()
    if args.save_img:
        _save_input(image, labels, index)

    # requires_grad_() also works when `image` is a non-leaf tensor (e.g. the
    # output of the attack / DDS step), where assigning `.requires_grad` raises.
    image = image.requires_grad_()
    # Output unused here; the forward pass is kept because explanation methods
    # may rely on state it populates (e.g. hooks) -- TODO confirm.
    evaluator(image)

    Res = run_baselines(models, args.method, image, device, index, args.is_ablation)

    # Binarise around the mean relevance (mean taken before NaN cleanup).
    # gt/le never produce NaN, so these masks need no cleanup.
    ret = Res.mean()
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Continuous maps for average precision. Res_0_AP is derived from Res
    # before the cleanup below, so both AP maps are 0 wherever Res was NaN.
    Res_0_AP = 1 - Res
    Res_0_AP[Res_0_AP != Res_0_AP] = 0
    # In-place NaN -> 0 on Res itself (previously done through the alias
    # `Res_1_AP = Res`): the PR scores and heat-map below use the cleaned Res.
    Res[Res != Res] = 0
    Res_1_AP = Res

    # Threshold-normalised scores for the PR curve. An all-zero map would give
    # 0/0 = NaN here, which precision_recall_curve rejects at the end of the run.
    peak = Res.max()
    if peak == 0:
        peak = torch.ones_like(peak)
    pred = (Res.clamp(min=args.thr) / peak).view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    # Channel 0: relevance <= mean, channel 1: relevance > mean.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        _save_explanation(Res, Res_1, index)

    # Segmentation results (2 classes), computed on the first sample.
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))

    return correct, labeled, inter, union, ap, f1, pred, target


def _prepare_output_dirs(saver, args):
    """Create results/{input,explain/img,explain/np} and record paths on saver/args."""
    saver.results_dir = os.path.join(saver.experiment_dir, 'results')
    os.makedirs(os.path.join(saver.results_dir, 'input'), exist_ok=True)
    args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
    os.makedirs(args.exp_img_path, exist_ok=True)
    args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')
    os.makedirs(args.exp_np_path, exist_ok=True)


def _build_dataloader(args):
    """Return an ordered DataLoader over the segmentation set resized to 224x224."""
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    test_img_trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    # Nearest-neighbour resizing keeps the label mask values unblended.
    test_lbl_trans = transforms.Compose([
        transforms.Resize((224, 224), Image.NEAREST),
    ])
    ds = Imagenet_Segmentation(args.imagenet_seg_path,
                               transform=test_img_trans, target_transform=test_lbl_trans)
    # Use the module-level worker count (previously hard-coded to 1 here).
    return DataLoader(ds, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, drop_last=False)


def _save_pr_curve(saver, method, predictions, targets):
    """Save precision/recall arrays and a PR-curve plot to the experiment dir."""
    pr, rc, _ = precision_recall_curve(np.concatenate(targets), np.concatenate(predictions))
    np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
    np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)

    fig = plt.figure()
    plt.plot(rc, pr)
    plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(method)))
    plt.close(fig)


def _write_summary(saver, mIoU, pixAcc, mAp, mF1):
    """Print the final metrics and write them to ``result_mIoU_<mIoU>.txt``."""
    lines = [
        "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
        "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
        "Mean AP over %d classes: %.4f\n" % (2, mAp),
        "Mean F1 over %d classes: %.4f\n" % (2, mF1),
    ]
    for line in lines:
        print(line)
    txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
    with open(txtfile, 'w') as fh:
        fh.writelines(lines)


def main(args):
    """Run the segmentation evaluation over the whole dataset.

    Sets ``args.exp_img_path`` / ``args.exp_np_path`` (read by ``eval_batch``).
    Optionally applies a PGD attack (``args.attack``) and a DDS purification
    step (``args.dds``) to each batch before explaining it. It accumulates
    pixel accuracy, IoU, AP and F1, then saves the PR curve and a text summary.

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    saver = Saver(args)
    _prepare_output_dirs(saver, args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dl = _build_dataloader(args)

    models = load_baseline_models(args.model, device)
    # NOTE(review): built even when --dds is off; kept unconditional in case
    # construction has side effects -- consider creating it only when needed.
    diffusion_model = create_diffusion_model()

    iterator = tqdm(dl)

    total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
    total_ap, total_f1 = [], []
    predictions, targets = [], []

    for batch_idx, (image, labels) in enumerate(iterator):
        if args.method == "blur":
            images = (image[0].to(device), image[1].to(device))
        else:
            images = image.to(device)
        labels = labels.to(device)

        if args.attack:
            images = batch_pgd_attack(images, models['vit'], noise_level=2/255)

        # Stabilising DDS step.
        if args.dds:
            images = dds(images, diffusion_model, fast_predict=True).unsqueeze(0)

        correct, labeled, inter, union, ap, f1, pred, target = eval_batch(
            models, images, labels, models['vit'], batch_idx, device)

        predictions.append(pred)
        targets.append(target)

        total_correct += correct.astype('int64')
        total_label += labeled.astype('int64')
        total_inter += inter.astype('int64')
        total_union += union.astype('int64')
        total_ap.append(ap)
        total_f1.append(f1)
        # np.spacing(1) guards against division by zero before any label is seen.
        pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
        IoU = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
        mIoU = IoU.mean()
        mAp = np.mean(total_ap)
        mF1 = np.mean(total_f1)
        iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))

    if not predictions:
        raise RuntimeError('No samples were loaded from %s' % args.imagenet_seg_path)

    _save_pr_curve(saver, args.method, predictions, targets)
    _write_summary(saver, mIoU, pixAcc, mAp, mF1)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``, so any
    explicit value used to switch the flag on. This converter accepts the usual
    spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Evaluate explanation methods as segmentation masks on ImageNet segmentation')
    parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                        help='Model architecture')
    parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                        help='Testing Dataset')
    parser.add_argument('--method', type=str, default='rollout', help='',
                        choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
                                 'attn_last_layer', 'attn_gradcam'])
    parser.add_argument("--model", type=str, default="vit",
                        choices=["vit", "deit", "swin"],
                        help="enter the vision transformer to be used")
    parser.add_argument('--thr', type=float, default=0.,
                        help='threshold')
    parser.add_argument('--save-img', action='store_true',
                        default=False,
                        help='')
    # The following options are not read by the code in this file (Saver may
    # consult some of them); they stay registered so existing command lines work.
    parser.add_argument('--K', type=int, default=1,
                        help='new - top K results')
    parser.add_argument('--no-ia', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fgx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-m', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-reg', action='store_true',
                        default=False,
                        help='')
    # Boolean options: accept `--flag`, `--flag True` and `--flag False`.
    parser.add_argument('--is-ablation', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='')
    parser.add_argument('--imagenet-seg-path',
                        default="data_dir/gtsegs_ijcv.mat",
                        type=str)
    parser.add_argument('--attack', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='')
    parser.add_argument('--dds', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='')
    # Module-level global: eval_batch reads `args` directly.
    args = parser.parse_args()

    args.checkname = args.method + '_' + args.arc

    main(args)