import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os.path as osp
import models.network as network


class pretrained_model(nn.Module):
    """Backbone (netF) -> bottleneck (netB) -> classifier (netC) network.

    All three sub-modules are moved to the GPU with ``.cuda()``. For each one, a
    checkpoint ``source_{F,B,C}_{args.seed}.pt`` under ``args.modelpath`` is
    loaded when the file exists; missing files are skipped without error. The
    module is put in eval mode at the end of construction.

    Reads from ``args``: ``net``, ``classifier``, ``bottleneck``, ``layer``,
    ``class_num``, ``modelpath`` and ``seed``.

    Raises:
        ValueError: if ``args.net`` starts with neither ``'res'`` nor ``'vgg'``.
    """

    def __init__(self, args):
        super().__init__()

        # The backbone family is chosen from the first three characters of the name.
        family = args.net[:3]
        if family == 'res':
            backbone = network.ResBase(res_name=args.net)
        elif family == 'vgg':
            backbone = network.VGGBase(vgg_name=args.net)
        else:
            raise ValueError(f"Unsupported backbone type: {args.net}")
        self.netF = backbone.cuda()

        self.netB = network.feat_bottleneck(
            type=args.classifier,
            feature_dim=self.netF.in_features,
            bottleneck_dim=args.bottleneck,
        ).cuda()
        self.netC = network.feat_classifier(
            type=args.layer,
            class_num=args.class_num,
            bottleneck_dim=args.bottleneck,
        ).cuda()

        # (sub-module, checkpoint tag, human-readable name), loaded in F, B, C order.
        checkpoints = (
            (self.netF, "F", "feature extractor"),
            (self.netB, "B", "bottleneck"),
            (self.netC, "C", "classifier"),
        )
        for module, tag, description in checkpoints:
            ckpt_path = osp.join(args.modelpath, f"source_{tag}_{args.seed}.pt")
            if not osp.exists(ckpt_path):
                continue
            module.load_state_dict(torch.load(ckpt_path))
            print(f"[✓] Loaded {description} from {ckpt_path}")

        self.eval()

    def get_feature_info(self, x):
        """Return ``(bottleneck_features, class_probabilities)`` for ``x``.

        The probabilities are the softmax of netC's logits taken over dim 1.
        """
        features = self.netB(self.netF(x))
        logits = self.netC(features)
        return features, F.softmax(logits, dim=1)

    def forward(self, x):
        """Return the classifier logits: netC(netB(netF(x)))."""
        return self.netC(self.netB(self.netF(x)))

# def load_model(args):
#     if args.net[0:3] == 'res':
#         netF = network.ResBase(res_name=args.net).cuda()
#     netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
#     netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()
#     model_F_path = osp.join(args.output_dir_src, "source_F_{}.pt".format(args.seed))
#     model_B_path = osp.join(args.output_dir_src, "source_B_{}.pt".format(args.seed))
#     model_C_path = osp.join(args.output_dir_src, "source_C_{}.pt".format(args.seed))

#     if osp.exists(model_F_path):
#         netF.load_state_dict(torch.load(model_F_path))
#         print(f"Loaded feature extractor from {model_F_path}")

#     if osp.exists(model_B_path):
#         netB.load_state_dict(torch.load(model_B_path))
#         print(f"Loaded bottleneck layer from {model_B_path}")

#     if osp.exists(model_C_path):
#         netC.load_state_dict(torch.load(model_C_path))
#         print(f"Loaded classifier from {model_C_path}")

#     # Set to evaluation mode
#     netF.eval()
#     netB.eval()
#     netC.eval()

#     return netF, netB, netC

def load_data(dataset):
    """Placeholder for dataset loading; not implemented yet.

    Args:
        dataset: Ignored by the current stub.

    Returns:
        None: the stub performs no work.
    """
    return None

def AURC(y_true, y_pred, score, thresholds=None):
    """
    Compute selective accuracy after rejecting the least-confident predictions.

    For every percentile ``p`` in ``thresholds``, the ``p`` percent of samples
    with the lowest ``score`` are discarded and the accuracy of the remaining
    samples is reported. At least one sample is always kept, so ``p = 100``
    keeps only the single most confident prediction.

    Parameters:
    - y_true (torch.Tensor): Ground truth labels of shape (n,).
    - y_pred (torch.Tensor): Predicted labels of shape (n,).
    - score (torch.Tensor): Confidence scores for each prediction (n,);
      higher means more confident.
    - thresholds (iterable of int or float, optional): Rejection percentiles
      in [0, 100] (e.g., [10, 20, 50, 90, 100]). Defaults to 0, 10, ..., 100.
      Values below 0 are treated as 0.

    Returns:
    - accuracy (list of float): Accuracy of the retained samples at each
      threshold, in the same order as ``thresholds``. Every entry is NaN when
      ``score`` is empty.
    """
    if thresholds is None:
        thresholds = range(0, 101, 10)
    thresholds = list(thresholds)

    num_samples = len(score)
    if num_samples == 0:
        # Nothing to evaluate; this matches torch's mean-of-empty result (NaN).
        return [float("nan")] * len(thresholds)

    # A stable sort cuts tied confidences at the same position on every run.
    _, order = torch.sort(score, dim=0, stable=True)
    is_correct = (y_true[order] == y_pred[order]).float()

    accuracy = []
    for p in thresholds:
        # Multiply before dividing: (p / 100) * n drifts (0.29 * 100 == 28.999...),
        # which would keep one extra sample.
        start = int(p * num_samples / 100)
        start = min(max(start, 0), num_samples - 1)
        accuracy.append(is_correct[start:].mean().item())
    return accuracy


