import os

import numpy as np
import torch
import tqdm
from selector_model import (
    PairwiseComparator_C0L2,
    PairwiseComparator_C1L2,
    PairwiseComparator_C1L2_prompt,
    PairwiseComparator_C1L3,
    PairwiseComparator_C1L3_prompt,
    PairwiseComparator_C2L2,
    PairwiseComparator_C3L2,
)

# Restrict this process to physical GPU 7. This must happen before any CUDA
# work; torch is imported above, but presumably initializes CUDA lazily on
# first use, so the mask still applies — confirm if import order changes.
os.environ.update(CUDA_VISIBLE_DEVICES="7")

# Which entry of each saved `*_gs_out_all.pt` file feeds the comparator
# (another value tried in experiments: 9).
STEP_IDX = 6
# Number of seed directories (candidate samples) read per prompt.
NUM_SEED = 64


def load_latent_data(base_path, step_idx=STEP_IDX, num_data=100, num_seed=None, save_path=None):
    """Load one step's latents for every (prompt, seed) pair and cache them to disk.

    For each prompt index in ``range(num_data)`` and each seed in
    ``range(num_seed)``, this reads the prompt's ``*_gs_out_all.pt`` file from
    the seed's 10-step inference directory under ``base_path``. It keeps entry
    ``step_idx`` of that file, moved to CPU. The stacked result is also written
    with ``torch.save`` before being returned.

    Args:
        base_path: Directory holding one ``sd3_num_inference_steps_10_seed_{i}``
            subdirectory per seed.
        step_idx: Which entry of each loaded file to keep.
        num_data: Number of prompts to load per seed (previously fixed at 100).
        num_seed: Number of seeds; ``None`` means the module-level ``NUM_SEED``.
        save_path: Output path for the cached tensor. ``None`` means
            ``latent_data_t3_multi_data{num_data}_seed{num_seed}.pt`` in the
            current working directory, which matches the old hard-coded name
            for the default arguments.

    Returns:
        Tensor of shape ``(num_data, num_seed, *entry_shape)``. The original
        inline notes recorded ``entry_shape`` as ``(4, 16, 32, 32)``; confirm
        this for other runs.
    """
    if num_seed is None:
        num_seed = NUM_SEED
    if save_path is None:
        # NOTE(review): the cache name says "t3_multi" even though main() loads
        # from a t3_single directory. It is kept unchanged for backward
        # compatibility; confirm the intended name.
        save_path = f"latent_data_t3_multi_data{num_data}_seed{num_seed}.pt"

    print(f"Loading latent data from {base_path}...")
    latent_data = []
    for data_idx in tqdm.tqdm(range(num_data)):
        per_seed = []
        for seed_idx in range(num_seed):
            latent_path = os.path.join(
                base_path,
                f"sd3_num_inference_steps_10_seed_{seed_idx}/gsdiff_gobj83k_sd35m__render/inference/000_{data_idx:03d}_013020_gs_out_all.pt",
            )
            data = torch.load(latent_path, weights_only=True)
            # Keep only the requested entry and move it off the GPU so that
            # NUM_SEED x num_data tensors do not accumulate in device memory.
            per_seed.append(data[step_idx].cpu())
        latent_data.append(torch.stack(per_seed, dim=0))

    latent_data = torch.stack(latent_data, dim=0)  # (num_data, num_seed, ...)
    torch.save(latent_data, save_path)
    return latent_data


