import torch
import numpy as np
import stork
from stork.generators import StandardGenerator

from .utils.plotting import plot_activity_snapshot
import copy
import matplotlib.pyplot as plt
import random
import sys

# Custom loss stacks
from .custom.loss import MeanSquareError, RootMeanSquareError, MeanAbsoluteError, HuberLoss, RootMeanSquareError_with_MaxOverTimeCrossEntropy
from .custom.readout import CustomReadoutGroup, AverageReadouts
from .get_crossSet_data import get_dataloader_foundation_crossSet

import logging
logger = logging.getLogger(__name__)

def get_optimizer(cfg, dtype):
    """Select the optimizer class and its keyword arguments from the config.

    Args:
        cfg: Config object. Reads ``cfg.training.optimizer`` (``"adam"`` or
            ``"SMORMS3"``) and ``cfg.training.lr``.
        dtype: Torch dtype used for training. ``torch.float16`` gets a larger
            ``eps`` than any other dtype.

    Returns:
        Tuple ``(optimizer_class, optimizer_kwargs)``. The class is returned
        uninstantiated; the kwargs contain ``lr`` and ``eps``.

    Raises:
        ValueError: If ``cfg.training.optimizer`` is not a supported name.
    """
    opt_kwargs = {"lr": cfg.training.lr}
    half_precision = dtype == torch.float16

    if cfg.training.optimizer == "adam":
        opt = torch.optim.Adam
        opt_kwargs["eps"] = 1e-4 if half_precision else 1e-8
    elif cfg.training.optimizer == "SMORMS3":
        opt = stork.optimizers.SMORMS3
        opt_kwargs["eps"] = 1e-5 if half_precision else 1e-16
    else:
        # Fail with a clear message instead of an UnboundLocalError on `opt`.
        raise ValueError(f"Unknown optimizer: {cfg.training.optimizer}")

    return opt, opt_kwargs


def _choose_loss(cfg):
    """Map ``cfg.training.loss`` to the corresponding loss class.

    Recognized names are ``"MSE"``, ``"RMSE"``, ``"MAE"`` and ``"Huber"``.
    The class itself is returned, not an instance.

    Raises:
        ValueError: If the configured name is not one of the above.
    """
    loss_name = cfg.training.loss

    if loss_name == "MSE":
        return MeanSquareError
    if loss_name == "RMSE":
        return RootMeanSquareError
    if loss_name == "MAE":
        return MeanAbsoluteError
    if loss_name == "Huber":
        return HuberLoss

    raise ValueError(f"Unknown loss: {loss_name}")


def get_train_loss(cfg):
    """Instantiate the training loss selected by the config.

    The loss class comes from ``_choose_loss(cfg)``. When
    ``cfg.training.mask_early_timesteps`` is truthy, the loss is constructed
    with a ``mask`` tensor of shape ``(nb_time_steps, 2)`` that is 0 for the
    first ``cfg.training.nb_masked_timesteps`` steps and 1 afterwards;
    otherwise it is constructed without arguments.

    Returns:
        An instance of the selected loss class.
    """
    loss_class = _choose_loss(cfg)

    if not cfg.training.mask_early_timesteps:
        return loss_class()

    # One entry per simulation step of a sample.
    n_steps = int(cfg.data.sample_duration / cfg.data.dt)
    keep = torch.ones(n_steps)
    keep[: cfg.training.nb_masked_timesteps] = 0

    # Duplicate the per-step mask into two identical columns.
    two_column_mask = torch.stack([keep, keep], dim=1)
    return loss_class(mask=two_column_mask)


