import time

import numpy as np
from torch import nn, optim
from labels import get_class_label
from src.iees_utils import generate_pfams, calculate_consistency_index, compute_iees_score, compute_iees_score1
from src.pfam_utils import ExitAttributionCache, generate_pfams_after_exit, generate_pfams_after_exit1
from src.visuals import plot_combined_image, visualize_xai
from src.xai import getAttribution
from util import load_model, reorganize_attribution_maps, generate_thresholds, getCumulativeFlops
import torch
import torch.nn.functional as F


def load_and_train_model(model_name, num_c, num_exits, device, deployment, num_epochs, trainloader, testloader,
                         loss_weights=None):
    """Build a multi-exit model and either train it or restore saved weights.

    The checkpoint lives at ``../models/{model_name}.pth``. In deployment mode the
    weights are loaded from it; otherwise the model is trained for ``num_epochs``
    epochs and the checkpoint is overwritten after every epoch.

    Args:
        model_name: Architecture name passed to ``load_model``; also names the checkpoint.
        num_c: Number of classes, forwarded to ``load_model``.
        num_exits: Number of early exits, forwarded to ``load_model``.
        device: Device the model (and a loaded checkpoint) is moved to.
        deployment: If truthy, load weights instead of training.
        num_epochs: Number of training epochs (ignored in deployment mode).
        trainloader: Training data iterable of ``(inputs, labels)`` batches.
        testloader: Currently unused; kept for interface compatibility.
        loss_weights: Per-exit loss weights passed to ``train``. Defaults to
            ``[1.0, 0.7, 0.5, 0.3]``; must have one entry per model output.

    Returns:
        The model, on ``device``.
    """
    model = load_model(model_name, num_c, num_exits).to(device)
    checkpoint_path = f"../models/{model_name}.pth"

    if deployment:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host
        # (and vice versa) instead of failing on device deserialisation.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        return model

    if loss_weights is None:
        loss_weights = [1.0, 0.7, 0.5, 0.3]

    print("Training started!")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Decay the learning rate by 10x every 30 epochs (scheduler is stepped once per epoch in `train`).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train(model, trainloader, criterion, optimizer, scheduler, device, epoch, loss_weights)
        # Checkpoint every epoch so an interrupted run keeps its latest weights.
        torch.save(model.state_dict(), checkpoint_path)

    return model


