import json
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from torch.utils.data import DataLoader
from torch import nn, optim
from transformers import get_scheduler
from tqdm.auto import tqdm
import wandb
from sklearn.metrics import accuracy_score
from transformers import ViTImageProcessor, AutoModel, AutoConfig
import torch.nn.functional as F
import torchvision.datasets as datasets
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
import sys
# Debug output: the directory containing this file, then that directory's parent.
print(Path(__file__).parents[0])
print(Path(__file__).parents[1])
# Append the parent of this file's directory to sys.path so that imports can
# also be resolved from there.
path_root = Path(__file__).parents[1]
print(path_root)
sys.path.append(str(path_root))
import numpy as np
import random
from tqdm import tqdm
from PIL import Image
import pickle
import matplotlib.pyplot as plt
from skimage.transform import resize
import matplotlib


def show_anns(masks, ax, size=(224, 224)):
    """Overlay every mask on *ax* as a translucent random colour with a blue outline.

    Args:
        masks: Sized iterable of 2-D mask arrays. Nothing is drawn when it is
            empty.
        ax: Matplotlib axes (any object providing ``set_autoscale_on`` and
            ``imshow``) that receives the overlays.
        size: ``(rows, cols)`` each mask is resized to before drawing.
            Defaults to ``(224, 224)``, the value that used to be hard-coded.

    Returns:
        None. The function only draws on *ax*.
    """
    if len(masks) == 0:
        return
    ax.set_autoscale_on(False)
    for mask in masks:
        m = resize(mask, size, order=1, mode='reflect', anti_aliasing=False)

        # Solid random RGB image; the (scaled) mask itself is the alpha channel.
        color_mask = np.random.random((1, 3))[0]
        img = np.ones((m.shape[0], m.shape[1], 3)) * color_mask
        ax.imshow(np.dstack((img, m * 0.35)))

        # Boundary = interior pixel that is non-zero while at least one of its
        # four direct neighbours is zero. The outermost rows/columns are never
        # marked, matching the original per-pixel loop over 1..n-2.
        filled = m != 0
        all_neighbours_filled = (
            filled[:-2, 1:-1]      # pixel above
            & filled[2:, 1:-1]     # pixel below
            & filled[1:-1, :-2]    # pixel to the left
            & filled[1:-1, 2:]     # pixel to the right
        )
        boundary_mask = np.zeros_like(m, dtype=bool)
        boundary_mask[1:-1, 1:-1] = filled[1:-1, 1:-1] & ~all_neighbours_filled

        # Blue outline (RGB 0, 0, 1) with alpha 0.7 on boundary pixels only.
        ax.imshow(np.dstack((boundary_mask[..., None] * np.array([0, 0, 1]),
                             boundary_mask * 0.7)))


