import copy
import os.path as osp
import argparse
import click
import cv2
import matplotlib.cm as cm
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import random
from torchvision import models, transforms
from check_dataset import check_dataset
from check_model import check_model
from utils import resnet_icml_ilsvrc


from grad_cam import (
    BackPropagation,
    Deconvnet,
    GradCAM,
    GradCAMATL,
    GuidedBackPropagation,
    occlusion_sensitivity,
)

import os
import sys
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from torchvision import transforms as T

def choose_model(option, opt):
    """Build the network selected by ``option`` and load its checkpoint, if any.

    Args:
        option: Which model to build: ``'source'``, ``'indep'``, ``'xstitch'``
            or ``'bandit'``.
        opt: Options namespace. Reads ``source_model`` and
            ``pretrained_src_model`` for the source network; ``target_model``,
            ``num_target_classes``, ``pretrained_tg_model`` and ``device`` for
            the target networks; additionally ``transfer_types`` and
            ``source_info`` for ``'xstitch'``/``'bandit'``.

    Returns:
        The constructed model. Only the ``'indep'`` model is moved to
        ``opt.device`` here; the other variants are returned on whatever
        device the constructor placed them, so callers move them explicitly.

    Raises:
        ValueError: If ``option`` is not one of the supported names.

    Side effects:
        For ``'xstitch'`` and ``'bandit'`` this sets ``opt.ru_units``
        (``False`` / ``True`` respectively) and overwrites ``opt.state_pairs``
        with the ``'pairs'`` entry stored in the checkpoint.
    """
    supported = ('source', 'indep', 'xstitch', 'bandit')
    # Fail early with a clear message instead of falling through every branch
    # and hitting an unbound local at the return statement.
    if option not in supported:
        raise ValueError(
            "Unknown model option {!r}; expected one of {}".format(option, supported))

    if option == 'source':
        return resnet_icml_ilsvrc.__dict__[opt.source_model](
            pretrained=opt.pretrained_src_model)

    build_target = resnet_icml_ilsvrc.__dict__[opt.target_model]
    checkpoint_path = 'models/cub200/{}_best.pth'.format(option)

    if option == 'indep':
        model = build_target(num_classes=opt.num_target_classes,
                             pretrained=opt.pretrained_tg_model).to(opt.device)
        ckpt = torch.load(checkpoint_path, map_location=torch.device(opt.device))
        model.load_state_dict(ckpt['target_model'])
        return model

    # 'xstitch' and 'bandit' are built identically apart from the ru_units flag
    # and the checkpoint file they are restored from.
    opt.ru_units = (option == 'bandit')
    model = build_target(num_classes=opt.num_target_classes,
                         pretrained=opt.pretrained_tg_model,
                         transfer_types=opt.transfer_types,
                         source_info=opt.source_info,
                         ru_units=opt.ru_units)
    ckpt = torch.load(checkpoint_path, map_location=torch.device(opt.device))
    # The checkpoint's pairs are passed to the model's forward call later on.
    opt.state_pairs = ckpt['pairs']
    model.load_state_dict(ckpt['target_model'])
    return model

def load_images(image_paths):
    """Preprocess every path in ``image_paths`` and collect the results.

    Prints a header followed by one line per image (index and path) as the
    images are processed.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        A pair ``(images, raw_images)`` of lists holding, for each path in
        order, the two values returned by ``preprocess``.
    """
    print("Images:")
    tensors, originals = [], []
    for position, path in enumerate(image_paths):
        print("\t#{}: {}".format(position, path))
        tensor, original = preprocess(path)
        tensors.append(tensor)
        originals.append(original)
    return tensors, originals

def preprocess(image_path):
    """Load one image and return it as a normalized tensor plus the resized original.

    Args:
        image_path: Path of the image file to read with OpenCV.

    Returns:
        A pair ``(image, raw_image)``: ``image`` is the output of
        ``ToTensor`` + ``Normalize`` applied to the channel-reversed picture,
        and ``raw_image`` is the 224x224 array exactly as OpenCV produced it
        (channel order untouched).

    Raises:
        OSError: If OpenCV could not read ``image_path``.
    """
    raw_image = cv2.imread(image_path)
    if raw_image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error would surface later inside cv2.resize.
        raise OSError("Could not read image: {}".format(image_path))
    raw_image = cv2.resize(raw_image, (224,) * 2)
    to_normalized_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Reverse the channel axis (OpenCV reads BGR) and copy so the transform
    # receives a contiguous array instead of a negatively-strided view.
    image = to_normalized_tensor(raw_image[..., ::-1].copy())
    return image, raw_image

