from tqdm import tqdm
import argparse
import numpy as np
import torch
import pandas as pd
import sys
sys.path.append('/mnt/nvme_share/wuwl/project/CARZero-main/')
import CARZero
from transformers import AutoTokenizer
import cv2
from PIL import Image
import torchvision.transforms as transforms
import pickle
import torch.nn.functional as F
import torch.nn as nn
from test_medclip.test_medclip import MedCLIPVisionModelViT, MedCLIPModel
from test_medclip.constants import *
from test_medclip.MedCLIPProcessor import MedCLIPProcessor
from concurrent.futures import ThreadPoolExecutor


def parse_args():
    """Parse the command-line options of this retrieval script.

    Every option is optional and defaults to ``None``.  ``--split_id``,
    ``--num_chunks`` and ``--batch_size`` are parsed as ints; all other
    options are parsed as strings.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    cli = argparse.ArgumentParser()
    string_options = (
        'model_name',
        'tokenizer_path',
        'ckpt_path',
        'test_data_path',
        'MIMIC_data_path',
        'report_corpus_path',
    )
    int_options = ('split_id', 'num_chunks', 'batch_size')
    # Registration order matches the original so `--help` output is unchanged.
    for option in string_options:
        cli.add_argument('--' + option, type=str, default=None)
    for option in int_options:
        cli.add_argument('--' + option, type=int, default=None)
    return cli.parse_args()

def split_list_into_chunks(lst, num_chunks):
    """Split ``lst`` into ``num_chunks`` lists of near-equal size.

    Each chunk first receives a contiguous run of ``len(lst) // num_chunks``
    items.  The ``len(lst) % num_chunks`` leftover items at the tail of
    ``lst`` are then appended one each to the first chunks.  The layout is
    fully determined by ``len(lst)`` and ``num_chunks``, so independent
    processes given the same ``lst`` agree on chunk membership.

    Args:
        lst: Any sliceable sequence (list, tuple, ...).  It is not modified.
        num_chunks: Number of chunks to produce; must be >= 1.

    Returns:
        A list of ``num_chunks`` lists.  Some chunks are empty when
        ``num_chunks > len(lst)``.

    Raises:
        ValueError: If ``num_chunks`` is None or smaller than 1.
    """
    if num_chunks is None or num_chunks < 1:
        raise ValueError(f"num_chunks must be a positive integer, got {num_chunks!r}")
    chunk_size, remainder = divmod(len(lst), num_chunks)
    # list(...) so tuples or other sequences still yield appendable chunks.
    chunks = [list(lst[i * chunk_size:(i + 1) * chunk_size]) for i in range(num_chunks)]
    tail_start = num_chunks * chunk_size
    for i in range(remainder):
        chunks[i].append(lst[tail_start + i])
    return chunks

def get_imgs(img_path, transform=None, size=256):
    """Load an image as grayscale, letterbox it to a square, and optionally transform it.

    Args:
        img_path: Path (str or path-like) of the image file to read.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        size: Side length of the square canvas produced by ``resize_img``.

    Returns:
        The RGB ``PIL.Image`` (grayscale replicated into three channels) when
        ``transform`` is None, otherwise whatever ``transform`` returns.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    x = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces as an obscure AttributeError
        # inside resize_img.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    x = resize_img(x, size)
    img = Image.fromarray(x).convert("RGB")
    if transform is not None:
        img = transform(img)
    return img

