import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import models
from PIL import Image
import csv
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
# from model.TestDeFnetV1 import GoldenRoad
# from model.TestDeFnetV3 import GoldenRoad
# from model.SimCLRTransfer import GoldenRoad
from model.CDQAE import GoldenRoad
# from model.Resnet import ResNet50
from torchvision.transforms import RandomApply
from torch.utils import data
from data.standard_data import StandardData,ValStandardData
import argparse
# shape3ds dimension traversal: decode latent sweeps from a pretrained CDQAE model.
BATCH_SIZE = 1
# NOTE(review): LEARNING_RATE, EPOCHS, WEIGHT_DECAY and NUM_CLASSES appear unused
# by this script (training leftovers) — confirm before removing.
LEARNING_RATE = 1e-3
EPOCHS = 100
WEIGHT_DECAY = 1e-4
LABEL_SMOOTHING = 0.1
NUM_CLASSES = 10  # Replace with the number of classes in your dataset

parser = argparse.ArgumentParser(
    description='Latent traversal visualisation for a pretrained CDQAE model')
parser.add_argument('--pretrained_path', default='/home/star/Data/g1/yxh/yy/information_logs/CDQAE/version_16/checkpoints/best-epoch=109-dci=0.902839-cosine=0.000109-MIG_discrete_mig=0.719538-modularity_score=0.886517-explicitness_score_test=0.975832.ckpt', type=str,
                    help='path to the pretrained model checkpoint (.ckpt) to load')
parser.add_argument('--log_dir', default='/home/star/Projects/g2/gyh/gyh/CDQAE-main/exploreDisentangle/3dshape', type=str,
                    help='directory where the latent traversal PDFs are written')
args = parser.parse_args()

# Paths: every traversal PDF is written below args.log_dir, so create it up front.
os.makedirs(args.log_dir, exist_ok=True)

# CSV listing the validation samples. Replace with your validation CSV path.
VAL_CSV = '/home/star/Projects/g2/gyh/gyh/DataCSV/3dshape/3dshapeWhole.csv'

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Validation data, iterated in the dataset's own order (no shuffling).
val_dataset = ValStandardData(csv_file=VAL_CSV, train=False)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

# Model: build the architecture on the target device, then load pretrained weights.
model = GoldenRoad().to(device)

checkpoint = torch.load(args.pretrained_path, map_location=device)

# Checkpoints written by a training wrapper (e.g. PyTorch Lightning) nest the
# weights under 'state_dict'; a bare state dict is used as-is.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

from collections import OrderedDict

# Keys saved from a wrapper that held the network as `self.model` carry a
# 'model.' prefix; strip it so they match GoldenRoad's own parameter names.
new_state_dict = OrderedDict(
    (k[len('model.'):] if k.startswith('model.') else k, v)
    for k, v in state_dict.items()
)

# load_state_dict copies into the existing parameters, which already live on
# `device`, so no second .to(device) is needed.
model.load_state_dict(new_state_dict)
# Inference only: disable dropout and stop BatchNorm from updating its stats.
model.eval()

# Loss function (cross-entropy with label smoothing)
def loss_fn(outputs, labels, epsilon=None):
    """Return the batch-mean label-smoothed cross-entropy of ``outputs`` vs ``labels``.

    The target distribution puts ``1 - epsilon`` on the true class and spreads
    ``epsilon`` evenly over the remaining ``C - 1`` classes. Note that this
    differs from ``F.cross_entropy(..., label_smoothing=...)``, which spreads
    the mass over all ``C`` classes.

    Args:
        outputs: Logits with classes along dim 1, typically shape ``(N, C)``.
        labels: Integer class indices, typically shape ``(N,)``. Any integer
            dtype is accepted.
        epsilon: Smoothing amount. Defaults to the module-level
            ``LABEL_SMOOTHING`` when ``None``.

    Returns:
        Scalar tensor: per-sample smoothed cross-entropy averaged over the batch.
    """
    if epsilon is None:
        epsilon = LABEL_SMOOTHING
    num_classes = outputs.size(1)

    log_preds = F.log_softmax(outputs, dim=1)
    with torch.no_grad():
        # With a single class there are no other classes to receive the
        # smoothed mass; avoid the division by zero the formula would hit.
        off_value = epsilon / (num_classes - 1) if num_classes > 1 else 0.0
        true_dist = torch.full_like(log_preds, off_value)
        # scatter_ requires int64 indices, so cast e.g. int32 labels.
        true_dist.scatter_(1, labels.long().unsqueeze(1), 1 - epsilon)
    return torch.mean(torch.sum(-true_dist * log_preds, dim=1))