def get_lr_scheduler(cfg, opt, pretrainFlag='train'):
    """Return the LR scheduler class and its kwargs for the given phase.

    Only ``cfg.training.lr_scheduler == "cosine"`` is supported. ``T_max`` is
    taken from ``cfg.training.nb_epochs_pretrain`` when ``pretrainFlag`` is
    ``'pretrain'`` and from ``cfg.training.nb_epochs_train`` when it is
    ``'train'``.

    Args:
        cfg: Config object (see above for the fields read).
        opt: Unused by this function; kept for interface compatibility.
        pretrainFlag: ``'pretrain'`` or ``'train'``.

    Returns:
        Tuple ``(scheduler_class, scheduler_kwargs)``.

    Raises:
        ValueError: For an unsupported scheduler name or phase flag.
    """
    if cfg.training.lr_scheduler != "cosine":
        raise ValueError(f"Unknown lr_scheduler: {cfg.training.lr_scheduler}")

    if pretrainFlag == 'pretrain':
        horizon = cfg.training.nb_epochs_pretrain
    elif pretrainFlag == 'train':
        horizon = cfg.training.nb_epochs_train
    else:
        raise ValueError(f"Unknown pretrainFlag: {pretrainFlag}")

    return torch.optim.lr_scheduler.CosineAnnealingLR, {"T_max": horizon}


def _find_readout_group(model):
    """Return the last readout group found in ``model.groups``."""
    readout = None
    for g in model.groups:
        if isinstance(g, (AverageReadouts, CustomReadoutGroup)):
            readout = g
    if readout is None:
        raise ValueError(
            "Model has no AverageReadouts or CustomReadoutGroup to use as output."
        )
    return readout


def _build_generator(cfg, worker_init_fn, persistent_workers_on_linux):
    """Build the platform-specific ``StandardGenerator`` for data loading."""
    if sys.platform == "linux":
        print("Linux")
        # Linux additionally passes worker_init_fn so each worker is seeded.
        return StandardGenerator(
            nb_workers=cfg.nb_workers,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers_on_linux,
        )

    if sys.platform == "win32":
        print("Windows")
    else:
        # NOTE(review): the original left `generator` undefined here; the
        # Windows settings are reused as the conservative fallback — confirm
        # this is the desired behavior on e.g. macOS.
        logger.warning(
            "Unrecognized platform %r: building the data generator without "
            "worker_init_fn and without persistent workers.",
            sys.platform,
        )
    return StandardGenerator(nb_workers=cfg.nb_workers, persistent_workers=False)


