import copy
import os.path as osp
import argparse
import click
import cv2
import matplotlib.cm as cm
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import random
from torchvision import models, transforms
from check_dataset import check_dataset
from check_model import check_model
from utils import resnet_icml_ilsvrc


from grad_cam import (
    BackPropagation,
    Deconvnet,
    GradCAM,
    GradCAMATL,
    GuidedBackPropagation,
    occlusion_sensitivity,
)

import os
import sys
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from torchvision import transforms as T

def choose_model(opt):
    """Build the network selected by ``opt.target_class`` and load its weights.

    Supported values of ``opt.target_class``:
      * ``'source'``  -- ``opt.source_model`` architecture, optionally pretrained
        (``opt.pretrained_src_model``); no checkpoint is loaded and the model is
        NOT moved to ``opt.device`` (the caller does that).
      * ``'indep'``   -- ``opt.target_model`` with ``opt.num_target_classes``
        outputs, moved to ``opt.device``, weights from ``indep_best.pth``.
      * ``'xstitch'`` / ``'bandit'`` -- ``opt.target_model`` built with the
        transfer settings in ``opt`` and weights from ``<name>_best.pth``.
        These also set ``opt.ru_units`` and ``opt.state_pairs`` (read from the
        checkpoint's ``'pairs'`` entry) as a side effect.

    Checkpoints are read from ``opt.checkpoint_dir`` if present, otherwise from
    ``models/cub200``.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if ``opt.target_class`` is not one of the names above.
    """
    kind = opt.target_class
    if kind not in ('source', 'indep', 'xstitch', 'bandit'):
        raise ValueError(
            "Unknown target_class {!r}; expected one of "
            "'source', 'indep', 'xstitch', 'bandit'".format(kind))

    if kind == 'source':
        return resnet_icml_ilsvrc.__dict__[opt.source_model](pretrained=opt.pretrained_src_model)

    ckpt_dir = getattr(opt, 'checkpoint_dir', 'models/cub200')
    ckpt_path = '{}/{}_best.pth'.format(ckpt_dir, kind)
    map_location = torch.device(opt.device)
    builder = resnet_icml_ilsvrc.__dict__[opt.target_model]

    if kind == 'indep':
        model = builder(num_classes=opt.num_target_classes,
                        pretrained=opt.pretrained_tg_model).to(opt.device)
        # NOTE: torch.load unpickles; only point this at trusted checkpoints.
        model.load_state_dict(torch.load(ckpt_path, map_location=map_location)['target_model'])
        return model

    # 'xstitch' checkpoints are built with ru_units=False, 'bandit' with True.
    opt.ru_units = kind == 'bandit'
    model = builder(num_classes=opt.num_target_classes,
                    pretrained=opt.pretrained_tg_model,
                    transfer_types=opt.transfer_types,
                    source_info=opt.source_info,
                    ru_units=opt.ru_units)
    ckpt = torch.load(ckpt_path, map_location=map_location)
    # The source/target pairing stored with the checkpoint is needed at forward time.
    opt.state_pairs = ckpt['pairs']
    model.load_state_dict(ckpt['target_model'])
    return model

def load_images(image_paths):
    """Preprocess every path in ``image_paths``.

    Prints an indexed listing of the paths as it goes.

    Returns:
        ``(tensors, originals)``: two parallel lists holding, for each path,
        the pair returned by ``preprocess``.
    """
    tensors, originals = [], []
    print("Images:")
    for idx, path in enumerate(image_paths):
        print(f"\t#{idx}: {path}")
        tensor, original = preprocess(path)
        tensors.append(tensor)
        originals.append(original)
    return tensors, originals