def load_score_data(base_path, score_filename="clip_similarity_scores.txt", num_seed=None):
    """Load per-seed score files into a ``(num_prompts, num_seed)`` array.

    Each seed directory ``sd3_num_inference_steps_10_seed_{i}`` under
    ``base_path`` holds a text file with one float per line, one line per
    prompt. Blank lines, such as a trailing empty line, are ignored.

    Args:
        base_path: Root directory containing the per-seed subdirectories.
        score_filename: Score file to read in each seed directory. Defaults to
            the CLIP-similarity file; e.g. ``"aesthetic_score.txt"`` also works.
        num_seed: Number of seeds to read; ``None`` means the module-level
            ``NUM_SEED``.

    Returns:
        ``np.ndarray`` of shape ``(num_prompts, num_seed)``. Row ``p`` holds
        prompt ``p``'s score for every seed.

    Raises:
        ValueError: If the seed files contain different numbers of scores.
    """
    if num_seed is None:
        num_seed = NUM_SEED

    score_data = []
    for seed_idx in range(num_seed):
        score_path = os.path.join(base_path, f"sd3_num_inference_steps_10_seed_{seed_idx}/{score_filename}")
        with open(score_path, "r", encoding="utf-8") as f:
            scores = [float(line.strip()) for line in f if line.strip()]
        score_data.append(scores)

    # np.array would otherwise fail obscurely (or build an object array on
    # older NumPy) when the seeds disagree on the number of prompts.
    if len({len(scores) for scores in score_data}) > 1:
        raise ValueError(f"score files '{score_filename}' under {base_path} have differing line counts")

    # Transpose from (num_seed, num_prompts) to (num_prompts, num_seed).
    score_data = np.array(score_data).transpose()
    print(score_data.shape)  # (100, NUM_SEED) for the standard runs
    return score_data


def load_score_data_image_reward(base_path, num_seed=None):
    """Load per-seed ImageReward scores into a ``(num_prompts, num_seed)`` array.

    Reads ``image_reward_scores.txt`` (one float per line, one line per prompt)
    from each ``sd3_num_inference_steps_10_seed_{i}`` directory under
    ``base_path``. Blank lines are ignored.

    Args:
        base_path: Root directory containing the per-seed subdirectories.
        num_seed: Number of seeds to read; ``None`` means the module-level
            ``NUM_SEED``.

    Returns:
        ``np.ndarray`` of shape ``(num_prompts, num_seed)``.

    Raises:
        ValueError: If the seed files contain different numbers of scores.
    """
    if num_seed is None:
        num_seed = NUM_SEED

    score_data = []
    for seed_idx in range(num_seed):
        score_path = os.path.join(base_path, f"sd3_num_inference_steps_10_seed_{seed_idx}/image_reward_scores.txt")
        with open(score_path, "r", encoding="utf-8") as f:
            reward_scores = [float(line.strip()) for line in f if line.strip()]
        score_data.append(reward_scores)

    # Fail loudly instead of letting np.array choke on ragged rows.
    if len({len(scores) for scores in score_data}) > 1:
        raise ValueError(f"image_reward_scores.txt files under {base_path} have differing line counts")

    # Transpose from (num_seed, num_prompts) to (num_prompts, num_seed).
    score_data = np.array(score_data).transpose()
    print(score_data.shape)  # (100, NUM_SEED) for the standard runs
    return score_data


def load_score_data_aesthetic(base_path="/path/to/score_results_t3_single/", num_seed=None):
    """Load per-seed aesthetic scores into a ``(num_prompts, num_seed)`` array.

    Reads ``aesthetic_score.txt`` (one float per line, one line per prompt)
    from each ``sd3_num_inference_steps_10_seed_{i}`` directory under
    ``base_path``. Blank lines are ignored.

    Args:
        base_path: Root directory containing the per-seed subdirectories. It
            used to be hard-coded inside the function; the default keeps that
            path, so existing no-argument calls behave the same.
        num_seed: Number of seeds to read; ``None`` means the module-level
            ``NUM_SEED``.

    Returns:
        ``np.ndarray`` of shape ``(num_prompts, num_seed)``.

    Raises:
        ValueError: If the seed files contain different numbers of scores.
    """
    if num_seed is None:
        num_seed = NUM_SEED

    score_data = []
    for seed_idx in range(num_seed):
        score_path = os.path.join(base_path, f"sd3_num_inference_steps_10_seed_{seed_idx}/aesthetic_score.txt")
        with open(score_path, "r", encoding="utf-8") as f:
            aesthetic_scores = [float(line.strip()) for line in f if line.strip()]
        score_data.append(aesthetic_scores)

    # Fail loudly instead of letting np.array choke on ragged rows.
    if len({len(scores) for scores in score_data}) > 1:
        raise ValueError(f"aesthetic_score.txt files under {base_path} have differing line counts")

    # Transpose from (num_seed, num_prompts) to (num_prompts, num_seed).
    score_data = np.array(score_data).transpose()
    print(score_data.shape)  # (100, NUM_SEED) for the standard runs
    return score_data