def configure_model(model, cfg, dtype, num_classes=None):
    """Assemble loss, optimizer, scheduler and generator, then configure ``model``.

    Args:
        model: Model exposing ``groups`` and ``configure(...)``. If it has a
            ``pretrainFlag`` attribute, that flag selects the scheduler horizon.
        cfg: Config object. Reads ``cfg.training.*``, ``cfg.model.*``,
            ``cfg.data.dt``, ``cfg.seed``, ``cfg.nb_workers`` and, for the
            multi-GPU path, ``cfg.multi_cuda``, ``cfg.device``, ``cfg.gpu_ids``.
        dtype: Torch dtype forwarded to ``get_optimizer``.
        num_classes: If given, the loss stack is
            ``RootMeanSquareError_with_MaxOverTimeCrossEntropy(num_classes=...)``
            instead of the loss selected by ``get_train_loss(cfg)``.

    Returns:
        The same ``model`` object after ``model.configure(...)`` was called.

    Raises:
        ValueError: If the model contains no readout group.
        AssertionError: On the multi-GPU path, if fewer than two GPU ids are
            configured or fewer GPUs are available than configured.
    """
    if num_classes is not None:
        loss_stack = RootMeanSquareError_with_MaxOverTimeCrossEntropy(num_classes=num_classes)
    else:
        loss_stack = get_train_loss(cfg)
    opt, opt_kwargs = get_optimizer(cfg, dtype)

    # Scheduler class and kwargs, or (None, None) when none is configured.
    if cfg.training.lr_scheduler is not None:
        if hasattr(model, 'pretrainFlag'):
            scheduler, scheduler_kwargs = get_lr_scheduler(cfg, opt, model.pretrainFlag)
        else:
            scheduler, scheduler_kwargs = get_lr_scheduler(cfg, opt)
    else:
        scheduler = None
        scheduler_kwargs = None

    def worker_init_fn(worker_id):
        # Seed numpy and random per data-loading worker for reproducibility.
        np.random.seed(cfg.seed + worker_id)
        random.seed(cfg.seed + worker_id)

    if hasattr(cfg.model, "output_feedback") and cfg.model.output_feedback:
        output_feedback = cfg.model.output_feedback
    elif hasattr(cfg.model, "session_classfication") and cfg.model.session_classfication:
        output_feedback = cfg.model.session_classfication
    else:
        output_feedback = False

    # NOTE(review): the original derived a local `pretrain_forze` from
    # model.pretrainFlag but never used it; the raw config value is what gets
    # passed to model.configure. Confirm which one is intended.

    model_output = _find_readout_group(model)

    use_multi_gpu = (
        torch.cuda.is_available()
        and hasattr(cfg, 'multi_cuda')
        and cfg.multi_cuda
        and 'cuda' in cfg.device
    )

    extra_kwargs = {}
    if use_multi_gpu:
        assert len(cfg.gpu_ids) > 1, "Not enough GPUs are specified: cfg.gpu_ids must list at least 2 devices, please modify cfg.gpu_ids."
        assert torch.cuda.device_count() >= len(cfg.gpu_ids), "Not enough GPUs available"
        logger.info("using multi gpu...")
        extra_kwargs["multi_cuda"] = cfg.multi_cuda
        extra_kwargs["device_ids"] = cfg.gpu_ids

    # Persistent workers are only enabled for multi-GPU training on Linux.
    generator = _build_generator(
        cfg,
        worker_init_fn,
        persistent_workers_on_linux=True if use_multi_gpu else False,
    )

    model.configure(
        input=model.groups[0],
        output=model_output,
        loss_stack=loss_stack,
        generator=generator,
        optimizer=opt,
        optimizer_kwargs=opt_kwargs,
        scheduler=scheduler,
        scheduler_kwargs=scheduler_kwargs,
        time_step=cfg.data.dt,
        earlystop=cfg.training.earlystop,
        earlystop_patience=cfg.training.earlystop_patience,
        earlystop_min_delta=cfg.training.earlystop_min_delta,
        earlystop_restore_best_weights=cfg.training.restore_best_weights,
        earlystop_min_ep=cfg.training.earlystop_min_epochs,
        earlystop_ema_alpha=cfg.training.earlystop_ema_alpha,
        earlystop_restart_epochs=cfg.training.earlystop_restart_epochs,
        earlystop_restart_lr_factor=cfg.training.earlystop_restart_lr_factor,
        output_feedback=output_feedback,
        pretrain_forze=cfg.model.pretrain_forze,
        **extra_kwargs,
    )

    return model


