import cv2
import numpy as np
import os
from utils.metric import calc_color_histogram, color_histogram_emd, histogram_intersection, ssim

def reinhard_color_transfer(source, target, eps=1e-6):
    """Transfer the colour statistics of ``target`` onto ``source`` (Reinhard-style).

    Both images are converted from BGR to OpenCV's LAB encoding. Each LAB
    channel of ``source`` is standardised with its own mean / standard
    deviation and then rescaled to ``target``'s per-channel mean / standard
    deviation.

    Args:
        source: BGR image whose content is kept. Assumed to be 8-bit: the
            0-255 clip below presumes OpenCV's uint8 LAB encoding.
        target: BGR image whose colour statistics are imposed on ``source``.
        eps: Added to the source standard deviation so flat (constant)
            channels do not divide by zero.

    Returns:
        BGR image produced by ``cv2.cvtColor(..., COLOR_LAB2BGR)`` from the
        uint8 recoloured LAB image (same height/width as ``source``).
    """
    # BGR (OpenCV's channel order) -> LAB, promoted to float for the arithmetic.
    source_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Per-channel mean and standard deviation; cv2.meanStdDev returns (channels, 1) arrays.
    mean_src, std_src = cv2.meanStdDev(source_lab)
    mean_tar, std_tar = cv2.meanStdDev(target_lab)

    # Transpose the stats to (1, channels) so they broadcast over (H, W, channels).
    result_lab = (source_lab - mean_src.T) / (std_src.T + eps)  # standardise
    result_lab = result_lab * std_tar.T + mean_tar.T  # rescale to target statistics
    # Round before the uint8 cast: astype() truncates toward zero, which would
    # bias every channel down by ~0.5 levels on average.
    result_lab = np.clip(np.rint(result_lab), 0, 255).astype(np.uint8)

    # LAB -> BGR
    return cv2.cvtColor(result_lab, cv2.COLOR_LAB2BGR)


def match_histograms(source, target):
    """Match each channel's intensity histogram of ``source`` to ``target``.

    CDF-based histogram specification: every source level ``v`` is mapped to
    the smallest target level whose cumulative frequency is at least the
    cumulative frequency of ``v`` in the source (capped at 255).

    Args:
        source: Integer image with values in [0, 255] (typically uint8),
            either 2-D (grayscale) or ``(H, W, C)``.
        target: Image with the same value range and at least as many channels
            as ``source``; its spatial size may differ.

    Returns:
        Array with the same shape and dtype as ``source`` holding the matched
        intensities.

    Raises:
        ValueError: if ``target`` is empty or has fewer channels than ``source``.
    """
    source = np.asarray(source)
    target = np.asarray(target)
    if source.size == 0:
        return np.zeros_like(source)
    if target.size == 0:
        # An empty target has an all-zero histogram, so its CDF would be NaN.
        raise ValueError("target image is empty")

    # View grayscale images as single-channel (H, W, 1) so one loop handles both.
    src = source[..., np.newaxis] if source.ndim == 2 else source
    tgt = target[..., np.newaxis] if target.ndim == 2 else target
    n_channels = src.shape[-1]
    if tgt.shape[-1] < n_channels:
        raise ValueError(
            f"target has {tgt.shape[-1]} channel(s) but source has {n_channels}"
        )

    matched = np.zeros_like(src)
    for c in range(n_channels):
        src_hist, _ = np.histogram(src[..., c], bins=256, range=(0, 256))
        tgt_hist, _ = np.histogram(tgt[..., c], bins=256, range=(0, 256))

        # Normalised CDFs. float64 keeps the integer counts exact for large
        # images (float32 only represents integers exactly up to 2**24).
        src_cdf = np.cumsum(src_hist).astype(np.float64)
        src_cdf /= src_cdf[-1]
        tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
        tgt_cdf /= tgt_cdf[-1]

        # For each source level, the first target level whose CDF reaches the
        # source CDF. This is the vectorised form of a monotone two-pointer scan.
        mapping = np.minimum(
            np.searchsorted(tgt_cdf, src_cdf, side='left'), 255
        ).astype(np.uint8)

        # Apply the lookup table to every pixel of this channel.
        matched[..., c] = mapping[src[..., c]]
    return matched.reshape(source.shape)



