import os
import ast
import yaml
import argparse
import random
from functools import partial

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import mlflow
import numpy as np
import matplotlib.pyplot as plt

import dataset
from models.network import get_network


def fix_seed(random_seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU seed plus both CUDA
    seeding calls) with ``random_seed`` and sets the cuDNN flags
    ``deterministic=True`` / ``benchmark=False``.

    Args:
        random_seed: Integer seed shared by all generators.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(random_seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


class make_args:
    """Minimal attribute namespace built from keyword arguments.

    Every keyword passed to the constructor is stored in the instance
    dictionary, so ``make_args(lr=0.1).lr`` evaluates to ``0.1``.
    """

    def __init__(self, **entries):
        namespace = vars(self)
        for key, value in entries.items():
            namespace[key] = value


def download_args(run_id):
    """Fetch the logged parameters of an MLflow run as Python values.

    Each parameter value is passed through ``ast.literal_eval`` so that
    values such as ``"0.1"``, ``"True"`` or ``"(0.5, 0.5, 0.5)"`` come back
    as Python literals. Values that are not valid literals are kept
    unchanged.

    Args:
        run_id: Identifier of the MLflow run to read.

    Returns:
        dict mapping parameter name to the parsed (or original) value.
    """
    params = mlflow.get_run(run_id).data.params
    parsed = {}
    for key, raw in params.items():
        try:
            parsed[key] = ast.literal_eval(raw)
        # Only the errors literal_eval is documented to raise on malformed
        # input; a bare ``except`` would also hide KeyboardInterrupt and
        # SystemExit.
        except (ValueError, TypeError, SyntaxError,
                MemoryError, RecursionError):
            parsed[key] = raw
    return parsed


def _parse_args(cfg):
    """Wrap the mapping ``cfg`` in a ``make_args`` attribute namespace."""
    namespace = make_args(**cfg)
    return namespace


def expand2img(x, b, height=224, width=224):
    """Broadcast three per-channel values to an image-shaped tensor.

    Args:
        x: Sequence of three numbers, one per channel.
        b: Batch size of the result.
        height: Spatial height of the result (default 224, the value that
            used to be hard-coded).
        width: Spatial width of the result (default 224).

    Returns:
        Tensor of shape ``(b, 3, height, width)`` in which every spatial
        position of channel ``c`` holds ``x[c]``. The result comes from
        ``expand`` and is therefore a broadcast view, not a fresh copy.
    """
    per_channel = torch.Tensor(x).reshape(1, 3, 1, 1)
    return per_channel.expand(b, 3, height, width)


def min_max(AA):
    """Rescale every map in a batch to the range [0, 1].

    Each ``AA[i]`` is shifted by its own minimum and divided by its own
    value range. The input tensor is left untouched; a new tensor is
    returned.

    Args:
        AA: Tensor of shape ``(b, h, w)``; it may be non-contiguous.

    Returns:
        Tensor of shape ``(b, h, w)``. A map whose values are all equal is
        returned as all zeros rather than NaN.
    """
    b, h, w = AA.shape
    # reshape (not view) so sliced / permuted inputs are accepted as well.
    flat = AA.reshape(b, -1)
    shifted = flat - flat.min(1, keepdim=True)[0]
    span = shifted.max(1, keepdim=True)[0]
    # A constant map has span 0; divide by 1 there to avoid 0/0.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (shifted / span).reshape(b, h, w)


def draw_attn(args, img, attn, mode='all'):
    """Turn per-head patch attention into image-sized, [0, 1]-scaled maps.

    The attention scores are soft-maxed over the patch axis, laid out on
    the patch grid (patch size 16), upsampled by the patch size with
    bicubic interpolation and min-max normalised.

    Args:
        args: Namespace providing ``mean`` and ``std``, the per-channel
            statistics used to undo the input normalisation.
        img: Normalised image batch with channels in dim 1; presumably
            ``(b, 3, H, W)`` with H and W multiples of 16 -- TODO confirm
            against the data pipeline.
        attn: Attention scores of shape ``(b, heads, patches)`` where
            ``patches == (H // 16) * (W // 16)``.
        mode: ``'all'`` keeps one map per head; any other value collapses
            the heads with a maximum.

    Returns:
        Tuple ``(img1, attn_cat, attn)``:

        * ``img1``: un-normalised images as a CPU tensor, channels last.
        * mode ``'all'``: ``attn_cat`` is a NumPy array with the per-head
          maps concatenated along the width and repeated over 3 channels
          in a new last axis; ``attn`` is a NumPy array that keeps the
          head axis in dim 1 and has the same 3-channel last axis.
        * other modes: ``attn_cat`` and ``attn`` are the same tensor, the
          min-max scaled maximum over heads.
    """
    with torch.no_grad():
        attn = attn.softmax(dim=-1)
        patch_size = 16
        w_featmap = img.shape[-2] // patch_size
        h_featmap = img.shape[-1] // patch_size
        b, nh, _ = attn.shape
        # Use the actual head count instead of a hard-coded 6.
        attn = attn.reshape(b, nh, w_featmap, h_featmap)
        attn = nn.functional.interpolate(
            attn, scale_factor=patch_size,
            mode="bicubic", align_corners=True)
        attn_cat = torch.moveaxis(attn, 1, 3)
        if mode == 'all':
            attn_cat = torch.cat(
                [min_max(attn_cat[:, :, :, i]) for i in range(nh)], dim=2)
            attn_cat = torch.stack(
                [attn_cat] * 3, dim=-1).cpu().detach().numpy()
            attn = torch.cat(
                [min_max(attn[:, i, :, :]).unsqueeze(1) for i in range(nh)],
                dim=1)
            attn = torch.stack([attn] * 3, dim=-1).cpu().detach().numpy()
        else:
            # max(dim=...) returns (values, indices); only the values are
            # a tensor that min_max can normalise.
            attn_cat = min_max(attn_cat.max(dim=-1)[0])
            attn = attn_cat
        # (1, C, 1, 1) statistics broadcast over any batch and image size
        # and live on the same device as the images.
        mean = torch.Tensor(args.mean).reshape(1, -1, 1, 1).to(img.device)
        std = torch.Tensor(args.std).reshape(1, -1, 1, 1).to(img.device)
        img1 = torch.moveaxis(img * std + mean, 1, 3).cpu().detach()
        return img1, attn_cat, attn


def attn_color_remap(attn_cat, attn):
    """Colourise attention maps for plotting.

    ``attn`` is tinted head by head with a fixed six-colour palette and
    then reduced with a maximum over the head axis. ``attn_cat`` is
    mapped through matplotlib's ``viridis`` lookup table, indexed by the
    first channel of every pixel scaled to 0..255.

    Args:
        attn_cat: NumPy array whose last axis has size 3; the first
            channel is presumably in [0, 1] (min-max scaled upstream) --
            TODO confirm.
        attn: NumPy array with heads in axis 1 and a last axis of size 3.
            At most six heads are supported; more raise ``IndexError``.

    Returns:
        Tuple ``(attn_cat, attn_sum)``: the viridis-coloured array with
        the shape of the input ``attn_cat``, and the per-pixel maximum of
        the tinted heads. Neither input array is modified.
    """
    company_colors = [
        (0, 160, 215),    # blue
        (220, 55, 60),    # red
        (245, 180, 0),    # yellow
        (10, 120, 190),   # navy
        (40, 150, 100),   # green
        (135, 75, 145),   # purple
    ]
    company_colors = [
        (float(r) / 255.0, float(g) / 255.0, float(b) / 255.0)
        for r, g, b in company_colors
    ]
    colors = company_colors[:attn.shape[1]]

    # Tint a copy so the caller's array stays intact.
    tinted = attn.copy()
    for i in range(attn.shape[1]):
        tinted[:, i] = tinted[:, i] * colors[i]
    attn_sum = np.max(tinted, axis=1)

    # One vectorised table lookup instead of a Python loop over pixels.
    lut = np.asarray(plt.get_cmap('viridis').colors, dtype=attn_cat.dtype)
    first_channel = attn_cat.reshape([-1, 3])[:, 0]
    # NaN -> 0 and clipping keep every index inside the lookup table.
    index = np.clip(
        np.nan_to_num(first_channel * 255), 0, len(lut) - 1
    ).astype(np.intp)
    attn_cat = lut[index].reshape(attn_cat.shape)
    return attn_cat, attn_sum


def mix_img_attn(img, attn, alpha=0.2):
    """Blend an image with an attention map.

    Args:
        img: Image values, weighted by ``alpha``.
        attn: Attention values, weighted by ``1 - alpha``.
        alpha: Weight of the image in the blend (default 0.2).

    Returns:
        ``alpha * img + (1 - alpha) * attn``.
    """
    attn_weight = 1 - alpha
    blended = alpha * img + attn_weight * attn
    return blended


def save(args, ori, attn_o, attn, name):
    """Write a 5x2 grid of image / overlay / attention strips to ``name``.

    Args:
        args: Unused; kept for a uniform call signature.
        ori: Iterable of original images.
        attn_o: Iterable of image/attention overlays.
        attn: Iterable of attention maps.
        name: Output path handed to ``fig.savefig``.

    Samples 0-4 fill the left column top to bottom, samples 5-9 the right
    column. Each cell shows the three inputs concatenated along axis -2.
    """
    fig, ax = plt.subplots(5, 2, figsize=(25, 10))
    for idx, strip in enumerate(zip(ori, attn_o, attn)):
        row, col = (idx - 5, 1) if idx > 4 else (idx, 0)
        panel = ax[row, col]
        panel.axis('off')
        panel.imshow(np.concatenate(strip, axis=-2))
    fig.tight_layout()
    fig.savefig(f"{name}")
    plt.clf()
    plt.close()

def save_all(args, ori, attn_o, attn, attn_cat, output, target, name):
    """Write a 5x2 grid of titled visualisation strips to ``name``.

    Args:
        args: Unused; kept for a uniform call signature.
        ori: Iterable of original images.
        attn_o: Iterable of image/attention overlays.
        attn: Array of combined attention maps; joined with ``attn_cat``
            along axis 2.
        attn_cat: Array of per-head attention maps.
        output: Model scores; the arg-max over the last axis is shown as
            the prediction.
        target: Labels, indexed per sample for the title.
        name: Output path handed to ``fig.savefig``.

    Samples 0-4 fill the left column top to bottom, samples 5-9 the right
    column. Each cell shows original, overlay and attention maps
    concatenated along axis -2.
    """
    fig, ax = plt.subplots(5, 2, figsize=(25, 10))
    outputs = torch.argmax(output, dim=-1)
    attn_strip = np.concatenate([attn, attn_cat], axis=2)
    for idx, strip in enumerate(zip(ori, attn_o, attn_strip)):
        row, col = (idx - 5, 1) if idx > 4 else (idx, 0)
        panel = ax[row, col]
        panel.axis('off')
        panel.title.set_text(f'prediction:{outputs[idx]}   label:{target[idx]}')
        panel.imshow(np.concatenate(strip, axis=-2))
    fig.tight_layout()
    fig.savefig(f"{name}")
    plt.clf()
    plt.close()


def validate(model, loader, args):
    """Render attention visualisations for the first batches of ``loader``.

    Puts ``model`` in eval mode, normalises ``args.dino`` to a strict
    bool (False when unset), and for each of the first
    ``args.num_samples`` batches writes one JPEG into
    ``viz_attn/<experiments_subname>_<tag>/``.

    Args:
        model: Network called on CUDA inputs. Element 0 of its result is
            passed on as the class scores and element 1 is sliced as the
            attention tensor.
        loader: Iterable yielding ``(inputs, target)`` batches.
        args: Namespace with ``num_samples``, ``experiments_subname``,
            ``tag``, ``mean`` and ``std``.
    """
    model.eval()
    args.dino = bool(vars(args).get('dino', False))
    out_dir = os.path.join(
        'viz_attn', f"{args.experiments_subname}_{args.tag}")
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            if batch_idx >= args.num_samples:
                break
            inputs = inputs.cuda()
            output = model(inputs)
            # Index -1 on axis 1, and drop the first entry of the last axis.
            last_attn = output[1][:, -1, :, 1:]
            ori, attn_cat, attn = draw_attn(args, inputs, last_attn)
            attn_cat, attn = attn_color_remap(attn_cat, attn)
            attn_ = mix_img_attn(ori, attn, alpha=0.2)
            out_file = f'{out_dir}/{batch_idx:03d}.jpg'
            save_all(args, ori, attn_, attn, attn_cat,
                     output[0], target, out_file)


if __name__ == '__main__':
    # Command-line entry point: rebuild the network from the parameters of
    # an MLflow run, load the logged weights and write attention
    # visualisations for a few validation batches.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--path',
        required=True,
        help='location of the logged MLflow model; the run id is read '
             'from its third-from-last "/"-separated component')
    parser.add_argument(
        '--num_samples',
        default=2,
        type=int,
        help='number of output samples (default: 2)')

    args_ = parser.parse_args()
    MODEL_PATH = args_.path
    # The run id is the third-from-last path component.
    PATH = download_args(args_.path.split('/')[-3])
    args = _parse_args(PATH)
    args.num_samples = args_.num_samples
    fix_seed(234)
    model = get_network(args)
    weight = mlflow.pytorch.load_model(MODEL_PATH, map_location='cpu')
    # Copy the weights into the freshly built network; strict=False
    # tolerates missing or unexpected keys.
    model.load_state_dict(weight.state_dict(), strict=False)
    del weight
    model = model.cuda()
    # Re-seed so that the shuffled sample order below does not depend on
    # how much randomness model construction consumed.
    fix_seed(234)
    args.batch_size = 10
    args.baseline = True
    transform_train, transform_val = dataset.make_aug(args)
    train_set, valid_set = dataset.make_dataset(
        args, transform_train, transform_val)
    loader_eval = DataLoader(
        valid_set,
        pin_memory=True,
        batch_size=args.batch_size,
        sampler=None,
        shuffle=True,
        worker_init_fn=partial(
            dataset._worker_init, worker_seeding='all'),
        # DataLoader rejects persistent workers when num_workers is 0.
        persistent_workers=args.workers > 0,
        num_workers=args.workers)
    validate(model, loader_eval, args)