def train(model, dataloader, criterion, optimizer, scheduler, device, epoch, loss_weights):
    """Run a single training epoch over all exit heads of a multi-exit model.

    The model's forward pass is expected to return one logits tensor per exit.
    Each exit's loss is scaled by the matching entry of ``loss_weights``, and the
    weighted losses are summed into one objective that is back-propagated per
    batch. The learning-rate scheduler is stepped once, after the last batch.

    Args:
        model: Module whose ``forward`` returns an iterable of per-exit logits.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function applied to every exit's logits.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per call.
        device: Device each batch is moved to.
        epoch: Epoch number, used only for logging.
        loss_weights: Indexable of per-exit loss multipliers (one per output).
    """
    log_every = 100
    model.train()
    window_loss = 0.0

    for step, (batch_x, batch_y) in enumerate(dataloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        exit_logits = model(batch_x)

        # Weighted sum of the per-exit losses.
        objective = 0
        for head, logits in enumerate(exit_logits):
            head_loss = criterion(logits, batch_y)
            objective += loss_weights[head] * head_loss

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        window_loss += objective.item()

        # Report the mean loss of the last `log_every` batches.
        if (step + 1) % log_every == 0:
            print(f"Epoch [{epoch}], Batch [{step + 1}], Loss: {window_loss / log_every:.4f}")
            window_loss = 0.0

    scheduler.step()


def _proxy_features(probs, activations):
    """Build the (1, 5) float32 feature row scored by the IEES proxy.

    Features, in order: max softmax probability, top-1 minus top-2 probability
    margin, softmax entropy, mean activation and max activation. ``probs`` is
    reduced over dim 1 and ``activations`` over dims 1-3; presumably they are a
    (1, num_classes) softmax and a (1, C, H, W) feature map — TODO confirm.

    Returns:
        ``(features, conf)`` where ``conf`` is the max-probability numpy array.
    """
    conf = probs.max(dim=1).values.detach().cpu().numpy()
    top2 = torch.topk(probs, 2, dim=1).values
    margin = (top2[:, 0] - top2[:, 1]).detach().cpu().numpy()
    # 1e-8 keeps log() finite when a class probability underflows to 0.
    entropy = (-probs * torch.log(probs + 1e-8)).sum(dim=1).detach().cpu().numpy()
    act_mean = activations.mean(dim=(1, 2, 3)).detach().cpu().numpy()
    act_max = activations.amax(dim=(1, 2, 3)).detach().cpu().numpy()
    features = np.array([conf, margin, entropy, act_mean, act_max], dtype=np.float32)
    return features.flatten().reshape(1, -1), conf


def inference(model, proxy, testloader, device, default_thresholds, dataset_name, output_dir, weights=None):
    """Run proxy-driven early-exit inference and save one PFAM figure per sample.

    For each test image, exits ``0 .. model.num_exits`` are evaluated in order. At
    each exit a feature row (see ``_proxy_features``) is scored with
    ``proxy.predict``. The sample leaves at the first exit ``idx`` whose predicted
    IEES is ``>= default_thresholds[idx]``. If none fires, the final exit
    (``model.num_exits``) is used. PFAMs up to that exit are then generated from
    the cached per-exit outputs and plotted with ``plot_combined_image``.

    Args:
        model: Early-exit model exposing ``num_exits``, ``forward_to_exit`` and an
            ``activations`` attribute read after each forward call.
        proxy: Object with a ``predict(2-D array)`` method returning a sequence;
            presumably a fitted regressor — TODO confirm.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.
        default_thresholds: Per-exit IEES thresholds. Exits without an entry never
            trigger an early exit.
        dataset_name: Passed to ``get_class_label`` for readable class names.
        output_dir: Figures are written under ``{output_dir}/pfam/``.
        weights: Currently unused; kept for interface compatibility.
    """
    model.eval()

    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)

        for i in range(inputs.size(0)):
            input_image = inputs[i:i + 1]
            label = labels[i].item()

            # IMPORTANT: gradient tracking for PFAMs is enabled inside the PFAM helpers, not here.
            input_image_np = input_image.cpu().detach().squeeze().permute(1, 2, 0).numpy()

            cache = ExitAttributionCache()
            iees_scores = []
            confidence_scores = []
            predicted_classes = []
            exit_idx = -1  # -1 means "no early exit fired"

            # Phase 1 & 2: forward-only inference + proxy IEES exit decision.
            for idx in range(model.num_exits + 1):
                print("processing***********************************:", idx)
                logits, class_idx, _ = model.forward_to_exit(input_image, idx)

                # Cache per-exit results so PFAMs can be built after the exit decision.
                cache.cache(idx, model, logits, class_idx)
                predicted_classes.append(class_idx.item())

                probs = F.softmax(logits, dim=1)
                feature_vector, conf = _proxy_features(probs, model.activations)
                predicted_iees = proxy.predict(feature_vector)[0]

                iees_scores.append(predicted_iees)
                confidence_scores.append(conf.item())

                print(idx, model.num_exits, predicted_iees, ">=", default_thresholds, idx)
                if idx < len(default_thresholds) and predicted_iees >= default_thresholds[idx]:
                    print(default_thresholds[idx])
                    exit_idx = idx
                    break

            # Exit actually used: the one that fired, or the final exit if none did.
            n_exits = model.num_exits if exit_idx == -1 else exit_idx

            pfams_list, cumulative_map = generate_pfams_after_exit(
                model,
                input_image,
                n_exits,
                cache.activations,
                cache.outputs,
                cache.class_ids,
            )

            # Use the resolved exit: with the raw exit_idx, a sample that never exited
            # early was saved as "exit_-1.png" instead of "exit_final.png".
            # NOTE(review): filenames do not include a sample index, so each sample
            # overwrites the previous figure for the same exit — confirm intended.
            filename = (
                f"{output_dir}/pfam/exit_final.png"
                if n_exits == model.num_exits
                else f"{output_dir}/pfam/exit_{n_exits}.png"
            )
            plot_combined_image(
                input_image_np,
                pfams_list,
                cumulative_map,
                filename,
                get_class_label(dataset_name, predicted_classes),
                get_class_label(dataset_name, label),
                iees_scores,
                confidence_scores,
            )

