import numpy as np
import ot
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt


# Index lists (`idx`) are stored alongside the data so entries can be referenced conveniently, e.g. like names.

# Sentinel string returned by mat2dist when a matrix carries no usable mass;
# seq_check treats any entry equal to it as an invalid time step.
Err = "Error: the mat has no mass."


class dataset(object):
    """Plain container bundling matrices with their index and time labels.

    Attributes:
        mat: the object passed as ``mat_list`` (stored as-is, not copied).
        idx: the object passed as ``idx_list`` (stored as-is, not copied).
        time: the object passed as ``time_list`` (stored as-is, not copied).
    """

    def __init__(self, mat_list, idx_list, time_list):
        self.mat, self.idx, self.time = mat_list, idx_list, time_list

class distribution:
    """Discrete distribution: support points ``locs`` with weights ``probs``.

    Both arguments are converted with ``np.array`` (which copies), so later
    mutation of the caller's inputs does not affect the stored arrays.
    """

    def __init__(self, locs, probs):
        self.locs, self.probs = np.array(locs), np.array(probs)

def find_globs(matrix, threshold=0.01):
    '''
    Find 4-connected components ("globs") of cells whose value exceeds ``threshold``.

    Args:
        matrix: 2-D array with a ``.shape`` attribute, indexed as ``matrix[r][c]``.
        threshold: cells with value ``<= threshold`` are treated as background.

    Returns:
        A list of globs. Each glob is a list of ``(row, col)`` tuples in
        depth-first visiting order (neighbours tried up, down, left, right).
        Globs appear in row-major order of their first (seed) cell.
    '''
    rows, cols = matrix.shape
    visited = set()
    globs = []

    def _collect(start_r, start_c):
        # Collect one glob starting from (start_r, start_c); empty if the seed is
        # background or already visited.
        # An explicit stack replaces recursion so large globs cannot hit Python's
        # recursion limit. Neighbours are pushed in reverse order and validated
        # when popped, which yields the same visiting order as a recursive
        # up/down/left/right depth-first search.
        glob = []
        stack = [(start_r, start_c)]
        while stack:
            r, c = stack.pop()
            if (r < 0 or r >= rows or c < 0 or c >= cols
                    or (r, c) in visited or matrix[r][c] <= threshold):
                continue
            visited.add((r, c))
            glob.append((r, c))
            stack.append((r, c + 1))  # Right (explored last)
            stack.append((r, c - 1))  # Left
            stack.append((r + 1, c))  # Down
            stack.append((r - 1, c))  # Up (explored first)
        return glob

    for r in range(rows):
        for c in range(cols):
            if (r, c) not in visited:
                glob = _collect(r, c)
                if glob:  # Only add non-empty globs
                    globs.append(glob)
    return globs

def mat2dist(mat, threshold = 0.01, mode = "normal"):
    """
    Convert a 2-D intensity matrix into a discrete ``distribution``.

    Args:
        mat: 2-D numpy array.
        threshold: cells with value ``<= threshold`` are ignored. In "normal"
            mode the total retained mass must also be at least ``threshold``.
        mode: "normal" -- one support point per above-threshold cell at its
            ``(row, col)``, weighted by the cell value;
            "globs" -- one support point per connected glob (``find_globs``) at
            the mean ``(row, col)`` of its cells, weighted by its cell count.

    Returns:
        A ``distribution`` whose probabilities sum to 1, or the ``Err`` string
        when the matrix has no usable mass.

    Raises:
        ValueError: if ``mode`` is not "normal" or "globs".
    """
    if mode == "globs":
        globs = find_globs(mat, threshold)
        if len(globs) == 0:
            return Err
        locs = [np.mean(cur_glob, axis=0) for cur_glob in globs]
        sizes = np.array([len(cur_glob) for cur_glob in globs])
        return distribution(locs=locs, probs=sizes / np.sum(sizes))
    if mode == "normal":
        rows, cols = np.where(mat > threshold)
        if len(rows) == 0:
            # Nothing above threshold: there is no mass to normalise.
            return Err
        values = mat[rows, cols]
        tot = np.sum(values)
        if tot < threshold:
            return Err
        locations = list(zip(rows, cols))
        return distribution(locs=locations, probs=values / tot)
    raise ValueError(f"Unknown mode {mode!r}; expected 'normal' or 'globs'.")

