import json
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from torch.utils.data import DataLoader
from torch import nn, optim
from transformers import get_scheduler
from tqdm.auto import tqdm
import wandb
from sklearn.metrics import accuracy_score
from transformers import ViTImageProcessor, AutoModel, AutoConfig
import torch.nn.functional as F
import torchvision.datasets as datasets
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
import sys
# Resolve the parent of this script's directory, echo the script directory
# and that parent (the parent is printed twice, as before) for debugging,
# then append the parent to sys.path so modules located there are importable.
path_root = Path(__file__).parents[1]
print(Path(__file__).parents[0], path_root, path_root, sep='\n')
sys.path.append(str(path_root))
import numpy as np
import random
from tqdm import tqdm
from PIL import Image
import pickle
import matplotlib.pyplot as plt
from skimage.transform import resize
import matplotlib


def show_anns(masks, ax):
    """Overlay every mask on ``ax`` as a tinted layer plus a blue outline.

    Each mask is resized to 224x224 (``resize`` with ``order=1``,
    ``mode='reflect'``, no anti-aliasing). Two images are then drawn on
    ``ax`` per mask:

    1. a uniformly colored layer (random color) whose fourth channel is the
       resized mask scaled by 0.35, and
    2. a boundary layer with color ``[0, 0, 1]`` whose fourth channel is
       0.7 on boundary pixels and 0 elsewhere.

    A boundary pixel is a non-zero interior pixel with at least one zero
    4-neighbour; pixels on the outermost rows/columns are never marked.

    Args:
        masks: Sized iterable of masks accepted by ``resize``
            (assumed to be 2-D arrays -- TODO confirm with callers).
        ax: Object exposing ``set_autoscale_on`` and ``imshow``
            (e.g. a matplotlib Axes).

    Returns:
        None. Returns immediately, without touching ``ax``, when ``masks``
        is empty.
    """
    # len() rather than truthiness: ``masks`` may be an array, whose truth
    # value is ambiguous.
    if len(masks) == 0:
        return
    ax.set_autoscale_on(False)
    for mask in masks:
        m = resize(mask, (224, 224), order=1, mode='reflect', anti_aliasing=False)

        # Uniform random-colored layer; the mask itself becomes the 4th channel.
        color_mask = np.random.random((1, 3))[0]
        img = np.ones((m.shape[0], m.shape[1], 3)) * color_mask
        ax.imshow(np.dstack((img, m * 0.35)))

        # Draw boundary lines.
        # Vectorized form of "pixel is set and some 4-neighbour is not",
        # evaluated on interior pixels only (the border stays False).
        filled = m.astype(bool)
        surrounded = (filled[:-2, 1:-1] & filled[2:, 1:-1]
                      & filled[1:-1, :-2] & filled[1:-1, 2:])
        boundary_mask = np.zeros_like(m, dtype=bool)
        boundary_mask[1:-1, 1:-1] = filled[1:-1, 1:-1] & ~surrounded
        ax.imshow(np.dstack((boundary_mask[..., None] * np.array([0, 0, 1]), boundary_mask * 0.7)))