def generate_joint_latent_grid(model, z_base, device, save_dir, dim1=8, dim2=9,
                                range1=(0.2, 0.5), range2=(0.0, 0.8), steps=6):
    """Decode a 2-D sweep over two latent dimensions and save it as one PDF grid.

    Starting from ``z_base``, latent dimension ``dim1`` of sample 0 is swept over
    ``range1`` (one value per grid row) and ``dim2`` over ``range2`` (one value
    per grid column), each with ``steps`` evenly spaced values. Every modified
    code is decoded with ``model.deconv`` followed by ``model.decoder`` (output
    clamped to [0, 1]), and the first decoded image becomes one grid cell.

    Args:
        model: Module exposing ``deconv`` and ``decoder`` submodules.
        z_base: Latent code, expected shape ``[B, D]``. A 1-D ``[D]`` code is
            also accepted and treated as a batch of one.
        device: Device the latent codes are moved to.
        save_dir: Directory for the output PDF; created if missing.
        dim1, dim2: Latent indices swept along rows and columns respectively.
        range1, range2: Inclusive ``(start, end)`` sweep ranges for dim1 / dim2.
        steps: Number of values per dimension; the grid is ``steps x steps``.

    Returns:
        Path of the written PDF, ``<save_dir>/z<dim1>_z<dim2>_joint.pdf``.

    Raises:
        ValueError: If ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    os.makedirs(save_dir, exist_ok=True)
    to_pil = T.ToPILImage()

    v1_range = torch.linspace(range1[0], range1[1], steps, device=device)
    v2_range = torch.linspace(range2[0], range2[1], steps, device=device)

    z_fixed = z_base.clone().to(device)
    if z_fixed.dim() == 1:
        z_fixed = z_fixed.unsqueeze(0)

    # Decode in eval mode without autograd. Plotting needs no graph, and train
    # mode would apply dropout / update BatchNorm statistics. The caller's
    # mode is restored afterwards.
    was_training = model.training
    model.eval()
    grid_imgs = []
    try:
        with torch.no_grad():
            for v1 in v1_range:
                row_imgs = []
                for v2 in v2_range:
                    z_mod = z_fixed.clone()
                    z_mod[0, dim1] = v1
                    z_mod[0, dim2] = v2
                    z_rec = model.deconv(z_mod.unsqueeze(-1).unsqueeze(-1))
                    x_gen = model.decoder(z_rec).clamp(0, 1)
                    row_imgs.append(to_pil(x_gen[0].cpu()))
                grid_imgs.append(row_imgs)
    finally:
        model.train(was_training)

    # Label with the values actually sampled. linspace spacing is
    # (end - start) / (steps - 1), not / steps.
    row_labels = v1_range.tolist()
    col_labels = v2_range.tolist()

    # squeeze=False keeps `axes` 2-D even when steps == 1.
    fig, axes = plt.subplots(steps, steps, figsize=(12, 12), squeeze=False)
    for i, row_imgs in enumerate(grid_imgs):
        for j, img in enumerate(row_imgs):
            ax = axes[i, j]
            ax.imshow(img)
            # Hide ticks and frame but keep the axis on; axis('off') would also
            # suppress the y-label that carries the row value.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if i == 0:
                ax.set_title(f"{col_labels[j]:.2f}")
            if j == 0:
                ax.set_ylabel(f"{row_labels[i]:.2f}", rotation=90)

    fig.suptitle(f"Joint Latent Traversal: z_{dim1} vs z_{dim2}", fontsize=16)
    fig.tight_layout()
    pdf_path = os.path.join(save_dir, f"z{dim1}_z{dim2}_joint.pdf")
    fig.savefig(pdf_path)
    plt.close(fig)

    print(f"[✓] Saved joint latent PDF at {pdf_path}")
    return pdf_path

def _save_traversal_strip(frames, values, pdf_path):
    """Save ``frames`` side by side in one row, titled with ``values``, to ``pdf_path``."""
    n = len(frames)
    # squeeze=False keeps `axes` 2-D even for a single frame.
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
    for ax, frame, value in zip(axes[0], frames, values):
        ax.imshow(frame)
        ax.axis('off')
        ax.set_title(f"{value:.2f}", fontsize=8)
    fig.tight_layout()
    fig.savefig(pdf_path)
    plt.close(fig)


def validate_epoch(model, dataloader, device, log_dir, steps=10):
    """Render single-dimension latent traversals for every batch in ``dataloader``.

    For each batch, the image pair is encoded by ``model([x1, x2], 5)`` and the
    first returned value is taken as the latent code. Then, for each latent
    dimension listed in ``dim_ranges``, that dimension of sample 0 is swept over
    its (min, max) range while all other dimensions stay fixed. Each modified
    code is decoded via ``model.deconv`` then ``model.decoder``. The frames are
    saved as one row to ``<log_dir>/<batch_idx>/pdf/z_<dim>.pdf``.

    Args:
        model: Module called as ``model([x1, x2], 5)`` returning a 4-tuple, and
            exposing ``deconv`` / ``decoder`` submodules.
        dataloader: Yields ``(images1, images2, labels)``; labels are unused.
        device: Device the inputs and latent codes live on.
        log_dir: Root directory for the per-batch PDF output.
        steps: Number of interpolation values per dimension (default 10).

    Raises:
        ValueError: If ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Sweep range (min, max) per latent dimension; dimensions absent from this
    # table are never traversed.
    dim_ranges = {
        0: (0.0, 0.24),
        1: (0.15, 1.8),
        2: (0.0, 1.7),
        3: (0.3, 2.4),
        6: (0.0, 2.5),
        8: (0.2, 0.5),
        9: (0.0, 0.8),
    }
    to_pil = T.ToPILImage()

    # Eval mode is required even under no_grad: in train mode BatchNorm still
    # updates its running statistics and dropout still fires. The caller's
    # mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            loop = tqdm(dataloader, desc="Validating", leave=False)
            for batch_idx, (images1, images2, _labels) in enumerate(loop):
                x1 = images1.to(device)
                x2 = images2.to(device)

                # NOTE(review): the meaning of the second argument (5) is defined
                # by the model's forward(); only the first output is used here.
                z_quant, _, _, _ = model([x1, x2], 5)
                if z_quant.dim() == 1:
                    z_quant = z_quant.unsqueeze(0)
                _, latent_dim = z_quant.shape

                save_dir = os.path.join(log_dir, str(batch_idx), "pdf")
                os.makedirs(save_dir, exist_ok=True)

                for dim, (v_min, v_max) in dim_ranges.items():
                    if dim >= latent_dim:
                        print(f"Warning: latent dim {dim} exceeds actual z shape {latent_dim}")
                        continue

                    interp_values = torch.linspace(v_min, v_max, steps=steps, device=device)
                    frames = []
                    for v in interp_values:
                        z_mod = z_quant.clone()
                        # Only sample 0 is traversed; the rest of the batch is decoded unchanged.
                        z_mod[0, dim] = v
                        z_rec = model.deconv(z_mod.unsqueeze(-1).unsqueeze(-1))  # [B, D, 1, 1]
                        x_gen = model.decoder(z_rec).clamp(0, 1)
                        frames.append(to_pil(x_gen[0].cpu()))

                    pdf_path = os.path.join(save_dir, f"z_{dim}.pdf")
                    _save_traversal_strip(frames, interp_values.tolist(), pdf_path)
                    print(f"Saved latent interpolation PDF for z_{dim} at {pdf_path}")
    finally:
        model.train(was_training)
                



if __name__ == "__main__":
    # Guard the entry point. DataLoader workers (num_workers=4) started with the
    # "spawn" method re-import this module under a different __name__, and
    # must not re-run the traversal themselves.
    validate_epoch(model, val_loader, device, args.log_dir)
