import argparse
import logging
import math
import os
import shutil
from collections import defaultdict
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from k_means_constrained import KMeansConstrained
from kmeans_pytorch import kmeans
from scipy import signal
from scipy.cluster import hierarchy as hcluster
from scipy.cluster.vq import kmeans2
from sklearn import cluster
from sklearn.manifold import TSNE


# Command-line interface shared by the scripts that import this module.
# No option is marked as required; `--model` is restricted to the names
# listed in `models`.
parser = argparse.ArgumentParser(
    description="Discovering regularities over events of low-frequency"
)
parser.add_argument('--dataset', type=str, help="Dataset name")
models = [
    'ComplEx', 'TComplEx', 'TNTComplEx', 'NE', 'search', 'DE_SimplE', 'DE_DistMult', 'DE_TransE', 'TeRo', 'ATISE',
]
parser.add_argument('--model', choices=models, help="Model in {}".format(models))

# Training options.
parser.add_argument('--max_epochs', default=50, type=int, help="Number of epochs.")
parser.add_argument('--valid_freq', default=5, type=int, help="Number of epochs between each valid.")
parser.add_argument('--rank', default=100, type=int, help="Factorization rank.")
parser.add_argument('--batch_size', default=1000, type=int, help="Batch size.")
parser.add_argument('--learning_rate', default=1e-1, type=float, help="Learning rate")
parser.add_argument('--emb_reg', default=0., type=float, help="Embedding regularizer strength")
parser.add_argument('--time_reg', default=0., type=float, help="Timestamp regularizer strength")
parser.add_argument(
    '--no_time_emb', default=False, action="store_true",
    help="Use a specific embedding for non temporal relations"
)
parser.add_argument('--time_rank', default=100, type=int, help="Time rank.")
parser.add_argument('--time_weight', default=0.1, type=float)
parser.add_argument('--blur_range', default=3, type=int)
parser.add_argument('--seed', default=929, type=int)

# Logging / checkpoint locations.
parser.add_argument('--log_dir', default='./logs', type=str, help="log dir")
parser.add_argument('--load_name', default=None, type=str, help="Model file name to be loaded")
parser.add_argument('--save_name', default=None, type=str, help="Model file name to be saved in")

# Clustering options.
parser.add_argument('--cluster_num', default=500, type=int)
parser.add_argument('--cluster_method', default='kmeans', type=str, help="Cluster method")
parser.add_argument('--cluster_element', default='feature', type=str, help="Cluster element")
# The help text used to be a copy of the --cluster_element one.
parser.add_argument('--cluster_distance', default='l2', type=str, help="Cluster distance")
parser.add_argument('--cluster_trainable', default=False, action='store_true')

parser.add_argument(
    '--pred_max_freq', default=-1, type=int,
    help="Max frequency of preserved predicates."
)
parser.add_argument('--data_split', default='regular', type=str)
parser.add_argument('--score_mode', default='cos', type=str)
parser.add_argument('--curve_mode', default='l2', type=str)
parser.add_argument('--objective_mode', default='classification', type=str)
parser.add_argument('--omg_init', default='exp', type=str)
parser.add_argument('--omg_max', default=100, type=float)
parser.add_argument('--g_th', default=0.3, type=float)
parser.add_argument('--run_mode', default='train', type=str)
parser.add_argument('--tag', default='none', type=str)
parser.add_argument('--task', default='tam', type=str)
parser.add_argument('--query_count', default=1000, type=int)
parser.add_argument('--query_pair_freq', default=5, type=int)
parser.add_argument('--pos_score', default=1, type=float)
parser.add_argument('--online_interval', default=365, type=int)
parser.add_argument('--online_iteration', default=365, type=int)
parser.add_argument('--query_max_delta', default=365, type=int)
parser.add_argument('--rule', default='specific', type=str)

parser.add_argument('--se_prop', default=0.68, type=float, help="static proportion.")
parser.add_argument('--dropout', default=0.4, type=float, help="dropout.")
parser.add_argument('--neg_ratio', default=20, type=int, help="ratio of negative samples.")
parser.add_argument('--gamma', default=110, type=int, help="margin for translational models.")
# The help text used to be a copy of the --n_day one ("number of days.").
parser.add_argument('--year', default=2014, type=int, help="year.")
parser.add_argument('--n_day', default=365, type=int, help="number of days.")
parser.add_argument(
    '--cmin', default=0.003, type=float,
    help="minimum threshold of covariance matrices of ATISE."
)
# Parse the process command line once; the rest of the module reads `args`.
args = parser.parse_args()

