from .common import (
    DataClass,
    LoraBatchDataConfig,
    MultiLoraBatchData,
    MixConfig,
)
from .tasks import (
    BasicTask,
    BasicMetric,
    CommonSenseTask,
    task_dict,
)
from .tokenizer import Tokenizer
from .model import LLMModel


from dataclasses import dataclass
from typing import List, Dict

import logging
import torch
import math
import json
import time
import numpy as np

@dataclass
class EvaluateConfig:
    """Describe one adapter/task evaluation job and carry its runtime state.

    ``adapter_name``, ``task_name``, ``batch_size`` and ``router_profile`` are
    the user-facing settings. The trailing-underscore fields are populated by
    :meth:`prepare` and advanced by the batching loop; callers should leave
    them at their defaults.
    """
    adapter_name: str = None
    task_name: str = None
    batch_size: int = 16
    router_profile: bool = False
    # Do not set these manually
    task_: BasicTask = None
    data_: List[DataClass] = None
    metric_: BasicMetric = None
    rollback_start_idx_: int = 0  # cursor restored after an OOM rollback
    batch_start_idx_: int = 0     # index of the next sample to dispatch
    batch_end_idx_: int = 0       # exclusive end of the last dispatched slice

    def prepare(self, tokenizer: Tokenizer, device: str):
        """Resolve the task from ``task_dict`` and load its data and metric.

        Also sets ``self.label_indices_`` (not a dataclass field). For a
        ``CommonSenseTask`` it is an int64 tensor on ``device``. It holds one
        entry per label in ``label_list()``: the last token id of
        ``tokenizer.encode(" " + label)``. For any other task it is ``None``.
        """
        task = task_dict[self.task_name]
        self.task_ = task
        # NOTE(review): the second argument False presumably selects the
        # evaluation (non-training) split — confirm against loading_data.
        self.data_ = task.loading_data(tokenizer, False)
        self.metric_ = task.loading_metric()

        if not isinstance(task, CommonSenseTask):
            self.label_indices_ = None
            return

        # Only the final sub-token of " <label>" is kept per label.
        last_token_ids = [tokenizer.encode(" " + label)[-1]
                          for label in task.label_list()]
        self.label_indices_ = torch.tensor(
            last_token_ids, dtype=torch.int64, device=device)

# source: https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
def ece(true_labels, predicted_label, confidences, n_bins=15):
    """Compute the expected calibration error (ECE) using uniform-width bins.

    The interval (0, 1] is split into ``n_bins`` equal bins. Each sample falls
    into the bin ``(lower, upper]`` that contains its confidence, and

        ECE = sum_m |avg_confidence(B_m) - accuracy(B_m)| * |B_m| / n

    Args:
        true_labels: ground-truth labels, any shape (flattened).
        predicted_label: predicted labels, any shape (flattened); must have
            as many elements as ``true_labels``.
        confidences: confidence of each prediction (flattened); must have as
            many elements as ``true_labels``.
        n_bins: number of equal-width confidence bins.

    Returns:
        ``(ece, correct_confidence, incorrect_confidence, accuracies)``.
        ``ece`` is a float, and is 0.0 for empty input.
        ``correct_confidence`` / ``incorrect_confidence`` are the mean
        confidences of correct / incorrect predictions, or NaN when there
        are none.
        ``accuracies`` is the flat boolean array of per-sample correctness.

    Raises:
        ValueError: if the three inputs differ in number of elements.
    """
    # Flatten everything so e.g. (N, 1) labels and (N,) predictions compare
    # element-wise instead of broadcasting to an (N, N) matrix.
    true_labels = np.asarray(true_labels).reshape(-1)
    predicted_label = np.asarray(predicted_label).reshape(-1)
    confidences = np.asarray(confidences).reshape(-1)
    if not (true_labels.size == predicted_label.size == confidences.size):
        raise ValueError(
            f"size mismatch: {true_labels.size} labels, "
            f"{predicted_label.size} predictions, "
            f"{confidences.size} confidences")

    # boolean array of correct/incorrect predictions
    accuracies = predicted_label == true_labels

    def _mean_or_nan(values):
        # Avoid numpy's "Mean of empty slice" warning; the mean is undefined.
        return float(values.mean()) if values.size else float("nan")

    correct_confidence = _mean_or_nan(confidences[accuracies])
    incorrect_confidence = _mean_or_nan(confidences[~accuracies])

    n_samples = confidences.size
    # Python floats (not numpy scalars) keep the comparison in the
    # confidences' own dtype under NumPy's scalar promotion rules.
    edges = np.linspace(0, 1, n_bins + 1).tolist()
    total = 0.0
    for bin_lower, bin_upper in zip(edges[:-1], edges[1:]):
        in_bin = np.logical_and(confidences > bin_lower,
                                confidences <= bin_upper)
        bin_count = int(in_bin.sum())
        if bin_count > 0:
            accuracy_in_bin = accuracies[in_bin].mean()          # acc(B_m)
            avg_confidence_in_bin = confidences[in_bin].mean()   # conf(B_m)
            # |acc(B_m) - conf(B_m)| weighted by |B_m| / n
            total += abs(avg_confidence_in_bin - accuracy_in_bin) * (bin_count / n_samples)
    return float(total), correct_confidence, incorrect_confidence, accuracies