def parse_args():
    """Build the command-line parser for the plotting script.

    Note that, despite its name, this returns the configured
    ``argparse.ArgumentParser`` itself; the caller is expected to invoke
    ``.parse_args()`` on it.

    Options:
        --input-dir: directory that is scanned for result files.
        --output-dir: directory the rendered plots are written to.
        --num-masks-show: how many top-scoring masks to show per prediction
            (integer, default 5).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--input-dir', type=str,
                        default='exps/cosmogrid_wrapper_lr00005_psz1/best/val_results',
                        help='input dir')
    parser.add_argument('--output-dir', type=str,
                        default='exps/cosmogrid_wrapper_lr00005_psz1/best/plots',
                        help='output dir')
    # Must be an int: the value is used as a slice bound, and a string coming
    # from the command line would raise a TypeError there.
    parser.add_argument('--num-masks-show', type=int,
                        default=5,
                        help='show masks with top 5 scores for each prediction')

    return parser

def _unique_in_order(values):
    """Return *values* as a list without duplicates, keeping first occurrences."""
    return list(dict.fromkeys(values))


def _plot_entry(data, output_filename, num_masks_show):
    """Render one unpickled result entry as a two-row figure and save it.

    Column 0 shows the input image (resized on top, as stored below) with the
    aggregated prediction in the title. Every further column shows one
    selected mask: the masked image with per-mask statistics on top and the
    raw mask in grayscale below.

    Args:
        data: dict read from a result file. The keys accessed here are
            'image', 'id2label', 'masks_used', 'counts', 'mask_weights',
            'outputs', 'outputs_avg', 'preds', 'pred' and 'label'.
            NOTE(review): 'mask_weights' and 'outputs' are indexed as
            [mask][class] below and 'pred' / 'label' are used directly as
            ``id2label`` keys - confirm against the code producing the files.
        output_filename: path the figure is saved to.
        num_masks_show: number of leading rows of the descending
            ``mask_weights`` ranking to keep.
    """
    # Keep only the first comma-separated part of every label string.
    id2label = {k: v.split(',')[0] for k, v in data['id2label'].items()}
    image_array = np.array(data['image'])
    masks_used = data['masks_used'].cpu().numpy()
    counts = data['counts']
    mask_weights = data["mask_weights"]
    outputs = data["outputs"]
    outputs_avg = data["outputs_avg"]
    preds = data["preds"]
    pred = data["pred"]

    # Rank masks by weight along dim 0 (descending), keep the first
    # `num_masks_show` rows of that ranking, flatten and de-duplicate while
    # preserving rank order.
    ranked = torch.argsort(mask_weights, dim=0).flip(dims=(0,))[:num_masks_show]
    max_idxs = _unique_in_order(ranked.reshape(-1).tolist())

    n_cols = len(max_idxs) + 1
    # squeeze=False keeps `axs` two-dimensional even when no mask is selected
    # (n_cols == 1), so the axs[row][col] indexing below always works.
    fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)
    try:
        image_arr = resize(image_array, (224, 224), order=1, mode='reflect',
                           anti_aliasing=False)
        axs[0][0].imshow(image_arr)
        axs[1][0].imshow(image_array)

        probs_avg = torch.softmax(outputs_avg, dim=-1)
        probs = torch.softmax(outputs, dim=-1)

        axs[0][0].set_title(f'Label: {id2label[data["label"]]}\nPrediction: {id2label[pred]}\n' +
                            f'Probability: {round(probs_avg[pred].item(), 4)}\n' +
                            f'Logits: {round(outputs_avg[pred].item(), 4)}')

        for col, j in enumerate(max_idxs, start=1):
            mask_j = masks_used[j]
            pred_j = preds[j]

            # Broadcast the mask over the channel axis to blank out the
            # image outside of it.
            # NOTE(review): assumes the mask has the resized image's 224x224
            # spatial shape - confirm.
            axs[0][col].imshow(image_arr * np.expand_dims(mask_j, axis=-1))

            # Weight of this mask for its own prediction and for the
            # aggregated prediction.
            mask_score_self_pred = round(mask_weights[j][pred_j].item(), 4)
            mask_score_aggr_pred = round(mask_weights[j][pred].item(), 4)

            axs[0][col].set_title(f'Prediction: {id2label[pred_j]}\n' +
                                  f'Probability: {round(probs[j][pred_j].item(), 4)}\n' +
                                  f'Logits: {round(outputs[j][pred_j].item(), 4)}\n' +
                                  f'Mask score self pred: {mask_score_self_pred}\n' +
                                  f'Mask score aggr pred: {mask_score_aggr_pred}\n' +
                                  f'Count: {counts[j]}\n' +
                                  f'Non-zero: {int(np.count_nonzero(mask_j))}')

            axs[1][col].imshow(mask_j, cmap='gray')

        fig.subplots_adjust(top=0.8, hspace=1)
        fig.savefig(output_filename)
    finally:
        # Always release the figure so a long run does not accumulate them.
        plt.close(fig)


def main():
    """Plot every result file found in ``--input-dir`` into ``--output-dir``.

    Each regular file directly inside the input directory is unpickled and
    rendered to an image whose name is the input name with '.pkl' replaced
    by '.png'. Sub-directories are skipped.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        # Report the type of the value (not of the attribute name, which is
        # always str).
        print(arg, value, '\t', type(value))

    # Coerce here as well so the slice bound is an int even if the parser
    # hands the option over as a string.
    num_masks_show = int(args.num_masks_show)

    # The input directory is created when missing, which turns a wrong path
    # into an empty run instead of a crash in os.listdir.
    os.makedirs(args.input_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    for filename in tqdm(os.listdir(args.input_dir)):
        input_filename = os.path.join(args.input_dir, filename)
        if os.path.isdir(input_filename):
            continue
        output_filename = os.path.join(args.output_dir,
                                       str(filename).replace('.pkl', '.png'))
        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # --input-dir at result files from a trusted source.
        with open(input_filename, 'rb') as input_file:
            data = pickle.load(input_file)
        # The file handle is no longer needed once the entry is in memory.
        _plot_entry(data, output_filename, num_masks_show)


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()