import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import wandb
from selector_model import PairwiseComparator_C0L2, PairwiseComparator_C1L2, PairwiseComparator_C1L2_prompt, PairwiseComparator_C1L3, PairwiseComparator_C2L2, PairwiseComparator_C3L2
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_cosine_schedule_with_warmup

# NOTE: anything previously named "C1L1" should now be referred to as "C1L2".

# Run configuration shared by the loaders and the training loop below.
# TOTAL_STEP selects the "sd3_num_inference_steps_{TOTAL_STEP}_seed_*" input
# directories and is embedded in cache / checkpoint file names.
TOTAL_STEP = 10
# Default index into the per-file sequence loaded by load_latent_data
# (data[step_idx]); also embedded in cache / checkpoint file names.
STEP_IDX = 1
# Learning rate passed to AdamW in train_binary_pairwise.
LR = 1e-3
# Batch size of the (shuffled) training DataLoader.
BATCH_SIZE = 64

# Run name used for the wandb run and as the checkpoint file-name prefix.
# The commented-out variants are names for the other comparator architectures.
# NAME = f"clip_idx{STEP_IDX}in10_C1L3_bs{BATCH_SIZE}_lr{LR}"
# NAME = f"clip_idx{STEP_IDX}in10_C2L2_bs{BATCH_SIZE}_lr{LR}"
# NAME = f"clip_idx{STEP_IDX}in10_C0L2_bs{BATCH_SIZE}_lr{LR}"
NAME = f"clip_idx{STEP_IDX}in{TOTAL_STEP}_C1L2_bs{BATCH_SIZE}_lr{LR}"


def load_score_data():
    """Read the per-seed CLIP similarity score files into one array.

    One text file is read for each of the 64 seeds; every line of a file is
    parsed as a float. The per-seed lists are stacked and transposed, so the
    returned array is indexed as ``[line_index, seed_index]``.

    Returns:
        np.ndarray: the transposed score matrix.

    Raises:
        OSError: if a score file cannot be opened.
        ValueError: if a line cannot be parsed as a float.
    """
    root = "/path/to/score_results_4o/"

    def read_scores(seed):
        # Each seed has its own result directory containing one score per line.
        path = os.path.join(root, f"sd3_num_inference_steps_{TOTAL_STEP}_seed_{seed}/clip_similarity_scores.txt")
        with open(path, "r") as handle:
            return [float(raw.strip()) for raw in handle]

    per_seed_scores = [read_scores(seed) for seed in range(64)]

    # Seeds are rows after stacking; transpose so they become columns.
    return np.array(per_seed_scores).T


def load_latent_data(step_idx=STEP_IDX):
    """Load one denoising-step latent for every (data item, seed) and cache it.

    For each of 100 data items and 64 seeds, the per-sample ``.pt`` file is
    loaded and the entry at ``step_idx`` is kept. The result is stacked into a
    single tensor whose first two dimensions are ``[data, seed]`` and saved to
    the working directory so later runs can skip the per-file loading.

    Args:
        step_idx: index of the entry to keep from each loaded file. It is also
            part of the cache file name, so caches for different steps never
            overwrite each other.

    Returns:
        torch.Tensor: stacked latents, ``[100, 64, ...]``.
    """
    base_path = "/path/to/out_eval_data_create_4o/"

    latent_data = []

    print(f"Loading latent data from {base_path}...")
    for data_idx in tqdm.tqdm(range(100)):
        latent_data_per_data = []
        for seed_idx in range(64):
            latent_path = os.path.join(base_path, f"sd3_num_inference_steps_{TOTAL_STEP}_seed_{seed_idx}/gsdiff_gobj83k_sd35m__render/inference/000_{data_idx:03d}_013020_gs_out_all.pt")
            data = torch.load(latent_path, weights_only=True)
            # `data` is indexable by step; keep only the requested step, on CPU.
            latent_data_per_data.append(data[step_idx].cpu())
        latent_data.append(torch.stack(latent_data_per_data, dim=0))

    latent_data = torch.stack(latent_data, dim=0)  # [100, 64, ...]

    # Name the cache after the step that was actually loaded (the argument),
    # not the module-level default, so a non-default call cannot poison the
    # default step's cache file.
    torch.save(latent_data, f"latent_data_4o_data100_seed64_step{step_idx}in{TOTAL_STEP}.pt")
    return latent_data