def get_classtable_cub():
    """Read the class names listed in ``data/cub200/classes.txt``.

    For every line, the text up to and including the first space is dropped,
    then the text up to and including the first period; what remains is the
    class name.

    Returns:
        List of class names in file order.
    """
    with open("data/cub200/classes.txt") as handle:
        return [
            row.strip().split(" ", 1)[1].split(".", 1)[1]
            for row in handle
        ]

def load_dataset(opt):
    """Create a loader over the test images of a single class.

    Args:
        opt: Options namespace; ``opt.dataroot`` is the dataset root (its
            ``test`` sub-folder is read as an ``ImageFolder``) and
            ``opt.image_class_tg`` is the class index to keep.

    Returns:
        A ``DataLoader`` over the matching samples, unshuffled, whose batch
        size equals the number of matching samples (i.e. one single batch).
    """
    resize_and_crop = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(224)])
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    eval_transform = transforms.Compose([resize_and_crop, to_normalized_tensor])

    test_set = dset.ImageFolder(root=os.path.join(opt.dataroot, 'test'),
                                transform=eval_transform)

    # Positions (within the full test set) of the samples of the wanted class.
    keep = [position
            for position, label in enumerate(test_set.targets)
            if label == opt.image_class_tg]

    return torch.utils.data.DataLoader(
        torch.utils.data.Subset(test_set, keep),
        batch_size=len(keep),
        shuffle=False,
        num_workers=0,
    )

def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
    """Blend a heat map with an image and write the result to ``filename``.

    Args:
        filename: Output path handed to ``cv2.imwrite``.
        gcam: Torch tensor holding the heat map; it is moved to the CPU and
            converted to NumPy before colouring.
            # NOTE(review): presumably 2-D with values in [0, 1] - confirm.
        raw_image: Image array to overlay; must broadcast against the
            coloured map.
        paper_cmap: If true, use the heat map itself as the per-pixel blend
            weight; otherwise average the coloured map and the image equally.
    """
    gcam = gcam.cpu().numpy()
    # Apply the reversed jet colormap, drop its alpha channel, scale to 0-255.
    cmap = cm.jet_r(gcam)[..., :3] * 255.0
    if paper_cmap:
        alpha = gcam[..., None]
        blended = alpha * cmap + (1 - alpha) * raw_image
    else:
        # np.float was removed from NumPy (it was only an alias of the builtin
        # float); np.float64 is the same dtype and works on every version.
        blended = (cmap.astype(np.float64) + raw_image.astype(np.float64)) / 2
    cv2.imwrite(filename, np.uint8(blended))


def validate(model, source_model, loader, opt):
    """Run ``model`` over ``loader`` and return the predicted class indices.

    Args:
        model: Target network; switched to eval mode here. Called as
            ``model(x)`` when ``opt.target_class == 'indep'`` and as
            ``model(x, source_features, opt.state_pairs)`` otherwise. Its
            first return value is used as the class scores.
        source_model: Network whose second return value supplies the
            features given to ``model``; not used for ``'indep'``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        opt: Options namespace providing ``device``, ``target_class`` and
            ``state_pairs``.

    Returns:
        1-D NumPy array with the argmax prediction of every sample, in loader
        order across all batches (empty if the loader yields nothing).
    """
    model.eval()
    per_batch = []
    # Only predictions are needed, so no graph has to be recorded.
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(opt.device)
            if opt.target_class == 'indep':
                y_hat, _ = model(x)
            else:
                _, source_features = source_model(x)
                y_hat, _ = model(x, source_features, opt.state_pairs)
            # Move to the CPU first: .numpy() is not available on CUDA tensors.
            per_batch.append(torch.argmax(y_hat, dim=1).cpu())
    if not per_batch:
        return np.empty(0, dtype=np.int64)
    return torch.cat(per_batch).numpy()