def generate_seq_dist(dataset_dist, idx, seq_length, w = 1):
    """
    Stack ``seq_length`` consecutive per-frame distributions into one
    space-time distribution.

    Every point of frame ``dataset_dist[idx + t]`` gets the time coordinate
    ``w * t`` prepended as an extra leading column. Each frame's probabilities
    are scaled by ``1 / seq_length``, so frames that each sum to 1 combine to a
    total of 1.

    Args:
        dataset_dist: indexable sequence of objects exposing ``locs`` (one
            point per row) and ``probs``.
        idx: index of the first frame.
        seq_length: number of frames to stack.
        w: weight of the time axis relative to the spatial axes.
    """
    all_locs, all_probs = [], []
    for step in range(seq_length):
        frame = dataset_dist[idx + step]
        n_points = len(frame.locs)
        time_column = np.ones((n_points, 1)) * w * step
        all_locs.extend(np.hstack((time_column, frame.locs)))
        all_probs.extend(frame.probs * 1 / seq_length)
    return distribution(locs=all_locs, probs=all_probs)

def seq_check(dist_list, cur_idx, seq_length):
    '''
    Return True if ``dist_list[cur_idx : cur_idx + seq_length]`` is a complete,
    valid sequence.

    A window is invalid if it starts before index 0, runs past the end of
    ``dist_list``, or contains an entry equal to ``0`` or to the ``Err``
    sentinel (which ``mat2dist`` returns for matrices without mass).
    '''
    # Slicing would silently truncate (or, for negative starts, reinterpret) an
    # out-of-range window, and generate_seq_dist would then index out of range
    # or wrap around -- so reject such windows explicitly.
    if cur_idx < 0 or cur_idx + seq_length > len(dist_list):
        return False
    for dist in dist_list[cur_idx: cur_idx + seq_length]:
        if dist == 0 or dist == Err:
            return False
    return True

### OT

def compute_sk(source_dist, target_dist, metric, shift = False, reg = 0.1, stopThr = 0.001):
    """
    Entropically regularised OT cost (Sinkhorn) between two distributions.

    Args:
        source_dist, target_dist: objects exposing ``locs`` (one point per row)
            and ``probs`` (one weight per point).
        metric: metric name forwarded to ``ot.dist`` (e.g. "sqeuclidean").
        shift: if True, subtract each cloud's probability-weighted centre
            (``sum_i p_i * loc_i``) before computing costs. When the weights
            sum to 1 this makes the cost insensitive to a global translation.
        reg: entropic regularisation strength forwarded to ``ot.sinkhorn2``.
        stopThr: stopping threshold forwarded to ``ot.sinkhorn2``.

    Returns:
        The value returned by ``ot.sinkhorn2``.
    """
    x, a = source_dist.locs, source_dist.probs
    y, b = target_dist.locs, target_dist.probs
    if shift:
        # a @ x is the probability-weighted sum of the points (vectorised).
        x = x - a @ x
        y = y - b @ y
    M = ot.dist(x, y, metric=metric)
    return ot.sinkhorn2(a, b, M, reg=reg, stopThr=stopThr)

def compute_OT(source_dist, target_dist, metric, shift = False, numItermax = 100000):
    """
    Exact OT cost (earth mover's distance) between two distributions.

    Args:
        source_dist, target_dist: objects exposing ``locs`` (one point per row)
            and ``probs`` (one weight per point).
        metric: metric name forwarded to ``ot.dist`` (e.g. "sqeuclidean").
        shift: if True, subtract each cloud's probability-weighted centre
            (``sum_i p_i * loc_i``) before computing costs. When the weights
            sum to 1 this makes the cost insensitive to a global translation.
        numItermax: maximum network-simplex iterations forwarded to
            ``ot.emd2``. Large stacked sequences may need more than the
            default to converge.

    Returns:
        The value returned by ``ot.emd2``.
    """
    x, a = source_dist.locs, source_dist.probs
    y, b = target_dist.locs, target_dist.probs
    if shift:
        # a @ x is the probability-weighted sum of the points (vectorised).
        x = x - a @ x
        y = y - b @ y
    M = ot.dist(x, y, metric=metric)
    return ot.emd2(a, b, M, numItermax=numItermax)

def downsampling(mat_list, factor = 4):
    """
    Subsample each matrix by keeping every ``factor``-th row and column.

    Returns a new list; for numpy inputs each entry is a strided view of the
    corresponding original matrix.
    """
    stride = slice(None, None, factor)
    return [m[stride, stride] for m in mat_list]

def seq_images(mat_list, source_idx, seq_length):
    """
    Return the ``seq_length`` consecutive entries of ``mat_list`` starting at
    ``source_idx`` (fewer if the slice runs past the end).
    """
    end_idx = source_idx + seq_length
    return mat_list[source_idx:end_idx]

