import numpy as np
import pandas as pd
from matplotlib import figure
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

import math
import os
import time
import random
import gc

import warnings
warnings.filterwarnings(action='ignore')

import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import copy

# Single seed shared by every RNG below; both module-level names are kept
# because other code may refer to either of them.
SEED = 10
random_seed = SEED

# Seed Python's `random`, NumPy's global RNG, and PyTorch on CPU and on every GPU.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)

# Ask cuDNN for deterministic algorithms and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Common Layer Visualization ###
def _feature_grid_shape(num_filter):
    """Return ``(rows, cols)`` with ``rows * cols == num_filter`` exactly.

    Perfect squares get a square grid; 80 and 112 channels keep their 8x10 and
    7x16 layouts; any other count uses its most nearly square factorisation
    (``rows <= cols``), which degenerates to a single row for primes.
    """
    side = int(math.sqrt(num_filter))
    if side * side == num_filter:
        return side, side
    preferred = {80: (8, 10), 112: (7, 16)}
    if num_filter in preferred:
        return preferred[num_filter]
    # rows == 1 always divides, so this loop always returns for num_filter >= 1.
    for rows in range(side, 0, -1):
        if num_filter % rows == 0:
            return rows, num_filter // rows


def layer_visualize(x, num_filter, layer, num=0, result_folder=None, *, close=True):
    """Plot every channel of one sample's feature map as a grid of line plots.

    ``x[num]`` is reshaped into a ``(rows, cols, 1, -1)`` grid with
    ``rows * cols == num_filter``; each cell is one channel's values drawn as a
    line. The figure is titled with ``layer`` and saved as
    ``{layer}_featuremap.pdf``. ``layer`` is also printed.

    Args:
        x: Tensor whose first axis is indexed by ``num``. Presumably ``x[num]``
            is ``(num_filter, time)``; TODO confirm against callers. Any shape
            whose size is divisible by ``num_filter`` is accepted.
        num_filter: Number of channels, i.e. grid cells, to draw.
        layer: Layer name, used in the title and the output filename.
        num: Index of the sample to plot.
        result_folder: Output directory; ``None`` means the current directory.
        close: Close the figure after saving so that repeated calls do not
            accumulate open figures. Pass ``False`` to keep it open.

    Raises:
        ValueError: If ``num_filter`` is not positive.
        The ``reshape`` error propagates if ``x[num]`` cannot be split into
        ``num_filter`` equally long cells.
    """
    if num_filter < 1:
        raise ValueError(f"num_filter must be positive, got {num_filter}")
    # Pick an exact layout up front. A trial reshape can "succeed" with a
    # wrong grid whenever the element count happens to be divisible.
    rows, cols = _feature_grid_shape(num_filter)
    grid = x.detach().cpu()[num].reshape(rows, cols, 1, -1)
    t = grid.size(-1)

    # squeeze=False keeps `ax` 2-D even for a single-row grid.
    fig, ax = plt.subplots(rows, cols, dpi=150, squeeze=False)
    for i in range(rows):
        for j in range(cols):
            cell = ax[i, j]
            cell.plot(grid[i, j, :, :].T)
            cell.set_xticks([])
            cell.set_yticks([])
    fig.suptitle(layer + f"\n{num_filter}feature map (1box=1channel)", fontsize=15)
    fig.supxlabel(f"Time Length ({t})", fontsize=10)
    fig.supylabel("Values", fontsize=10)
    print(layer)
    folder = os.curdir if result_folder is None else result_folder
    fig.savefig(os.path.join(folder, f'{layer}_featuremap.pdf'))
    if close:
        plt.close(fig)
    