# # takes in numpy arrays
# def ece(y_ground, y_pred, y_conf, n_bins=10):
#     print(f'y_ground: {y_ground.shape}, y_pred: {y_pred.shape}')
#     print(f'y_ground: {y_ground}, y_pred: {y_pred}')
#     # group into bins based on confidence, stores indicies
#     bins = [[] for i in range(n_bins)]
#     # go through each confidence interval and arrange elements
#     for i in range(n_bins):
#         minim = i * (1.0/n_bins)
#         maxim = ((i+1) * (1.0/n_bins))
#         indices = np.where(np.logical_and(y_conf >= minim, y_conf < maxim))
#         bins[i].append(indices)
#     # for each bin, calculate ece
#     total_ece = 0
#     for inds in bins:
#         if (len(inds[0][0]) == 0):
#             continue
#         conf = np.mean(y_conf[inds])
# #         acc = np.mean(np.equal(y_ground[inds], y_pred[inds]))
#         # print(f'y_ground[inds]: {y_ground[inds]}, y_pred[inds]: {y_pred[inds]}')
#         acc = np.mean(y_ground[inds] == y_pred[inds])
#         print(f'bins: {bins}, inds: {inds}')
#         print(f'per-bin (size: {len(inds)}) | confidence: {conf}, accuracy: {acc}')
#         bin_score = (float(len(inds))/y_pred.shape[0]) * np.abs(acc - conf)
#         total_ece = total_ece + bin_score
#     return total_ece

def _prepare_tasks(model, tokenizer, configs):
    for config in configs:
        config.prepare(tokenizer, model.device_)
        if not isinstance(model.adapter_configs_[config.adapter_name], MixConfig):
            continue
        for layer in model.model_.layers_:
            layer.mlp_.moes_[
                config.adapter_name].router_profile_ = config.router_profile


