import matplotlib.pyplot as plt
import torch
import os
from models.model import Net, cifar10_small, cifar10_big, mnist_config, mnist_patching_conf, cifar10_small_config, \
    GTSRBCNN, gtsrb_config, TaxiNetCNN, taxinet_config
from train import train_model

# TODO: these label tables should logically be placed somewhere else.

# CIFAR-10 class names; a class index is used directly as the list position.
cifar10_class_labels = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# GTSRB traffic-sign class names (43 entries); a class index is used directly as the list position.
gtsrb_class_labels = [
    "Speed limit (20 km/h)",
    "Speed limit (30 km/h)",
    "Speed limit (50 km/h)",
    "Speed limit (60 km/h)",
    "Speed limit (70 km/h)",
    "Speed limit (80 km/h)",
    "End of speed limit (80 km/h)",
    "Speed limit (100 km/h)",
    "Speed limit (120 km/h)",
    "No passing",
    "No passing for vehicles > 3.5 t",
    "Right-of-way at the next intersection",
    "Priority road",
    "Yield",
    "Stop",
    "No vehicles",
    "Vehicles > 3.5 t prohibited",
    "No entry",
    "General caution",
    "Dangerous curve to the left",
    "Dangerous curve to the right",
    "Double curve",
    "Bumpy road",
    "Slippery road",
    "Road narrows on the right",
    "Road work",
    "Traffic signals",
    "Pedestrians",
    "Children crossing",
    "Bicycles crossing",
    "Beware of ice/snow",
    "Wild animals crossing",
    "End of all speed and passing limits",
    "Turn right ahead",
    "Turn left ahead",
    "Ahead only",
    "Go straight or right",
    "Go straight or left",
    "Keep right",
    "Keep left",
    "Roundabout mandatory",
    "End of no passing",
    "End of no passing, vehicles > 3.5 t",
]


def create_mask_from_comps(dataset, comps):
    """Build the patch-verification mask that flags the 'dead' components in ``comps``.

    * ``'mnist'``: returns a flat tensor of ``mnist_patching_conf['total_neurons']``
      zeros, with 1.0 written at ``layer_offset[name] + neuron_idx`` for every
      ``(name, neuron_idx)`` pair under ``comps['dead']``.
    * ``'cifar10*'``, ``'taxinet'``, ``'gtsrb'`` (channel-wise patching only):
      returns a dict mapping each conv layer name from the dataset's config to a
      list of per-channel floats, with 1.0 for every ``(layer_name, filter_idx)``
      pair under ``comps['dead']``. Pairs naming a layer that is not in the
      config are ignored.

    A missing ``'dead'`` key is treated as "no dead components".

    Raises:
        ValueError: if ``dataset`` is none of the names above.
    """
    if dataset == 'mnist':
        flat_mask = torch.zeros(mnist_patching_conf['total_neurons'])
        for name, neuron_idx in comps.get('dead', []):
            # Per-layer index -> position in the flattened neuron vector.
            flat_mask[mnist_patching_conf['layer_offset'][name] + neuron_idx] = 1.0
        return flat_mask

    # All remaining datasets share the same dictionary-based, channel-wise mask;
    # only the config that describes the conv layers differs.
    if dataset.startswith("cifar10"):
        # NOTE(review): every 'cifar10*' name (including 'cifar10-big') uses the
        # small config here - confirm that this is intended.
        layer_specs = cifar10_small_config['conv_layers']
    elif dataset == 'taxinet':
        layer_specs = taxinet_config['conv_layers']
    elif dataset == 'gtsrb':
        layer_specs = gtsrb_config['conv_layers']
    else:
        raise ValueError(f"Unsupported dataset for mask creation: {dataset}")

    # Start with every channel unmasked (0.0).
    channel_mask = {spec['name']: [0.0] * spec['conv_channels'] for spec in layer_specs}

    # Set dead components to 1 (patch).
    for layer_name, filter_idx in comps.get('dead', []):
        flags = channel_mask.get(layer_name)
        if flags is not None:
            flags[filter_idx] = 1.0

    return channel_mask