def resize_img(img, scale):
    """Resize a 2-D image so its longer side equals ``scale``, then zero-pad to a square.

    Aspect ratio is preserved; the shorter side is truncated to an int and
    then padded with zeros on both sides (the extra pixel, if any, goes to
    the right/bottom).  When height and width are equal the image is treated
    as "tall".

    Args:
        img: 2-D numpy array (height, width), e.g. a grayscale image.
        scale: Target side length of the returned square array.

    Returns:
        A ``scale`` x ``scale`` numpy array.
    """
    dims = img.shape
    # list.index returns the first match, so a square image counts as tall.
    tall = dims.index(max(dims)) == 0

    if tall:
        ratio = scale / float(dims[0])
        target_hw = (scale, int(float(dims[1]) * float(ratio)))
    else:
        ratio = scale / float(dims[1])
        target_hw = (int(float(dims[0]) * float(ratio)), scale)

    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(img, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_AREA)

    if tall:
        gap = scale - resized.shape[1]
        row_pad = (int(0), int(0))
        col_pad = (int(np.floor(gap / 2)), int(np.ceil(gap / 2)))
    else:
        gap = scale - resized.shape[0]
        row_pad = (int(np.floor(gap / 2)), int(np.ceil(gap / 2)))
        col_pad = (int(0), int(0))

    return np.pad(resized, [row_pad, col_pad], "constant", constant_values=0)

def build_transformation(args, is_train=False):
    """Build the image preprocessing pipeline for ``args.model_name``.

    For 'CARZero' / 'soft_label_plus' the pipeline ends with ``ToTensor`` and
    ``Normalize(mean=[0.4978], std=[0.2449])``.  It is preceded by:

    * evaluation (``is_train=False``, the default): a deterministic
      ``CenterCrop(224)``;
    * training (``is_train=True``): random crop, horizontal flip, affine
      and color-jitter augmentation.

    For 'medclip' an empty (identity) ``Compose`` is returned; the images
    are passed to ``MedCLIPProcessor`` as PIL images.

    Args:
        args: Namespace with a ``model_name`` attribute.
        is_train: Whether to include random training-time augmentation.

    Returns:
        A ``torchvision.transforms.Compose``.

    Raises:
        ValueError: For an unsupported ``model_name``.
    """
    model_name = args.model_name
    if model_name in ('CARZero', 'soft_label_plus'):
        if is_train:
            t = [
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(0.3),
                transforms.RandomAffine(30, translate=[0.1, 0.1], scale=[0.9, 1.1]),
                transforms.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2]),
            ]
        else:
            # Evaluation must be deterministic; random augmentation would
            # make retrieval scores vary from run to run.
            t = [transforms.CenterCrop(224)]
        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean=[0.4978], std=[0.2449]))
    elif model_name == 'medclip':
        t = []
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    return transforms.Compose(t)

def batch_generator(data_list, batch_size):
    """Yield consecutive slices of ``data_list`` holding at most ``batch_size`` items.

    The final slice may be shorter.  An empty ``data_list`` yields nothing.

    Args:
        data_list: Any sliceable sequence.
        batch_size: Maximum number of items per slice; must be >= 1.

    Yields:
        Slices of ``data_list`` (same type as ``data_list``).

    Raises:
        ValueError: If ``batch_size`` is None or smaller than 1 (raised on
            the first iteration, since this is a generator).
    """
    if batch_size is None or batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    for start in range(0, len(data_list), batch_size):
        yield data_list[start:start + batch_size]

def compute_logits(img_emb, text_emb, temperature=0.07):
    """Return CLIP-style image-to-text similarity logits.

    The scale is ``1 / temperature``, with its log clamped to ``[0, 4.6052]``
    (i.e. a scale of at most ~100).

    Args:
        img_emb: 2-D tensor of image embeddings, one row per image.
        text_emb: 2-D tensor of text embeddings, one row per text; same
            embedding width as ``img_emb``.
        temperature: Softmax temperature; defaults to 0.07.

    Returns:
        Tensor of shape ``(num_images, num_texts)`` equal to
        ``img_emb @ text_emb.T * scale``.
    """
    # A fixed constant is enough here; wrapping it in a fresh nn.Parameter on
    # every call only attached a meaningless autograd graph to the output.
    logit_scale = torch.log(torch.tensor(1.0 / temperature)).clamp(0, 4.6052).exp()
    logits_per_text = torch.matmul(text_emb, img_emb.t()) * logit_scale
    return logits_per_text.t()