def process_latent_data(latent_data, score_data, threshold=0.00, is_test=False):
    """Build a DataLoader of latent pairs labelled by their score difference.

    For every data item, all seed pairs ``(i, j)`` with ``i < j`` are
    candidates. Each kept pair yields ``(latent_i, latent_j, score_i - score_j,
    float(score_i > score_j))``. The candidate set is then down-sampled per
    score-difference bin by ``process_data_balance`` (applied to both the
    training and the test split).

    Only pair *indices* and labels are materialised before balancing; latents
    are gathered afterwards for the selected pairs, so memory use scales with
    the number of selected pairs rather than with all candidate pairs.

    Args:
        latent_data: torch.Tensor whose first two dimensions are
            ``[data, seed]``.
        score_data: array-like of scores indexed ``[data, seed]``; must cover
            at least the data items and seeds present in ``latent_data``.
        threshold: training pairs whose absolute score difference is below
            this value are dropped. Ignored when ``is_test`` is true.
        is_test: if true, no threshold filtering is done and the loader uses
            batch size 256 without shuffling; otherwise ``BATCH_SIZE`` with
            shuffling.

    Returns:
        torch.utils.data.DataLoader over
        ``(lat1, lat2, score_diff, binary_label)`` tuples.

    Raises:
        ValueError: if ``score_data`` does not cover ``latent_data`` or if no
            candidate pair remains.
    """
    from torch.utils.data import DataLoader, TensorDataset

    num_data, num_seeds = latent_data.shape[0], latent_data.shape[1]

    # float64 matches the Python-float arithmetic of a per-pair loop.
    scores = np.asarray(score_data, dtype=np.float64)
    if scores.ndim != 2 or scores.shape[0] < num_data or scores.shape[1] < num_seeds:
        raise ValueError(f"score_data with shape {scores.shape} does not cover latent_data with {num_data} items and {num_seeds} seeds")
    scores = scores[:num_data, :num_seeds]

    # All seed pairs (i, j), i < j, in the same order as a nested i/j loop.
    first_seed, second_seed = np.triu_indices(num_seeds, k=1)
    score1 = scores[:, first_seed]  # [num_data, num_pairs]
    score2 = scores[:, second_seed]
    diff = score1 - score2

    # Exact number of candidate pairs whose two scores differ.
    total_pairs = int(np.count_nonzero(score1 != score2))

    if is_test:
        keep = np.ones(diff.shape, dtype=bool)
    else:
        # Written as a negation so NaN differences are kept, not dropped.
        keep = ~(np.abs(diff) < threshold)

    # Row-major order: data item outermost, then pair order.
    data_pos, pair_pos = np.nonzero(keep)
    if data_pos.size == 0:
        raise ValueError("No latent pairs available to build a dataset from.")

    label_tensors = torch.tensor(diff[data_pos, pair_pos], dtype=torch.float32)
    label_01_tensors = torch.tensor(score1[data_pos, pair_pos] > score2[data_pos, pair_pos], dtype=torch.float32)  # 1 if score1 > score2, 0 otherwise

    # (data index, seed index) references standing in for the latents.
    lat1_refs = torch.from_numpy(np.stack([data_pos, first_seed[pair_pos]], axis=1).astype(np.int64))
    lat2_refs = torch.from_numpy(np.stack([data_pos, second_seed[pair_pos]], axis=1).astype(np.int64))

    # process_data_balance only selects rows along dim 0, so it can balance the
    # index references exactly as it would balance the latents themselves.
    lat1_refs, lat2_refs, label_tensors, label_01_tensors = process_data_balance(lat1_refs, lat2_refs, label_tensors, label_01_tensors)

    # Gather latents only for the pairs that survived balancing.
    lat1_refs = lat1_refs.to(latent_data.device)
    lat2_refs = lat2_refs.to(latent_data.device)
    lat1_tensors = latent_data[lat1_refs[:, 0], lat1_refs[:, 1]]
    lat2_tensors = latent_data[lat2_refs[:, 0], lat2_refs[:, 1]]

    dataset = TensorDataset(lat1_tensors, lat2_tensors, label_tensors, label_01_tensors)

    if is_test:
        loader = DataLoader(dataset, batch_size=256, shuffle=False)
    else:
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    print(f"Used pairs/total pairs: {len(dataset)}/{total_pairs}")

    return loader