def _dispatch_task_in(tokenizer, configs, concurrent_jobs, max_seq_len):
    """Assemble the next inference batch from up to ``concurrent_jobs`` configs.

    For each config with unprocessed samples, the next ``config.batch_size``
    samples are taken, starting at ``config.batch_start_idx_``. Each token
    sequence is truncated to ``max_seq_len`` and the config's cursor is
    advanced. All sequences are then right-padded with ``tokenizer.pad_id_``
    to a common length. That length is the longest sequence rounded up to a
    multiple of 8, and never more than ``max_seq_len``.

    Args:
        tokenizer: provides ``pad_id_`` and ``mask_from(tokens)``.
        configs: prepared EvaluateConfig objects; their ``batch_start_idx_``
            and ``batch_end_idx_`` cursors are updated in place.
        concurrent_jobs: maximum number of configs packed into this batch.
        max_seq_len: upper bound on the (padded) sequence length.

    Returns:
        ``(current_configs, sequence_lengths, batch_labels, MultiLoraBatchData)``.
        ``sequence_lengths[i]`` is the index of the last real (unpadded)
        token in row ``i``.
        ``batch_labels[i]`` is a copy of that sample's labels.
        ``current_configs`` is empty when every config is exhausted.
    """
    batch_data_config = []
    current_configs = []
    batch_tokens = []
    batch_labels = []
    max_tokens_len = 0
    for config in configs:
        if len(current_configs) >= concurrent_jobs:
            break
        if config.batch_start_idx_ >= len(config.data_):
            continue
        config.batch_end_idx_ = min(
            config.batch_start_idx_ + config.batch_size, len(config.data_))
        batch_start_idx = len(batch_tokens)
        for sample in config.data_[config.batch_start_idx_:config.batch_end_idx_]:
            # Always work on a private copy: the padding below must never
            # mutate the dataset's own token list. Otherwise a sample
            # re-dispatched after an OOM rollback comes back already padded,
            # and its last-token index points at padding.
            tokens = list(sample.tokens_[:max_seq_len])
            max_tokens_len = max(len(tokens), max_tokens_len)
            batch_tokens.append(tokens)
            batch_labels.append(sample.labels_.copy())

        config.batch_start_idx_ = config.batch_end_idx_
        current_configs.append(config)
        batch_data_config.append(LoraBatchDataConfig(
            adapter_name_=config.adapter_name,
            batch_start_idx_=batch_start_idx,
            batch_end_idx_=len(batch_tokens)))

    if max_tokens_len < max_seq_len:
        # Pad to a multiple of 8, but never past the caller's limit.
        max_seq_len = min(math.ceil(max_tokens_len / 8) * 8, max_seq_len)

    sequence_lengths = []
    atten_masks = []
    for tokens in batch_tokens:
        sequence_lengths.append(len(tokens) - 1)
        tokens.extend([tokenizer.pad_id_] * (max_seq_len - len(tokens)))
        atten_masks.append(tokenizer.mask_from(tokens))

    return (current_configs,
            sequence_lengths,
            batch_labels,
            MultiLoraBatchData(
                lora_batch_data_config_=batch_data_config,
                batch_tokens_=batch_tokens,
                attention_masks_=atten_masks,
                inference_mode_=True))