def parse_args():
    """Build the command-line parser for this plotting script.

    Despite the name, this does not parse anything: it returns the
    configured parser and the caller invokes ``.parse_args()`` on it.

    Options:
        --input-dir: directory that is scanned for ``.pkl`` result files.
        --output-dir: directory that receives the rendered ``.png`` plots.
        --num-masks-show: how many top-scoring masks to show (int).
        --no-plot: flag; skip plotting.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--input-dir', type=str,
                        default='exps/cosmogrid_wrapper_lr00005_psz1/best/val_results',
                        help='input dir')
    parser.add_argument('--output-dir', type=str,
                        default='exps/cosmogrid_wrapper_lr00005_psz1/best/plots',
                        help='output dir')
    # Must be an int: the value is used as a slice bound. It was previously
    # declared as str, so any value given on the command line broke slicing.
    parser.add_argument('--num-masks-show', type=int,
                        default=5,
                        help='show masks with top 5 scores for each prediction')
    parser.add_argument('--no-plot',
                        default=False,
                        action='store_true',
                        help='do not plot')

    return parser

def _dedupe_preserving_order(items):
    """Return ``items`` as a list without repeats, keeping first-seen order."""
    return list(dict.fromkeys(items))


def _nonzero_stats(mask):
    """Return ``(count, fraction)`` of the non-zero entries of array ``mask``."""
    nonzeros = int(np.count_nonzero(mask))
    return nonzeros, nonzeros / np.size(mask)


def main():
    """Render one summary figure per pickled result file and report mask sparsity.

    For every ``.pkl`` file directly inside ``--input-dir`` the pickled dict
    is loaded and the mask indices with the highest ``mask_weights`` are
    selected (top ``--num-masks-show`` rows of a descending argsort along
    dim 0, flattened and de-duplicated).

    Unless ``--no-plot`` is given, a 2-row figure is written to
    ``--output-dir`` under the same base name with a ``.png`` extension:
    column 0 shows the color-mapped image (top) and ``masks_all`` or, if that
    is ``None``, the image again (bottom); each further column shows the
    image multiplied by one selected mask (top) and the mask itself in gray
    (bottom).

    In both modes the fraction of non-zero entries of every selected mask is
    accumulated and the average is printed at the end.

    Keys read from each pickle: ``masks_used``, ``mask_weights`` (always);
    ``image``, ``masks_all``, ``counts``, ``outputs_avg``, ``outputs``,
    ``outputs_original``, ``pred``, ``label``, ``preds`` (plot mode only).

    Raises:
        ValueError: if a selected mask cannot be multiplied with the image.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg, value in vars(args).items():
        # Report the type of the parsed value (type(arg) is always str).
        print(arg, value, '\t', type(value))

    # NOTE(review): the input dir is created when missing, in which case the
    # run processes nothing -- confirm this is intended.
    os.makedirs(args.input_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Used as a slice bound below, so it has to be an int even if the parser
    # hands back a string.
    num_masks_show = int(args.num_masks_show)

    total_nonzeros = 0
    total_count = 0

    for filename in tqdm(os.listdir(args.input_dir)):
        input_filename = os.path.join(args.input_dir, filename)
        if os.path.isdir(input_filename) or not input_filename.endswith('.pkl'):
            continue
        # Swap only the extension; a '.pkl' inside the stem is left alone.
        output_filename = os.path.join(
            args.output_dir, os.path.splitext(str(filename))[0] + '.png')

        # Keep the file open only for the load itself.
        with open(input_filename, 'rb') as input_file:
            # NOTE(security): unpickling can execute arbitrary code; only
            # point this script at trusted result files.
            data = pickle.load(input_file)

        masks_used = data['masks_used'].cpu().numpy()
        # Ascending argsort flipped along dim 0 -> highest weights first.
        max_idxs = torch.argsort(data["mask_weights"],
                                 dim=0).flip(dims=(0,))[:num_masks_show]
        max_idxs = _dedupe_preserving_order(max_idxs.view(-1).tolist())

        if args.no_plot:
            # Statistics only.
            for j in max_idxs:
                _, fraction = _nonzero_stats(masks_used[j])
                total_nonzeros += fraction
                total_count += 1
            continue

        image_array = np.array(data['image'])[0]
        masks_all = data['masks_all']
        counts = data['counts']

        # squeeze=False keeps ``axs`` 2-D even when no mask index was
        # selected (a single column would otherwise collapse to 1-D).
        n_cols = len(max_idxs) + 1
        fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10),
                                squeeze=False)

        # Color-map the raw image with matplotlib's default colormap.
        cmap = plt.get_cmap(matplotlib.rcParams['image.cmap'])
        norm = plt.Normalize()
        rgb_image = cmap(norm(image_array))

        axs[0][0].imshow(rgb_image)
        axs[1][0].imshow(rgb_image if masks_all is None else masks_all)

        probs_avg = data["outputs_avg"]
        probs = data["outputs"]
        outputs_original = data["outputs_original"]
        pred_round = [round(x, 4) for x in data["pred"]]
        label_round = [round(x, 4) for x in data["label"]]
        # NOTE(review): the 'Loss' line shows data["pred"] -- confirm that
        # this key really holds the loss.
        axs[0][0].set_title(f'Label: {label_round}\n'
                            f'Wrapped Pred:  {[round(p, 4) for p in probs_avg.tolist()]}\n'
                            f'Original Pred: {[round(p, 4) for p in outputs_original.tolist()]}\n'
                            f'Loss: {pred_round}\n')

        for col, j in enumerate(max_idxs, start=1):
            mask_j = masks_used[j]
            try:
                masked_image = rgb_image * np.expand_dims(mask_j, axis=-1)
            except ValueError as err:
                # Fail with context instead of dropping into a debugger.
                raise ValueError(
                    f'{input_filename}: mask {j} with shape {np.shape(mask_j)} '
                    f'cannot be applied to image of shape {rgb_image.shape}'
                ) from err
            axs[0][col].imshow(masked_image)

            mask_score = [round(x, 4) for x in data["mask_weights"][j].tolist()]
            pred_j = [round(x, 4) for x in probs[j].tolist()]
            # NOTE(review): shown under 'Loss' but read from data['preds'] --
            # confirm.
            loss_j = [round(x, 4) for x in data['preds'][j]]
            nonzeros, fraction = _nonzero_stats(mask_j)
            axs[0][col].set_title(f'Prediction: {pred_j}\n'
                                  f'Loss: {loss_j}\n'
                                  f'Mask score: {mask_score}\n'
                                  f'Count: {counts[j]}\n'
                                  f'Non-zero: {nonzeros}')

            axs[1][col].imshow(mask_j, cmap='gray')
            total_nonzeros += fraction
            total_count += 1

        fig.subplots_adjust(top=0.8, hspace=1)
        fig.savefig(output_filename)
        plt.close(fig)

    # Avoid ZeroDivisionError when nothing was processed.
    if total_count == 0:
        print('avg non-zeros', 'n/a (no masks processed)')
    else:
        print('avg non-zeros', total_nonzeros / total_count)

# Run the plotting pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()