def comparision_with_XAI_tools(model, dataset_name, testloader, device, default_thresholds, weights):
    """Visualise PFAMs side by side with other XAI attribution maps for every test sample.

    Each sample is passed individually to ``getAttribution`` with
    ``default_thresholds``. The returned attribution maps are regrouped with
    ``reorganize_attribution_maps`` and rendered by ``visualize_xai`` using the
    "jet" colormap, together with the predicted and true class names.

    Args:
        model: Early-exit model passed to ``getAttribution``.
        dataset_name: Used by ``get_class_label`` to name classes.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.
        default_thresholds: Exit thresholds forwarded to ``getAttribution``.
        weights: Currently unused; kept for interface compatibility.
    """
    for batch_x, batch_y in testloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        for sample_idx in range(batch_x.size(0)):
            image = batch_x[sample_idx:sample_idx + 1]
            true_label = batch_y[sample_idx].item()

            attribution_result = getAttribution(model, image, default_thresholds)
            pfams, fpam_map, per_exit_maps, predicted_classes, _scores, raw_attributions = attribution_result

            grouped_attributions = reorganize_attribution_maps(raw_attributions)
            predicted_names = get_class_label(dataset_name, predicted_classes)
            true_name = get_class_label(dataset_name, true_label)

            visualize_xai(
                image,
                pfams,
                grouped_attributions,
                per_exit_maps,
                fpam_map,
                "jet",
                predicted_names,
                true_name,
            )


def distribution_drift(model, model_name, testloader, device, exit_criterion="iees", threshold_range=None, weights=None):
    """Evaluate early-exit behaviour and return a compact summary for one threshold.

    Delegates to ``evaluate_model_thresholds`` and keeps only the first entry of
    its per-threshold result list, i.e. the first value of ``threshold_range``
    (or the first generated threshold when it is None).

    Args:
        model: Early-exit model.
        model_name: Name used for the FLOPs lookup.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Device samples are moved to.
        exit_criterion: 'iees' or 'confidence'.
        threshold_range: Threshold(s) forwarded as ``threshold``; None to generate them.
        weights: Weights forwarded to the IEES scorer; None means ``[]``.

    Returns:
        dict with keys 'threshold', 'accuracy', 'avg_flops', 'exit_counts',
        'avg_latencies', 'exit_accuracies'.
    """
    if weights is None:
        weights = []

    # Forward the caller's criterion. Previously "iees" was hard-coded here, so
    # exit_criterion="confidence" was silently ignored.
    results = evaluate_model_thresholds(
        model,
        model_name,
        testloader,
        device,
        exit_criterion=exit_criterion,
        threshold=threshold_range,
        weights=weights,
    )
    bbc_results = results[0]
    print(bbc_results)

    # NOTE(review): 'accuracy' reports exit_mean_accuracy (unweighted mean of the
    # per-exit accuracies, 0.0 for unused exits), not the overall 'avg_accuracy'.
    # Confirm this is intended.
    return {
        'threshold': bbc_results["threshold"],
        'accuracy': bbc_results["exit_mean_accuracy"],
        'avg_flops': bbc_results["avg_flops"],
        'exit_counts': bbc_results["exit_counts"],
        'avg_latencies': bbc_results["avg_latencies"],
        'exit_accuracies': bbc_results["exit_accuracies"],
    }