def _compute_metrcis(model, current_configs, sequence_lengths, batch_labels, outputs):
    # for ECE calculation
    preds_list = {'prediction':[],'label':[], 'pred_probs': []}
    for idx, output in enumerate(outputs):
        config: EvaluateConfig = current_configs[idx]
        task: BasicTask = config.task_
        metric: BasicMetric = config.metric_
        start_idx = output.batch_start_idx_
        end_idx = output.batch_end_idx_
        logits = output.logits

        if config.router_profile:
            adapter_config = model.adapter_configs_[
                config.adapter_name]
            if isinstance(adapter_config, MixConfig):
                router_statistic_ = list(
                    0 for _ in range(adapter_config.num_experts_))
                for layer in model.model_.layers_:
                    for idx, val in enumerate(layer.mlp_.moes_[config.adapter_name].profiler_):
                        router_statistic_[idx] += val
                for idx, val in enumerate(router_statistic_):
                    logging.info(
                        f"{config.adapter_name}: expert {idx}, load = {val/32}")

        batch_size = logits.shape[0]
        pooled_logits = logits[torch.arange(
            batch_size, device=logits.device), sequence_lengths[start_idx:end_idx]]
        labels = torch.tensor(batch_labels[start_idx:end_idx],
                              dtype=task.label_dtype_, device=logits.device)
        if task.task_type_ == "common_sense":
            # print(f'shape of pooled logits before cropping: {pooled_logits.shape}')
            # print(f'seq_lengths: {sequence_lengths[start_idx:end_idx]}')
            pooled_logits = pooled_logits[:, config.label_indices_]
            # print(f'shape of pooled logits: {pooled_logits.shape}')
            pooled_probs = pooled_logits.softmax(-1).max(dim=-1)[0]
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
            # pooled_logits = pooled_logits_probs.argmax(-1)
            # preds_list['output_label_confidence'].append(pooled_logits_probs[:,pooled_logits].detach().unsqueeze(-1).cpu().numpy())
        elif task.task_type_ == "single_label_classification":
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
            # pooled_logits = pooled_logits_probs.argmax(-1)
            # preds_list['output_label_confidence'].append(pooled_logits_probs[:,pooled_logits].detach().unsqueeze(-1).cpu().numpy())
            pooled_logits = pooled_logits.to(task.label_dtype_)
        elif task.task_type_ != "multi_label_classification":
            raise ValueError(f"unknown task type {task.task_type_}")

        predictions = pooled_logits.detach().cpu()
        references = labels.detach().cpu()

        metric.add_batch(predictions=predictions,
                         references=references)

        # save output & labels for ECE calculation
        preds_list['prediction'].append(predictions.numpy())
        preds_list['pred_probs'].append(pooled_probs.cpu().numpy())
        preds_list['label'].append(references.numpy())
        # Confidence for UQ
        # aggregate along the sequence dimension
        # print(f'shape of confidence before mean: {output.var_exp.detach().cpu()[start_idx:end_idx].shape}')
        # print(f'shape of labels: {labels.shape}')
        # selected_inds = sequence_lengths[start_idx:end_idx]
        # print(f'selected_inds: {selected_inds}, {np.array(selected_inds).shape}')
        # TODO: This is seqnece avg, change back if needed
        # confidence = output.var_exp.detach().cpu()[start_idx:end_idx].mean(dim=1)
        # batch_indices = torch.arange(batch_size)
        # confidence = output.var_exp.detach().cpu()[start_idx:end_idx][batch_indices, selected_inds]
        # preds_list['confidence'].append(confidence.numpy())

        logging.info(
            f"{config.adapter_name}, {config.task_name}")
        logging.info(
            f"    step: {config.batch_start_idx_}/{len(config.data_)}")

    return preds_list

def _compute_result(model, configs, save_file, ece_res_output_only, correct_confidence_output_only, incorrect_confidence_output_only):
    results = []
    for config in configs:
        result = {
            "adapter_name": config.adapter_name,
            "task_name": config.task_name,
            "date_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "metrics": {},
            "ece_res_output_only": ece_res_output_only,
            "correct_confidence_output_only": correct_confidence_output_only,
            "incorrect_confidence_output_only": incorrect_confidence_output_only,
        }
        compute_results = config.metric_.compute()
        result["metrics"] = compute_results
        if config.router_profile:
            adapter_config = model.adapter_configs_[config.adapter_name]
            if isinstance(adapter_config, MixConfig):
                router_statistic_ = list(
                    0 for _ in range(adapter_config.num_experts_))
                for layer in model.model_.layers_:
                    for idx, val in enumerate(layer.mlp_.moes_[config.adapter_name].profiler_):
                        router_statistic_[idx] += val
                    layer.mlp_.moes_[config.adapter_name].profiler_ = None
                result["router_profile"] = list(
                    val / 32 for val in router_statistic_)

        results.append(result)

    if save_file is not None:
        with open(save_file, "w") as f:
            json.dump(results, f, indent=4)
        logging.info(f"evaluation result: {results}")
        logging.info(f"saving evaluation result to {save_file}")
    else:
        logging.info(json.dumps(results, indent=4))

    return results