# Without --load_name the run gets a directory whose name encodes the main
# hyper-parameters; with --load_name that path is used as the run directory.
if args.load_name is None:
    save_dir = f'{args.log_dir}/tag:{args.tag}-{args.model}-{args.dataset}-int:{args.online_interval}-iter:{args.online_iteration}-en:{args.query_pair_freq}-p:{args.pos_score}-r:{args.rank}-e:{args.max_epochs}-lr:{args.learning_rate}-omg:{args.omg_max}-init:{args.omg_init}'
else:
    save_dir = args.load_name

# os.makedirs also creates missing parents (e.g. --log_dir itself) and raises
# on failure, unlike the previous shell `mkdir`, whose errors went unnoticed
# until the log file could not be opened.
os.makedirs(save_dir, exist_ok=True)

# Plots left over from a previous run in the same directory are discarded so
# that rule_plot/ only ever holds the current run's output.
_rule_plot_dir = save_dir + '/rule_plot'
if os.path.isdir(_rule_plot_dir):
    shutil.rmtree(_rule_plot_dir)
os.makedirs(_rule_plot_dir, exist_ok=True)


# Module-wide logger: bare messages on the console, detailed records in
# ./run_time.log and in <save_dir>/<run_mode>.log (both opened with mode 'w').
logger = logging.getLogger('Logger')
logger.setLevel(logging.DEBUG)

fmt_s = logging.Formatter(fmt="%(message)s")
fmt_f = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)-5s - %(filename)-8s : %(lineno)s line - %(message)s",
    datefmt="%Y/%m/%d %H:%M:%S",
)

sh = logging.StreamHandler()
fg = logging.FileHandler(filename='./run_time.log', mode='w')
fh = logging.FileHandler(filename=save_dir + f'/{args.run_mode}.log', mode='w')

# Handlers are attached in the order console, run log, global log.
for _handler, _formatter in ((sh, fmt_s), (fh, fmt_f), (fg, fmt_f)):
    _handler.setLevel(logging.DEBUG)
    _handler.setFormatter(_formatter)
    logger.addHandler(_handler)

logger.info(f'Save logs @ {save_dir}')


def cos_threshold(x):
    """Keep the entries of ``x`` that are greater than zero and zero out the rest."""
    positive = x > 0
    return positive * x


def conj(x):
    """Negate slice 1 along dim 1 and keep slice 0.

    With slice 0 holding the real part and slice 1 the imaginary part (the
    layout used by ``mul``/``div``), this is the complex conjugate.
    """
    real_part, imag_part = x[:, 0], x[:, 1]
    return torch.stack((real_part, -imag_part), 1)


# def sym(x):
#     return torch.stack((x[:, 1], x[:, 0]), 1)


def mul(a, b):
    """Complex product of ``a`` and ``b``, each stored as (real, imag) along dim 1."""
    a_re, a_im = a[:, 0], a[:, 1]
    b_re, b_im = b[:, 0], b[:, 1]
    product = (
        a_re * b_re - a_im * b_im,
        a_re * b_im + a_im * b_re,
    )
    return torch.stack(product, 1)


def div(a, b):
    """Complex quotient ``a / b`` for tensors stored as (real, imag) along dim 1.

    1e-9 is added to ``|b|^2`` before dividing so a zero divisor does not
    produce inf/nan.
    """
    a_re, a_im = a[:, 0], a[:, 1]
    b_re, b_im = b[:, 0], b[:, 1]
    safe_den = (b_re ** 2 + b_im ** 2) + 1e-9
    quotient_re = (a_re * b_re + a_im * b_im) / safe_den
    quotient_im = (a_im * b_re - a_re * b_im) / safe_den
    return torch.stack((quotient_re, quotient_im), 1)