def train_validate_model(
    model,
    cfg,
    train_data,
    valid_data,
    nb_epochs,
    verbose=True,
    snapshot_prefix="",
    random_Trial_for_plot=False,
):
    """Train and validate ``model``, optionally saving activity snapshots.

    With ``cfg.model.step_training`` and ``cfg.model.self_and_crossAttention``
    both truthy, training runs in ``cfg.model.step_num + 1`` steps. Each step
    reconfigures optimizer and scheduler and calls
    ``model.fit_validate_step``. Otherwise a single ``model.fit_validate``
    call is made.

    Args:
        model: Model to train.
        cfg: Config object (reads ``cfg.model.*``, ``cfg.training.*``,
            ``cfg.plotting.plot_snapshots``).
        train_data: Training dataset passed to the fit method.
        valid_data: Validation dataset passed to the fit method and plotting.
        nb_epochs: Number of epochs per fit call.
        verbose: Forwarded to the fit method.
        snapshot_prefix: Path prefix for the snapshot image files.
        random_Trial_for_plot: Forwarded to ``plot_activity_snapshot``.

    Returns:
        Tuple ``(model, history)``. In step-training mode ``history`` is the
        one returned by the last step only.

    Raises:
        AssertionError: If ``cfg.model.output_BN`` is falsy but the last
            connection still has a ``bn`` layer.
        ValueError: In step-training mode, if ``cfg.model.step_num < 1``.
    """
    if not cfg.model.output_BN:
        assert model.connections[-1].bn is None, "The last connection should not have a batch normalization layer when cfg.model.output_BN is disabled."
    if cfg.plotting.plot_snapshots:
        plot_activity_snapshot(
            model, valid_data, save_path=snapshot_prefix + "snapshot_before.png",
            random_Trial_for_plot=random_Trial_for_plot
        )

    if cfg.model.step_training and cfg.model.self_and_crossAttention:
        step_num = cfg.model.step_num
        if step_num < 1:
            # step_num == 0 divides by zero below; a negative value would skip
            # the loop and leave `history` undefined.
            raise ValueError(f"cfg.model.step_num must be >= 1 for step training, got {step_num}")

        for i in range(step_num + 1):
            # Patience grows linearly with the step index, never below 40.
            model.earlystop_patience = max(40, int(cfg.training.earlystop_patience * (i / step_num)))
            model.input_group.set_epoch_step(i, step_num)  # set the current training step

            # Fresh optimizer (and scheduler, if configured) for every step.
            opt, opt_kwargs = get_optimizer(cfg, model.dtype)
            if cfg.training.lr_scheduler is not None:
                if hasattr(model, 'pretrainFlag'):
                    scheduler, scheduler_kwargs = get_lr_scheduler(cfg, opt, model.pretrainFlag)
                else:
                    scheduler, scheduler_kwargs = get_lr_scheduler(cfg, opt)
            else:
                scheduler = None
                scheduler_kwargs = None

            model.configure_optimizer(opt, opt_kwargs)
            model.configure_scheduler(scheduler, scheduler_kwargs)

            history = model.fit_validate_step(
                train_data,
                valid_data,
                nb_epochs=nb_epochs,
                verbose=verbose,
            )
    else:
        history = model.fit_validate(
            train_data,
            valid_data,
            nb_epochs=nb_epochs,
            verbose=verbose,
        )

    if cfg.plotting.plot_snapshots:
        plot_activity_snapshot(
            model,
            valid_data,
            save_path=snapshot_prefix + "snapshot_after_e{}.png".format(nb_epochs),
            random_Trial_for_plot=random_Trial_for_plot,
        )

    return model, history