def display(images, nrow = 6, padding = 5, save_name = False):
    """
    Tile a batch of 2-D images into one grid and show it with the "jet" colormap.

    Args:
        images: sequence of equally sized 2-D arrays, stacked via ``np.array``
            into an (N, H, W) batch.
        nrow: images per grid row (forwarded to ``make_grid``); also scales
            the figure size.
        padding: pixels between tiles (forwarded to ``make_grid``).
        save_name: if truthy, the figure is also written to
            ``f"{save_name}.pdf"`` before being shown.
    """
    stacked = torch.tensor(np.array(images)).unsqueeze(1)  # (N, 1, H, W)
    grid = vutils.make_grid(stacked, nrow=nrow, padding=padding)
    # The grid is channel-first; move channels last and keep only channel 0.
    single_channel = np.transpose(grid.numpy(), (1, 2, 0))[:, :, 0]
    plt.figure(figsize=(nrow * 5, nrow * 3))
    plt.imshow(single_channel, cmap="jet", vmin=0)  # colour scale starts at 0
    plt.axis('off')
    if save_name:
        plt.savefig(f"{save_name}.pdf", format='pdf', bbox_inches='tight')
    plt.show()


def compare_and_sort(source_idx, seq_length, dataset_dist, invalid_dis = 100000):
    """
    Rank every candidate window of ``dataset_dist`` by OT distance to the
    source window.

    Each window of ``seq_length`` consecutive per-frame distributions is turned
    into a space-time distribution (``generate_seq_dist``). Its exact OT cost to
    the source window is then computed with squared-Euclidean ground cost and
    centring (``compute_OT(..., shift=True)``).

    Args:
        source_idx: start index of the query sequence.
        seq_length: number of consecutive frames per sequence (must be >= 1).
        dataset_dist: per-frame list of distributions. Entries equal to ``0``
            or ``Err`` mark unusable frames (see ``seq_check``).
        invalid_dis: distance assigned to windows that fail ``seq_check``.
            Valid OT costs are not bounded by this value, so pass
            ``float('inf')`` to guarantee invalid windows sort last.

    Returns:
        List of ``[distance, target_idx]`` pairs sorted by ascending distance,
        covering every start index whose window fits inside ``dataset_dist``.

    Raises:
        ValueError: if ``seq_length < 1`` or the source window is not valid.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    if not seq_check(dataset_dist, source_idx, seq_length):
        raise ValueError(
            f"source window starting at {source_idx} with length {seq_length} "
            "is not a valid sequence")

    try:
        from tqdm import tqdm
    except ImportError:  # progress bar is optional; fall back to a plain iterator
        def tqdm(iterable):
            return iterable

    source_seq_dist = generate_seq_dist(dataset_dist, source_idx, seq_length)
    # Valid window starts are 0 .. len(dataset_dist) - seq_length, inclusive.
    n_starts = max(len(dataset_dist) - seq_length + 1, 0)
    dis_and_idx = []
    for target_idx in tqdm(range(n_starts)):
        if seq_check(dataset_dist, target_idx, seq_length):
            target_seq_dist = generate_seq_dist(dataset_dist, target_idx, seq_length)
            dis = compute_OT(source_seq_dist, target_seq_dist,
                             metric="sqeuclidean", shift=True)
        else:
            dis = invalid_dis
        dis_and_idx.append([dis, target_idx])
    # sorted() is stable: equal distances keep ascending target_idx order.
    return sorted(dis_and_idx, key=lambda pair: pair[0])

def retrieving(sorted_dis_and_idx, top_k = 10, time_buffer = 100):
    """
    Greedily pick up to ``top_k`` indices from a distance-sorted candidate
    list. A candidate is skipped if it lies closer than ``time_buffer`` to any
    index already picked.

    Args:
        sorted_dis_and_idx: iterable of ``(distance, idx)`` pairs, best first
            (e.g. the output of ``compare_and_sort``).
        top_k: maximum number of indices to return; ``top_k <= 0`` gives ``[]``.
        time_buffer: minimum absolute index gap between any two picks.

    Returns:
        The selected indices, in the order they were picked.
    """
    idx_list = []
    for _, cur_idx in sorted_dis_and_idx:
        # Check the quota before adding, so top_k <= 0 yields an empty list.
        if len(idx_list) >= top_k:
            break
        if all(abs(cur_idx - picked) >= time_buffer for picked in idx_list):
            idx_list.append(cur_idx)
    return idx_list

def collecting_for_display(idx_list, seq_length, mat_list, time_list):
    """
    Gather image windows and their matching time labels for each start index.

    For every ``idx`` in ``idx_list`` (in order), the slices
    ``mat_list[idx:idx + seq_length]`` and ``time_list[idx:idx + seq_length]``
    are appended element by element to two flat output lists.

    Returns:
        ``(images, time_names)`` as two flat lists.
    """
    images, time_names = [], []
    for start in idx_list:
        window = slice(start, start + seq_length)
        images.extend(mat_list[window])
        time_names.extend(time_list[window])
    return images, time_names