if __name__ == '__main__':
    import os

    args = parse_args()

    # ---------------------------------------------------------------- data
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --report_corpus_path at trusted, locally produced files.
    with open(args.report_corpus_path, "rb") as f:
        print(f"Loading captions from {args.report_corpus_path}")
        path2sent, _path2label, _to_remove, _label_ids = pickle.load(f)

    # De-duplicate reports while keeping first-seen order.  The previous
    # set()-based de-duplication ordered reports by string hash, which is
    # randomised per process (PYTHONHASHSEED), so separate --split_id
    # processes could disagree on chunk contents and the stored indices
    # could not be mapped back to reports.  Any consumer mapping indices back
    # must rebuild the corpus the same way.
    report_corpus = [list(report) for report in dict.fromkeys(tuple(sents) for sents in path2sent.values())]
    report_corpus_chunk = split_list_into_chunks(report_corpus, args.num_chunks)
    # --split_id is 1-based; reject 0/negatives, which would otherwise wrap
    # around to chunks at the end of the list via negative indexing.
    if args.split_id is None or not 1 <= args.split_id <= len(report_corpus_chunk):
        raise ValueError(f"split_id must be in [1, {len(report_corpus_chunk)}], got {args.split_id!r}")
    report_corpus_work = report_corpus_chunk[args.split_id - 1]
    test_data = pd.read_csv(args.test_data_path).values.tolist()
    # Only this process's chunk is needed from here on.
    del report_corpus, report_corpus_chunk, path2sent, _path2label, _to_remove, _label_ids

    # --------------------------------------------------------------- model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # One {index_within_chunk: best_score} dict per test image, in test order.
    result = []
    num_batches = (len(test_data) + args.batch_size - 1) // args.batch_size

    if args.model_name == 'CARZero' or args.model_name == 'soft_label_plus':
        CARZero_model = CARZero.load_CARZero(name=args.ckpt_path, device=device)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
        transform = build_transformation(args)
        with torch.no_grad(), ThreadPoolExecutor() as executor:
            CARZero_model.eval()
            for idx, batch in enumerate(batch_generator(test_data, args.batch_size)):
                # Column 0 of each test row is the image path; loading is
                # I/O-bound, so the thread pool overlaps the reads.
                imgs = list(executor.map(lambda item: get_imgs(item[0], transform), batch))

                # Image features depend only on the batch, not on the report,
                # so encode once per batch instead of once per report.
                # NOTE(review): `imgs` is a list of per-image tensors passed
                # as-is (as before); confirm image_encoder_forward accepts a
                # list, otherwise torch.stack(imgs).to(device) is needed.
                label_img_emb_l, label_img_emb_g = CARZero_model.image_encoder_forward(imgs)
                label_img_emb_l = label_img_emb_l.view(label_img_emb_l.size(0), label_img_emb_l.size(1), -1)
                label_img_emb_l = label_img_emb_l.permute(0, 2, 1)
                img_tokens = torch.cat([label_img_emb_g.unsqueeze(1), label_img_emb_l], dim=1)

                SimR_batch = [[] for _ in range(len(batch))]
                # Progress label: "current subset: <id> | current batch: <i>/<n>".
                for sent_list in tqdm(report_corpus_work, desc=f'当前子集: {args.split_id} | 当前批次: {idx}/{num_batches}'):
                    ids, tokens, attention, cap_len = [], [], [], []
                    for sent in sent_list:
                        enc = tokenizer(sent, return_tensors="pt", truncation=True, padding="max_length", max_length=100)
                        ids.append(enc["input_ids"])
                        tokens.append(enc["token_type_ids"])
                        attention.append(enc["attention_mask"])
                        # Counts token ids != 0 (assumes 0 is the pad id — confirm for this tokenizer).
                        cap_len.append(int((enc["input_ids"][0] != 0).sum()))
                    # Order sentences by descending token length.
                    _, sorted_cap_indices = torch.sort(torch.tensor(cap_len), 0, True)
                    ids = torch.stack(ids).squeeze(1)
                    tokens = torch.stack(tokens).squeeze(1)
                    attention = torch.stack(attention).squeeze(1)
                    caption_ids = ids[sorted_cap_indices].to(device)
                    token_type_ids = tokens[sorted_cap_indices].to(device)
                    attention_mask = attention[sorted_cap_indices].to(device)

                    query_emb_l, query_emb_g, _ = CARZero_model.text_encoder_forward(caption_ids, attention_mask, token_type_ids)
                    query_emb_l_ = query_emb_l.view(query_emb_l.size(0), query_emb_l.size(1), -1)
                    query_emb_l_ = query_emb_l_.permute(0, 2, 1)
                    text_tokens = torch.cat([query_emb_g.unsqueeze(1), query_emb_l_], dim=1)

                    # Symmetric score: image-queried and text-queried fusion, averaged.
                    i2t_cls = CARZero_model.fusion_module(img_tokens, query_emb_g, return_atten=False, use_MLP=True)
                    t2i_cls = CARZero_model.fusion_module(text_tokens, label_img_emb_g, return_atten=False, use_MLP=True)
                    i2t_cls = i2t_cls.squeeze(-1)
                    t2i_cls = t2i_cls.squeeze(-1).transpose(1, 0)
                    SimR_item = (i2t_cls + t2i_cls) / 2
                    SimR_item = torch.sum(SimR_item, dim=1)
                    if args.model_name == 'soft_label_plus':
                        SimR_item = F.softmax(SimR_item, dim=1)
                        SimR_item = SimR_item[:, 0]
                    scores = [round(float(x), 4) for x in SimR_item]
                    for i, row in enumerate(SimR_batch):
                        row.append(scores[i])

                # Keep the best-scoring report (first one on ties) per image.
                for row in SimR_batch:
                    max_value = max(row)
                    result.append({row.index(max_value): max_value})

    elif args.model_name == 'medclip':
        model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        # Load onto CPU first so GPU-saved checkpoints also load on CPU-only
        # hosts; the model is moved to `device` right after.
        state_dict = torch.load(VIT_Pretrain, map_location="cpu")
        model.load_state_dict(state_dict)
        model = model.to(device)
        processor = MedCLIPProcessor()
        transform = build_transformation(args)
        with torch.no_grad(), ThreadPoolExecutor() as executor:
            model.eval()
            for idx, batch in enumerate(batch_generator(test_data, args.batch_size)):
                imgs = list(executor.map(lambda item: get_imgs(item[0], transform), batch))
                SimR_batch = [[] for _ in range(len(batch))]
                # Progress label: "current subset: <id> | current batch: <i>/<n>".
                for sent_list in tqdm(report_corpus_work, desc=f'当前子集: {args.split_id} | 当前批次: {idx}/{num_batches}'):
                    # NOTE(review): the processor re-preprocesses the same
                    # images for every report; consider splitting image and
                    # text processing if MedCLIPProcessor allows it.
                    inputs = processor(
                        text=sent_list,
                        images=imgs,
                        return_tensors="pt",
                        padding=True
                    ).to(device)
                    SimR_item = torch.sum(model(**inputs)['logits'], dim=1)
                    scores = [round(float(x), 4) for x in SimR_item]
                    for i, row in enumerate(SimR_batch):
                        row.append(scores[i])
                for row in SimR_batch:
                    max_value = max(row)
                    result.append({row.index(max_value): max_value})
    else:
        raise ValueError(f"Unsupported model_name: {args.model_name!r}")

    out_path = f'/mnt/nvme_share/wuwl/project/CARZero-main/data/output/retrieval_based_report_generation/{args.model_name}/SimR_{args.split_id}.pkl'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(result, f)
    # "Subset <id> finished."
    print(f'子集{args.split_id}计算完成')