def train_validate_model_step_by_step(
    model,
    cfg,
    train_data,
    valid_data,
    nb_epochs,
    filenames,
    finetune_monkeyname,
    stateFlag,
    verbose=True,
    snapshot_prefix="",
    random_Trial_for_plot=False,
):
    """Train in steps while shrinking the sample duration from max to 0.4.

    For each of ``cfg.model.step_num + 1`` steps, ``cfg.data.sample_duration``
    is interpolated linearly from its configured value down to 0.4. Data is
    then reloaded through ``get_dataloader_foundation_crossSet``, the model is
    reconfigured with ``configure_model`` and trained with
    ``model.fit_validate``. Afterwards the configured sample duration and the
    model's step count are restored.

    Args:
        model: Model to train.
        cfg: Config object; ``cfg.data.sample_duration`` is modified during
            the run and restored at the end (also when a step raises).
        train_data: Not used by this function; data is reloaded per step.
        valid_data: Only used for the before/after snapshots.
        nb_epochs: Number of epochs per step.
        filenames: Passed to the dataloader's data getters.
        finetune_monkeyname: Passed as ``monkeyname`` on the fine-tune path.
        stateFlag: ``"pretrain"``, ``"fine-tune"`` or ``"finetune"``.
        verbose: Forwarded to ``model.fit_validate``.
        snapshot_prefix: Path prefix for the snapshot image files.
        random_Trial_for_plot: Forwarded to ``plot_activity_snapshot``.

    Returns:
        Tuple ``(model, history)`` where ``history`` is from the last step.

    Raises:
        ValueError: If ``cfg.model.step_num < 1`` or ``stateFlag`` is unknown.
    """
    step_num = cfg.model.step_num
    if step_num < 1:
        # step_num == 0 divides by zero in the schedule; a negative value
        # would skip the loop and leave `history` undefined.
        raise ValueError(f"cfg.model.step_num must be >= 1, got {step_num}")
    # The original test `stateFlag == "fine-tune" or "finetune"` was always
    # true, so unknown flags silently took the fine-tune path.
    if stateFlag not in ("pretrain", "fine-tune", "finetune"):
        raise ValueError(f"Unknown stateFlag: {stateFlag}")

    if cfg.plotting.plot_snapshots:
        plot_activity_snapshot(
            model, valid_data, save_path=snapshot_prefix + "snapshot_before.png",
            random_Trial_for_plot=random_Trial_for_plot
        )

    max_sample_dur = cfg.data.sample_duration
    min_sample_dur = 0.4

    try:
        for i in range(step_num + 1):
            cfg.data.sample_duration = (max_sample_dur - i * (max_sample_dur - min_sample_dur) / step_num)
            if cfg.data.sample_duration < 1.2:
                # NOTE(review): this flag is never restored afterwards —
                # confirm whether later code relies on it staying False.
                cfg.data.continuous_trial = False
            dataloader = get_dataloader_foundation_crossSet(cfg, dtype=model.dtype)
            print('dataloader.sample_duration: ', dataloader.sample_duration)

            if stateFlag == "pretrain":
                train_dat, val_dat, _ = dataloader.get_multiple_set_data(
                    filenames, nb_inputs=model.nb_inputs, with_S1=cfg.with_S1
                )
            else:
                train_dat, val_dat, _ = dataloader.get_single_session_data(
                    filenames,
                    monkeyname=finetune_monkeyname,
                    nb_inputs=model.nb_inputs,
                )

            nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
            model.set_nb_steps(nb_time_steps)
            model = configure_model(model, cfg, model.dtype)

            # Patience grows linearly with the step index, never below 40.
            model.earlystop_patience = max(40, int(cfg.training.earlystop_patience * (i / step_num)))

            history = model.fit_validate(
                train_dat,
                val_dat,
                nb_epochs=nb_epochs,
                verbose=verbose,
            )
    finally:
        # Restore the configured duration even when a step fails.
        cfg.data.sample_duration = max_sample_dur

    nb_time_steps = int(max_sample_dur / cfg.data.dt)
    model.set_nb_steps(nb_time_steps)
    if cfg.plotting.plot_snapshots:
        plot_activity_snapshot(
            model,
            valid_data,
            save_path=snapshot_prefix + "snapshot_after_e{}.png".format(nb_epochs),
            random_Trial_for_plot=random_Trial_for_plot,
        )

    return model, history

def _plot_snapshots_per_set(model, valid_data, path_prefix, path_suffix, random_Trial_for_plot):
    """Save one activity snapshot per validation set, switching BN/hidden first."""
    for key in valid_data:
        if model.multiBN:
            model.BN_switch(key)
        if model.multiHidden:
            model.hidden_switch(key)
        # NOTE(review): ':' is not allowed in Windows file names, although
        # this module has a win32 code path elsewhere; kept unchanged so
        # existing output paths stay the same.
        plot_activity_snapshot(
            model,
            valid_data[key],
            save_path=path_prefix + key + path_suffix,
            random_Trial_for_plot=random_Trial_for_plot,
        )


