import torch
from .train import _choose_loss
import numpy as np
from challenge.neurobench import StorkModel, StorkBenchmark
from challenge.test_plot import plot_predictions_vs_labels
import stork
from scipy.stats import pearsonr
import matplotlib.pyplot as plt



def get_test_loss(cfg):
    """Build the loss object used as the evaluation metric.

    The loss class is picked by ``_choose_loss(cfg)`` (the original note
    mentions MSE, RMSE, MAE, etc. as candidates) and instantiated without
    arguments.

    Args:
        cfg: Configuration object forwarded unchanged to ``_choose_loss``.

    Returns:
        A fresh instance of the selected loss class.
    """
    return _choose_loss(cfg)()


def _configure_model_eval(model, test_dat, cfg):
    """Configure ``model`` in place for evaluation on ``test_dat``.

    Args:
        model: Model exposing ``set_nb_steps`` and a ``loss_stack`` attribute.
        test_dat: Indexable dataset; ``test_dat[0][0]`` must expose ``shape``.
        cfg: Configuration object forwarded to ``get_test_loss``.

    Returns:
        The same ``model`` object, after being modified.
    """
    # The number of steps is taken from the first dimension of the first
    # sample's input.
    first_input = test_dat[0][0]
    model.set_nb_steps(first_input.shape[0])

    # Swap in the loss used for evaluation.
    model.loss_stack = get_test_loss(cfg)
    return model

def evaluate_model(model, cfg, test_dat):
    """Evaluate ``model`` on ``test_dat`` and collect R2 and benchmark metrics.

    The model is re-configured with ``_configure_model_eval``, run through
    ``model.evaluate_continuos_testdata`` and its predictions are compared
    with the targets of the first sample (``test_dat[0][1]``). Afterwards the
    model is wrapped in ``StorkModel`` and benchmarked with ``StorkBenchmark``
    using the metric lists found in ``cfg.evaluation``.

    Args:
        model: Model exposing ``evaluate_continuos_testdata`` (plus whatever
            ``_configure_model_eval`` requires).
        cfg: Configuration providing ``evaluation.static_metrics`` and
            ``evaluation.workload_metrics``.
        test_dat: Indexable dataset; ``test_dat[0][1]`` must expose
            ``.numpy()``.

    Returns:
        Tuple ``(model, scores, pred, bm, target)``. ``bm`` is a dict holding
        ``"R2 X (JR)"``, ``"R2 Y (JR)"`` (only when there is more than one
        output), ``"R2 mean (JR)"`` and everything returned by
        ``benchmark.run()``.
    """
    # Re-configure model for eval
    model = _configure_model_eval(model, test_dat, cfg)

    # Evaluate model
    scores, pred, _ = model.evaluate_continuos_testdata(test_dat)

    target = test_dat[0][1].numpy()

    # Coefficient of determination, reduced over axis 0.
    SST = np.sum((target - np.mean(target, 0)) ** 2, 0)
    SSR = np.sum((target - pred) ** 2, 0)
    # atleast_1d keeps R2 indexable even when the reduction yields a scalar
    # (1-D target).
    R2 = np.atleast_1d(1 - SSR / SST)

    # The previous check `R2.shape == 1` compared a tuple with an int and was
    # therefore never true, so single-output data always hit R2[1] and raised
    # IndexError. Check the number of outputs explicitly instead.
    bm = {"R2 X (JR)": R2[0].astype(float)}
    if R2.shape[0] > 1:
        bm["R2 Y (JR)"] = R2[1].astype(float)
    bm["R2 mean (JR)"] = np.mean(R2).astype(float)

    print(bm)

    # Optional 2-D comparison plot of pred vs. target:
    # plot_predictions_vs_labels(pred, target)

    # BENCHMARK MODEL
    # convert to stork model
    test_model = StorkModel(model)

    # Benchmark expects the following dataloader
    test_set_loader = torch.utils.data.DataLoader(
        test_dat,
        batch_size=1,
        shuffle=False,
    )

    # NOTE(review): the two empty lists are presumably the pre-/post-processor
    # lists of the benchmark -- confirm against StorkBenchmark's signature.
    benchmark = StorkBenchmark(
        test_model,
        test_set_loader,
        [],
        [],
        [cfg.evaluation.static_metrics, cfg.evaluation.workload_metrics],
    )
    bm_results = benchmark.run()
    bm.update(bm_results)

    return model, scores, pred, bm, target


