import os
import numpy as np
from PIL import Image
import cv2
import clip
import torch
from torchvision.transforms import functional as F
from model.ControlFusion import ControlFusion as create_model
from model.Classification import Classification
import argparse
from thop import profile, clever_format
# NOTE(review): presumably a workaround for the Intel OpenMP runtime aborting when
# it is loaded more than once (e.g. by both torch and another library). This
# suppresses that check rather than fixing the duplicate load — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# File extensions (compared case-insensitively) accepted as source images.
# Note the leading dots: os.path.splitext returns e.g. ".tif", so bare "tif"
# entries would never match.
_SUPPORTED_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


def _list_images(folder):
    """Return the sorted paths of the files in *folder* with a supported image extension."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in _SUPPORTED_EXTENSIONS
    )


def _report_model_complexity(model, classification_model, device):
    """Print parameter counts and MAC estimates (computed with thop) for both networks.

    The dummy inputs have fixed sizes (224x224 for the classifier, 256x256 for the
    fusion network). The reported numbers therefore describe those sizes only, while
    real images are processed at their native resolution. Must be called under
    ``torch.no_grad()``.
    """
    # Classification model statistics.
    cls_input = torch.randn(1, 3, 224, 224).to(device)
    cls_macs, cls_params = profile(classification_model, inputs=(cls_input,), verbose=False)

    # ControlFusion model statistics: visible, infrared, text tokens and classifier feature.
    dummy_vi = torch.randn(1, 3, 256, 256).to(device)
    dummy_ir = torch.randn(1, 3, 256, 256).to(device)
    dummy_text = clip.tokenize(["dummy text"]).to(device)
    dummy_feature = classification_model(dummy_vi)
    fusion_macs, fusion_params = profile(
        model,
        inputs=(dummy_vi, dummy_ir, dummy_text, dummy_feature),
        verbose=False,
    )

    cls_macs, cls_params = clever_format([cls_macs, cls_params], "%.3f")
    fusion_macs, fusion_params = clever_format([fusion_macs, fusion_params], "%.3f")

    # thop.profile counts multiply-accumulate operations, so label them as MACs.
    print("\n========== 模型统计信息 ==========")
    print("Classification Model:")
    print(f"  Parameters: {cls_params}")
    print(f"  MACs: {cls_macs}")
    print("\nControlFusion Model:")
    print(f"  Parameters: {fusion_params}")
    print(f"  MACs: {fusion_macs}")
    print("==================================\n")


def main(args):
    """Fuse every visible/infrared image pair under ``args.dataset_path``.

    Images are read from ``<dataset_path>/vi`` and ``<dataset_path>/ir``. They are
    paired by position after sorting each folder's paths, so both folders must hold
    the same number of images in matching name order. Each infrared image is resized
    to its visible partner's size. The pair is passed through the Classification and
    ControlFusion models, and the result is written to ``args.save_path`` via
    ``save_pic``.

    Args:
        args: parsed CLI namespace providing ``dataset_path``, ``save_path``,
            ``weights_path``, ``input_text`` and ``device``.

    Raises:
        ValueError: if the two folders contain different numbers of images.
    """
    root_path = args.dataset_path
    save_path = args.save_path
    MODEL_PATH = './classification/model_epoch_5.pth'
    os.makedirs(save_path, exist_ok=True)

    # Fall back to CPU when CUDA is unavailable. Every tensor below follows this
    # device; the previous hard-coded .cuda() calls crashed on CPU-only hosts.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    visible_path = _list_images(os.path.join(root_path, "vi"))
    infrared_path = _list_images(os.path.join(root_path, "ir"))

    print("Find the number of visible image: {},  the number of the infrared image: {}".format(len(visible_path), len(infrared_path)))
    if len(visible_path) != len(infrared_path):
        raise ValueError("The number of the source images does not match!")

    print("Begin to run!")
    with torch.no_grad():
        model_clip, _ = clip.load("ViT-L/14@336px", device=device)
        model = create_model(model_clip).to(device)

        classification_model = Classification()
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        load_result = classification_model.load_state_dict(
            torch.load(MODEL_PATH, map_location=device), strict=False
        )
        # strict=False tolerates key mismatches; report them instead of hiding them.
        if load_result.missing_keys or load_result.unexpected_keys:
            print("Warning: classification checkpoint mismatch - missing keys: {}, unexpected keys: {}".format(
                load_result.missing_keys, load_result.unexpected_keys))
        classification_model.to(device)
        classification_model.eval()  # switch to evaluation mode

        model.load_state_dict(torch.load(args.weights_path, map_location=device)['model'])
        model.eval()

        _report_model_complexity(model, classification_model, device)

        # The text prompt is identical for every pair, so tokenize it only once.
        text = clip.tokenize(args.input_text).to(device)

        for vi_path, ir_path in zip(visible_path, infrared_path):
            img_name = os.path.basename(vi_path)

            # Context managers close the underlying file handles; convert() returns a
            # fully loaded copy.
            with Image.open(ir_path) as ir_file:
                ir = ir_file.convert(mode="RGB")
            with Image.open(vi_path) as vi_file:
                vi = vi_file.convert(mode="RGB")

            # Align the infrared image to the visible one (PIL sizes are (width, height)).
            ir = ir.resize(vi.size)

            ir = F.to_tensor(ir).unsqueeze(0).to(device)
            vi = F.to_tensor(vi).unsqueeze(0).to(device)

            # Runs inside no_grad so no autograd graph is built per image.
            img_feature = classification_model(vi)
            fused = model(vi, ir, text, img_feature)
            save_pic(tensor2numpy(fused), save_path, img_name)

            print("Save the {}".format(img_name))
    print("Finish! The results are saved in {}.".format(save_path))

def tensor2numpy(img_tensor):
    """Move an image tensor to the CPU as a channels-last NumPy array.

    ``squeeze(0)`` drops a leading axis only when its size is 1. The remaining
    3-D tensor is reordered from ``(C, H, W)`` to ``(H, W, C)``. A batch with
    more than one element therefore leaves a 4-D array, which the 3-axis
    transpose rejects.
    """
    chw = img_tensor.squeeze(0).cpu().detach().numpy()
    hwc = chw.transpose(1, 2, 0)
    return hwc

def mergy_Y_RGB_to_YCbCr(img1, img2):
    """Replace the luma of an RGB image with a separately supplied Y channel.

    Args:
        img1: tensor holding the new luma. Presumably shaped ``(1, 1, H, W)``
            (TODO confirm). After dropping the leading axis it is reordered to
            ``(H, W, 1)`` so it can be stacked with the two chroma planes.
        img2: RGB tensor supplying the chroma, presumably ``(1, 3, H, W)``
            float32 with values in [0, 1] — TODO confirm against callers.

    Returns:
        An ``(H, W, 3)`` NumPy RGB image combining ``img1``'s Y with
        ``img2``'s chroma.
    """
    y_channel = np.transpose(img1.squeeze(0).detach().cpu().numpy(), [1, 2, 0])
    # detach() so this also works when img2 still requires grad; previously
    # only img1 was detached and .numpy() raised for img2.
    rgb = np.transpose(img2.squeeze(0).detach().cpu().numpy(), [1, 2, 0])
    # COLOR_RGB2YCrCb orders channels as Y, Cr, Cb, so [1:] holds Cr then Cb.
    # The inverse conversion below uses the same ordering.
    crcb_channels = cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb)[:, :, 1:]
    merged_ycrcb = np.concatenate((y_channel, crcb_channels), axis=2)
    return cv2.cvtColor(merged_ycrcb, cv2.COLOR_YCrCb2RGB)

def save_pic(outputpic, path, index : str):
    """Write a channels-last RGB float image to ``os.path.join(path, index)``.

    Processing steps:
    1. Clip values into [0, 1].
    2. Min-max stretch to [0, 255] with ``cv2.NORM_MINMAX``. This rescales each
       image's own range, so it is not a plain ``* 255``.
    3. Reverse the channel order to BGR for ``cv2.imwrite``.

    A ``.jpg``/``.jpeg`` file name (any case) is saved as ``.png`` instead.
    The caller's array is not modified.

    Args:
        outputpic: ``(H, W, C)`` float array, e.g. from ``tensor2numpy``.
        path: output directory.
        index: output file name.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    # np.clip returns a new array, so the caller's buffer stays intact. That buffer
    # may share memory with the model's output tensor. The copy is made contiguous
    # for OpenCV, replacing the former cv2.UMat round trip.
    outputpic = np.ascontiguousarray(np.clip(outputpic, 0., 1.))
    outputpic = cv2.normalize(outputpic, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_32F)
    outputpic = np.ascontiguousarray(outputpic[:, :, ::-1])  # RGB -> BGR

    # Only rewrite the file's own extension. The old str.replace could also alter
    # directory names or ".jpg" occurring mid-name, and it missed ".JPG".
    stem, ext = os.path.splitext(index)
    if ext.lower() in (".jpg", ".jpeg"):
        index = stem + ".png"
    save_path = os.path.join(path, index)
    if not cv2.imwrite(save_path, outputpic):
        raise OSError("Failed to write image to {}".format(save_path))

if __name__ == '__main__':
    # Command-line entry point: parse options and run the fusion pipeline.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default='', help='test data root path')
    parser.add_argument('--weights_path', type=str, default='', help='initial weights path')
    parser.add_argument('--save_path', type=str, default='', help='output save image path')
    parser.add_argument('--input_text', type=str, default="", help='text control input')

    parser.add_argument('--device', default='cuda', help='device (i.e. cuda or cpu)')
    parser.add_argument('--gpu_id', default='0', help='device id (i.e. 0, 1, 2 or 3)')
    opt = parser.parse_args()
    # --gpu_id used to be parsed but ignored. Apply it when --device names CUDA
    # without an explicit index; an explicit "cuda:N" or "cpu" is left untouched.
    if opt.device == 'cuda':
        opt.device = 'cuda:{}'.format(opt.gpu_id)
    main(opt)
