import sys

import numpy as np
import torch

from torch.nn import functional as F
from torchmetrics import Accuracy, CalibrationError
from utils import create_if_not_exists
import os

import pickle
import matplotlib.pyplot as plt
import textwrap

def _entropy_from_probs(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # probs: [B, C]
    p = probs.clamp_min(eps)
    return -(p * p.log()).sum(dim=-1)  # [B]

def _to_numpy(x: torch.Tensor) -> np.ndarray:
    return x.detach().float().cpu().numpy()


def plot_mean_variance_diagrams(logits, batch_texts=None, labels=None, save_dir="vis_result", args=None):
    """
    Save one bar chart per sample showing the mean class probability +/- std.

    logits: tensor [B, S, C]; softmax is taken over the last dimension and the
        mean/std are computed across the S dimension
    batch_texts: list of decoded text for each sample (optional), appended to the title
    labels: per-sample ground-truth labels (optional), appended to the title
    save_dir: directory to save figures (default: 'vis_result'); created if missing
    args: optional object whose `modelwrapper` and `dataset` attributes are used
        in the output file name (falls back to 'model' / 'data')
    """

    # Create directory if it does not exist
    os.makedirs(save_dir, exist_ok=True)

    # Detach and upcast before leaving torch: Tensor.numpy() rejects tensors
    # that require grad and dtypes NumPy does not know (e.g. bfloat16).
    probs = torch.softmax(logits.detach().float(), dim=-1).cpu().numpy()  # [B, S, C]
    num_samples, _, num_classes = probs.shape

    # These do not depend on the sample, so resolve them once.
    model_name = getattr(args, 'modelwrapper', 'model')
    dataset_name = getattr(args, 'dataset', 'data')
    class_positions = np.arange(num_classes)

    for i in range(num_samples):
        # Statistics across the S dimension (axis 0 of a single sample)
        mean_probs = probs[i].mean(axis=0)  # shape [C]
        std_probs = probs[i].std(axis=0)  # shape [C]

        # Slightly larger figure to accommodate wrapped text
        fig = plt.figure(figsize=(8, 6))
        try:
            # Symmetric error bars (mean +/- std); they may extend below 0 or above 1
            plt.bar(class_positions, mean_probs, yerr=std_probs, capsize=5)

            plt.xlabel("Class index", fontsize=12)
            plt.ylabel("Mean probability ± std", fontsize=12)

            raw_title = f"Sample {i}"
            if batch_texts is not None:
                # Replace newlines in the text to prevent formatting issues
                text_preview = batch_texts[i].replace("\n", " ")
                raw_title += f" — {text_preview}"
            if labels is not None:
                label = labels[i]
                # Show the plain value rather than "tensor(3)" for tensor labels
                if torch.is_tensor(label) and label.numel() == 1:
                    label = label.item()
                raw_title += f"-GT Label:{label}"

            # Wrap after roughly 60 characters so long titles are not cut off
            wrapped_title = "\n".join(textwrap.wrap(raw_title, width=60))
            plt.title(wrapped_title, fontsize=10)

            # Make room for the potentially multi-line title
            plt.tight_layout()

            out_path = os.path.join(save_dir, f"{model_name}_{dataset_name}_sample_{i}.png")
            plt.savefig(out_path, dpi=200)
        finally:
            # Always release the figure, even if drawing or saving failed
            plt.close(fig)

        print(f"[Saved] {out_path}")


def append_dictionary(dic1, dic2):
    """
    Append each value of dic2 to the list stored under the same key in dic1.

    A key that dic1 does not have yet is created with a one-element list.
    dic1 is modified in place and nothing is returned.
    """
    for key in dic2.keys():
        dic1.setdefault(key, []).append(dic2[key])


def accuracy_topk(output, target, k=1):
    """Compute the top-k accuracy in percent.

    Args:
        output: 2-D score tensor; the ``k`` largest entries along dim 1 of
            each row are taken as that row's predictions.
        target: class indices, one per row of ``output``.
        k: number of top predictions that count as a hit.

    Returns:
        A scalar tensor equal to 100 * (rows whose target is among their
        top-k predictions) / (number of rows).
    """
    _, pred = torch.topk(output, k=k, dim=1, largest=True, sorted=True)

    # top-k indices are distinct within a row, so a row matches its target in
    # at most one column; counting all matches therefore counts hit rows.
    hits = torch.eq(pred, target.reshape(-1, 1)).sum()
    return hits / len(output) * 100


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, running mean, weighted sum and count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def evaluate_all(model, dataloader, accelerator, args, sample=True, num_classes=None):
    """
    Evaluates the **acc, ece, nll, std, brier** of the model for given dataset.

    Args:
        model: the model to be evaluated. ``model.net`` is put in eval mode for
            the duration of the call and restored afterwards; logits are
            obtained with ``model(batch, sample=...)``.
        dataloader: the dataset to evaluate the model on. ``len(dataloader)``
            and ``len(dataloader.dataset)`` are used to trim the last gathered
            batch in multi-process runs.
        accelerator: provides ``device``, ``gather`` and ``num_processes``.
        args: read for ``outdim``, ``num_bins``, ``dataset_type``,
            ``bayes_inference_notsample``, ``modelwrapper`` and ``model``.
        sample: unused; whether the model samples is decided by
            ``not args.bayes_inference_notsample``. Kept for compatibility.
        num_classes: number of classes for the metrics; defaults to
            ``args.outdim``.
    Returns:
        acc: accuracy over all batches
        ece: calibration error over all batches
        nll: unweighted mean of the per-batch mean NLL
        std: for sampled predictions, the mean over the last batch of the
            std of the probabilities along dim 1; otherwise 0
        brier: unweighted mean of the per-batch mean Brier score
    """
    status = model.net.training
    model.net.eval()

    nlls = AverageMeter()
    briers = AverageMeter()
    if num_classes is None:
        num_classes = args.outdim
    metric_kwargs = {"task": "multiclass", "num_classes": num_classes}
    acc_metric = Accuracy(**metric_kwargs).to(accelerator.device)
    ece_metric = CalibrationError(**metric_kwargs, n_bins=args.num_bins).to(accelerator.device)
    loss_func = torch.nn.NLLLoss(reduction="mean")

    # Loop-invariant: does the model return several sampled predictions per
    # input (along dim 1) that must be averaged in probability space?
    use_samples = not args.bayes_inference_notsample
    averages_samples = (
        (use_samples and args.modelwrapper.startswith(('blob', 'light')))
        or args.model.startswith(('deepensemble', 'mcdropout'))
    )

    # Defined up front so an empty dataloader does not leave it unbound.
    std = 0
    samples_seen = 0
    for step, batch in enumerate(dataloader):
        # `torch.no_grad() and torch.inference_mode()` evaluated to the second
        # operand only, so inference_mode was the sole active context.
        with torch.inference_mode():
            if args.dataset_type == 'mcdataset':
                _, labels, _ = batch
                logits = model(batch, sample=use_samples).detach()
            else:
                logits = model(batch, sample=use_samples).detach()
                labels = batch["labels"]
            logits, labels = accelerator.gather([logits, labels])
            if accelerator.num_processes > 1:
                if step == len(dataloader) - 1:
                    # Keep only the entries up to the dataset size on the last step
                    remaining = len(dataloader.dataset) - samples_seen
                    labels = labels[:remaining]
                    logits = logits[:remaining]
                else:
                    samples_seen += labels.shape[0]

            if averages_samples:
                sample_probs = torch.softmax(logits, dim=-1)
                probs = sample_probs.mean(dim=1)
                std = sample_probs.std(dim=1).mean()
            else:
                probs = torch.softmax(logits, dim=-1)
                std = 0

            acc_metric(probs, labels)
            ece_metric(probs, labels)
            # NOTE(review): each batch contributes its mean with weight 1, so a
            # smaller final batch is over-weighted in the NLL/Brier averages.
            nll = loss_func(torch.log(probs), labels)
            nlls.update(nll)

            brier = (probs - F.one_hot(labels, num_classes=logits.size(-1))).pow(2).sum(dim=-1).mean()
            briers.update(brier)

    acc = acc_metric.compute().item()
    ece = ece_metric.compute().item()
    nll = nlls.avg
    brier = briers.avg
    model.net.train(status)

    return acc, ece, nll, std, brier


def logit_entropy(probs):
    """Return the entropy (natural log) of ``probs`` along dim 1 as a NumPy array.

    ``torch.xlogy`` evaluates 0 * log(0) as 0, so rows containing exact zero
    probabilities give a finite entropy instead of ``nan``.
    """
    entropy = -torch.xlogy(probs, probs).sum(dim=1)
    # detach so tensors that still carry autograd history can be converted
    return entropy.detach().cpu().numpy()


def max_softmax(probs):
    """Return ``1 - max(probs)`` along dim 1 as a NumPy array.

    Despite the name, the result is one minus the largest entry of each row,
    so a larger value corresponds to a smaller top probability.
    """
    top_prob = probs.max(dim=1).values
    score = 1 - top_prob
    return score.cpu().numpy()


def logit_std(probs):
    """Return the standard deviation of ``probs`` along dim 1 as a NumPy array."""
    spread = torch.std(probs, dim=1)
    return spread.cpu().numpy()


def _ood_batch_probs(model, batch, loader, accelerator, args, ns, step, samples_seen):
    """Forward one batch, gather it, and return ``(probs, logits, labels, samples_seen)``."""
    if ns == 0:
        logits = model.model.forward_logits(batch, sample=False, n_samples=ns).detach()
    else:
        logits = model.model.forward_logits(batch, sample=not args.bayes_inference_notsample,
                                            n_samples=ns).detach()

    if args.dataset_type == 'mcdataset':
        _, labels, _ = batch
    else:
        labels = batch['labels']

    logits, labels = accelerator.gather([logits, labels])
    if accelerator.num_processes > 1:
        if step == len(loader) - 1:
            # Keep only the entries up to the dataset size on the last step
            remaining = len(loader.dataset) - samples_seen
            labels = labels[:remaining]
            logits = logits[:remaining]
        else:
            samples_seen += labels.shape[0]

    if ns != 0:
        # Several predictions per input along dim 1: average them in
        # probability space, and reduce the logits the same way.
        probs = torch.softmax(logits, dim=-1).mean(dim=1)
        logits = logits.mean(dim=1)
    elif args.modelwrapper.startswith(('blob', 'light')):
        probs = torch.softmax(logits, dim=-1)
    else:
        probs = torch.softmax(logits, dim=-1).squeeze(1)

    return probs, logits, labels, samples_seen


def evaluate_ood_detection(model, dataset, ood_ori_dataset, dataloader, ood_ori_dataloader, accelerator, args, nsamp=1,
                           sample=True, num_classes=None):
    """
    Evaluates the model on ``ood_ori_dataloader`` and then on ``dataloader``.

    Before each pass ``model.tokenizer`` and ``model.model.target_ids`` are
    switched to those of the matching dataset. ``model.model`` is put in eval
    mode and restored to its previous mode at the end.

    Args:
        model: wrapper exposing ``model.model`` (with ``forward_logits`` and,
            when ``args.laplace_ood`` is set, ``load_laplace``)
        dataset: dataset behind ``dataloader``; supplies tokenizer and target_ids
        ood_ori_dataset: dataset behind ``ood_ori_dataloader``; supplies
            tokenizer and target_ids
        dataloader: loader evaluated in the second pass
        ood_ori_dataloader: loader evaluated in the first pass
        accelerator: provides ``device``, ``gather`` and ``num_processes``
        args: read for ``laplace_ood``, ``dataset_type``, ``ood_ori_outdim``,
            ``outdim``, ``num_bins``, ``modelwrapper`` and
            ``bayes_inference_notsample`` (set to False if missing and
            ``nsamp != 0``)
        nsamp: ``n_samples`` passed to ``forward_logits``; 0 disables sampling
        sample: unused; kept for compatibility
        num_classes: number of classes for the first-pass metrics (defaults to
            ``args.ood_ori_outdim``); the second pass always uses ``args.outdim``
    Returns:
        acc, ece, nll, brier: metrics on ``dataloader``
        acc_ood_value, ece_ood_value, nll_ood_value: metrics on
            ``ood_ori_dataloader``
    """

    print('model.model.training')
    status = model.model.training
    print(status)

    loss = F.nll_loss

    if args.laplace_ood:
        model.model.load_laplace(target_ids=ood_ori_dataset.target_ids.squeeze(-1), dataset_type=args.dataset_type,
                                 args=args, device=accelerator.device)
    model.model.eval()
    model.tokenizer = ood_ori_dataset.tokenizer
    model.model.target_ids = ood_ori_dataset.target_ids.squeeze(-1)

    if num_classes is None:
        num_classes = args.ood_ori_outdim

    ns = nsamp
    # Both passes read this flag when sampling, so default it once up front
    # rather than only inside the first loop.
    if ns != 0 and not hasattr(args, 'bayes_inference_notsample'):
        args.bayes_inference_notsample = False

    ################################################################################
    # First pass: ood_ori_dataloader
    ################################################################################
    nll_ood = AverageMeter()
    metric_kwargs = {"task": "multiclass", "num_classes": num_classes}
    acc_ood = Accuracy(**metric_kwargs).to(accelerator.device)
    ece_ood = CalibrationError(**metric_kwargs, n_bins=args.num_bins).to(accelerator.device)

    samples_seen = 0
    id_score_chunks = []
    for step, batch in enumerate(ood_ori_dataloader):
        # `torch.no_grad() and torch.inference_mode()` evaluated to the second
        # operand only, so inference_mode was the sole active context.
        with torch.inference_mode():
            probs, logits, labels, samples_seen = _ood_batch_probs(
                model, batch, ood_ori_dataloader, accelerator, args, ns, step, samples_seen)

            acc_ood(probs, labels)
            ece_ood(probs, labels)
            nll = loss(torch.log(probs), labels, reduction='mean')
            if torch.isinf(nll):
                # log(0) made the loss infinite; retry with a small offset
                nll = loss(torch.log(probs + 1e-6), labels, reduction='mean')
            nll_ood.update(nll)

            id_score_chunks.append(np.ravel(max_softmax(probs)))

    acc_ood_value = acc_ood.compute().item()
    ece_ood_value = ece_ood.compute().item()
    nll_ood_value = nll_ood.avg

    id_prob_list = np.concatenate([np.array([])] + id_score_chunks)
    id_label_list = np.zeros_like(id_prob_list)

    ################################################################################
    # Second pass: dataloader (OOD dataset)
    ################################################################################
    model.tokenizer = dataset.tokenizer
    model.model.target_ids = dataset.target_ids.squeeze(-1)

    samples_seen = 0
    nlls = AverageMeter()
    briers = AverageMeter()

    num_classes = args.outdim
    metric_kwargs = {"task": "multiclass", "num_classes": num_classes}
    acc_metric = Accuracy(**metric_kwargs).to(accelerator.device)
    ece_metric = CalibrationError(**metric_kwargs, n_bins=args.num_bins).to(accelerator.device)
    loss_func = torch.nn.NLLLoss(reduction="mean")

    ood_score_chunks = []
    for step, batch in enumerate(dataloader):
        with torch.inference_mode():
            probs, logits, labels, samples_seen = _ood_batch_probs(
                model, batch, dataloader, accelerator, args, ns, step, samples_seen)

            acc_metric(probs, labels)
            ece_metric(probs, labels)
            nll = loss_func(torch.log(probs), labels)
            nlls.update(nll)

            brier = (probs - F.one_hot(labels, num_classes=logits.size(-1))).pow(2).sum(dim=-1).mean()
            briers.update(brier)

            ood_score_chunks.append(np.ravel(max_softmax(probs)))

    acc = acc_metric.compute().item()
    ece = ece_metric.compute().item()
    nll = nlls.avg
    brier = briers.avg

    # NOTE(review): the detection labels (0 for the first pass, 1 for the
    # second) and the matching 1 - max-probability scores are assembled here
    # but are not part of the return value — presumably intended for an
    # AUROC-style metric; confirm before removing.
    ood_prob_list = np.concatenate([np.array([])] + ood_score_chunks)
    ood_label_list = np.ones_like(ood_prob_list)
    labels = np.concatenate((id_label_list, ood_label_list))
    probs = np.concatenate((id_prob_list, ood_prob_list))

    model.model.train(status)
    return acc, ece, nll, brier, acc_ood_value, ece_ood_value, nll_ood_value