@torch.inference_mode()
def evaluate(model: LLMModel,
             tokenizer: Tokenizer,
             configs: List[EvaluateConfig],
             max_concurrent_jobs: int = None,
             retrying_steps: int = 20,
             max_seq_len: int = 512,
             save_file: str = None) -> List[Dict]:
    """Evaluate every config's adapter on its task, then report metrics and ECE.

    Batches are built by ``_dispatch_task_in`` from up to ``concurrent_jobs``
    configs at a time; this starts at ``max_concurrent_jobs``, which defaults
    to ``len(configs)``.

    On a ``RuntimeError`` whose message contains "out of memory":
    * the concurrency is lowered by one;
    * the configs in the failed batch are rolled back to their last
      successful position;
    * after ``retrying_steps`` further loop iterations one job is added back.

    Args:
        model: the model; ``model.forward`` is called on each dispatched batch.
        tokenizer: used for data loading, padding and attention masks.
        configs: one EvaluateConfig per adapter/task pair.
        max_concurrent_jobs: upper bound on configs packed per batch.
        retrying_steps: iterations to wait before raising concurrency again.
        max_seq_len: maximum (padded) token sequence length.
        save_file: optional path for the JSON results.

    Returns:
        The list of per-config result dicts from ``_compute_result``. The ECE
        values in it are pooled over all configs.

    Raises:
        AssertionError: if ``max_concurrent_jobs`` or ``retrying_steps`` is
            not positive.
        RuntimeError: a non-OOM error, or OOM with concurrency already at one.
    """
    if max_concurrent_jobs is None:
        max_concurrent_jobs = len(configs)
        logging.info(
            f"Setting max_concurrent_jobs to {max_concurrent_jobs} automatically")

    assert max_concurrent_jobs > 0
    assert retrying_steps > 0

    _prepare_tasks(model, tokenizer, configs)

    concurrent_jobs = max_concurrent_jobs
    retrying_count = 0
    # Per-batch numpy arrays gathered for the final calibration (ECE) summary.
    preds_list = {'prediction': [], 'label': [], 'pred_probs': []}
    while True:
        # After an OOM, wait `retrying_steps` iterations before re-adding a job.
        if concurrent_jobs < max_concurrent_jobs and retrying_count > 0:
            retrying_count -= 1
            if retrying_count == 0:
                concurrent_jobs += 1
                logging.info(
                    f"recovering concurrent jobs to {concurrent_jobs}")

        current_configs, sequence_lengths, batch_labels, input_args = _dispatch_task_in(
            tokenizer, configs, concurrent_jobs, max_seq_len)

        if len(current_configs) == 0:
            break

        try:
            batch_preds_list = _compute_metrcis(model, current_configs,
                                                sequence_lengths, batch_labels,
                                                model.forward(input_args))
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            concurrent_jobs -= 1
            if concurrent_jobs == 0:
                raise
            logging.warning(
                f"deprecating concurrent jobs to {concurrent_jobs} due to OOM.")
            # rollback the configs of the failed batch and retry them
            retrying_count = retrying_steps
            for config in current_configs:
                config.batch_start_idx_ = config.rollback_start_idx_
                logging.info(f"{config.adapter_name}, {config.task_name}: " +
                             f"rollback to {config.batch_start_idx_}/{len(config.data_)}")
            continue

        for key, arrays in batch_preds_list.items():
            preds_list[key].append(np.concatenate(arrays))

        for config in current_configs:
            config.rollback_start_idx_ = config.batch_start_idx_

    logging.info(
        f'Evaluation Finished, Calculating ECE...')
    (ece_res_output_only,
     correct_confidence_output_only,
     incorrect_confidence_output_only) = _calibration_summary(preds_list)
    return _compute_result(model, configs, save_file, ece_res_output_only,
                           correct_confidence_output_only, incorrect_confidence_output_only)


def _calibration_summary(preds_list):
    """Concatenate the per-batch arrays in ``preds_list`` and compute ECE stats.

    Returns ``(ece, mean confidence of correct predictions, mean confidence
    of incorrect predictions)``. All three are NaN when nothing was evaluated
    (``np.concatenate`` cannot join an empty list).
    """
    if not preds_list['prediction']:
        nan = float("nan")
        return nan, nan, nan
    predictions = np.concatenate(preds_list['prediction'])
    labels = np.concatenate(preds_list['label'])
    confidences = np.concatenate(preds_list['pred_probs'])
    ece_value, correct_conf, incorrect_conf, _ = ece(labels, predictions, confidences)
    return ece_value, correct_conf, incorrect_conf