def avg_both(mrrs: Dict[str, float], hits: Dict[str, torch.FloatTensor], ranks):
    """
    Average the 'lhs' and 'rhs' entries of each metric.
    :param mrrs: mapping with 'lhs' and 'rhs' MRR values
    :param hits: mapping with 'lhs' and 'rhs' hits values
    :param ranks: mapping with 'lhs' and 'rhs' rank values
    :return: dict with keys 'MRR', 'hits@[1,3,10]' and 'ranks'
    """
    def _mean_of_sides(per_side):
        return (per_side['lhs'] + per_side['rhs']) / 2.

    return {
        'MRR': _mean_of_sides(mrrs),
        'hits@[1,3,10]': _mean_of_sides(hits),
        'ranks': _mean_of_sides(ranks),
    }

def eval_model(model, dataset):
    """Evaluate ``model`` on the valid, test and train splits and log the results.

    The train split is evaluated with the limit argument 50000, the others
    with -1. When ``dataset.has_intervals()`` is false, each result is passed
    through ``avg_both`` before being logged. Nothing is returned.
    """
    model.eval()
    split_names = ['valid', 'test', 'train']

    if dataset.has_intervals():
        def finalize(raw):
            return raw
    else:
        def finalize(raw):
            return avg_both(*raw)

    # All splits are evaluated before anything is logged.
    results = []
    for split in split_names:
        limit = 50000 if split == 'train' else -1
        results.append(finalize(dataset.eval(model, split, limit)))

    for split, metrics in zip(split_names, results):
        logger.info(f"{split}: {metrics}")


class LocalBlurLoss(torch.nn.Module):
    """Squared-error loss between per-class scores and a one-hot target.

    Only the columns inside ``effective_range`` of the score matrix are used.
    The target is ``pos_score`` at the labelled column and ``neg_score``
    elsewhere; the loss is the mean squared error over the positive cells
    plus the mean squared error over the negative cells.

    A normalised Gaussian kernel and a matching ``Conv1d`` are built in the
    constructor, but ``forward`` does not currently apply them.

    :param num_classes: stored on the instance; not read by the code here
    :param effective_range: (start, stop) column range of the scores to use
    :param blur_range: length of the Gaussian kernel (its std is a third of it)
    """

    def __init__(self, num_classes = 365, effective_range = (0, 365), blur_range: int = 3):
        super().__init__()
        self.num_classes = num_classes
        self.effective_range = effective_range
        self.effective_classes = effective_range[1] - effective_range[0]
        self.weight = 1
        std = blur_range / 3
        # `scipy.signal.gaussian` was a deprecated alias that newer SciPy
        # releases no longer provide; `scipy.signal.windows.gaussian` is the
        # same window function under its supported name.
        blur_kernel = torch.FloatTensor(signal.windows.gaussian(blur_range, std)).cuda()
        logger.info(f'blur_kernel: {blur_kernel}')
        blur_kernel = blur_kernel / blur_kernel.sum()
        blur_kernel = torch.nn.Parameter(blur_kernel.reshape((1, 1, blur_range)), requires_grad=False)
        self.blur_conv = nn.Conv1d(1, 1, blur_range, padding = 'same', bias = False).cuda()

        # Target values; the positive one depends on the global CLI arguments.
        self.pos_score = math.sqrt(args.rank) * args.pos_score
        self.neg_score = 0

        # Replace the randomly initialised conv weight with the fixed kernel.
        with torch.no_grad():
            self.blur_conv.weight = blur_kernel

    def invariable(self, x):
        """Return ``(x - x.detach()) ** 2 + 1e-9`` (numerically 1e-9 everywhere)."""
        return (x - x.detach()) ** 2 + 1e-9

    def forward(self, scores, labels):
        """Compute the loss.

        :param scores: score matrix; columns ``effective_range[0]:effective_range[1]`` are used
        :param labels: integer class indices, expressed before the
            ``effective_range[0]`` offset is subtracted
        :return: scalar tensor, positive-cell mean error + negative-cell mean error
        """
        scores = scores[:, self.effective_range[0] : self.effective_range[1]]
        labels = labels - self.effective_range[0]
        one_hot = F.one_hot(labels, num_classes=self.effective_classes)
        # The Gaussian blur of the one-hot target is disabled; the raw one-hot is used.
        one_hot_b = one_hot
        ground_truth = one_hot_b * self.pos_score + (1 - one_hot_b) * self.neg_score
        dist = (scores - ground_truth) ** 2 + 1e-9
        # NOTE(review): `invar` is computed from the already-sliced scores and
        # is not part of the returned value; kept as in the original code.
        if self.effective_range[0] > 0:
            invar = self.invariable(scores[: , : self.effective_range[0]]).mean()
        else:
            invar = 0
        assert labels.max() < self.effective_classes
        assert not torch.isnan(one_hot).any()
        assert one_hot.sum() > 0
        # Log some context before the assertion below fires.
        if (1 - one_hot).sum() == 0:
            logger.info(f'one hot: {one_hot.shape}')
            logger.info(f'{one_hot[:10]}')
        assert (1 - one_hot).sum() > 0
        return (dist * one_hot).sum() / one_hot.sum() + (dist * (1 - one_hot)).sum() / (1 - one_hot).sum()