def calculate_clip_rprec(winner_idx_list):
    """Score the selected seeds with text-conditioned metrics; return mean CLIP R-precision.

    For each prompt index ``i``, this loads the 4-view render produced by seed
    ``winner_idx_list[i]`` and groups the images by view. Each view's image
    list is then evaluated against the prompts in ``t3bench_single.txt`` with
    ``TextConditionMetrics``, and all three metrics are printed per view.

    Returns:
        The mean of the four per-view R-precision values.
    """
    # Project-local imports are deferred so the rest of the script runs without them.
    from run_score import load_4view_images
    from src.utils.metrics import TextConditionMetrics

    num_views = 4
    per_view_images = [[] for _ in range(num_views)]
    for image_idx, winner_idx in enumerate(winner_idx_list):
        image_path = f"/path/to/sd3_num_inference_steps_10_seed_{winner_idx}/gsdiff_gobj83k_sd35m__render/inference/000_{image_idx:03d}_013020_gs_all.png"
        views = load_4view_images(image_path)
        for view_idx in range(num_views):
            per_view_images[view_idx].append(views[view_idx])

    with open("/path/to/t3bench_single.txt", "r") as prompt_file:
        prompts = [text.strip() for text in prompt_file.readlines() if text.strip() != ""]

    # NOTE(review): CUDA_VISIBLE_DEVICES is set to "7" at import, which likely
    # leaves only one visible GPU in this process. Confirm how
    # TextConditionMetrics interprets device_idx=7 under that mask.
    metrics = TextConditionMetrics(device_idx=7)

    rprec_per_view = []
    for view_images in per_view_images:
        clipsim, cliprprec, imagereward = metrics.evaluate(view_images, prompts)
        rprec_per_view.append(cliprprec)
        print(f"clipsim: {clipsim}, cliprprec: {cliprprec}, imagereward: {imagereward}")

    return np.mean(rprec_per_view)


def _tournament_winner(model, latent, prompt=None, num_candidates=None):
    """Return the index of the candidate that wins a single-elimination bracket.

    Candidates ``0 .. num_candidates - 1`` are paired in order (0 vs 1,
    2 vs 3, ...). For each pair, the comparator's sigmoid output decides the
    winner: ``> 0.5`` keeps the first of the pair, otherwise the second. With
    an odd count, the last candidate advances without a match. Rounds repeat
    until one index remains.

    Args:
        model: Pairwise comparator, called as ``model(lat1, lat2)`` or
            ``model(lat1, lat2, prompt)``.
        latent: Indexable of per-candidate latents. Each item is unsqueezed to
            a batch of one and moved to CUDA.
        prompt: Optional prompt embedding forwarded to the model.
        num_candidates: Bracket size; defaults to ``len(latent)``.

    Raises:
        ValueError: If the bracket is empty.
    """
    if num_candidates is None:
        num_candidates = len(latent)
    if num_candidates < 1:
        raise ValueError("tournament needs at least one candidate")

    remaining = list(range(num_candidates))
    while len(remaining) > 1:
        next_round = []
        for i in range(0, len(remaining), 2):
            if i + 1 >= len(remaining):
                # Odd number left: the last candidate advances without a match.
                next_round.append(remaining[i])
                continue

            idx1, idx2 = remaining[i], remaining[i + 1]
            lat1 = latent[idx1].unsqueeze(0).cuda()
            lat2 = latent[idx2].unsqueeze(0).cuda()
            with torch.no_grad():
                logits = model(lat1, lat2) if prompt is None else model(lat1, lat2, prompt)
                prob = torch.sigmoid(logits)

            # Training target was 1 if score1 > score2, 0 otherwise,
            # so prob > 0.5 means the model prefers idx1.
            next_round.append(idx1 if prob > 0.5 else idx2)
        remaining = next_round
    return remaining[0]