def benchmark_model(model, cfg, test_dat):
    """Run the benchmark on ``model`` without computing R2 scores.

    Args:
        model: Model to benchmark; it is first re-configured for evaluation
            via ``_configure_model_eval``.
        cfg: Configuration providing ``evaluation.static_metrics`` and
            ``evaluation.workload_metrics``.
        test_dat: Dataset wrapped in a ``DataLoader`` (batch size 1, no
            shuffling) for the benchmark.

    Returns:
        Whatever ``StorkBenchmark.run()`` returns.
    """
    # Re-configure model for eval, then wrap it for the benchmark harness.
    wrapped_model = StorkModel(_configure_model_eval(model, test_dat, cfg))

    # Benchmark expects a dataloader with one sample per batch, in order.
    loader = torch.utils.data.DataLoader(test_dat, batch_size=1, shuffle=False)

    metric_groups = [
        cfg.evaluation.static_metrics,
        cfg.evaluation.workload_metrics,
    ]
    return StorkBenchmark(wrapped_model, loader, [], [], metric_groups).run()

def evaluate_with_traindata(model, train_dat):
    """Print the result of ``model.evaluate(train_dat)``.

    Args:
        model: Object exposing an ``evaluate`` method.
        train_dat: Data passed unchanged to ``model.evaluate``.

    Returns:
        None; the evaluation result is only printed.
    """
    print(model.evaluate(train_dat))


def evaluate_with_testdata(model, cfg, test_dat, for_draw=False):
    """Evaluate ``model`` on one dataset or on a sequence of per-trial datasets.

    If ``cfg.evaluation.rep`` is truthy, ``model.rep()`` is called first.

    * When ``test_dat`` is a ``stork.datasets.RasDataset`` the call is
      delegated to ``evaluate_model`` and its scores, predictions and
      benchmark dict are returned unchanged (``for_draw`` is not used).
    * Otherwise ``test_dat`` is treated as a sequence of datasets. Each one is
      passed to ``evaluate_model``; per-trial metrics are recorded, averaged
      over trials, and R2 / cc / mse / rmse are additionally recomputed on the
      concatenation of all trial targets and predictions.

    Args:
        model: Model to evaluate.
        cfg: Configuration providing ``evaluation.rep`` plus whatever
            ``evaluate_model`` needs.
        test_dat: A ``RasDataset`` or a non-empty sequence of datasets.
        for_draw: Sequence case only. If true, ``bm`` is the full list
            ``[trial_0, ..., trial_n, averages, summary]``; otherwise only
            the final ``summary`` dict.

    Returns:
        Tuple ``(model, scores_average, pred, bm)``. In the sequence case
        ``pred`` is the list of per-trial predictions and ``scores_average``
        is the mean of the per-trial scores along axis 0.

    Raises:
        ValueError: If ``test_dat`` is an empty sequence.
        KeyError: If a per-trial benchmark dict lacks one of the expected
            metric keys.
    """
    if cfg.evaluation.rep:
        model.rep()

    if isinstance(test_dat, stork.datasets.RasDataset):
        print("test_dat is a RasDataset.")
        model, scores_average, pred, bm, _ = evaluate_model(model, cfg, test_dat)
        return model, scores_average, pred, bm

    print("test_dat is a List.")
    if len(test_dat) == 0:
        # Previously this surfaced later as an opaque ZeroDivisionError.
        raise ValueError("test_dat must contain at least one trial dataset.")

    # (name of the averaged entry, key in the per-trial benchmark dict)
    averaged_metrics = (
        ("R2X_average", "R2 X (JR)"),
        ("R2Y_average", "R2 Y (JR)"),
        ("R2Mean_average", "R2 mean (JR)"),
        ("cc_average", "cc"),
        ("mse_average", "mse"),
        ("rmse_average", "rmse"),
        ("activation_sparsity_average", "activation_sparsity"),
    )
    synaptic_keys = ("Effective_MACs", "Effective_ACs", "Dense")

    scores = []
    pred = []
    target = []
    trial_benchmarks = []
    evaluate_result = []

    for trial in test_dat:
        # evaluate_model yields the evaluation result of a single trial.
        _, scores_once, pred_once, bm_once, target_once = evaluate_model(
            model, cfg, trial
        )

        entry = {key: bm_once[key] for _, key in averaged_metrics}
        if "synaptic_operations" in bm_once:
            entry["synaptic_operations"] = bm_once["synaptic_operations"]
        entry["pred"] = pred_once
        entry["target"] = target_once
        evaluate_result.append(entry)

        scores.append(scores_once)
        pred.append(pred_once)
        target.append(target_once)
        trial_benchmarks.append(bm_once)

    # Averages over trials.
    averages = {}
    for name, key in averaged_metrics:
        values = [b[key] for b in trial_benchmarks]
        averages[name] = sum(values) / len(values)

    # Synaptic-operation metrics are optional; average over the trials that
    # reported them instead of inspecting only the last trial.
    synaptic = [
        b["synaptic_operations"]
        for b in trial_benchmarks
        if "synaptic_operations" in b
    ]
    if synaptic:
        for key in synaptic_keys:
            averages[f"{key}_average"] = sum(s[key] for s in synaptic) / len(synaptic)
    evaluate_result.append(averages)

    # Join all trials end to end and recompute the metrics on the whole run.
    target_continuous = np.concatenate(target)
    pred_continuous = np.concatenate(pred)
    r2_x, r2_y, r2_mean = r2_get(target_continuous, pred_continuous)

    # NOTE(review): footprint and connection_sparsity are presumably
    # model-level metrics that do not vary between trials -- confirm. The
    # first trial is used so that a one-element list works (index 1 raised
    # IndexError in that case).
    first_bm = trial_benchmarks[0]
    summary = {
        "R2 X (JR)": r2_x,
        "R2 Y (JR)": r2_y,
        "R2 mean (JR)": r2_mean,
        "footprint": first_bm["footprint"],
        "connection_sparsity": first_bm["connection_sparsity"],
        "cc": cc_get(target_continuous, pred_continuous),
        "mse": mse_get(target_continuous, pred_continuous),
        "rmse": rmse_get(target_continuous, pred_continuous),
        "activation_sparsity": averages["activation_sparsity_average"],
        "r2": r2_mean,
    }
    if synaptic:
        summary["synaptic_operations"] = {
            key: averages[f"{key}_average"] for key in synaptic_keys
        }
    evaluate_result.append(summary)

    bm = evaluate_result if for_draw else summary
    scores_average = np.mean(np.array(scores), axis=0)

    return model, scores_average, pred, bm