def train_validate_model_multiBN(
    model,
    cfg,
    train_data,
    valid_data,
    nb_epochs,
    verbose=True,
    snapshot_prefix="",
    random_Trial_for_plot=False,
):
    """Train and validate a model whose validation data is keyed per dataset.

    When ``cfg.plotting.plot_snapshots`` is truthy, a snapshot is saved for
    every key of ``valid_data`` before and after training. Before each plot
    the model is switched to that key via ``model.BN_switch`` /
    ``model.hidden_switch`` (when ``model.multiBN`` / ``model.multiHidden``
    are truthy). The original only did this for the "before" snapshots; the
    "after" snapshots now do the same.

    Args:
        model: Model exposing ``fit_validate``, ``multiBN``, ``multiHidden``.
        cfg: Config object; reads ``cfg.plotting.plot_snapshots``.
        train_data: Training data passed to ``model.fit_validate``.
        valid_data: Mapping from dataset key to validation data.
        nb_epochs: Number of training epochs.
        verbose: Forwarded to ``model.fit_validate``.
        snapshot_prefix: Path prefix for the snapshot image files.
        random_Trial_for_plot: Forwarded to ``plot_activity_snapshot``.

    Returns:
        Tuple ``(model, history)``.
    """
    if cfg.plotting.plot_snapshots:
        _plot_snapshots_per_set(
            model, valid_data, snapshot_prefix, ":snapshot_before.png",
            random_Trial_for_plot,
        )

    history = model.fit_validate(
        train_data,
        valid_data,
        nb_epochs=nb_epochs,
        verbose=verbose,
    )

    if cfg.plotting.plot_snapshots:
        _plot_snapshots_per_set(
            model, valid_data, snapshot_prefix,
            ":snapshot_after_e{}.png".format(nb_epochs),
            random_Trial_for_plot,
        )

    return model, history

# Prune the model by removing the smallest weights
def prune_model(model, prune_percentage):
    """Zero the smallest-magnitude entries of every weight parameter in place.

    For each parameter whose name contains ``'weight'``, the
    ``int(numel * prune_percentage)`` entries with the smallest absolute
    value are set to zero.

    Args:
        model: Object exposing ``named_parameters()``.
        prune_percentage: Fraction of entries to zero per parameter.

    Returns:
        Dict mapping parameter name to a mask tensor with the parameter's
        shape: 0 at pruned positions, 1 elsewhere.
    """
    masks = {}
    for pname, tensor in model.named_parameters():
        if 'weight' not in pname:
            continue

        weights = tensor.data
        n_drop = int(weights.numel() * (prune_percentage))
        # Flat indices of the n_drop smallest magnitudes.
        smallest = torch.topk(torch.abs(weights.flatten()), n_drop, largest=False).indices

        keep = torch.ones_like(weights).flatten()
        keep[smallest] = 0
        keep = keep.view_as(weights)

        weights.mul_(keep)
        masks[pname] = keep
    return masks


def train_validate_model_pruning(
    model,
    cfg,
    train_data,
    valid_data,
    nb_epochs,
    verbose=True,
    snapshot_prefix="",
    mask=None,
    random_Trial_for_plot=False,
    ):
    """Train with ``model.fit_validate_masked`` and optionally save snapshots.

    Args:
        model: Model exposing ``fit_validate_masked``.
        cfg: Config object; reads ``cfg.plotting.plot_snapshots``.
        train_data: Training data passed to the fit method.
        valid_data: Validation data passed to the fit method and plotting.
        nb_epochs: Number of training epochs.
        verbose: Forwarded to the fit method.
        snapshot_prefix: Path prefix for the snapshot image files.
        mask: Forwarded to ``fit_validate_masked`` as ``mask``.
        random_Trial_for_plot: Forwarded to ``plot_activity_snapshot``.

    Returns:
        Tuple ``(model, history)``.
    """

    def _save_snapshot(file_suffix):
        # One activity snapshot of the validation data at the given path.
        plot_activity_snapshot(
            model,
            valid_data,
            save_path=snapshot_prefix + file_suffix,
            random_Trial_for_plot=random_Trial_for_plot,
        )

    if cfg.plotting.plot_snapshots:
        _save_snapshot("snapshot_before.png")

    fit_kwargs = {"nb_epochs": nb_epochs, "verbose": verbose, "mask": mask}
    history = model.fit_validate_masked(train_data, valid_data, **fit_kwargs)

    if cfg.plotting.plot_snapshots:
        _save_snapshot("snapshot_after_e{}.png".format(nb_epochs))

    return model, history