def preprocess(image_path, size=224):
    """Read an image from disk and produce a normalized network input.

    Args:
        image_path: path to the image file (str or path-like).
        size: side length the image is resized to (square), default 224.

    Returns:
        ``(image, raw_image)`` where ``raw_image`` is the resized array as read
        by ``cv2.imread`` (OpenCV channel order), and ``image`` is the
        channel-reversed copy converted by ``ToTensor`` and normalized with the
        ImageNet mean/std used elsewhere in this file.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    raw_image = cv2.imread(str(image_path))
    if raw_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(image_path))
    raw_image = cv2.resize(raw_image, (size, size))
    to_input = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Reverse the channel axis (BGR -> RGB); .copy() makes the view contiguous.
    image = to_input(raw_image[..., ::-1].copy())
    return image, raw_image

def get_classtable_cub(path="data/cub200/classes.txt"):
    """Read CUB-200 class names from a ``classes.txt`` file.

    Each non-blank line is expected to look like ``"<id> <num>.<Name>"``
    (e.g. ``"1 001.Black_footed_Albatross"``); the part after the first space
    and then after the first dot is kept.

    Args:
        path: location of the class list; defaults to the original CUB path.

    Returns:
        List of class names in file order.
    """
    classes = []
    with open(path, encoding="utf-8") as lines:
        for line in lines:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of raising IndexError.
                continue
            name = line.split(" ", 1)[1]
            name = name.split(".", 1)[1]
            classes.append(name)
    return classes

def load_dataset(opt):
    """Build a loader over the test images of class ``opt.image_class_tg``.

    Images come from ``<opt.dataroot>/test`` via ``ImageFolder``. They are
    resized to 256, center-cropped to 224, converted to tensors and
    normalized. The loader is unshuffled and puts every selected image in a
    single batch.
    """
    crop = transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224)])
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    test_root = os.path.join(opt.dataroot, 'test')
    dataset = dset.ImageFolder(root=test_root,
                               transform=transforms.Compose([crop, to_normalized_tensor]))

    wanted = opt.image_class_tg
    class_indices = [idx for idx, label in enumerate(dataset.targets) if label == wanted]
    return torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, class_indices),
                                       batch_size=len(class_indices),
                                       shuffle=False,
                                       num_workers=0)

def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
    """Overlay a Grad-CAM map on an image and write the result to ``filename``.

    Args:
        filename: output path passed to ``cv2.imwrite``.
        gcam: torch tensor with the map; it is fed to a matplotlib colormap, so
            presumably 2-D (H, W) with values in [0, 1] -- TODO confirm against
            ``grad_cam.generate``.
        raw_image: image array broadcast-compatible with the (H, W, 3) colormap
            output (e.g. the ``raw_image`` returned by ``preprocess``).
        paper_cmap: if True, alpha-blend using the map itself as per-pixel
            alpha; otherwise take the plain average of colormap and image.
    """
    gcam = gcam.cpu().numpy()
    # Debug dumps (gcam.npy / raw_image.npy in the CWD), overwritten on every
    # call. NOTE(review): likely leftover debugging; consider removing.
    np.save('gcam', gcam)
    np.save('raw_image', raw_image)
    cmap = cm.jet_r(gcam)[..., :3] * 255.0
    if paper_cmap:
        alpha = gcam[..., None]
        blended = alpha * cmap + (1 - alpha) * raw_image
    else:
        # np.float was removed in NumPy 1.24; it was an alias of float64.
        blended = (cmap.astype(np.float64) + raw_image.astype(np.float64)) / 2
    cv2.imwrite(filename, np.uint8(blended))


def validate(model, source_model, opt):
    """Return predicted class indices for every sample in ``opt.loaders``.

    For ``opt.target_class == 'indep'`` the target model is called as
    ``model(x)``. Otherwise the source model is run first and its features are
    passed in as ``model(x, source_features, opt.state_pairs)``. In both cases
    the first element of the returned tuple is treated as class scores.

    Both models are switched to eval mode and the whole pass runs under
    ``torch.no_grad()``.

    Returns:
        1-D numpy array of argmax predictions, concatenated across all
        batches in loader order (empty if the loader yields nothing).
    """
    model.eval()
    uses_source = opt.target_class != 'indep'
    if uses_source:
        # Without this, BatchNorm in the source model would use batch stats.
        source_model.eval()

    batch_preds = []
    with torch.no_grad():
        for x, _ in opt.loaders:
            x = x.to(opt.device)
            if uses_source:
                _, source_features = source_model(x)
                y_hat, _ = model(x, source_features, opt.state_pairs)
            else:
                y_hat, _ = model(x)
            # .cpu() so .numpy() also works when running on CUDA.
            batch_preds.append(torch.argmax(y_hat, dim=1).cpu())

    if not batch_preds:
        return np.empty(0, dtype=np.int64)
    return torch.cat(batch_preds).numpy()


def _parse_options():
    """Parse CLI flags and attach the fixed experiment settings to the namespace."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--dataroot', required=True, help='Path to the dataset')
    parser.add_argument('--dataset', default='cub200')
    parser.add_argument('--datasplit', default='cub200')
    parser.add_argument('--datanoise', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--source-model', default='resnet34', type=str)
    parser.add_argument('--source-domain', default='imagenet', type=str)
    parser.add_argument('--source-path', type=str, default=None)
    parser.add_argument('--pretrained-src-model', action='store_true', default=False)
    parser.add_argument('--target-model', default='resnet18', type=str)
    parser.add_argument('--pretrained-tg-model', action='store_true', default=False)
    parser.add_argument('--numTrain', type=int, default=100, help='Train sample size. 100 means use all')

    parser.add_argument('--transfer-type', default='xstitch', type=str,
                        choices=['indep','spottune', 'xstitch', 'routenorm','transnorm', 'simpleblock','linstitch', 'combine', 'shared', 'block', 'attention'],
                        help='Different options to combine the source and target')
    opt = parser.parse_args()

    # Default transfer settings consumed by choose_model.
    opt.source_feature_ids = [0, 1, 2, 3, 4]
    opt.target_decisioner_ids = [0, 1, 2, 3]
    opt.source_input_pass_id = 5  # ID corresponding to PASS/skip action
    opt.narms = len(opt.source_feature_ids) + 1
    opt.transfer_types = [opt.transfer_type for _ in opt.source_feature_ids]

    # Placeholder decision makers: one None per target decisioner.
    decisioners = [None] * len(opt.target_decisioner_ids)
    opt.source_info = (opt.source_feature_ids, opt.source_input_pass_id, decisioners)

    random.seed(opt.seed)
    opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt.image_class_tg = 13
    opt.image_class_src = 14
    opt.state_pairs = []
    return opt


def _select_roi_images(opt, source_model):
    """Return image paths that the 'indep' model misclassifies but 'bandit' gets right."""
    preds = {}
    # Order matters: choose_model updates opt.state_pairs for xstitch/bandit.
    for name in ('indep', 'xstitch', 'bandit'):
        opt.target_class = name
        target_model = choose_model(opt).to(opt.device)
        preds[name] = validate(target_model, source_model, opt)
        del target_model

    roi_inds = np.where(np.logical_and(preds['indep'] != opt.image_class_tg,
                                       preds['bandit'] == opt.image_class_tg))[0].tolist()
    subset = opt.loaders.dataset
    # Predictions follow Subset order (the loader is unshuffled).
    roi_data = [subset.dataset.imgs[i] for i in subset.indices]
    imgs = [roi_data[i][0] for i in roi_inds]
    print(list(preds.values()))
    return imgs


def _run_gradcam(opt, source_model, model_name, imgs, target_layers, classes):
    """Write Grad-CAM overlays for ``imgs`` under ``results/<model_name>/<image stem>/``."""
    opt.target_class = model_name
    target_model = choose_model(opt).to(opt.device)
    print(opt.state_pairs)
    target_model.eval()

    # The source model is explained w.r.t. the source-domain class index.
    target_class = opt.image_class_src if model_name == 'source' else opt.image_class_tg

    if model_name in ('indep', 'source'):
        def build_gcam():
            return GradCAM(model=target_model)
    elif model_name in ('bandit', 'xstitch'):
        def build_gcam():
            return GradCAMATL(model=target_model, source_model=source_model, pairs=opt.state_pairs)
    else:
        raise ValueError("No Grad-CAM wrapper for model {!r}".format(model_name))

    for image in imgs:
        images, raw_images = load_images([image])
        images = torch.stack(images).to(opt.device)
        gcam = build_gcam()
        output_dir = os.path.join('results', model_name, Path(image).stem)
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        try:
            probs, ids = gcam.forward(images)
            ids_ = torch.LongTensor([[target_class]] * len(images)).to(opt.device)
            gcam.backward(ids=ids_)
            print("Top 5 class is @{}".format(ids[0][:5]))
            for target_layer in target_layers:
                print("Generating Grad-CAM @{}".format(target_layer))
                # Grad-CAM
                regions = gcam.generate(target_layer=target_layer)
                for j in range(len(images)):
                    # Probability of target_class for this row only.
                    prob = float(probs[j][ids[j] == target_class])
                    print("\t#{}: {} ({:.5f})".format(j, classes[target_class], prob))
                    save_gradcam(
                        filename=osp.join(
                            output_dir,
                            "{}-gradcam-{}-{}.png".format(
                                j, target_layer, classes[target_class]
                            ),
                        ),
                        gcam=regions[j, 0],
                        raw_image=raw_images[j],
                    )
        finally:
            # A fresh wrapper is built per image; drop its hooks so they do not
            # accumulate on target_model. Guarded in case a wrapper lacks remove_hook().
            remove_hook = getattr(gcam, 'remove_hook', None)
            if callable(remove_hook):
                remove_hook()


def main():
    """Pick CUB test images where 'bandit' beats 'indep', then save their Grad-CAMs."""
    opt = _parse_options()
    opt.loaders = load_dataset(opt)
    opt.num_target_classes = len(opt.loaders.dataset.dataset.classes)

    opt.target_class = 'source'
    source_model = choose_model(opt).to(opt.device)
    imgs = _select_roi_images(opt, source_model)

    # Add "layer1".."layer3" to also visualize shallower layers.
    target_layers = ["layer4"]
    classes = get_classtable_cub()

    for model_name in ['bandit']:
        _run_gradcam(opt, source_model, model_name, imgs, target_layers, classes)


if __name__ == '__main__':
    main()