def get_tri_enc(sizes):
    """Return a function packing the first three fields of an event into one int.

    The code is ``(e[0] * sizes[1] + e[1]) * sizes[2] + e[2]`` with every
    operand converted to ``int``.
    """
    def encode(e):
        head = int(e[0]) * int(sizes[1]) + int(e[1])
        return head * int(sizes[2]) + int(e[2])

    return encode


def get_tri_dec(sizes):
    """Return a function unpacking a ``get_tri_enc`` code into three ints."""
    def decode(x):
        leading = x // sizes[2]
        first = int(leading // sizes[1])
        second = int(leading % sizes[1])
        third = int(x % sizes[2])
        return (first, second, third)

    return decode


def get_valid_so(so_to_freq, thd, ent_cnt):
    """Decode the keys of ``so_to_freq`` whose frequency is at least ``thd``.

    Each key is split into ``key // ent_cnt`` and ``key % ent_cnt``.

    :return: array of shape (2, n): row 0 holds the quotients, row 1 the remainders
    """
    subjects, objects = [], []
    for so, freq in so_to_freq.items():
        subj, obj = so // ent_cnt, so % ent_cnt
        if freq < thd:
            continue
        subjects.append(subj)
        objects.append(obj)
    return np.stack((np.array(subjects), np.array(objects)))


def get_valid_rel_pairs(sizes, tri_to_freq, thd):
    """Collect ordered predicate pairs that co-occur on the same (subject, object).

    Triples with frequency below ``thd`` are ignored. For every (s, o) that
    remains, each pair of its predicates is emitted in both orders.

    :param sizes: shape tuple understood by ``get_tri_dec``
    :param tri_to_freq: mapping from encoded triple to its frequency
    :param thd: minimum frequency for a triple to be considered
    :return: array of shape (2, n_pairs); column k is the pair (p1[k], p2[k])
    """
    dec = get_tri_dec(sizes)

    # Group predicates by the (s, o) tuple itself. The previous integer key
    # `s * sizes[0] + o` could map two different (s, o) to the same bucket
    # whenever o >= sizes[0], producing spurious pairs.
    so_to_p = defaultdict(list)
    for tri, freq in tri_to_freq.items():
        if freq < thd:
            continue
        s, p, o = dec(tri)
        so_to_p[(s, o)].append(p)

    pairs = set()
    for p_list in so_to_p.values():
        for i, p_a in enumerate(p_list):
            for p_b in p_list[i + 1:]:
                pairs.add((p_a, p_b))
                pairs.add((p_b, p_a))

    p1_list = [p1 for p1, _ in pairs]
    p2_list = [p2 for _, p2 in pairs]
    return np.stack((np.array(p1_list), np.array(p2_list)))


def count_s_o(events, ent_cnt):
    """Map each ``e[0]`` to the set of ``e[2]`` values seen with it.

    ``ent_cnt`` is accepted for signature parity with the other counters and
    is not used.
    """
    objects_of = defaultdict(set)
    for event in events:
        subj, obj = event[0], event[2]
        objects_of[subj].add(obj)
    return objects_of


def count_so(events, ent_cnt):
    """Count events per (subject, object), keyed by ``e[0] * ent_cnt + e[2]``."""
    pair_counts = defaultdict(int)
    for event in events:
        subj, obj = event[0], event[2]
        key = subj * ent_cnt + obj
        pair_counts[key] += 1
    return pair_counts


def count_so_p(events, ent_cnt):
    """Map each ``(int(e[0]), int(e[2]))`` to the set of ``int(e[1])`` seen with it.

    ``ent_cnt`` is not used.
    """
    preds_of = defaultdict(set)
    for event in events:
        subj, pred, obj = event[0], event[1], event[2]
        key = (int(subj), int(obj))
        preds_of[key].add(int(pred))
    return preds_of


def count_predicates(events, pred_cnt):
    """Return a float array of length ``pred_cnt`` counting events per ``e[1]``."""
    counts = np.zeros((pred_cnt))
    for event in events:
        pred = event[1]
        counts[pred] += 1
    return counts


def count_triples(sizes, events):
    """Build frequency and timestamp indexes over encoded triples.

    :param sizes: shape tuple passed to ``get_tri_enc``
    :param events: iterable of events; ``e[3]`` is recorded as the time
    :return: dict with
        'tri_to_freq': encoded triple -> number of events,
        'freq_to_tri': frequency -> list of encoded triples with that frequency,
        'tri_to_time': encoded triple -> list of ``int(e[3])`` in input order
    """
    encode = get_tri_enc(sizes)
    tri_to_freq = defaultdict(int)
    tri_to_time = defaultdict(list)

    for event in events:
        code = encode(event)
        tri_to_freq[code] += 1
        tri_to_time[code].append(int(event[3]))

    freq_to_tri = defaultdict(list)
    for code, freq in tri_to_freq.items():
        freq_to_tri[freq].append(code)

    return {'tri_to_freq': tri_to_freq, 'freq_to_tri': freq_to_tri, 'tri_to_time': tri_to_time}


def count_entities(events):
    """Count how often each entity appears as ``e[0]`` or ``e[2]``.

    :return: dict with
        'ent_to_freq': entity -> number of appearances,
        'freq_to_ent': frequency -> list of entities with that frequency
    """
    ent_to_freq = defaultdict(int)
    for event in events:
        for endpoint in (event[0], event[2]):
            ent_to_freq[endpoint] += 1

    freq_to_ent = defaultdict(list)
    for entity, freq in ent_to_freq.items():
        freq_to_ent[freq].append(entity)

    return {'ent_to_freq': ent_to_freq, 'freq_to_ent': freq_to_ent}


def analyze(dataset, model):
    """Compute entity and triple frequency statistics for the train and test data.

    Only the first half of ``dataset.get_train()`` is used for the train
    statistics. The results are currently held in locals only; nothing is
    returned and ``model`` is not used (the clustering report that consumed
    them is disabled).
    """
    logger.info('Entities:')

    sizes = dataset.get_shape()
    ent_to_id = dataset.get_ent_id()
    train_data, test_data = dataset.get_train(), dataset.get_examples('test')
    half = len(train_data) // 2
    train_data = train_data[: half]
    dec = get_tri_dec(sizes)
    enc = get_tri_enc(sizes)

    # Frequency of entities and triples, train split first, then test.
    ent_dicts, tri_dicts = {}, {}
    for split_name, split_events in (('train', train_data), ('test', test_data)):
        ent_dicts[split_name] = count_entities(split_events)
        tri_dicts[split_name] = count_triples(sizes, split_events)

    ent_unique_freqs = {
        name: list(stats['freq_to_ent'].keys()) for name, stats in ent_dicts.items()
    }
    tri_unique_freqs = {
        name: list(stats['freq_to_tri'].keys()) for name, stats in tri_dicts.items()
    }


def cosine_distance(x: torch.Tensor):
    """Pairwise distance between the rows of ``x`` derived from cosine similarity.

    Rows are L2-normalised (with 1e-8 added to the norm) and the result is
    ``(1.01 - cos) / 2``.
    """
    row_norms = x.norm(p=2, dim=1, keepdim=True)
    unit_rows = x / (row_norms + 1e-8)
    neg_similarity = -unit_rows @ unit_rows.t()
    return (neg_similarity + 1.01) / 2


def get_arg_matrix(x: torch.Tensor):
    """Return ``atan(second_half / first_half)`` of the columns of ``x``.

    The columns are split at ``x.shape[1] // 2``.
    """
    half = x.shape[1] // 2
    first_half = x[:, : half]
    second_half = x[:, half :]
    return torch.atan(second_half / first_half)


def cluster_constrained(features: np.ndarray, cluster_num: int, cluster_distance: str):
    """Cluster ``features`` with size-constrained k-means.

    Clusters hold between 0 and 200 points. ``cluster_distance`` is accepted
    but not used by the current code.

    :return: (labels, centroids) where ``centroids[i]`` is the centre of the
        cluster that point ``i`` was assigned to
    """
    model = KMeansConstrained(
        n_clusters=cluster_num,
        size_min=0,
        size_max=200,
        n_init=5,
        n_jobs=-1,
        copy_x=False,
        verbose=True,
    )
    model.fit_predict(features)
    labels = model.labels_
    # Expand the per-cluster centres to one row per input point.
    centroids = np.array(model.cluster_centers_)[labels]
    logger.info(f"{centroids.shape}")
    return labels, centroids


def cluster_pytorch_kmeans(features: torch.Tensor, cluster_num: int):
    """Run ``kmeans_pytorch.kmeans`` with cosine distance on the features' device.

    :return: (labels, centers) as NumPy arrays
    """
    assignments, centers = kmeans(
        features, cluster_num, distance='cosine', device=features.device
    )
    labels_np = assignments.cpu().numpy()
    centers_np = centers.cpu().numpy()
    return labels_np, centers_np


def cluster_sklearn_kmeans(features: np.ndarray, cluster_num: int):
    """Fit scikit-learn ``KMeans`` (at most 100 iterations, verbose).

    :return: (labels, cluster centres)
    """
    estimator = cluster.KMeans(n_clusters=cluster_num, max_iter=100, verbose=True)
    fitted = estimator.fit(features)
    return fitted.labels_, fitted.cluster_centers_


def cluster_scipy_kmeans2(features: np.ndarray, cluster_num: int):
    """Cluster with ``scipy.cluster.vq.kmeans2`` using its default settings.

    :return: (labels, centroids); note this is the reverse of kmeans2's own order
    """
    centroids, labels = kmeans2(features, cluster_num)
    return labels, centroids


def cluster_DBSCAN(features: torch.Tensor):
    """Run DBSCAN (eps 0.2, min_samples 2) on the ``cosine_distance`` matrix.

    Summary statistics of the distance matrix are logged first.

    :return: the ``labels_`` array of the fitted DBSCAN estimator
    """
    dist = cosine_distance(features).cpu().numpy()
    logger.info(f'dist min: {dist.min()}, dist max: {dist.max()}, dist mean: {dist.mean()}, dist var: {dist.var()}')

    estimator = cluster.DBSCAN(eps=0.2, min_samples=2, metric='precomputed')
    fitted = estimator.fit(dist)
    return fitted.labels_


def cluster_hierarchy(features: np.ndarray, cluster_num: int):
    """Centroid-linkage hierarchical clustering cut into at most ``cluster_num`` clusters.

    :return: zero-based cluster label per row of ``features``
    """
    tree = hcluster.linkage(features, method='centroid')
    one_based = hcluster.fcluster(tree, t=cluster_num, criterion='maxclust')
    return one_based - 1


def plot_time_curve(time_scores, file_name, color='b', y_max=-1):
    """Plot ``time_scores`` against its index and save it under ``save_dir``.

    :param time_scores: 1-D array of values, one per x position
    :param file_name: path of the image relative to ``save_dir``
    :param color: matplotlib colour of the curve
    :param y_max: accepted for compatibility; the y-limit code is disabled
    """
    plt.figure()
    positions = np.arange(0, time_scores.shape[0])
    plt.plot(positions, time_scores, c=color)
    target = f'{save_dir}/{file_name}'
    plt.savefig(target)
    plt.close()


def calc_gc(c_l):
    """Harmonic mean of ``c_l[2] / c_l[0]`` and ``c_l[2] / c_l[1]``.

    Returns 0 when ``c_l[2]`` is zero; 1e-9 in the denominator guards the
    final division.
    """
    matched = c_l[2]
    if matched == 0:
        return 0
    precision = matched / c_l[0]
    recall = matched / c_l[1]
    return 2 * precision * recall / (precision + recall + 1e-9)


def calc_pair_gc(delta, dt, occur_s, occur_o):
    """Count occurrences in ``occur_s`` that are followed by one in ``occur_o``.

    An occurrence ``t1`` of the first event matches an occurrence ``t2`` of
    the second one when ``dt - delta <= t2 - t1 <= dt + delta``. Matching is
    greedy over the sorted times and every ``t2`` is used at most once.

    :param delta: tolerance around the expected lag
    :param dt: expected lag between the two events
    :param occur_s: occurrence times of the first event (any order)
    :param occur_o: occurrence times of the second event (any order)
    :return: ``[len(occur_s), len(occur_o), match]``; ``match`` is 0 when
        either list has at most one element
    """
    occur_s = sorted(occur_s)
    occur_o = sorted(occur_o)
    counts = [len(occur_s), len(occur_o)]
    if len(occur_s) <= 1 or len(occur_o) <= 1:
        return counts + [0]

    lo, hi = dt - delta, dt + delta
    match = 0
    j = 0
    for t1 in occur_s:
        while j < len(occur_o):
            gap = occur_o[j] - t1
            if gap > hi:
                # Too late for this t1 but possibly right for a later one, so
                # do not consume it. (Previously it was skipped for good,
                # which under-counted matches.)
                break
            # From here on occur_o[j] is either matched or too early for this
            # and every later t1, so it is consumed in both cases.
            j += 1
            if gap >= lo:
                match += 1
                break

    return counts + [match]


def evaluate_ta(rule, tri_to_time, tri_enc):
    """Count how often the two events of a rule occur ``dt`` apart.

    The layout of ``rule`` depends on the global ``args.rule``:
    'specific' expects ``(s, o, p1, p2, dt)`` with both events sharing the
    same subject and object; 'general' expects
    ``(s1, p1, o1, s2, p2, o2, dt)``.

    :param rule: rule tuple as described above
    :param tri_to_time: mapping from encoded triple to its occurrence times
    :param tri_enc: encoder turning ``[s, p, o]`` into a ``tri_to_time`` key
    :return: the ``[count_1, count_2, match]`` list from ``calc_pair_gc``
    :raises ValueError: if ``args.rule`` is neither 'specific' nor 'general'
    """
    if args.rule == 'specific':
        s, o, p1, p2, dt = rule
        s1 = s2 = s
        o1 = o2 = o
    elif args.rule == 'general':
        s1, p1, o1, s2, p2, o2, dt = rule
    else:
        # Previously this fell through to an UnboundLocalError below.
        raise ValueError(f"Unsupported rule mode: {args.rule!r}")

    # The tolerance is 10% of the lag, rounded up. It is taken from |dt| so a
    # negative lag keeps a non-negative tolerance; without abs() a negative dt
    # gave delta <= 0, collapsing the window to an exact match or to nothing.
    eta = 0.1
    delta = math.ceil(abs(dt) * eta)
    code1 = tri_enc([s1, p1, o1])
    code2 = tri_enc([s2, p2, o2])
    return calc_pair_gc(delta, dt, tri_to_time[code1], tri_to_time[code2])


def plot_artificial(model, artifacts):
    """Plot the rule curves of the first and last entries of ``artifacts``.

    ``model.calc_ta_curve`` is evaluated on the first four fields of each of
    the two entries. Two figures are written into ``save_dir``:
    'taq-distance.png' (negated scores) and 'taq-probability.png' (softmax of
    the scores). The first entry is drawn in red, the last in blue.
    """
    pos_ex, neg_ex = artifacts[0], artifacts[-1]
    pos_rule_score = model.calc_ta_curve([pos_ex[0]], [pos_ex[1]], [pos_ex[2]], [pos_ex[3]]).squeeze()
    neg_rule_score = model.calc_ta_curve([neg_ex[0]], [neg_ex[1]], [neg_ex[2]], [neg_ex[3]]).squeeze()

    legend_font = 20
    tick_font = 15

    logger.info(f'Pos, Neg rule shape: {pos_rule_score.shape}, {neg_rule_score.shape}')

    logger.info(f'Rule Single  ({pos_ex}), max point: {pos_rule_score.argmax()}')
    logger.info(f'Rule Single  ({neg_ex}), max point: {neg_rule_score.argmax()}')

    def _draw_pair(pos_curve, neg_curve, x_label, y_label, out_path):
        # One figure with both curves, shared styling for the two plots below.
        plt.figure()
        plt.tick_params(labelsize=tick_font)
        plt.plot(np.arange(0, pos_curve.shape[0]), pos_curve, c='r', label='TA with high gc')
        plt.plot(np.arange(0, neg_curve.shape[0]), neg_curve, c='b', label='TA with low gc')
        plt.legend(fontsize=legend_font)
        plt.xlabel(x_label, fontsize=legend_font)
        plt.ylabel(y_label, fontsize=legend_font)
        plt.savefig(out_path, bbox_inches='tight')
        plt.close()

    pos_rule_exp = torch.softmax(pos_rule_score, 0).cpu().numpy()
    neg_rule_exp = torch.softmax(neg_rule_score, 0).cpu().numpy()

    pos_rule_dis = -pos_rule_score.cpu().numpy()
    neg_rule_dis = -neg_rule_score.cpu().numpy()

    _draw_pair(pos_rule_dis, neg_rule_dis, 'Timestamps', 'Distance', f'{save_dir}/taq-distance.png')
    _draw_pair(pos_rule_exp, neg_rule_exp, 'Timestamp', 'Probability', f'{save_dir}/taq-probability.png')


def find_artifical_rule(model, artifacts):
    """Log and plot the rule curve and event curves of every artificial rule.

    Each element of ``artifacts`` is unpacked as ``(s, o, r1, r2, dt, _)``.
    For every rule the argmax of ``model.calc_ta_curve`` is logged and four
    images are written under ``save_dir``/rule_plot: the negated rule curve,
    its softmax, and the event curves of ``(s, r1, o)`` and ``(s, r2, o)``.
    Rule curves are drawn in red when ``dt > 0`` and in blue otherwise.
    """
    ## rule: (s, o, r1, r2, dt, mode)
    # artifacts += [(101, 102, 50, 80, 100, 'f', 10)]
    logger.info(artifacts[:3])
    for s, o, r1, r2, dt, _ in artifacts:
        # rule_scores = model.calc_er_curve([ls], [lo], [r1], [r2]).squeeze()
        single_rule_scores = model.calc_ta_curve([s], [o], [r1], [r2]).squeeze()
        # logger.info(rule_scores)
        # logger.info(f'Rule Cluster ({ls}, {lo}, {r1}, {r2}, {dt}, {mode}), max point: {rule_scores.argmax()}')
        logger.info(f'Rule Single  ({s}, {o}, {r1}, {r2}, {dt}), max point: {single_rule_scores.argmax()}')
        # logger.info(f'ent_to_label: {s}-> {model.ent_to_lbl[s]}, {o}-> {model.ent_to_lbl[o]}')

        # Colour encodes the sign of the rule's time lag.
        color = 'r' if dt > 0 else 'b'

        # plot_time_curve(rule_scores.cpu().numpy(), f'art_rule_({ls}, {lo}, {r1}, {r2}, {dt}, {mode}).png', color=color)
        # Negated scores and their softmax are plotted side by side.
        plot_time_curve(-single_rule_scores.cpu().numpy(), f'rule_plot/art_s_rule_({s}, {o}, {r1}, {r2}, {dt}).png', color=color)
        plot_time_curve(torch.softmax(single_rule_scores, 0).cpu().numpy(), f'rule_plot/art_exp_rule_({s}, {o}, {r1}, {r2}, {dt}).png', color=color, y_max=1)

        time_scores_1 = model.calc_event_curve(s, r1, o).squeeze()
        plot_time_curve(time_scores_1.cpu().numpy(), f'rule_plot/event_({s}, {r1}, {o}).png', y_max=-1)

        # NOTE(review): the condition is hard-coded, so the else branch (second
        # event with subject and object swapped) is never executed.
        if True:
            time_scores_2 = model.calc_event_curve(s, r2, o).squeeze()
            plot_time_curve(time_scores_2.cpu().numpy(), f'rule_plot/event_({s}, {r2}, {o}).png')
        else:
            time_scores_2 = model.calc_event_curve(o, r2, s).squeeze()
            plot_time_curve(time_scores_2.cpu().numpy(), f'event_({o}, {r2}, {s}).png')