def prune_retrain_model(
        model,
        prune_percentage,
        cfg,
        train_data,
        valid_data,
        logger,
        nb_epochs_retrain=50,
        is_pruning_ver=False,
        session_name='None',
):
    """Prune ``model`` once, retrain it with the mask, and report metrics.

    The model is pruned with ``prune_model``. Its optimizer is re-created
    from ``model.optimizer_class`` / ``model.optimizer_kwargs`` and its
    scheduler from ``model.scheduler_class`` with ``T_max`` set to
    ``nb_epochs_retrain``. It is then retrained through
    ``train_validate_model_pruning`` using the pruning mask.

    Args:
        model: Model to prune and retrain (modified in place).
        prune_percentage: Fraction of each weight parameter to zero.
        cfg: Config object forwarded to the training function.
        train_data: Training data.
        valid_data: Validation data.
        logger: Logger used for progress messages.
        nb_epochs_retrain: Number of retraining epochs.
        is_pruning_ver: Passed as ``verbose`` to the training function.
        session_name: Used in the snapshot file prefix.

    Returns:
        Tuple ``(model, r2_train, r2_val, sparsity)``. The r2 values are the
        last entries of ``history['r2']`` and ``history['val_r2']``;
        ``sparsity`` is the fraction of exactly-zero entries over all
        parameters whose name contains ``'weight'``.
    """

    def calculate_con_sparsity(net):
        # Fraction of zero-valued entries among all weight parameters.
        weight_tensors = [p for pname, p in net.named_parameters() if 'weight' in pname]
        n_total = sum(p.numel() for p in weight_tensors)
        n_zero = sum((p == 0).sum().item() for p in weight_tensors)
        return n_zero / n_total

    mask = prune_model(model, prune_percentage)
    logger.info(f"Pruning percentage: {prune_percentage:.2f}")
    logger.info("Retrain the pruned model while keeping the pruned weights as zero...")

    # Fresh optimizer and a scheduler whose horizon matches the retraining run.
    model.configure_optimizer(model.optimizer_class, model.optimizer_kwargs)
    model.configure_scheduler(model.scheduler_class, {"T_max": nb_epochs_retrain})

    prefix = session_name + ' Pruned ' + str(np.round(prune_percentage, 2)) + "_"
    model, history = train_validate_model_pruning(
        model,
        cfg,
        train_data,
        valid_data,
        nb_epochs_retrain,
        verbose=is_pruning_ver,
        snapshot_prefix=prefix,
        mask=mask,
    )

    final_train_r2 = history['r2'][-1]
    final_val_r2 = history['val_r2'][-1]
    sparsity = calculate_con_sparsity(model)

    return model, final_train_r2, final_val_r2, sparsity