def main():
    """Pick one seed per prompt with the pairwise comparator and report scores.

    Steps:
        1. Load the C1L2 comparator checkpoint for ``STEP_IDX``.
        2. Load the CLIP-similarity and ImageReward score tables.
        3. Load the step-``STEP_IDX`` latents.
        4. Run a single-elimination tournament over the ``NUM_SEED`` seeds of
           each of the 100 prompts.
        5. For both metrics, print the mean score over all seeds ("random")
           next to the mean score of the tournament winners ("selected").
    """
    model_path = f"/path/to/clip_idx{STEP_IDX}in10_C1L2_bs64_lr0.001_last_group_acc_1.0000.pt"

    print(f"Loading model from {model_path}...")
    # Other comparators imported above: PairwiseComparator_C0L2, _C1L3, _C2L2,
    # _C3L2, _C1L2_prompt, _C1L3_prompt. The *_prompt variants presumably
    # expect prompt_data below to be loaded.
    model = PairwiseComparator_C1L2()
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()
    model = model.cuda()

    score_path = "/path/to/score_results_t3_single_cfg5/"
    score_data = load_score_data(score_path)  # (100, NUM_SEED)
    score_data_image_reward = load_score_data_image_reward(score_path)  # (100, NUM_SEED)

    data_path = "/path/to/out_eval_data_create_t3_single_cfg5/"
    latent_data = load_latent_data(data_path)  # (100, NUM_SEED, ...)
    # A cached tensor can be used instead:
    # latent_data = torch.load("latent_data_t3_single_data100_seed64.pt", weights_only=True).cuda()

    prompt_data = None
    # prompt_data = torch.load("text_embeds_t3_single_data100.pt", weights_only=True)

    winner_idx_list = []  # collected for e.g. calculate_clip_rprec(); not used further here
    selected_score = []
    selected_score_image_reward = []
    for data_idx in tqdm.tqdm(range(100)):
        score = score_data[data_idx]
        score_image_reward = score_data_image_reward[data_idx]
        prompt = prompt_data[data_idx].unsqueeze(0).cuda() if prompt_data is not None else None

        # To probe position bias, score and latent can be permuted with a
        # shared np.random.permutation before running the tournament.
        winner_idx = _tournament_winner(model, latent_data[data_idx], prompt, num_candidates=NUM_SEED)
        winner_idx_list.append(winner_idx)

        selected_score.append(score[winner_idx])
        selected_score_image_reward.append(score_image_reward[winner_idx])

    print("average random clip_similarity score: ", np.mean(np.mean(score_data, axis=1)))
    print("average selected clip_similarity score: ", np.mean(selected_score))

    print("\naverage random image_reward score: ", np.mean(np.mean(score_data_image_reward, axis=1)))
    print("average selected image_reward score: ", np.mean(selected_score_image_reward))


def main_baseline():
    """Print baseline mean scores for CLIP similarity and ImageReward.

    Reads the first ``NUM_SEED`` seed directories of the ``STEP``-step cfg7
    results and prints, per metric, the mean over prompts of the per-prompt
    mean across seeds. ``NUM_SEED`` is overridden locally to 1, so only seed 0
    is read.
    """
    # Local override of the module-level NUM_SEED: the baseline uses only seed 0.
    NUM_SEED = 1
    STEP = 10
    print(f"NUM_SEED: {NUM_SEED}")

    # NOTE(review): main() reads cfg5 results; confirm which CFG setting the
    # baseline should be compared against.
    base_path = "/path/to/score_results_t3_single_cfg7/"

    for metric_file in ("clip_similarity_scores.txt", "image_reward_scores.txt"):
        per_seed = []
        for seed_idx in range(NUM_SEED):
            path = os.path.join(base_path, f"sd3_num_inference_steps_{STEP}_seed_{seed_idx}/{metric_file}")
            with open(path, "r") as fh:
                per_seed.append([float(line.strip()) for line in fh.readlines()])
        table = np.array(per_seed).transpose()  # (num_prompts, NUM_SEED)
        print(np.mean(np.mean(table, axis=1)))


if __name__ == "__main__":
    # Run the comparator-based tournament selection. Swap in main_baseline()
    # to print the baseline score means instead.
    main()
    # main_baseline()