def r2_get(target, pred):
    """Compute the coefficient of determination (R2) of ``pred`` vs ``target``.

    Sums of squares are reduced over axis 0, so for 2-D input one R2 value is
    obtained per column.

    Args:
        target: NumPy array of reference values.
        pred: NumPy array of predictions, broadcast-compatible with
            ``target``.

    Returns:
        Tuple ``(R2X, R2Y, R2Mean)``: R2 of the first output, R2 of the second
        output (equal to ``R2X`` when there is only one output) and the mean
        R2 over all outputs.
    """
    SST = np.sum((target - np.mean(target, 0)) ** 2, 0)
    SSR = np.sum((target - pred) ** 2, 0)
    # atleast_1d keeps R2 indexable when the inputs are 1-D (scalar result).
    R2 = np.atleast_1d(1 - SSR / SST)

    R2X = R2[0].astype(float)
    # The previous check `R2.shape == 1` compared a tuple with an int and was
    # never true, so single-output input raised IndexError on R2[1].
    R2Y = R2[1].astype(float) if R2.shape[0] > 1 else R2X
    R2Mean = np.mean(R2).astype(float)

    return R2X, R2Y, R2Mean

def cc_get(target, pred):
    """Return the mean Pearson correlation coefficient of ``pred`` vs ``target``.

    The inputs are handed to ``scipy.stats.pearsonr`` as they are (no
    conversion is performed here) and the resulting statistic is averaged
    with ``.mean()``.

    Args:
        target: Reference values.
        pred: Predicted values.

    Returns:
        Mean of the correlation statistic returned by ``pearsonr``.
    """
    # The p-value is not needed.
    statistic, _pvalue = pearsonr(pred, target)
    return statistic.mean()

def mse_get(target, preds):
    """Return the mean squared error between ``preds`` and ``target``.

    The mean is taken over all elements.
    """
    residual = preds - target
    return np.mean(np.square(residual))

def rmse_get(target, preds):
    """Return the root mean squared error between ``preds`` and ``target``.

    The mean is taken over all elements before the square root is applied.
    """
    residual = preds - target
    mean_squared = np.mean(np.square(residual))
    return np.sqrt(mean_squared)