def process_data_balance(lat1_tensors, lat2_tensors, label_tensors, label_01_tensors, bins=None, max_per_group=100, seed=42):
    """Down-sample pairs so no score-difference bin exceeds ``max_per_group``.

    Pairs are grouped by ``abs(label_tensors)`` into half-open bins
    ``[bins[k], bins[k + 1])``. Bins holding more than ``max_per_group`` pairs
    are randomly sub-sampled without replacement; smaller bins are kept whole.
    Pairs that fall in no bin (e.g. NaN labels) are dropped. The output is
    ordered bin by bin.

    The four inputs are only indexed along their first dimension, with the
    same row selection applied to each.

    Args:
        lat1_tensors: tensor with one row per pair (first element of the pair).
        lat2_tensors: tensor with one row per pair (second element).
        label_tensors: 1-D CPU tensor of signed score differences.
        label_01_tensors: tensor of binary labels, one per pair.
        bins: increasing bin edges; defaults to the CLIP-score edges
            ``0.0, 0.01, ..., 0.1, inf``.
        max_per_group: maximum number of pairs kept per bin.
        seed: seed of the private random generator used for sub-sampling.

    Returns:
        Tuple ``(lat1, lat2, label, label_01)`` restricted to the kept pairs.
    """
    if bins is None:
        bins = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, float("inf")]  # clip score
        # bins = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, float('inf')] # aesthetic score

    magnitudes = label_tensors.abs().numpy()

    # A private legacy generator draws the same sequence as
    # np.random.seed(seed) + np.random.choice, without resetting the global
    # NumPy RNG for the rest of the program.
    rng = np.random.RandomState(seed)

    kept_per_bin = []
    for lower, upper in zip(bins[:-1], bins[1:]):
        bin_indices = np.where((magnitudes >= lower) & (magnitudes < upper))[0]
        if len(bin_indices) > max_per_group:
            bin_indices = rng.choice(bin_indices, max_per_group, replace=False)
        kept_per_bin.append(bin_indices)

    if kept_per_bin:
        selected = np.concatenate(kept_per_bin).astype(np.int64)
    else:
        selected = np.empty(0, dtype=np.int64)
    selected = torch.from_numpy(selected)

    lat1_tensors = lat1_tensors[selected]
    lat2_tensors = lat2_tensors[selected]
    label_tensors = label_tensors[selected]
    label_01_tensors = label_01_tensors[selected]

    # Report how many pairs each bin contributes after balancing.
    print("Group counts after balancing:")
    for i, kept in enumerate(kept_per_bin):
        count = len(kept)
        print(f"[{bins[i]:.2f}-{bins[i+1] if bins[i+1]!=float('inf') else 'inf'}]: {count}")

    return lat1_tensors, lat2_tensors, label_tensors, label_01_tensors


