import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from domainbed.lib.fast_data_loader import FastDataLoader
from domainbed import algorithms
from domainbed.uncertainty import get_model_uncertainty_values, \
    get_calibration_errors, calc_tot_unc_auc_roc, calc_pn_know_unc_auc_roc

import numpy as np

# Small epsilon constant, presumably for numerical stability. It is not referenced
# anywhere in this chunk of the file; it may be imported by other modules — verify
# before removing.
EPS = 1e-10

def accuracy_from_loader(algorithm, loader, weights, device, debug=False):
    """Evaluate ``algorithm`` on ``loader`` and gather uncertainty/calibration metrics.

    Args:
        algorithm: model exposing ``predict(x)``, ``eval()`` and ``train()``.
        loader: iterable of batches, each a mapping with ``"x"`` and ``"y"`` tensors.
        weights: optional per-sample weights, consumed sequentially in loader
            order; ``None`` gives every sample weight 1.
        device: device the batch tensors are moved to.
        debug: if True, stop after the first batch.

    Returns:
        ``(acc, loss, uncertainties)``:
        * ``acc`` -- weighted accuracy (sum of weighted hits / sum of weights);
        * ``loss`` -- cross-entropy accumulated over all batches
          (per-batch mean loss times batch size, divided by the summed weights);
        * ``uncertainties`` -- dict returned by ``get_model_uncertainty_values``
          updated with the output of ``get_calibration_errors``.

    Raises:
        ValueError: if the loader yields no batches.

    The algorithm is always switched back to training mode on exit, even when
    prediction raises.
    """
    correct = 0
    total = 0
    losssum = 0.0
    weights_offset = 0
    all_probs = []
    all_labels = []
    all_logits = []
    all_preds = []

    algorithm.eval()
    try:
        for batch in loader:
            x = batch["x"].to(device)
            y = batch["y"].to(device)

            with torch.no_grad():
                logits = algorithm.predict(x)
                probs = F.softmax(logits, dim=1)
                _, preds = probs.max(dim=1)
                batch_loss = F.cross_entropy(logits, y).item()

            all_logits.append(logits.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
            all_labels.append(y.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

            B = len(x)
            # cross_entropy returns the batch mean; scale back to a per-sample sum.
            losssum += batch_loss * B

            if weights is None:
                batch_weights = torch.ones(B)
            else:
                batch_weights = weights[weights_offset : weights_offset + B]
                weights_offset += B
            batch_weights = batch_weights.to(device)
            if logits.size(1) == 1:
                hits = logits.gt(0).eq(y)
            else:
                hits = logits.argmax(1).eq(y)
            correct += (hits.float() * batch_weights).sum().item()
            total += batch_weights.sum().item()

            if debug:
                break
    finally:
        # Restore training mode even if prediction fails part-way through.
        algorithm.train()

    if not all_logits:
        raise ValueError("accuracy_from_loader: loader yielded no batches")

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    all_logits = np.concatenate(all_logits, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)

    acc = correct / total
    # Average over the whole loader, not just the last batch. NOTE: losssum is
    # unweighted while total sums the weights (as in the legacy implementation);
    # the two agree when weights is None.
    loss = losssum / total

    domain_hits = all_preds == all_labels
    uncertainties = get_model_uncertainty_values(all_logits)
    uncertainties.update(
        get_calibration_errors(all_probs, domain_hits, all_labels, num_bins=10)
    )
    return acc, loss, uncertainties

# NOTE(review): the triple-quoted string below is an inert, never-executed copy of an
# older accuracy_from_loader that returned only (acc, loss), with loss = losssum / total.
# It has no runtime effect; consider deleting it and relying on version control.
'''def accuracy_from_loader(algorithm, loader, weights, device, debug=False):
    correct = 0
    total = 0
    losssum = 0.0
    weights_offset = 0

    algorithm.eval()

    for i, batch in enumerate(loader):
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        with torch.no_grad():
            logits = algorithm.predict(x)
            loss = F.cross_entropy(logits, y).item()

        B = len(x)
        losssum += loss * B

        if weights is None:
            batch_weights = torch.ones(len(x))
        else:
            batch_weights = weights[weights_offset : weights_offset + len(x)]
            weights_offset += len(x)
        batch_weights = batch_weights.to(device)
        if logits.size(1) == 1:
            correct += (logits.gt(0).eq(y).float() * batch_weights).sum().item()
        else:
            correct += (logits.argmax(1).eq(y).float() * batch_weights).sum().item()
        total += batch_weights.sum().item()

        if debug:
            break

    algorithm.train()

    acc = correct / total
    loss = losssum / total
    # print(total,acc,"!!!!!!!")
    return acc, loss  
    '''

def accuracy(algorithm, loader_kwargs, weights, device, **kwargs):
    """Resolve a data loader and delegate to ``accuracy_from_loader``.

    ``loader_kwargs`` is either a dict of ``FastDataLoader`` constructor
    arguments (a new loader is built from it) or an existing
    ``FastDataLoader`` (used as-is); anything else raises ``ValueError``.
    Extra keyword arguments (e.g. ``debug``) are forwarded unchanged.
    """
    if isinstance(loader_kwargs, dict):
        data_loader = FastDataLoader(**loader_kwargs)
    elif not isinstance(loader_kwargs, FastDataLoader):
        raise ValueError(loader_kwargs)
    else:
        data_loader = loader_kwargs
    return accuracy_from_loader(algorithm, data_loader, weights, device, **kwargs)

def load_logits(algorithm, loader_kwargs, device):
    """Run ``algorithm.predict`` over every batch and return the stacked logits.

    Args:
        algorithm: model exposing ``predict(x)``; it is moved to ``device``
            (the move is not undone) and evaluated in eval mode, after which
            its previous train/eval mode is restored.
        loader_kwargs: either a dict of ``FastDataLoader`` keyword arguments or
            an already-built ``FastDataLoader``.
        device: device the model and each batch's ``"x"`` are moved to.

    Returns:
        numpy array of per-batch logits concatenated along axis 0.

    Raises:
        ValueError: if ``loader_kwargs`` is neither a dict nor a
            ``FastDataLoader``, or (from ``np.concatenate``) if the loader
            yields no batches.
    """
    was_training = algorithm.training
    algorithm.to(device)
    algorithm.eval()
    try:
        if isinstance(loader_kwargs, dict):
            loader = FastDataLoader(**loader_kwargs)
        elif isinstance(loader_kwargs, FastDataLoader):
            loader = loader_kwargs
        else:
            raise ValueError(loader_kwargs)

        chunks = []
        with torch.no_grad():
            for batch in loader:
                x = batch["x"].to(device)
                chunks.append(algorithm.predict(x).cpu().numpy())
    finally:
        # Previously the model was left in eval mode; restore the caller's mode.
        algorithm.train(was_training)
    return np.concatenate(chunks, axis=0)
   
class Evaluator:
    """Evaluate an algorithm on every split listed in ``eval_meta``.

    Reports per-split accuracy, loss and uncertainty metrics, averaged
    summaries for train/test environments, and two AUROC scores ("Tauc",
    "Kauc") that contrast logits of training-environment ``out`` splits
    (treated as in-distribution) with logits of test-environment splits
    (treated as out-of-distribution).
    """

    def __init__(
        self, test_envs, eval_meta, n_envs, logger, device, evalmode="fast", debug=False, target_env=None
    ):
        all_envs = list(range(n_envs))
        train_envs = sorted(set(all_envs) - set(test_envs))
        self.test_envs = test_envs
        self.train_envs = train_envs
        # Iterable of (name, loader_kwargs, weights); name has the form "env<k>_<in|out>".
        self.eval_meta = eval_meta
        self.n_envs = n_envs
        self.logger = logger
        self.evalmode = evalmode
        self.debug = debug
        self.device = device

        if target_env is not None:
            self.set_target_env(target_env)

    def set_target_env(self, target_env):
        """When len(test_envs) == 2, you can specify target env for computing exact test acc.

        ``train_envs`` is not recomputed, so any other held-out env ends up in
        neither ``train_envs`` nor ``test_envs`` and is ignored by ``evaluate``.
        """
        self.test_envs = [target_env]

    @staticmethod
    def _parse_split_name(name):
        """Split ``"env<k>_<in|out>"`` into ``(k, "in" | "out")``."""
        env_name, inout = name.split("_")
        return int(env_name[3:]), inout

    def _ood_aucs(self, algorithm):
        """Return ``(total_unc_auroc, knowledge_unc_auroc)`` for ID-vs-OOD logits.

        ID logits come from the ``out`` splits of training envs; OOD logits come
        from every split of the test envs. Returns NaNs when either side is empty,
        since no AUROC can be computed.
        """
        id_logits = []
        ood_logits = []
        for name, loader_kwargs, _weights in self.eval_meta:
            env_num, inout = self._parse_split_name(name)
            if env_num in self.test_envs:
                ood_logits.append(load_logits(algorithm, loader_kwargs, self.device))
            elif env_num in self.train_envs and inout == "out":
                # Require a genuine train env: a held-out, non-target env must not
                # leak into the in-distribution set.
                id_logits.append(load_logits(algorithm, loader_kwargs, self.device))

        if not id_logits or not ood_logits:
            return float("nan"), float("nan")

        id_logits = np.concatenate(id_logits, axis=0)
        ood_logits = np.concatenate(ood_logits, axis=0)
        return (
            calc_tot_unc_auc_roc(id_logits, ood_logits),
            calc_pn_know_unc_auc_roc(id_logits, ood_logits),
        )

    def evaluate(self, algorithm, ret_losses=False):
        """Evaluate ``algorithm`` on all configured splits.

        Returns ``(accuracies, summaries, losses)`` when ``ret_losses`` is True,
        otherwise ``(accuracies, summaries, uncertainties)``; the per-split
        dicts are keyed by split name.
        """
        n_train_envs = len(self.train_envs)
        n_test_envs = len(self.test_envs)
        assert n_test_envs == 1
        summaries = collections.defaultdict(float)
        # for key order
        summaries["test_in"] = 0.0
        summaries["test_out"] = 0.0
        summaries["train_in"] = 0.0
        summaries["train_out"] = 0.0
        summaries["EoE"] = 0.0
        accuracies = {}
        losses = {}
        uncertainties = {}

        # NOTE(review): this runs a separate inference pass over splits that the
        # loop below evaluates again; caching logits would avoid the duplicate work.
        summaries["Tauc"], summaries["Kauc"] = self._ood_aucs(algorithm)

        for name, loader_kwargs, weights in self.eval_meta:
            env_num, inout = self._parse_split_name(name)
            is_test = env_num in self.test_envs

            # Fast mode skips the (large) "in" splits of the training envs.
            if self.evalmode == "fast" and inout == "in" and not is_test:
                continue

            acc, loss, uncertainty = accuracy(
                algorithm, loader_kwargs, weights, device=self.device, debug=self.debug
            )
            accuracies[name] = acc
            losses[name] = loss
            uncertainties[name] = uncertainty

            if env_num in self.train_envs:
                summaries["train_" + inout] += acc / n_train_envs
                if inout == "out":
                    summaries["tr_" + inout + "loss"] += loss / n_train_envs
            elif is_test:
                summaries["test_" + inout] += acc / n_test_envs
                if inout == "in":
                    summaries["EoE"] = uncertainty["entropy_of_expected"]
                    summaries["ExE"] = uncertainty["expected_entropy"]
                    summaries["ML"] = uncertainty["mutual_information"]

        if ret_losses:
            return accuracies, summaries, losses
        return accuracies, summaries, uncertainties

# NOTE(review): the triple-quoted string below is an inert, never-executed copy of an
# older Evaluator whose evaluate() returned only accuracies/summaries (no uncertainty
# or AUROC metrics). It has no runtime effect; consider deleting it.
'''
class Evaluator:
    def __init__(
        self, test_envs, eval_meta, n_envs, logger, device, evalmode="fast", debug=False, target_env=None
    ):
        all_envs = list(range(n_envs))
        train_envs = sorted(set(all_envs) - set(test_envs))
        self.test_envs = test_envs
        self.train_envs = train_envs
        self.eval_meta = eval_meta
        self.n_envs = n_envs
        self.logger = logger
        self.evalmode = evalmode
        self.debug = debug
        self.device = device

        if target_env is not None:
            self.set_target_env(target_env)

    def set_target_env(self, target_env):
        """When len(test_envs) == 2, you can specify target env for computing exact test acc."""
        self.test_envs = [target_env]

    def evaluate(self, algorithm, ret_losses=False):
        n_train_envs = len(self.train_envs)
        n_test_envs = len(self.test_envs)
        assert n_test_envs == 1
        summaries = collections.defaultdict(float)
        # for key order
        summaries["test_in"] = 0.0
        summaries["test_out"] = 0.0
        summaries["train_in"] = 0.0
        summaries["train_out"] = 0.0
        accuracies = {}
        losses = {}

        # order: in_splits + out_splits.
        for name, loader_kwargs, weights in self.eval_meta:
            # env\d_[in|out]
            env_name, inout = name.split("_")
            env_num = int(env_name[3:])
            # print(name, loader_kwargs)

            skip_eval = self.evalmode == "fast" and inout == "in" and env_num not in self.test_envs
            if skip_eval:
                continue

            is_test = env_num in self.test_envs
            acc, loss = accuracy(algorithm, loader_kwargs, weights, device = self.device, debug=self.debug)
            accuracies[name] = acc
            losses[name] = loss

            if env_num in self.train_envs:
                summaries["train_" + inout] += acc / n_train_envs
                if inout == "out":
                    summaries["tr_" + inout + "loss"] += loss / n_train_envs
            elif is_test:
                summaries["test_" + inout] += acc / n_test_envs

        if ret_losses:
            return accuracies, summaries, losses
        else:
            return accuracies, summaries
'''