def main():
    """Save saliency maps for images that 'bandit' gets right and 'indep' gets wrong.

    Steps:
      1. Parse the command line and fill in the fixed settings on ``opt``.
      2. Load the test images of class ``opt.image_class_tg`` and collect the
         predictions of the 'indep', 'xstitch' and 'bandit' target models.
      3. Keep the images where 'indep' does not predict the class but
         'bandit' does.
      4. For each kept image, back-propagate the top score of the 'bandit'
         model to the input and write the channel-wise maximum of the
         absolute gradient to ``results/bandit/<image stem>/saliency.png``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--dataroot', required=True, help='Path to the dataset')
    parser.add_argument('--dataset', default='cub200')
    parser.add_argument('--datasplit', default='cub200')
    parser.add_argument('--datanoise', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--source-model', default='resnet34', type=str)
    parser.add_argument('--source-domain', default='imagenet', type=str)
    parser.add_argument('--source-path', type=str, default=None)
    parser.add_argument('--pretrained-src-model', action='store_true', default=False)
    parser.add_argument('--target-model', default='resnet18', type=str)
    parser.add_argument('--pretrained-tg-model', action='store_true', default=False)
    parser.add_argument('--numTrain', type=int, default=100, help='Train sample size. 100 means use all')

    parser.add_argument('--transfer-type', default='xstitch', type=str,
                        choices=['indep','spottune', 'xstitch', 'routenorm','transnorm', 'simpleblock','linstitch', 'combine', 'shared', 'block', 'attention'],
                        help='Different options to combine the source and target')
    # default settings
    opt = parser.parse_args()

    opt.source_feature_ids = [0, 1, 2, 3, 4]
    opt.target_decisioner_ids = [0, 1, 2, 3]
    opt.source_input_pass_id = 5  # ID corresponding to PASS/skip action
    opt.narms = len(opt.source_feature_ids) + 1
    opt.transfer_types = [opt.transfer_type for _ in opt.source_feature_ids]

    # Initialize Decision Maker
    decisioners = [None for _ in range(4)]
    opt.source_info = (opt.source_feature_ids, opt.source_input_pass_id, decisioners)

    random.seed(opt.seed)
    opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt.image_class_tg = 13
    opt.image_class_src = 14
    opt.state_pairs = []

    loaders = load_dataset(opt)
    print(opt)
    # NOTE(review): the source model is never switched to eval mode anywhere
    # in this script - confirm that this matches how the checkpoints were
    # trained before changing it.
    source_model = choose_model('source', opt).to(opt.device)
    opt.num_target_classes = len(loaders.dataset.dataset.classes)

    # Stage 1: predictions of every compared model on the selected class.
    compared_models = ['indep', 'xstitch', 'bandit']
    preds = []
    for model_name in compared_models:
        opt.target_class = model_name
        target_model = choose_model(model_name, opt).to(opt.device)
        preds.append(validate(target_model, source_model, loaders, opt))

    # Images of interest: misclassified by 'indep', correctly classified by 'bandit'.
    indep_preds = preds[compared_models.index('indep')]
    bandit_preds = preds[compared_models.index('bandit')]
    roi_inds = np.where(np.logical_and(indep_preds != opt.image_class_tg,
                                       bandit_preds == opt.image_class_tg))[0].tolist()
    # (path, label) entries of the class subset, in the order the loader visits them.
    subset_samples = [loaders.dataset.dataset.imgs[idx] for idx in loaders.dataset.indices]
    imgs = [subset_samples[idx][0] for idx in roi_inds]
    print(preds)

    # Stage 2: input-gradient saliency maps for the selected images.
    # NOTE(review): images are re-read through load_images here, which resizes
    # straight to 224x224, whereas stage 1 used Resize(256) + CenterCrop(224).
    saliency_models = ['bandit']
    for model_name in saliency_models:
        opt.target_class = model_name
        target_model = choose_model(model_name, opt).to(opt.device)
        target_model.eval()
        for image in imgs:
            images, _ = load_images([image])
            images = torch.stack(images).to(opt.device)

            images.requires_grad_()
            output_dir = os.path.join('results', model_name, Path(image).stem)
            Path(output_dir).mkdir(parents=True, exist_ok=True)

            if model_name == 'indep':
                scores, _ = target_model(images)
            else:
                # Source features are computed without a graph, so gradients
                # reach the input only through the target model's own path.
                with torch.no_grad():
                    _, source_features = source_model(images)
                scores, _ = target_model(images, source_features, opt.state_pairs)
            score_max_index = scores.argmax(dim=1)
            score_max = scores[0, score_max_index]
            score_max.backward()
            saliency, _ = torch.max(images.grad.detach().abs(), dim=1)
            # Convert on the CPU: matplotlib cannot consume a CUDA tensor.
            plt.imsave(os.path.join(output_dir, 'saliency.png'),
                       saliency[0].cpu().numpy(), cmap='gray')

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
