import os
import sys
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from absl import app, flags
from torchdiffeq import odeint
from torchdyn.core import NeuralODE

# Make the directory three levels above this file importable (the original
# comment calls it the repository root). Note this only edits sys.path; the
# process working directory is not changed.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(current_dir, '../../../')
sys.path.insert(0, os.path.abspath(root_dir))

# Must come after the sys.path insertion above so the in-repo package resolves.
from torchcfm.models.unet.unet import UNetModelWrapper

# Command-line configuration (absl flags). Every function below reads its
# settings from this global FLAGS object.
FLAGS = flags.FLAGS
# UNet architecture
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Visualization / evaluation settings
# input_dir is used both to locate the model checkpoint and as the directory
# the distribution plots are written to (see main).
flags.DEFINE_string("input_dir", "./results", help="output_directory")
flags.DEFINE_string("dataset_path", "./data/cifar10_lt/cifar10-lt-ratio50.npz", help="Path to the .npz dataset file")
# model, training_params and step are combined in main to build the checkpoint path.
flags.DEFINE_string("model", "sinkhorn_otwfm", help="flow matching model type")
flags.DEFINE_string("classifier_model", "resnet56", help="classifier model type")
flags.DEFINE_string("training_params", "r1_tauinf1_inv_u_efm_beta1", help="model hyperparameters")
# integration_steps is only read by the "euler" branch of generate_images;
# tol is only read by the non-euler (odeint) branch.
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 20000, help="training steps")
flags.DEFINE_integer("num_samples", 10000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance")
flags.DEFINE_integer("batch_size", 1024, help="Batch size")
flags.DEFINE_boolean("visualize_dataset", True, help="Whether to visualize dataset distribution")
flags.DEFINE_boolean("visualize_generated", False, help="Whether to visualize generated distribution")
flags.DEFINE_boolean("show_true_labels", True, help="Whether to show true labels for dataset")

# Parse the command line eagerly, at import time.
# NOTE(review): app.run(main) at the bottom of the file is presumably going to
# parse flags as well -- confirm this early parse is still needed.
FLAGS(sys.argv)

# CIFAR10 class names. plot_distribution uses entry i as the tick label for
# the bar that counts label/prediction value i, so the order matters.
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']

def load_classifier(model_name='resnet56'):
    """Fetch a pretrained CIFAR-10 ResNet through ``torch.hub`` in eval mode.

    Args:
        model_name: One of ``'resnet20'``, ``'resnet32'``, ``'resnet44'`` or
            ``'resnet56'``.

    Returns:
        The object returned by ``torch.hub.load`` after ``eval()`` was called
        on it.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    hub_repo = 'chenyaofo/pytorch-cifar-models'
    # Supported depths, each mapped to its (hub repo, hub entry point) pair.
    model_options = {
        f'resnet{depth}': (hub_repo, f'cifar10_resnet{depth}')
        for depth in (20, 32, 44, 56)
    }

    selected = model_options.get(model_name)
    if selected is None:
        raise ValueError(f"Cannot find {model_name} in {list(model_options.keys())}")

    repo, entry_point = selected
    classifier = torch.hub.load(repo, entry_point, pretrained=True)
    classifier.eval()
    return classifier

def get_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect arg-max predictions.

    Args:
        model: Callable mapping a batch of images to per-class scores; the
            arg-max is taken over dimension 1 of its output.
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the images are moved to before the forward pass.

    Returns:
        A ``(predictions, true_labels)`` tuple of NumPy arrays, concatenated
        over all batches in iteration order.
    """
    pred_chunks = []
    label_chunks = []

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            scores = model(batch_images.to(device))
            pred_chunks.append(scores.argmax(dim=1).cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    if not pred_chunks:
        # Nothing was iterated: return empty arrays like np.array([]) would.
        return np.array([]), np.array([])

    return np.concatenate(pred_chunks), np.concatenate(label_chunks)

def generate_images(model, num_samples, batch_size, device):
    """Sample images by integrating ``model`` from Gaussian noise at t=0 to t=1.

    The solver is chosen by ``FLAGS.integration_method``:

    * ``"euler"``: fixed-step integration over ``FLAGS.integration_steps``
      steps via a ``trajectory`` method. If ``model`` does not already expose
      ``trajectory`` it is wrapped in ``NeuralODE`` first.
    * anything else: ``odeint`` with ``rtol = atol = FLAGS.tol``.

    Args:
        model: Vector field network (or an object that already provides
            ``trajectory``).
        num_samples: Total number of images to generate.
        batch_size: Maximum number of images generated per solver call.
        device: Device on which noise is drawn and integration runs.

    Returns:
        A ``uint8`` tensor of shape ``(num_samples, 3, 32, 32)``, obtained from
        the final state via ``x * 127.5 + 128`` clipped to ``[0, 255]``. An
        empty tensor of shape ``(0, 3, 32, 32)`` is returned when
        ``num_samples <= 0``.
    """
    if num_samples <= 0:
        # torch.cat on an empty list raises; hand back an empty batch instead.
        return torch.empty((0, 3, 32, 32), dtype=torch.uint8, device=device)

    use_euler = FLAGS.integration_method == "euler"
    if use_euler:
        # A plain nn.Module has no `trajectory`; wrap it so the euler path
        # works for the raw network as well as for a ready-made NeuralODE.
        if hasattr(model, "trajectory"):
            node = model
        else:
            node = NeuralODE(model, solver=FLAGS.integration_method)
        t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
    else:
        # Adaptive solvers only need the two end points.
        t_span = torch.linspace(0, 1, 2, device=device)

    images = []
    # Sampling is inference only; without no_grad every solver step would be
    # recorded by autograd and held in memory.
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            current_batch = min(batch_size, num_samples - start)
            x = torch.randn(current_batch, 3, 32, 32, device=device)

            if use_euler:
                traj = node.trajectory(x, t_span=t_span)
            else:
                traj = odeint(
                    model, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
                )

            # Keep only the state at the last time point.
            final_state = traj[-1, :]
            images.append((final_state * 127.5 + 128).clip(0, 255).to(torch.uint8))

    return torch.cat(images, dim=0)

def plot_distribution(predictions, true_labels=None, title="Class Distribution", save_path=None):
    """Draw a per-class bar chart of predicted (and optionally true) label counts.

    Args:
        predictions: Integer class predictions; counted with ``np.bincount``
            using at least 10 bins.
        true_labels: Optional integer ground-truth labels. They are drawn as a
            second bar series only when ``FLAGS.show_true_labels`` is set.
        title: Figure title.
        save_path: If truthy, the figure is saved to this path.

    The figure is always closed before returning.
    """
    plt.figure(figsize=(15, 6))

    positions = np.arange(10)  # one x slot per class
    bar_width = 0.35
    half_width = bar_width / 2
    show_truth = true_labels is not None and FLAGS.show_true_labels

    # (x offset, counts) of every series drawn; reused for the value labels.
    drawn_series = []

    # Predicted-class counts sit on the left half of each slot.
    pred_counts = np.bincount(predictions, minlength=10)
    plt.bar(positions - half_width, pred_counts, bar_width, alpha=0.8,
            label=f'Predictions of {FLAGS.classifier_model}')
    drawn_series.append((-half_width, pred_counts))

    # Ground-truth counts (optional) sit on the right half.
    if show_truth:
        true_counts = np.bincount(true_labels, minlength=10)
        plt.bar(positions + half_width, true_counts, bar_width, alpha=0.8,
                label='True Labels')
        drawn_series.append((half_width, true_counts))

    plt.xticks(positions, CIFAR10_CLASSES, rotation=45, fontsize=12)
    plt.title(title, fontsize=16, pad=20)
    plt.xlabel('Class', fontsize=12)
    plt.ylabel('Count', fontsize=12)
    plt.legend(fontsize=14, loc='upper right')

    # Write the numeric count on top of every bar.
    for offset, counts in drawn_series:
        for class_idx, count in enumerate(counts):
            plt.text(class_idx + offset, count, str(count),
                     ha='center', va='bottom', fontsize=10)

    plt.grid(True, axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


class NPZDataset(torch.utils.data.Dataset):
    """Image/label dataset backed by one split of an ``.npz`` archive.

    The archive is read with keys ``train_data`` / ``train_labels`` or
    ``test_data`` / ``test_labels``. Images are converted to float tensors and
    permuted from channels-last ``(N, H, W, C)`` to channels-first
    ``(N, C, H, W)``; labels are converted to ``long`` tensors.

    Args:
        npz_path: Path of the ``.npz`` file.
        train: Load the ``train_*`` arrays when true, the ``test_*`` arrays
            otherwise.
        transform: Optional callable applied to each channels-first float
            image tensor (values are NOT rescaled beforehand). When falsy, the
            image is divided by 255 and normalized with mean 0.5 / std 0.5 per
            channel instead.
    """

    def __init__(self, npz_path, train=True, transform=None):
        split = 'train' if train else 'test'

        # The object returned by np.load keeps the archive open; the context
        # manager closes it as soon as the needed arrays have been read.
        with np.load(npz_path) as data:
            print(data.keys())  # [train_data, train_labels, test_data, test_labels]
            images = torch.from_numpy(data[f'{split}_data']).float()
            labels = torch.from_numpy(data[f'{split}_labels']).long()

        self.images = images.permute(0, 3, 1, 2)  # (N, 32, 32, 3) to (N, 3, 32, 32)
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        else:
            # Default preprocessing: scale to [0, 1], then normalize each
            # channel with mean 0.5 and std 0.5.
            image = image / 255.0
            image = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)

        return image, label

def load_npz_dataset(npz_path, train=True):
    """Build an ``NPZDataset`` whose images are scaled and normalized.

    Args:
        npz_path: Path of the ``.npz`` file.
        train: Select the training split when true, the test split otherwise.

    Returns:
        An ``NPZDataset`` using its built-in preprocessing (divide by 255,
        then per-channel normalization with mean 0.5 and std 0.5).
    """
    # NPZDataset hands its transform a torch.Tensor, which
    # transforms.ToTensor() rejects (it only accepts PIL images and ndarrays),
    # so a ToTensor-based pipeline would fail on every item. Passing no
    # transform selects NPZDataset's own scaling + normalization, which is the
    # preprocessing the ToTensor + Normalize pipeline was meant to produce.
    return NPZDataset(npz_path, train=train, transform=None)

def main(argv):
    """Plot classifier-predicted class distributions for datasets and/or samples.

    Depending on the flags:

    * ``visualize_dataset``: classify the ``.npz`` dataset at
      ``FLAGS.dataset_path`` and the torchvision CIFAR10 training set, and
      save one bar chart for each.
    * ``visualize_generated``: load the EMA weights of the flow-matching UNet,
      generate ``FLAGS.num_samples`` images, classify them and save a bar
      chart of the predicted classes.

    All plots are written into ``FLAGS.input_dir``.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Every plot is saved under FLAGS.input_dir. Create it up front so a
    # missing directory does not make savefig fail after the expensive
    # classification / generation work has already been done.
    os.makedirs(FLAGS.input_dir, exist_ok=True)

    # Load the pretrained classifier.
    # NOTE(review): all classifier inputs below are normalized with mean/std
    # 0.5; confirm this matches the preprocessing the hub model was trained with.
    classifier = load_classifier(FLAGS.classifier_model).to(device)

    # Class distribution of the real datasets.
    if FLAGS.visualize_dataset:
        # Dataset stored in the .npz file (default NPZDataset preprocessing).
        dataset = NPZDataset(FLAGS.dataset_path)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=FLAGS.batch_size,
                                                shuffle=False, num_workers=2)
        predictions, true_labels = get_predictions(classifier, dataset_loader, device)
        plot_distribution(predictions, true_labels, 
                         "CIFAR10-LT Dataset Class Distribution",
                         f"{FLAGS.input_dir}/dataset_lt_distribution.png")

        # CIFAR10 training set from torchvision, for comparison.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                              download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=FLAGS.batch_size,
                                                shuffle=False, num_workers=2)
        
        predictions, true_labels = get_predictions(classifier, trainloader, device)
        plot_distribution(predictions, true_labels, 
                         "CIFAR10 Dataset Class Distribution",
                         f"{FLAGS.input_dir}/dataset_distribution.png")
    
    # Class distribution of generated images.
    if FLAGS.visualize_generated:
        # Build the network and load its weights.
        new_net = UNetModelWrapper(
            dim=(3, 32, 32),
            num_res_blocks=2,
            num_channels=FLAGS.num_channel,
            channel_mult=[1, 2, 2, 2],
            num_heads=4,
            num_head_channels=64,
            attention_resolutions="16",
            dropout=0.1,
        ).to(device)
        
        PATH = f"{FLAGS.input_dir}/{FLAGS.model}_{FLAGS.training_params}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
        checkpoint = torch.load(PATH, map_location=device)
        state_dict = checkpoint["ema_model"]
        
        try:
            new_net.load_state_dict(state_dict)
        except RuntimeError:
            # Presumably the checkpoint was saved from a wrapped model whose
            # parameter names carry a "module." prefix. Strip that prefix only
            # where it is actually present, so an unrelated mismatch is
            # reported with the real key names instead of truncated ones.
            from collections import OrderedDict
            prefix = "module."
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items()
            )
            new_net.load_state_dict(new_state_dict)
        
        new_net.eval()
        
        # Generate images (uint8 tensor, see generate_images).
        generated_images = generate_images(new_net, FLAGS.num_samples, FLAGS.batch_size, device)
        
        # Classify the generated images batch by batch.
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        predictions = []
        for i in range(0, len(generated_images), FLAGS.batch_size):
            # Same preprocessing as the real data: scale to [0, 1], then normalize.
            batch = generated_images[i:i + FLAGS.batch_size].float() / 255.0
            batch = normalize(batch)
            with torch.no_grad():
                outputs = classifier(batch.to(device))
                preds = outputs.argmax(dim=1)
                predictions.extend(preds.cpu().numpy())
        
        plot_distribution(np.array(predictions), None,
                         "Generated Images Class Distribution",
                         f"{FLAGS.input_dir}/generated_distribution.png")

# Script entry point: hand control to absl, which invokes main.
if __name__ == '__main__':
    app.run(main)



