import os
import torch
import pandas as pd
from safetensors.torch import load_file,save_file
import numpy as np
import torch
from torch import nn
import re

    

def get_K_hw(real_dict,cache_dict,k_threshold=1, add_mask=False):
    """Compute one element-wise gain tensor per key and concatenate them.

    Keys of ``real_dict`` are visited in ascending order of the integers
    embedded in their names, compared as tuples (so ``"x2"`` comes before
    ``"x10"``). For each key the tensors taken from both dicts have their
    leading dimension squeezed (``squeeze(0)`` only removes it when its size
    is 1) and the following quantity is evaluated element-wise::

        k = (cache**2 - real * cache) / (cache**2 + lambda + 1e-8)

    where ``lambda = 0.01 * mean(cache**2) / var(real)`` is a scalar computed
    separately for every key.

    Args:
        real_dict: mapping from key to the reference tensor.
        cache_dict: mapping holding the cached tensor under the same keys.
        k_threshold: accepted for interface compatibility; unused here.
        add_mask: accepted for interface compatibility; unused here.

    Returns:
        A tensor holding one slice per key along a new leading dimension,
        in the sorted key order described above.
    """
    eps = 1e-8  # keeps the denominator strictly positive

    def numeric_order(name):
        # Every run of digits in the key, as a tuple of ints.
        return tuple(int(digits) for digits in re.findall(r'\d+', name))

    gains = []
    for name in sorted(real_dict, key=numeric_order):
        ref = real_dict[name].squeeze(0)
        approx = cache_dict[name].squeeze(0)
        # NOTE(review): an earlier comment recorded torch.Size([16, 21]) for
        # these tensors -- not verifiable from this function alone.
        approx_sq = approx**2
        # Scalar regulariser: cached energy relative to the reference variance.
        reg = 0.01 * (torch.mean(approx_sq)/ torch.var(ref))
        gain = (approx_sq - ref * approx) / (approx_sq + reg + eps)
        gains.append(gain.unsqueeze(0))

    return torch.concat(gains, 0)

    
def get_mean_K_hw(func, data_folder_ori, data_folder_cache ,save_file_name, k_threshold=1, add_mask=False, seed=0, num=None):
    """Average ``func`` outputs over every file of one seed and save the mean.

    File names are taken from ``data_folder_cache``; each selected name is
    loaded from both folders with ``load_file`` and the two dicts are passed
    to ``func``. The per-file results are averaged along a new leading
    dimension and written to ``save_file_name`` under the key ``'output'``.

    Args:
        func: callable ``func(real_dict, cache_dict, k_threshold, add_mask)``
            returning a tensor; all per-file results must share one shape.
        data_folder_ori: folder holding the reference files.
        data_folder_cache: folder holding the cached files; also the folder
            whose listing decides which file names are processed.
        save_file_name: destination path handed to ``save_file``.
        k_threshold: forwarded unchanged to ``func``.
        add_mask: forwarded unchanged to ``func``.
        seed: only file names containing ``seed<seed>`` not followed by
            another digit are used.
        num: if truthy, use at most this many matching files (the first
            ones in sorted name order); otherwise use all of them.

    Raises:
        RuntimeError: if no file name in ``data_folder_cache`` matches the
            requested seed.
    """
    # A plain substring test would let seed 1 also pick up seed10, seed11, ...
    # so require that the seed number is not followed by another digit.
    seed_pattern = re.compile(re.escape(f'seed{seed}') + r'(?!\d)')

    # os.listdir returns names in arbitrary order; sorting makes the subset
    # selected through `num` reproducible between runs and machines.
    selected = sorted(
        name for name in os.listdir(data_folder_cache) if seed_pattern.search(name)
    )
    if num:
        selected = selected[:num]

    if not selected:
        # Fail with an explicit message instead of an opaque error from
        # concatenating an empty list of tensors further down.
        raise RuntimeError(
            f"no file matching 'seed{seed}' found in {data_folder_cache!r}"
        )

    per_file = []
    for name in selected:
        print(name)
        real_dict = load_file(os.path.join(data_folder_ori, name))
        cache_dict = load_file(os.path.join(data_folder_cache, name))
        out_K = func(real_dict, cache_dict, k_threshold, add_mask)
        per_file.append(out_K.unsqueeze(0))

    # Mean over the file axis; the trailing squeeze(0) additionally drops the
    # next leading dimension when it has size 1.
    res = torch.mean(torch.concat(per_file, 0), 0).squeeze(0)
    save_file({'output': res}, save_file_name)
    
  
    

# Input folders are read from the environment; an unset variable yields ''.
data_folder_ori = os.environ.get('DATA_FOLDER_ORI', '')
data_folder_cache = os.environ.get('DATA_FOLDER_CACHE', '')

# range(1) means only seed 0 is processed. Each result file is written into
# os.path.dirname(data_folder_cache) as K_hw_<seed>.safetensors.
for seed in range(1):
    save_file_name = os.path.join(
        os.path.dirname(data_folder_cache),
        f'K_hw_{seed}.safetensors',
    )
    get_mean_K_hw(
        func=get_K_hw,
        data_folder_ori=data_folder_ori,
        data_folder_cache=data_folder_cache,
        save_file_name=save_file_name,
        k_threshold=1,
        add_mask=False,
        seed=seed,
    )
    