def prune_retrain_model_iterate(
        ori_model,
        cfg,
        train_data,
        valid_data,
        logger,
        r2_train_before_pruned,
        r2_val_before_pruned,
        nb_epochs_retrain=50,
        prune_percentage_start=0.1,
        tolerance=0.03,
        prune_precision=(0.1,),  # go through one by one, should decrease gradually, for example, [0.1, 0.01]
        max_prune_percentage=1.0,
        is_plot_pruning=True,
        is_pruning_ver=False,
        session_name='None',
        pruning_plot_prefix=''
):
    """Search for the largest pruning percentage that keeps training r2.

    Starting at ``prune_percentage_start``, each attempt reloads the original
    weights, then prunes and retrains through ``prune_retrain_model``. An
    attempt is rejected when the training r2 drops by more than
    ``tolerance * |r2_train_before_pruned|``. On rejection the search steps
    back by the current precision and continues with the next (finer) entry
    of ``prune_precision``; it stops when no finer entry is left. Accepted
    attempts advance the percentage by the current precision until
    ``max_prune_percentage`` is exceeded.

    Args:
        ori_model: Model to prune. It is modified in place and returned.
        cfg: Config object forwarded to ``prune_retrain_model``.
        train_data: Training data.
        valid_data: Validation data.
        logger: Logger used for progress messages.
        r2_train_before_pruned: Training r2 of the unpruned model.
        r2_val_before_pruned: Validation r2 of the unpruned model (only
            logged and plotted, not used for the accept/reject decision).
        nb_epochs_retrain: Retraining epochs per attempt.
        prune_percentage_start: First pruning percentage to try.
        tolerance: Allowed relative drop of the training r2.
        prune_precision: Sequence of step sizes, used in order.
        max_prune_percentage: Upper bound of the search.
        is_plot_pruning: If truthy, save a summary figure to
            ``pruning_plot_prefix + ' Pruning.png'``.
        is_pruning_ver: Verbosity flag forwarded to retraining.
        session_name: Forwarded to ``prune_retrain_model``.
        pruning_plot_prefix: Path prefix of the summary figure.

    Returns:
        ``ori_model`` loaded with the weights of the last accepted attempt,
        or with its original weights if no attempt was accepted.
    """
    sparsity_scores = []
    r2_trains = []
    r2_vals = []
    prune_percentages = []
    precision_idx = 0
    logger.info(f"r2 performance of unpruned model: training: {r2_train_before_pruned:.4f}, validation: {r2_val_before_pruned:.4f}")
    logger.info("Pruning model iteratively...")
    prune_percentage = prune_percentage_start
    ori_state_dict = copy.deepcopy(ori_model.state_dict())
    final_state_dict = copy.deepcopy(ori_model.state_dict())  # to be updated in the loop, unless pruning can't work at all
    pruned_model = ori_model  # Note: just rename, the same object is pruned in place

    while prune_percentage <= max_prune_percentage:
        # Every attempt starts from the original, unpruned weights.
        pruned_model.load_state_dict(ori_state_dict)

        (pruned_model,
         r2_train_after_retraining,
         r2_val_after_retraining,
         sparsity_after_retraining
         ) = prune_retrain_model(
                                pruned_model,
                                prune_percentage,
                                cfg,
                                train_data,
                                valid_data,
                                logger,
                                nb_epochs_retrain,
                                is_pruning_ver,
                                session_name
        )

        if r2_train_after_retraining - r2_train_before_pruned < -np.abs(tolerance * r2_train_before_pruned):
            if precision_idx < len(prune_precision) - 1:
                # Step back to the last accepted percentage and continue with
                # the next, finer step size.
                prune_percentage -= prune_precision[precision_idx]
                precision_idx += 1
                prune_percentage += prune_precision[precision_idx]
            else:
                break
            continue

        r2_trains.append(r2_train_after_retraining)
        r2_vals.append(r2_val_after_retraining)
        sparsity_scores.append(sparsity_after_retraining)
        logger.info(f"Pruning percentage: {prune_percentage:.2f}, Sparsity: {sparsity_after_retraining:.4f}, r2 after retraining: training: {r2_train_after_retraining:.4f}, validation: {r2_val_after_retraining:.4f}")
        prune_percentages.append(prune_percentage)

        final_state_dict = copy.deepcopy(pruned_model.state_dict())
        prune_percentage += prune_precision[precision_idx]

    if is_plot_pruning:

        # Plot the results
        fig = plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.scatter(prune_percentages, r2_trains, marker='o', label='trained_pruned')
        plt.scatter(prune_percentages, r2_vals, marker='o', label='val_pruned')
        plt.axhline(y=r2_train_before_pruned, color='b', linestyle='--', label='trained_Unpruned')
        plt.axhline(y=r2_val_before_pruned, color='r', linestyle='--', label='val_Unpruned')
        plt.xlabel("Pruning Percentage")
        plt.ylabel("r2_mean")
        plt.title("Impact of Pruning and Retraining on Regression Performance")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.scatter(prune_percentages, sparsity_scores, marker='o')
        plt.xlabel("Pruning Percentage")
        plt.ylabel("Sparsity")
        plt.title("Connection Sparsity vs. Pruning Percentage")

        fig.savefig(pruning_plot_prefix + ' Pruning.png', dpi=250)
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    if prune_percentages:
        logger.info(f"Maximum pruning percentage that retains the performance of the unpruned model: {prune_percentages[-1]:.2f}")
    else:
        logger.info("No suitable pruned model found, try smaller starting pruning percentage.")

    pruned_model.load_state_dict(final_state_dict)

    return pruned_model