def plot_aurc(y_true, y_pred, score_dict, thresholds):
    """
    Plot selective-accuracy curves (computed with ``AURC``) for several scores.

    Args:
        y_true (Tensor or ndarray): Ground truth labels.
        y_pred (Tensor or ndarray): Model predictions.
        score_dict (dict): Dict of {score_name: score_values} to plot; each
            value is a Tensor or ndarray of per-sample confidence scores.
        thresholds (iterable): Rejection percentiles passed to ``AURC`` and
            used as the x-axis values.
    """
    # torch.as_tensor accepts tensors on any device as well as numpy arrays,
    # so the ndarray inputs promised above work; convert the labels only once.
    y_true = torch.as_tensor(y_true).cpu()
    y_pred = torch.as_tensor(y_pred).cpu()
    # Materialise once so a one-shot iterable can be used for both AURC and the x-axis.
    thresholds = list(thresholds)

    plt.figure(figsize=(8, 6))
    for score_name, score in score_dict.items():
        score = torch.as_tensor(score).cpu()
        accuracy = AURC(y_true, y_pred, score, thresholds=thresholds)
        plt.plot(thresholds, accuracy, label=score_name)

    plt.xlabel("Confidence Threshold")
    plt.ylabel("Accuracy")
    plt.title("AURC Curve")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()




def generate_separated_2d_dataset(num_samples_per_class=100, std=1.5, seed=42, means=None):
    """
    Generate a labelled 2D dataset of isotropic Gaussian clusters, one per class.

    By default four classes are produced, centred at (0, 0), (3, 0), (0, 3)
    and (3, 3).

    Note: this calls ``torch.manual_seed(seed)``, which reseeds torch's global
    RNG, so it also affects random operations performed afterwards.

    Args:
        num_samples_per_class (int): Number of samples per class.
        std (float): Standard deviation for each class's normal distribution.
        seed (int): Random seed for reproducibility.
        means (array-like of shape (num_classes, 2), optional): Cluster
            centres; class ``i`` is sampled around ``means[i]``. Defaults to
            the four centres listed above.

    Returns:
        data (torch.Tensor): Shape (num_classes * num_samples_per_class, 2)
        labels (torch.Tensor): Shape (num_classes * num_samples_per_class,),
            int64 class ids 0 .. num_classes - 1 in blocks of
            num_samples_per_class.

    Raises:
        ValueError: if ``means`` is not a non-empty (num_classes, 2) array.
    """
    torch.manual_seed(seed)

    if means is None:
        # Default centres on a square grid with side length 3.
        means = torch.tensor([
            [0.0, 0.0],
            [3.0, 0.0],
            [0.0, 3.0],
            [3.0, 3.0],
        ])
    else:
        means = torch.as_tensor(means, dtype=torch.get_default_dtype())
    if means.dim() != 2 or means.shape[1] != 2 or means.shape[0] == 0:
        raise ValueError(
            f"means must have shape (num_classes, 2), got {tuple(means.shape)}"
        )

    data = []
    labels = []
    for class_id, mean in enumerate(means):
        # One randn call per class, in class order, keeps the sample stream
        # reproducible for a given seed.
        data.append(mean + std * torch.randn(num_samples_per_class, 2))
        labels.append(torch.full((num_samples_per_class,), class_id))

    return torch.cat(data, dim=0), torch.cat(labels, dim=0)


def label_proportions(predicted_label, num_classes=None):
    """
    Return the fraction of labels that fall into each class.

    Args:
        predicted_label (torch.Tensor): Non-negative integer class ids of any
            shape (including non-contiguous views); the tensor is flattened.
        num_classes (int, optional): Minimum length of the result. Defaults to
            ``max(predicted_label) + 1``. Labels >= num_classes make the result
            longer than num_classes (``torch.bincount`` ``minlength`` semantics).

    Returns:
        torch.Tensor: Float tensor whose entry ``c`` is count(c) / total. For an
        empty input every entry is 0.0; the result is empty when num_classes is
        None.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as transposes.
    predicted_label = predicted_label.reshape(-1)

    if num_classes is None:
        # max() raises on an empty tensor, so an empty input means zero classes.
        if predicted_label.numel() == 0:
            num_classes = 0
        else:
            num_classes = int(predicted_label.max().item()) + 1

    counts = torch.bincount(predicted_label, minlength=num_classes)
    # The clamp avoids 0/0 -> NaN for an empty input; any non-empty total is
    # already >= 1, so normal results are unchanged.
    total = counts.sum().float().clamp(min=1.0)
    return counts.float() / total


def compute_accuracy(pred, target):
    """
    Return the fraction of positions where ``pred`` equals ``target``.

    Args:
        pred (torch.Tensor): Predicted labels, expected to have the same shape
            as ``target``. NOTE: broadcastable but mismatched shapes are not
            detected.
        target (torch.Tensor): Ground-truth labels.

    Returns:
        float: ``correct / target.numel()``, or NaN when ``target`` is empty.
    """
    total = target.numel()
    if total == 0:
        # Accuracy over zero samples is undefined; avoid ZeroDivisionError.
        return float("nan")
    correct = (pred == target).sum().item()
    return correct / total


# Example usage and visualization
if __name__ == "__main__":
    data, labels = generate_separated_2d_dataset()

    # Derive the classes from the generated labels rather than hard-coding a
    # count, so the plot never shows phantom empty classes.
    class_ids = labels.unique().tolist()

    # Plot the dataset
    plt.figure(figsize=(8, 6))
    for class_id in class_ids:
        class_data = data[labels == class_id]
        plt.scatter(class_data[:, 0], class_data[:, 1], label=f'Class {class_id}', alpha=0.6)
    plt.legend()
    plt.title(f'2D Gaussian Clusters for {len(class_ids)} Classes')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.grid(True)
    plt.show()