# Image pairs to process, matched position-by-position: file names in testA
# (source images) and testB (target images). A trailing ".jpg" is optional
# because the path templates below append it.
input_problems = ['00010']
output_problems = ['2014-08-01 17:41:55.jpg']

method = 'hismatching'  # 'hismatching' | 'reinhard'
method = method.lower()

# NOTE(review): "modent2photo" may be a typo of the CycleGAN "monet2photo"
# dataset; kept as-is because it has to match the directory on disk.
DATA_DIR = './data/Images/modent2photo'
# Single destination for every result image and results.txt (the original
# scattered outputs across this path and a misspelled
# './exps/OT_MMD/Color_Transfer/Classicam_Methods').
OUTPUT_DIR = './exps/Color_Transfer/Classical_Methods'

# method name -> (transfer function, suffix used in the saved file name)
TRANSFER_METHODS = {
    'reinhard': (reinhard_color_transfer, 'Reinhard'),
    'hismatching': (match_histograms, 'histogram'),
}


def _strip_jpg(name):
    """Return *name* without a trailing '.jpg' (case-insensitive)."""
    return name[:-len('.jpg')] if name.lower().endswith('.jpg') else name


def _load_image(path):
    """Read an image with cv2.imread, raising instead of returning None."""
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _transfer_and_evaluate(content, style, content_name, style_name, method):
    """Recolour *content* with *style*'s colours, save it, and log metrics.

    The result is written to OUTPUT_DIR. One line is appended to
    OUTPUT_DIR/results.txt with three numbers: the EMD and histogram
    intersection between the colour histograms of *style* and the result,
    and the SSIM between *content* and the result.
    """
    transfer, suffix = TRANSFER_METHODS[method]
    transported_image = transfer(content, style)

    out_path = os.path.join(OUTPUT_DIR, f'{content_name}_{style_name}_{suffix}.jpg')
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(out_path, transported_image):
        raise OSError(f"Could not write image: {out_path}")

    hist_style = calc_color_histogram(style)
    hist_result = calc_color_histogram(transported_image)
    emd_val = color_histogram_emd(hist_style, hist_result)
    inter_r = histogram_intersection(hist_style, hist_result)
    ssim_val = ssim(content, transported_image)

    with open(os.path.join(OUTPUT_DIR, 'results.txt'), "a") as file:
        file.write(f"-- {method} {content_name:>20} -> {style_name:<20} EMD : {emd_val:.4f} | Hist Intersection: {inter_r:.4f} | SSIM: {ssim_val:.4f}\n")


def main():
    """Run forward and backward colour transfer for every configured image pair."""
    if method not in TRANSFER_METHODS:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(TRANSFER_METHODS)}"
        )
    if len(input_problems) != len(output_problems):
        raise ValueError("input_problems and output_problems must have the same length")
    # Make sure the output directory exists; otherwise cv2.imwrite fails silently.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for input_name, output_name in zip(input_problems, output_problems):
        input_problem = _strip_jpg(input_name).lower()
        output_problem = _strip_jpg(output_name).lower()

        # Load Images
        source = _load_image(os.path.join(DATA_DIR, 'testA', f'{input_problem}.jpg'))
        target = _load_image(os.path.join(DATA_DIR, 'testB', f'{output_problem}.jpg'))

        # Forward: source recoloured with target's palette.
        _transfer_and_evaluate(source, target, input_problem, output_problem, method)
        # Backward: target recoloured with source's palette.
        _transfer_and_evaluate(target, source, output_problem, input_problem, method)


if __name__ == "__main__":
    main()