def train_binary_pairwise(model, train_loader, test_loader, num_epochs=20, device="cuda"):
    """Train a pairwise comparator with BCE-with-logits and evaluate each epoch.

    The model is called as ``model(lat1, lat2)`` and its output is treated as
    a logit for "first latent scores higher than the second". After every
    epoch the model is evaluated on ``test_loader``, per-distance-group
    accuracy is printed via ``group_acc``, and metrics are logged to wandb.
    After the final epoch the model's state dict is saved under
    ``verifier_model_{STEP_IDX}in{TOTAL_STEP}/``, with the last epoch's
    largest-distance-group accuracy in the file name.

    Reported losses are per-sample averages (each batch's mean loss is
    weighted by its batch size).

    Args:
        model: torch.nn.Module taking two latent batches and returning one
            logit per pair.
        train_loader: iterable of ``(lat1, lat2, score_diff, binary_label)``.
        test_loader: iterable with the same tuple layout.
        num_epochs: number of passes over ``train_loader``.
        device: device the model and batches are moved to.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)

    # Cosine decay over the whole run; warmup is currently disabled.
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = 0
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    criterion = nn.BCEWithLogitsLoss()

    # Defined up front so the checkpoint name is valid even with zero epochs.
    last_group_acc = float("nan")

    for epoch in range(num_epochs):
        model.train()
        total_loss, total_count, total_correct = 0.0, 0, 0

        for lat1, lat2, _, label_01 in train_loader:
            lat1 = lat1.to(device)
            lat2 = lat2.to(device)
            label_01 = label_01.to(device)

            pred = model(lat1, lat2)
            loss = criterion(pred, label_01)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            batch_size = label_01.size(0)
            # criterion returns the batch mean; weight it by the batch size so
            # the epoch figure is a true per-sample average.
            total_loss += loss.item() * batch_size
            total_count += batch_size
            # pred is a logit (the loss applies the sigmoid), so the decision
            # boundary sigmoid(pred) > 0.5 is pred > 0.
            total_correct += ((pred > 0) == label_01).sum().item()

        acc = total_correct / total_count
        avg_loss = total_loss / total_count
        print(f"\n[Epoch {epoch}] Train Avg Loss: \t{avg_loss:.4f}, Train Acc: \t{acc:.4f}")

        # Test
        test_loss, test_count, test_correct = 0.0, 0, 0
        all_label = []
        all_label_01 = []
        all_pred = []
        model.eval()
        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for lat1, lat2, label, label_01 in test_loader:
                lat1 = lat1.to(device)
                lat2 = lat2.to(device)
                label_01 = label_01.to(device)

                pred = model(lat1, lat2)
                hard_pred = pred > 0
                loss = criterion(pred, label_01)

                all_label.extend(label.cpu().numpy())
                all_label_01.extend(label_01.cpu().numpy())
                all_pred.extend(hard_pred.cpu().numpy())

                batch_size = label_01.size(0)
                test_loss += loss.item() * batch_size
                test_count += batch_size
                test_correct += (hard_pred == label_01).sum().item()

        test_acc = test_correct / test_count
        test_avg_loss = test_loss / test_count
        print(f"[Epoch {epoch}] Test Avg Loss: {test_avg_loss:.4f}, Test Acc: {test_acc:.4f}")

        last_group_acc = group_acc(all_label, all_label_01, all_pred)

        # Log metrics to wandb
        wandb.log(
            {"epoch": epoch, "train_loss": avg_loss, "train_acc": acc, "test_loss": test_avg_loss, "test_acc": test_acc, "learning_rate": scheduler.get_last_lr()[0], "last_group_acc": last_group_acc}
        )

    # Save the final model once, after the last epoch (no best-model tracking);
    # the file name records the last epoch's largest-distance-group accuracy.
    model_save_path = f"verifier_model_{STEP_IDX}in{TOTAL_STEP}/{NAME}_last_group_acc_{last_group_acc:.4f}.pt"
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(model.state_dict(), model_save_path)


def group_acc(all_label, all_label_01, all_pred):
    """Print accuracy per score-distance group and return the last group's.

    Samples are grouped by ``abs(all_label)`` into eleven groups:
    ``[0, 0.01]``, ``(0.01, 0.02]``, ..., ``(0.09, 0.1]`` and ``> 0.1``.
    A prediction counts as correct when it equals the matching entry of
    ``all_label_01``. The label distribution and a four-line per-group table
    are printed.

    Args:
        all_label: sequence of signed score differences, one per sample.
        all_label_01: sequence of binary targets (0 or 1), one per sample.
        all_pred: sequence of binary predictions, one per sample.

    Returns:
        float: accuracy of the ``> 0.1`` group, or NaN if that group is empty.
    """
    all_label = np.asarray(all_label).reshape(-1)
    all_label_01 = np.asarray(all_label_01).reshape(-1)
    all_pred = np.asarray(all_pred).reshape(-1)

    # Label balance of the evaluated set.
    num_zeros = np.sum(all_label_01 == 0)
    num_ones = np.sum(all_label_01 == 1)
    total = len(all_label_01)
    zeros_pct = num_zeros / total * 100 if total else float("nan")
    ones_pct = num_ones / total * 100 if total else float("nan")

    print("Label distribution:")
    print(f"0s: {num_zeros}/{total} ({zeros_pct:.2f}%)")
    print(f"1s: {num_ones}/{total} ({ones_pct:.2f}%)")

    # Upper edges of the first ten groups; everything above the last edge
    # falls into the eleventh group.
    thresholds = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
    num_groups = len(thresholds) + 1

    distances = np.abs(all_label)
    groups = np.zeros(distances.shape, dtype=int)
    for i, upper in enumerate(thresholds):
        if i == 0:
            mask = distances <= upper
        else:
            mask = (distances > thresholds[i - 1]) & (distances <= upper)
        groups[mask] = i
    groups[distances > thresholds[-1]] = num_groups - 1

    # Per-group totals and hit counts in one pass each.
    hits = (all_pred == all_label_01).astype(float)
    group_total = np.bincount(groups, minlength=num_groups).astype(float)
    group_correct = np.bincount(groups, weights=hits, minlength=num_groups)

    # Print results in 4 lines
    ranges = ["[0-0.01]", "[0.01-0.02]", "[0.02-0.03]", "[0.03-0.04]", "[0.04-0.05]", "[0.05-0.06]", "[0.06-0.07]", "[0.07-0.08]", "[0.08-0.09]", "[0.09-0.1]", "[>0.1]"]
    print("Group ranges:\t" + "\t".join(f"{r:>10}" for r in ranges))
    print("Total samples:\t" + "\t".join(f"{n:>10d}" for n in group_total.astype(int)))
    print("Correct pred:\t" + "\t".join(f"{n:>10d}" for n in group_correct.astype(int)))
    print("Accuracy:\t" + "\t".join(f"{acc:>10.4f}" for acc in np.divide(group_correct, group_total, out=np.zeros_like(group_correct), where=group_total != 0)))

    # An empty last group has no defined accuracy; return NaN explicitly
    # instead of evaluating 0/0.
    if group_total[-1] == 0:
        return float("nan")
    return float(group_correct[-1] / group_total[-1])  # accuracy of the last group


if __name__ == "__main__":
    # Entry point: start a wandb run, load (or build) the cached latents and
    # the scores, split by data item, then train and evaluate the comparator.
    wandb.init(project="latent-verifier", name=NAME)

    # Reuse the cached latent tensor when present; otherwise build it from the
    # per-sample files (which also writes the cache).
    cache_file = f"latent_data_4o_data100_seed64_step{STEP_IDX}in{TOTAL_STEP}.pt"
    if os.path.exists(cache_file):
        latents = torch.load(cache_file, weights_only=True)
    else:
        latents = load_latent_data()
    print(latents.shape)  # torch.Size([100, 64, 4, 16, 32, 32])

    scores = load_score_data()

    # The first 70 data items are used for training, the remainder for testing.
    n_train = 70
    train_latents, test_latents = latents[:n_train], latents[n_train:]
    train_scores, test_scores = scores[:n_train], scores[n_train:]

    train_loader = process_latent_data(train_latents, train_scores)
    test_loader = process_latent_data(test_latents, test_scores, is_test=True)

    # Other imported architectures (C1L3, C2L2, C3L2, ...) can be swapped in here.
    comparator = PairwiseComparator_C1L2()
    train_binary_pairwise(comparator, train_loader, test_loader)

    wandb.finish()