def visualize_alignmatrix(x, y, num, class_num, result_folder, name, *, close=True):
    """Draw one sample's alignment matrix as a heatmap and save ``{name}.pdf``.

    Dashed horizontal lines separate the first (up to) four rows. Dashed
    vertical lines mark column boundaries found by walking a single column
    cursor across those rows: for each row the cursor advances until that row
    is zero, and that column is recorded. The cursor is never reset between
    rows. Reaching the last column ends a row's search without recording.

    Args:
        x: Tensor indexed by ``num``; ``x[num]`` must be 2-D (rows x columns).
        y: Unused. Only a disabled mask overlay referenced it; it is kept for
            API compatibility.
        num: Index of the sample to plot; also shown in the axis labels.
        class_num: Shown in the y-axis label.
        result_folder: Output directory.
        name: Output filename stem.
        close: Close the figure after saving so that repeated calls do not
            accumulate open figures. Pass ``False`` to keep it open.
    """
    matrix = x.detach().cpu()[num]
    fig = plt.figure(dpi=150)
    cmap = LinearSegmentedColormap.from_list('Custom', ["white", "blue"], 2)
    ax = sns.heatmap(matrix, cbar=True, cmap=cmap)

    # At most four rows are annotated. Clamp so that smaller matrices do not
    # raise IndexError.
    n_rows = min(4, matrix.size(0))
    # Seaborn draws row k over y in [k, k+1], so the boundary above row k is
    # y = k. This does not depend on which tick labels seaborn chose to show.
    x_start = ax.get_xticks()[0]
    for i in range(1, n_rows):
        ax.hlines(i, x_start, matrix.size(-1), ls='--', linewidth=1, color="black")

    last_col = matrix.size(1) - 1
    boundaries = []
    t = 0
    for j in range(n_rows):
        while t != last_col:
            if matrix[j][t].item() == 0:
                boundaries.append(t)
                break
            t += 1
    for col in boundaries:
        ax.axvline(x=col, ls='--', linewidth=1, color='black')

    ax.set_title("feature heatmap", fontsize=15)
    ax.set_xlabel(f"{num}sample's timelength", fontsize=10)
    ax.set_ylabel(f"{num}sample's {class_num} Prototype", fontsize=10)
    fig.savefig(os.path.join(result_folder, f'{name}.pdf'))
    if close:
        plt.close(fig)
    

def visualize_attn(x, protos_num, num, class_num, result_folder, name, count, *, close=True):
    """Save a one-row heatmap of one sample's values averaged over axis 1.

    ``x[num]`` is averaged over its second axis, giving one value per column of
    the strip. Thick vertical lines are drawn after ``protos_num`` and
    ``2 * protos_num`` columns. The three groups are labelled GTP, STP and DTP
    at their centres.

    Args:
        x: Tensor indexed by ``num``. Presumably ``x[num]`` has
            ``3 * protos_num`` rows (one per prototype); TODO confirm against
            the caller.
        protos_num: Number of columns in each labelled group; must be >= 1.
        num: Index of the sample to plot.
        class_num: Unused; kept for API compatibility.
        result_folder: Output directory.
        name: Filename prefix.
        count: Indexable whose first two items are written into the filename
            as ``{count[0]}epoch_{count[1]}batch``.
        close: Close the figure after saving so that repeated calls (e.g. per
            batch) do not accumulate open figures. Pass ``False`` to keep it.

    Raises:
        ValueError: If ``protos_num`` is less than 1.
    """
    if protos_num < 1:
        raise ValueError(f"protos_num must be >= 1, got {protos_num}")
    n = protos_num
    strip = x.detach().cpu()[num].mean(axis=1).unsqueeze(0)

    fig, ax = plt.subplots(figsize=(7, 1.3))
    sns.heatmap(strip, cbar=False, cmap='Reds', ax=ax, square=True)
    ax.tick_params(axis='both', which='both', bottom=False, top=False,
                   left=False, labelbottom=True)
    for boundary in (n, 2 * n):
        ax.axvline(x=boundary, linewidth=3, color='black')
    # Only the centre of each n-column group carries a label.
    ax.set_xticks([0.5 * n, 1.5 * n, 2.5 * n])
    ax.set_xticklabels(['GTP', 'STP', 'DTP'], rotation=0, fontsize=20)
    ax.set_yticklabels('')
    fig.savefig(os.path.join(result_folder, f'{name}_{count[0]}epoch_{count[1]}batch.pdf'))
    if close:
        plt.close(fig)
    
    
    
    