# SNN Models
from SNN.Layers.NeuronConfig import NeuronConfig
import SNN.models as snn_models

# Options and Datasets
from AbstractModels import Options

from AbstractModels.SpikingConvolutionNetwork import SpikingConvolutionNetwork

import torch
from torch.utils.data import Subset, DataLoader, Dataset


import argparse

import json

# Used to dynamically import classes for easier model addition and removal
import importlib

# Module-level aliases for the registries exposed by ``Options``.
# Most of them are indexed below by the value of a CLI argument (e.g. the
# dataset, optimizer or loss name); ``surrogate_gradients`` is handed to
# argparse directly as the collection of valid ``--sg`` choices.
encoders = Options.encoders
decoders = Options.decoders
static_datasets = Options.datasets
neuromorphic_datasets = Options.neuromorphic_datasets
collate_fns = Options.collate_fns
dataset_transforms = Options.dataset_transforms
dataset_classes = Options.dataset_classes
optimizers = Options.optimizers
losses = Options.losses
schedulers = Options.schedulers
surrogate_gradients = Options.surrogate_gradients

class Config:
    """Experiment configuration driven by command-line arguments.

    The arguments are parsed once in ``__init__`` and stored on ``self.args``.
    The remaining methods are accessors for individual options or factories
    that build the model, datasets, data loaders, optimizer, scheduler and
    loss from the module-level ``Options`` registries.

    Attributes:
        args: The parsed ``argparse.Namespace``.
        is_snn: True when ``--model`` matches (case-insensitively) a name in
            ``snn_models.classification_all``.
        dataset_choices: Static and neuromorphic dataset registries merged
            into one mapping (set by ``parse_args``).
        model: The network built by ``initialize_model``; the attribute does
            not exist until that method has been called.
    """

    # Spellings that ``--validate`` interprets as "off". Every other value is
    # treated as "on", which keeps previously working invocations such as
    # ``--validate True`` unchanged.
    _FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'off'})

    def __init__(self):
        self.args = self.parse_args()

        requested = self.args.model.lower()
        self.is_snn: bool = any(
            requested == known.lower() for known in snn_models.classification_all
        )

        # Neuromorphic datasets are not re-encoded: force the identity encoder
        # regardless of what was passed on the command line.
        if self.args.dataset in neuromorphic_datasets:
            self.args.encoder = 'identity'

    def initialize_model(self) -> torch.nn.Module:
        """Build the network, store it on ``self.model`` and return it."""
        self.model = self._load_model()
        return self.model

    @staticmethod
    def _str_to_bool(value: str) -> bool:
        """Convert a command-line string into a bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because any non-empty string is truthy. This helper returns
        ``False`` for the spellings in ``_FALSE_STRINGS`` (case-insensitive,
        surrounding whitespace ignored) and ``True`` for anything else.
        """
        return value.strip().lower() not in Config._FALSE_STRINGS

    def parse_args(self):
        """Define the command-line interface and parse ``sys.argv``.

        Side effect: sets ``self.dataset_choices`` to the union of the static
        and neuromorphic dataset registries.

        Returns:
            The populated ``argparse.Namespace``.
        """
        models = [model.lower() for model in snn_models.classification_all]

        self.dataset_choices = static_datasets.copy()
        self.dataset_choices.update(neuromorphic_datasets)

        parser = argparse.ArgumentParser()

        parser.add_argument('--model', type=str, default='itliflenet5', choices=models, help='Model to train')
        parser.add_argument('--dataset', type=str, default='fmnist', choices=list(self.dataset_choices.keys()), help='Dataset to train on')
        parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
        parser.add_argument('--checkpoint', type=int, default=-1, help='Number of epochs to save model a checkpoint')
        parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
        parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
        parser.add_argument('--optimizer', type=str, default='adam', choices=list(optimizers.keys()), help='Optimizer to use')
        parser.add_argument('--loss', type=str, default='crossentropy', choices=list(losses.keys()), help='Loss function to use')
        parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for data loading')
        parser.add_argument('--seq_length', type=int, default=2, help='Sequence length')
        parser.add_argument('--sg', type=str, default='rectangle', choices=surrogate_gradients, help='Surrogate gradient method')
        parser.add_argument('--input_scale', type=float, default=1, help='Input scale')
        parser.add_argument('--encoder', type=str, default='copy', choices=list(encoders.keys()), help='Encoder to encode static data into neuromorphic format. If data is already in neuromorphic format, use identity')
        parser.add_argument('--decoder', type=str, default='mean', choices=list(decoders.keys()), help='Decoder to convert neural output to a usable format')
        parser.add_argument('--subset', type=int, default=-1, help='Train on a small subset of the data')
        parser.add_argument('--load_weights', type=str, default=None, help='Load weights from file')
        # Not ``type=bool``: that would turn ``--validate False`` into True.
        parser.add_argument('--validate', type=self._str_to_bool, default=False, help='Validate model instead of training')
        parser.add_argument('--l2', type=float, default=0.0, help='L2 regularization lambda value')
        parser.add_argument('--momentum', type=float, default=0.0, help='Momentum value')
        parser.add_argument('--scheduler', type=str, default=None, choices=list(schedulers.keys()), help='Scheduler to use')
        parser.add_argument('--eta_min', type=float, default=0, help='Minimum learning rate multiplier for cosine annealing scheduler: eta_min = lr * eta_min')
        parser.add_argument('--name', type=str, default=None, help='Name for model and associated metadata. Used to distinguish between different runs and models')
        parser.add_argument('--resume', type=str, default='False', help='Resume training from a checkpoint')
        parser.add_argument('--nesterov', type=str, default='False', help='Use Nesterov momentum')
        parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
        parser.add_argument('--energy', type=str, default='False', help='Calculate energy consumption')
        parser.add_argument('--cutmix', type=str, default='False', help='Use CutMix augmentation')
        parser.add_argument('--grad_steps', type=int, default=1, help='Number of gradient steps to take before updating weights')

        return parser.parse_args()

    def cutmix(self) -> bool:
        """Return True when ``--cutmix`` is the string 'true' (any case)."""
        return self.args.cutmix.lower() == 'true'

    def get_seed(self) -> int:
        """Return the ``--seed`` value."""
        return self.args.seed

    def get_name(self) -> str:
        """Return the ``--name`` value (``None`` when not supplied)."""
        return self.args.name

    def get_epochs(self) -> int:
        """Return the ``--epochs`` value."""
        return self.args.epochs

    def get_timesteps(self) -> int:
        """Return the ``--seq_length`` value."""
        return self.args.seq_length

    def calculate_energy_consumption(self) -> bool:
        """Return True when ``--energy`` is the string 'true' (any case)."""
        return self.args.energy.lower() == 'true'

    def get_checkpoint_period(self) -> int:
        """Return ``--checkpoint``, with every negative value mapped to -1."""
        return -1 if self.args.checkpoint < 0 else self.args.checkpoint

    def get_dataset(self, train=True) -> Dataset:
        """Instantiate the selected dataset under ``./data/``.

        Args:
            train: Forwarded to the dataset constructor and used to select
                the transform.

        Returns:
            The dataset, or a ``Subset`` of its first ``--subset`` items when
            ``0 < subset < len(dataset)``.
        """
        dataset_factory = self.dataset_choices[self.args.dataset]
        data = dataset_factory(
            root='./data/',
            train=train,
            transform=self.get_transform(train),
        )

        if 0 < self.args.subset < len(data):
            return Subset(data, range(self.args.subset))
        return data

    def get_transform(self, train: bool) -> torch.nn.Module:
        """Look up the transform registered for the selected dataset.

        A transform registered under the bare dataset name takes precedence;
        otherwise the ``<dataset>_train`` or ``<dataset>_val`` entry is used
        depending on ``train``.

        Raises:
            ValueError: If no matching entry exists in ``dataset_transforms``.
        """
        dataset = self.args.dataset
        if dataset in dataset_transforms:
            return dataset_transforms[dataset]

        split_key = f'{dataset}_train' if train else f'{dataset}_val'
        if split_key in dataset_transforms:
            return dataset_transforms[split_key]

        raise ValueError(f"Transform for dataset {self.args.dataset} not found.")

    def get_dataloader(self, train: bool = True) -> DataLoader:
        """Wrap ``get_dataset(train)`` in a ``DataLoader``.

        The last incomplete batch is dropped only when ``train`` is true.
        """
        return DataLoader(
            self.get_dataset(train=train),
            batch_size=self.args.batch_size,
            # NOTE(review): shuffling is enabled for the non-training loader
            # as well; kept as-is, confirm this is intended.
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=True,
            collate_fn=collate_fns[self.args.dataset],
            drop_last=train,
        )

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Build the optimizer selected by ``--optimizer`` for ``self.model``.

        ``initialize_model`` must have been called first, since this reads
        ``self.model``. 'lamb' receives per-group parameters from
        ``param_groups_lrd``; 'sgd' additionally receives momentum and the
        Nesterov flag; every other optimizer receives lr and weight decay.
        """
        optimizer_name = self.args.optimizer

        if optimizer_name == 'lamb':
            param_groups = self.param_groups_lrd(
                self.model,
                self.args.l2,
                layer_decay=1.0,
            )
            return optimizers[optimizer_name](
                param_groups,
                lr=self.args.lr,
                trust_clip=True,
            )

        if optimizer_name == 'sgd':
            return optimizers[optimizer_name](
                self.model.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.l2,
                momentum=self.args.momentum,
                nesterov=self.args.nesterov.lower() == 'true',
            )

        return optimizers[optimizer_name](
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.l2,
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> "torch.optim.lr_scheduler._LRScheduler | None":
        """Build the learning-rate scheduler selected by ``--scheduler``.

        Args:
            optimizer: The optimizer the scheduler will drive.

        Returns:
            The scheduler, or ``None`` when no scheduler was requested
            (``--scheduler`` absent or the literal string 'None').

        Raises:
            ValueError: If the scheduler name is not handled here.
        """
        scheduler_name = self.args.scheduler

        if scheduler_name is None or scheduler_name == 'None':
            return None

        if scheduler_name == 'step':
            return schedulers[scheduler_name](
                optimizer=optimizer,
                step_size=100,
                gamma=0.1,
            )

        if scheduler_name == 'multistep':
            return schedulers[scheduler_name](
                optimizer=optimizer,
                gamma=0.1,
                milestones=[150, 225, 300],
            )

        if scheduler_name == 'cosine':
            return schedulers[scheduler_name](
                optimizer=optimizer,
                # The horizon is fixed at 100 rather than tied to --epochs.
                T_max=100,
                # --eta_min is documented as a multiplier of the base lr; its
                # default of 0 gives a floor of 0.0.
                eta_min=self.args.lr * self.args.eta_min,
            )

        # Raise (not return) so that an unsupported name fails loudly instead
        # of handing an exception object back as if it were a scheduler.
        raise ValueError(f"Scheduler {self.args.scheduler} not found.")

    def get_criterion(self) -> torch.nn.Module:
        """Instantiate the loss registered under ``--loss``."""
        return losses[self.args.loss]()

    def get_gradient_steps(self) -> int:
        """Return the ``--grad_steps`` value."""
        return self.args.grad_steps

    def get_lr(self) -> float:
        """Return the ``--lr`` value."""
        return self.args.lr

    def validate(self) -> bool:
        """Return the parsed ``--validate`` flag."""
        return self.args.validate

    def resume(self) -> bool:
        """Return True when ``--resume`` is the string 'true' (any case)."""
        return self.args.resume.lower() == 'true'

    def get_model_dir(self) -> str:
        """Return the fixed model directory, ``"./SNN/"``."""
        return "./SNN/"

    def _load_model(self) -> torch.nn.Module:
        """Assemble the encoder, spiking network and decoder.

        The model class is resolved from ``--model`` via
        ``get_class_from_string``. The encoder and the network are
        instantiated here, while the decoder registry entry is passed through
        as stored.
        """
        snn_config = NeuronConfig(
            timesteps=self.args.seq_length
        )
        model_cls = self.get_class_from_string(
            self.args.model,
            snn_models.classification_all.copy(),
            [snn_models.classification],
        )
        encoder = encoders[self.args.encoder](
            seq_length=self.args.seq_length,
            config=snn_config,
        )
        snn = model_cls(
            num_classes=dataset_classes[self.args.dataset],
            config=snn_config,
            method=self.args.sg,
        )
        return SpikingConvolutionNetwork(
            encoder=encoder,
            snn=snn,
            decoder=decoders[self.args.decoder],
            seq_length=self.args.seq_length,
            input_scale=self.args.input_scale,
        )

    def load_weights(self) -> str:
        """Return the ``--load_weights`` path (``None`` when not supplied)."""
        return self.args.load_weights

    def save(self, path: str) -> None:
        """Write the parsed arguments to ``path`` as indented JSON."""
        with open(path, 'w') as f:
            json.dump(self.args.__dict__, f, indent=4)

    def get_class_from_string(self, class_name: str, class_list: list, class_import: list) -> type:
        """Resolve a class by case-insensitive name.

        Args:
            class_name: The name to look for.
            class_list: Candidate class names in their canonical spelling.
            class_import: Modules to search; a module is only consulted when
                the canonical name appears in its ``__all__``.

        Returns:
            The attribute with the canonical name from the first module that
            exports it.

        Raises:
            ValueError: If no candidate matches, no module exports the match,
                or an attribute lookup fails along the way.
        """
        try:
            wanted = class_name.lower()
            for candidate in class_list:
                if candidate.lower() != wanted:
                    continue
                for source_module in class_import:
                    if candidate in source_module.__all__:
                        module = importlib.import_module(source_module.__name__)
                        return getattr(module, candidate)
        except AttributeError as err:
            raise ValueError(f"Class {class_name} is not defined. Available classes include {class_list}.\nMake sure you define the class in the correct module and __all__ variable.") from err

        raise ValueError(f"Class '{class_name}' not found in {class_list}.\nMake sure you define the class in the correct module and __all__ variable.")

    def param_groups_lrd(
        self, model, weight_decay=0.05, no_weight_decay_list=None, layer_decay=0.75
    ):
        """Group trainable parameters by layer id and weight-decay setting.

        Taken from https://github.com/BICLab/Spike-Driven-Transformer-V2/blob/main/classification/util/lr_decay_spikformer.py

        Args:
            model: Network whose ``named_parameters()`` are grouped. It must
                expose ``model.model.model.block3`` and ``.block4``, whose
                lengths determine the number of layers.
            weight_decay: Decay applied to parameters in the "decay" groups.
            no_weight_decay_list: Parameter names that are exempt from decay;
                ``None`` means no extra names. One-dimensional parameters are
                always exempt.
            layer_decay: Base of the per-layer ``lr_scale``.

        Returns:
            A list of dicts with the keys ``lr_scale``, ``weight_decay`` and
            ``params``, one per (layer id, decay/no_decay) combination.
        """
        exempt_names = frozenset(no_weight_decay_list or ())
        param_groups = {}

        num_layers = len(model.model.model.block3) + len(model.model.model.block4) + 1

        # layer_scales[i] is the lr multiplier for layer id i; the last layer
        # (id == num_layers) gets layer_decay ** 0 == 1.
        layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

        for n, p in model.named_parameters():
            if not p.requires_grad:  # only parameters that are updated by gradients
                continue

            # no decay: all 1D parameters and model specific ones
            if p.ndim == 1 or n in exempt_names:
                g_decay = "no_decay"
                this_decay = 0.0
            else:
                g_decay = "decay"
                this_decay = weight_decay

            layer_id = get_layer_id_for_vit(n, num_layers)
            group_name = "layer_%d_%s" % (layer_id, g_decay)

            if group_name not in param_groups:
                param_groups[group_name] = {
                    "lr_scale": layer_scales[layer_id],
                    "weight_decay": this_decay,
                    "params": [],
                }

            param_groups[group_name]["params"].append(p)

        return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """Map a parameter name to the layer id used for layer-wise lr scaling.

    Taken from https://github.com/BICLab/Spike-Driven-Transformer-V2/blob/main/classification/util/lr_decay_spikformer.py

    Embedding parameters ("cls_token", "pos_embed" and anything whose name
    starts with "patch_embed") are assigned id 0. Every other parameter,
    including those whose name starts with "block", is assigned
    ``num_layers``; no per-block index is derived from the name.
    """
    is_embedding = name in ("cls_token", "pos_embed") or name.startswith("patch_embed")
    return 0 if is_embedding else num_layers