def create_comps_from_mask(dataset, Z_mask):
    """
    Inverse of create_mask_from_comps.
    Returns both 'dead' (masked == 1) and 'active' (masked == 0) components.

    Args:
        dataset: 'mnist', any name starting with 'cifar10', 'gtsrb' or 'taxinet'.
        Z_mask: for 'mnist', a flat indexable mask covering
            ``mnist_patching_conf['total_neurons']`` positions; for the other
            datasets, a dict mapping layer name -> sequence of per-channel flags.

    Returns:
        A dict with keys ``'dead'`` and ``'active'`` (lists of
        ``(layer_name, local_index)`` tuples) and ``'granularity'``
        (``"neurons"`` for 'mnist', ``"conv_channels"`` otherwise).

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    if dataset == 'mnist':
        comps = {'dead': [], 'active': [], "granularity": "neurons"}
        total = mnist_patching_conf['total_neurons']
        # Layers sorted by start offset: each layer's slice of the flat mask ends
        # where the next layer starts (the last one ends at total_neurons).
        ordered = sorted(mnist_patching_conf['layer_offset'].items(), key=lambda item: item[1])
        for i, (layer_name, start) in enumerate(ordered):
            end = ordered[i + 1][1] if i + 1 < len(ordered) else total
            for flat_idx in range(start, end):
                target = 'dead' if Z_mask[flat_idx] == 1 else 'active'
                # Store the index relative to the layer, not the flat position.
                comps[target].append((layer_name, flat_idx - start))
        return comps

    # Channel-wise datasets (dictionary mask structure). 'taxinet' is included
    # because create_mask_from_comps produces the same dictionary mask for it.
    if dataset.startswith('cifar10') or dataset in ('gtsrb', 'taxinet'):
        comps = {'dead': [], 'active': [], "granularity": "conv_channels"}
        for layer_name, channel_flags in Z_mask.items():
            for ch_idx, flag in enumerate(channel_flags):
                target = 'dead' if float(flag) == 1.0 else 'active'
                comps[target].append((layer_name, ch_idx))
        return comps

    raise ValueError(f"Unsupported dataset for comps creation: {dataset}")

def unnormalize_cifar10(img):
    """Undo per-channel normalization with the mean/std constants below and clamp to [0, 1].

    The constants are shaped (3, 1, 1), so ``img`` must broadcast against a
    3-channel, channel-first tensor. They are created on ``img``'s device.
    """
    channel_mean = torch.tensor([0.4914, 0.4822, 0.4465], device=img.device).reshape(3, 1, 1)
    channel_std = torch.tensor([0.2023, 0.1994, 0.2010], device=img.device).reshape(3, 1, 1)

    # Inverse of (x - mean) / std.
    restored = img * channel_std + channel_mean
    return torch.clamp(restored, 0, 1)

def unnormalize_gtsrb(img):
    """Undo per-channel normalization with the mean/std constants below and clamp to [0, 1].

    The constants are shaped (3, 1, 1), so ``img`` must broadcast against a
    3-channel, channel-first tensor. They are created on ``img``'s device.
    """
    # These are example values; replace them with the actual dataset statistics if available.
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).reshape(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=img.device).reshape(3, 1, 1)

    # Inverse of (x - mean) / std.
    restored = img * channel_std + channel_mean
    return torch.clamp(restored, 0, 1)

def evaluate_model(model, test_gen, test_data, device, dataset='mnist', save_plot_path='pred.png',
                   sample_index=67):
    """Print classification accuracy over ``test_gen`` and save a one-sample prediction plot.

    Args:
        model: the classifier; it is switched to eval mode.
        test_gen: iterable yielding ``(images, labels)`` batches.
        test_data: indexable collection of ``(image, label)`` samples.
        device: device the inputs are moved to before inference.
        dataset: 'mnist' (inputs are flattened to 28*28), a name starting with
            'cifar10', or 'gtsrb'.
        save_plot_path: where the visualization is written.
        sample_index: index into ``test_data`` of the sample to visualize
            (default 67).

    Raises:
        NotImplementedError: for ``dataset == 'taxinet'``.
        ValueError: if ``test_gen`` yields no samples, or if ``dataset`` is not
            supported for plotting.
    """
    if dataset == 'taxinet':
        raise NotImplementedError("Evaluation plotting for TaxiNet is not implemented.")

    is_mnist = dataset == 'mnist'

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_gen:
            if is_mnist:
                images = images.view(-1, 28 * 28).to(device)  # Flatten
            else:
                images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Fail with a clear message instead of a ZeroDivisionError on an empty loader.
    if total == 0:
        raise ValueError("test_gen yielded no samples; cannot compute accuracy.")

    accuracy = (100 * correct) / total
    print(f'Accuracy of the model: {accuracy:.3f} %')

    # Select an image for inference
    image, label = test_data[sample_index]
    label = label.item() if isinstance(label, torch.Tensor) else label

    image_for_model = image.unsqueeze(0).to(device)
    if is_mnist:
        image_for_model = image_for_model.view(-1, 28 * 28)  # Flatten for MNIST

    with torch.no_grad():
        output = model(image_for_model)
        predicted_label_index = torch.argmax(output, dim=1).item()

    # Resolve the image array and title BEFORE opening a figure, so that an
    # unsupported dataset raises without leaving an open figure behind.
    if is_mnist:
        image_np = image.cpu().numpy().reshape(28, 28)
        imshow_kwargs = {'cmap': 'gray_r'}
        title = f"Predicted: {predicted_label_index} | Gold: {label}"
    else:  # CIFAR-10 or GTSRB
        if dataset.startswith('cifar10'):
            image = unnormalize_cifar10(image)
            class_labels = cifar10_class_labels
        elif dataset == 'gtsrb':
            image = unnormalize_gtsrb(image)
            class_labels = gtsrb_class_labels
        else:
            raise ValueError(f"Unsupported dataset for evaluation plotting: {dataset}")

        pred_label = class_labels[predicted_label_index]
        gold_label = class_labels[label]
        # Channel-first -> channel-last, as expected by imshow.
        image_np = image.permute(1, 2, 0).cpu().numpy()
        imshow_kwargs = {'interpolation': 'bilinear'}
        title = f"Predicted: {pred_label} | Gold: {gold_label}"

    # Plotting
    fig = plt.figure(figsize=(4, 4))
    try:
        plt.imshow(image_np, **imshow_kwargs)
        plt.title(title, fontsize=12, pad=10, color='blue')
        plt.axis('off')
        plt.savefig(save_plot_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Prediction visualization saved to {save_plot_path}")

def load_taxinet_model(model_path, device):
    """Build a TaxiNetCNN on ``device``, load weights from ``model_path``, and return it in eval mode."""
    taxinet = TaxiNetCNN().to(device)
    # The file holds the model's state dict itself, not a checkpoint dictionary wrapping it.
    # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted files.
    weights = torch.load(model_path, weights_only=False, map_location=device)
    taxinet.load_state_dict(weights)
    taxinet.eval()
    return taxinet

def load_gtsrb_model(model_path, device):
    """Build a GTSRBCNN on ``device``, load weights from ``model_path``, and return it in eval mode."""
    gtsrb_net = GTSRBCNN().to(device)
    # The file holds the model's state dict itself, not a checkpoint dictionary wrapping it.
    # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted files.
    weights = torch.load(model_path, weights_only=False, map_location=device)
    gtsrb_net.load_state_dict(weights)
    gtsrb_net.eval()
    return gtsrb_net

def create_and_train_model(test_gen, train_gen, test_data, device, model_path=''):
    """Create the MNIST ``Net``, train it, evaluate it, and optionally save its weights.

    Sizes, epoch count and learning rate all come from ``mnist_config``. The
    state dict is written to ``model_path`` only when that path is non-empty.
    Returns the trained network.
    """
    cfg = mnist_config
    net = Net(
        cfg['input_size'],
        cfg['hidden_size_1'],
        cfg['hidden_size_2'],
        cfg['num_classes'],
    ).to(device)

    train_model(net, train_gen, test_gen, device, cfg['num_epochs'], cfg['lr'])
    evaluate_model(net, test_gen, test_data, device)

    # An empty path means "do not persist".
    if not model_path:
        return net

    torch.save(net.state_dict(), model_path)
    print(f"model saved to {model_path}")
    return net

def load_mnist_model(model_path, device, **kwargs):
    """Build the MNIST ``Net`` from ``mnist_config``, load weights from ``model_path``, return it in eval mode.

    Extra keyword arguments are accepted and ignored.
    """
    cfg = mnist_config
    mnist_net = Net(
        cfg['input_size'],
        cfg['hidden_size_1'],
        cfg['hidden_size_2'],
        cfg['num_classes'],
    ).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted files.
    weights = torch.load(model_path, weights_only=False, map_location=device)
    mnist_net.load_state_dict(weights)
    mnist_net.eval()
    return mnist_net


def load_cifar10_model(model_path, device, model_type):
    """Load a CIFAR-10 model of the requested type and return it in eval mode.

    Args:
        model_path: path of the weights file. For 'cifar10-big' the file is
            created by downloading pretrained weights when it does not exist.
        device: device the model and weights are placed on.
        model_type: 'cifar10-small' or 'cifar10-big'.

    Raises:
        ValueError: if ``model_type`` is neither of the two supported values
            (instead of silently returning ``None``).
    """
    # NOTE(review): the torch.load calls below use weights_only=False, which
    # unpickles arbitrary objects - only load weight files from trusted sources.

    # select model by dataset name
    if model_type == "cifar10-small":
        model = cifar10_small().to(device)
        ckpt = torch.load(model_path, map_location=device, weights_only=False)
        # Accept either a checkpoint dict wrapping "state_dict" or a bare state dict.
        model.load_state_dict(ckpt.get("state_dict", ckpt))
        model.eval()
        return model

    if model_type == 'cifar10-big':
        model = cifar10_big().to(device)
        if not os.path.exists(model_path):
            print(f"Weights not found at {model_path}. Downloading...")
            pretrained_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
            # Make sure the target directory exists so saving cannot fail on it.
            parent_dir = os.path.dirname(model_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            torch.save(pretrained_model.state_dict(), model_path)
            print(f"Weights downloaded and saved to {model_path}.")
        else:
            print(f"Weights found at {model_path}. Loading...")

        model.load_state_dict(torch.load(model_path, weights_only=False, map_location=device))
        model.eval()
        return model

    raise ValueError(
        f"Unsupported CIFAR-10 model type: {model_type!r} "
        "(expected 'cifar10-small' or 'cifar10-big')"
    )


def load_and_evaluate_mnist_model(model_path, test_gen, test_data, device):
    """Load the MNIST model from ``model_path``, evaluate it (plot saved to 'mnist_pred.png'), and return it."""
    mnist_net = load_mnist_model(model_path, device)
    evaluate_model(
        mnist_net,
        test_gen,
        test_data,
        device,
        dataset='mnist',
        save_plot_path='mnist_pred.png',
    )
    return mnist_net

def load_and_evaluate_cifar10_model(test_gen, test_data, model_path, device, model_type='cifar10-small'):
    """Load the CIFAR-10 model selected by ``model_type``, evaluate it, and return it.

    Evaluation always runs with ``dataset='cifar10'`` and saves its plot to
    'cifar10_pred.png', whichever ``model_type`` was loaded.
    """
    cifar_model = load_cifar10_model(model_path, device, model_type)
    evaluate_model(
        cifar_model,
        test_gen,
        test_data,
        device,
        dataset='cifar10',
        save_plot_path='cifar10_pred.png',
    )
    return cifar_model

def compare_model_predictions(formal_pruned_model_path, informal_pruned_model_path, x, device, dataset, save_plot_path, save_stats_path):
    """Run one sample through the formally and informally pruned models and record the results.

    Both models are loaded from their paths, ``x`` is classified by each, a plot
    of the sample titled with both predictions is saved to ``save_plot_path``,
    and both logit vectors are written to ``save_stats_path``.

    Args:
        formal_pruned_model_path: weights file of the formally pruned model.
        informal_pruned_model_path: weights file of the informally pruned model.
        x: a single sample tensor (no batch dimension; one is added here).
        device: device used for loading and inference.
        dataset: 'mnist', 'cifar10-small', 'cifar10-big' or 'gtsrb'.
        save_plot_path: output path of the comparison image.
        save_stats_path: output path of the text file holding the logits.

    Returns:
        ``(formal_logits, informal_logits)`` as nested Python lists.

    Raises:
        ValueError: for an unsupported ``dataset``, including the bare name
            'cifar10', which does not identify a model variant.
    """
    # load models from paths
    if dataset == 'mnist':
        formal_pruned_model = load_mnist_model(formal_pruned_model_path, device)
        informal_pruned_model = load_mnist_model(informal_pruned_model_path, device)
    elif dataset in ('cifar10-small', 'cifar10-big'):
        formal_pruned_model = load_cifar10_model(formal_pruned_model_path, device, dataset)
        informal_pruned_model = load_cifar10_model(informal_pruned_model_path, device, dataset)
    elif dataset == 'gtsrb':
        formal_pruned_model = load_gtsrb_model(formal_pruned_model_path, device)
        informal_pruned_model = load_gtsrb_model(informal_pruned_model_path, device)
    elif dataset == 'cifar10':
        # load_cifar10_model only knows the 'cifar10-small' and 'cifar10-big'
        # variants, so the bare name cannot be resolved to an architecture.
        raise ValueError(
            "Ambiguous dataset 'cifar10': use 'cifar10-small' or 'cifar10-big' "
            "so the matching model can be loaded."
        )
    else:
        raise ValueError(f"Unsupported dataset type: {dataset}")

    formal_pruned_model.eval()
    informal_pruned_model.eval()

    image = x
    is_mnist = dataset == 'mnist'

    image_for_model = image.unsqueeze(0).to(device)
    if is_mnist:
        image_for_model = image_for_model.view(-1, 28 * 28)  # Flatten for MNIST

    # Get logits from both models
    with torch.no_grad():
        logits1 = formal_pruned_model(image_for_model)
        logits2 = informal_pruned_model(image_for_model)

    # Get predictions
    pred1 = torch.argmax(logits1, dim=1).item()
    pred2 = torch.argmax(logits2, dim=1).item()

    # Prepare the image array and title before opening the figure.
    if is_mnist:
        image_np = image.cpu().numpy().reshape(28, 28)
        imshow_kwargs = {'cmap': 'gray_r'}
        title = f"Formal: {pred1} | Informal: {pred2}"
    else:
        # After the validation above only 'gtsrb' and the cifar10 variants remain.
        if dataset == 'gtsrb':
            image = unnormalize_gtsrb(image)
            class_labels = gtsrb_class_labels
        else:
            image = unnormalize_cifar10(image)
            class_labels = cifar10_class_labels

        pred1_label = class_labels[pred1]
        pred2_label = class_labels[pred2]
        # Channel-first -> channel-last, as expected by imshow.
        image_np = image.permute(1, 2, 0).cpu().numpy()
        imshow_kwargs = {'interpolation': 'bilinear'}
        title = f"Formal: {pred1_label} | Informal: {pred2_label}"

    # Plot the sample
    fig = plt.figure(figsize=(4, 4))
    try:
        plt.imshow(image_np, **imshow_kwargs)
        plt.title(title, fontsize=12, pad=10, color='blue')
        plt.axis('off')
        plt.savefig(save_plot_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Comparison visualization saved to {save_plot_path}")

    # Convert once; the same lists are written to the file and returned.
    formal_logits = logits1.cpu().numpy().tolist()
    informal_logits = logits2.cpu().numpy().tolist()

    # log logits to the stats file
    with open(save_stats_path, 'w', encoding='utf-8') as f:
        f.write(f"Formal Model Logits: {formal_logits}\n")
        f.write(f"Informal Model Logits: {informal_logits}\n")
    print(f"Logits saved to {save_stats_path}")

    return formal_logits, informal_logits