from DeepCCASoftHGR_Corr import CCAscore, DeepCCAscore, SoftCCAscore, SoftHGR, UniFastHGR, OptFastHGR
import torch
import time
from tqdm import tqdm
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

from distcorr import distcorr
from utils.metrics import id_correlation
from CKA import linear_CKA, kernel_CKA

# Benchmark configuration.
bz = 128  # number of rows in each random input (first dim of f and g)
el = [10, 50, 100, 150, 200, 300, 400, 500]  # column counts (second dim) to sweep
ttimes = 10000  # timed repetitions per column count

# Mean per-call wall-clock time of each method. Every list receives one entry
# per value in `el`, in the same order as `el`.
mcosttime_CCA = []
mcosttime_DeepCCA = []
mcosttime_SoftCCA = []
mcosttime_CKA = []
mcosttime_distcorr = []
mcosttime_idcorr = []
mcosttime_SoftHGR = []
mcosttime_UniFastHGR = []
mcosttime_OptFastHGR = []

# (function under test, list collecting its mean time). The order here is the
# order in which the methods are timed within a single repetition. Keeping the
# pairing in one table means the timing code is written once and a method can
# never be accumulated into another method's list by a copy-paste slip.
_benchmarks = [
    (CCAscore, mcosttime_CCA),
    (DeepCCAscore, mcosttime_DeepCCA),
    (SoftCCAscore, mcosttime_SoftCCA),
    (kernel_CKA, mcosttime_CKA),
    (distcorr, mcosttime_distcorr),
    (id_correlation, mcosttime_idcorr),
    (SoftHGR, mcosttime_SoftHGR),
    (UniFastHGR, mcosttime_UniFastHGR),
    (OptFastHGR, mcosttime_OptFastHGR),
]

for dim in el:
    # Elapsed time accumulated per method for the current column count.
    totals = [0.0] * len(_benchmarks)
    for _ in tqdm(range(ttimes)):
        # Fresh random inputs every repetition; all methods are timed on the
        # same (f, g) pair.
        f = torch.rand((bz, dim))
        g = torch.rand((bz, dim))

        for idx, (score_fn, _history) in enumerate(_benchmarks):
            start_time = time.perf_counter()
            score_fn(f, g)
            # Stop the clock before touching `totals` so bookkeeping is not
            # part of the measured interval.
            duration = time.perf_counter() - start_time
            totals[idx] += duration

    # Convert the accumulated totals into per-call means for this column count.
    for (_score_fn, history), total in zip(_benchmarks, totals):
        history.append(total / ttimes)
# Print, for each method, a label line followed by its list of mean per-call
# times (one entry per value in `el`).
# The last two entries previously referenced the undefined names
# `mcosttime_HGRscore3` / `mcosttime_HGRscore4`, which raised NameError only
# after the whole benchmark had finished; they now use the lists that the
# timing loop actually fills.
for label, mean_times in (
    ('CCA', mcosttime_CCA),
    ('DeepCCA', mcosttime_DeepCCA),
    ('SoftCCA', mcosttime_SoftCCA),
    ('CKA', mcosttime_CKA),
    ('distcorr', mcosttime_distcorr),
    ('idcorr', mcosttime_idcorr),
    ('SoftHGR', mcosttime_SoftHGR),
    ('UniFastHGR', mcosttime_UniFastHGR),
    ('OptFastHGR', mcosttime_OptFastHGR),
):
    print(label + ':')
    print(mean_times)