def evaluate_model_thresholds(model, model_name, testloader, device, exit_criterion="iees", threshold=None, weights=None):
    """
    Evaluate the model with threshold-based early exits using either IEES or confidence.

    For every threshold value, the whole ``testloader`` is traversed one sample at a
    time. Exits ``0 .. model.num_exits - 1`` are tried in order, and the sample
    leaves at the first one whose score is ``>= threshold``. Otherwise it is
    classified by the final exit (index ``model.num_exits``).

    Args:
        model: Early-exit model exposing ``num_exits``, ``forward_to_exit`` and the
            ``activations`` / ``gradients`` attributes read for IEES.
        model_name: Name for FLOPs lookup (``getCumulativeFlops``).
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: 'cuda' or 'cpu'.
        exit_criterion: 'iees' or 'confidence' (max softmax probability).
        threshold: Scalar, list/tuple/1-D array of scalars, or None to use
            ``generate_thresholds()``.
        weights: Weights forwarded to ``compute_iees_score``; None means ``[]``.

    Returns:
        List with one metrics dict per threshold. Keys: 'threshold', 'avg_accuracy',
        'avg_flops', 'exit_counts', 'avg_latencies' (ms, per exit),
        'exit_accuracies', 'weighted_overall_accuracy', 'exit_mean_accuracy'.

    Raises:
        ValueError: if ``exit_criterion`` is unknown or ``testloader`` yields no samples.
    """
    # Validate up front instead of failing part-way through the first sample.
    if exit_criterion not in ("iees", "confidence"):
        raise ValueError("Invalid exit_criterion: choose 'iees' or 'confidence'")
    if weights is None:
        weights = []

    if threshold is None:
        threshold_range = generate_thresholds()
    elif isinstance(threshold, (list, tuple, np.ndarray)) and np.ndim(threshold) > 0:
        # Any 1-D sequence of thresholds, not only a list.
        threshold_range = list(threshold)
    else:
        threshold_range = [threshold]

    num_exits = model.num_exits
    flops_lookup = getCumulativeFlops(model_name, baseline=False)
    all_results = []

    for threshold_val in threshold_range:
        correct_preds = 0
        total_samples = 0
        total_flops = 0

        # Index num_exits is the final (non-early) exit.
        exit_correct_preds = [0] * (num_exits + 1)
        exit_counts = [0] * (num_exits + 1)
        exit_latencies = [0.0] * (num_exits + 1)

        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            for i in range(inputs.size(0)):
                input_image = inputs[i:i + 1]
                label = labels[i].item()
                exit_point = num_exits  # fallback: final exit
                predicted_class = None
                pfams_list = []
                cumulative_maps = []

                start_time = time.time()

                for exit_idx in range(num_exits):
                    logits, class_idx, _ = model.forward_to_exit(input_image, exit_idx, True)
                    predicted_class = class_idx.item()
                    probs = F.softmax(logits, dim=1)

                    if exit_criterion == "iees":
                        pfams_list, cumulative_map = generate_pfams(
                            model.activations, model.gradients, input_image, pfams_list
                        )
                        cumulative_maps.append(cumulative_map)
                        consistency_index = calculate_consistency_index(cumulative_maps)
                        # First element of compute_iees_score's result is the IEES score.
                        exit_score = compute_iees_score(
                            model.activations, model.gradients, probs, consistency_index, weights
                        )[0]
                    else:
                        # `probs` is already a softmax distribution. The old code applied
                        # softmax a second time, squashing every confidence toward uniform.
                        exit_score = probs.max().item()

                    if exit_score >= threshold_val:
                        exit_point = exit_idx
                        break

                if exit_point == num_exits:
                    _, class_idx, _ = model.forward_to_exit(input_image, num_exits)
                    predicted_class = class_idx.item()

                latency = (time.time() - start_time) * 1000  # in milliseconds

                # Metrics collection
                if predicted_class == label:
                    correct_preds += 1
                    exit_correct_preds[exit_point] += 1

                total_flops += flops_lookup[exit_point]
                exit_counts[exit_point] += 1
                exit_latencies[exit_point] += latency
                total_samples += 1

        if total_samples == 0:
            raise ValueError("testloader yielded no samples")

        # Final metrics
        avg_acc = correct_preds / total_samples
        avg_flops = total_flops / total_samples
        avg_latencies = [lat / count if count > 0 else 0.0 for lat, count in zip(exit_latencies, exit_counts)]
        exit_accuracies = [
            correct / count if count > 0 else 0.0
            for correct, count in zip(exit_correct_preds, exit_counts)
        ]

        weighted_overall_accuracy = sum(
            acc * (count / total_samples) for acc, count in zip(exit_accuracies, exit_counts)
        )
        # Unweighted mean over all exits; unused exits contribute 0.0.
        exit_mean_accuracy = sum(exit_accuracies) / (num_exits + 1)

        all_results.append({
            'threshold': threshold_val,
            'avg_accuracy': avg_acc,
            'avg_flops': avg_flops,
            'exit_counts': exit_counts,
            'avg_latencies': avg_latencies,
            'exit_accuracies': exit_accuracies,
            'weighted_overall_accuracy': weighted_overall_accuracy,
            'exit_mean_accuracy': exit_mean_accuracy,
        })

    return all_results
