# SNN Models
from SNN.Layers.NeuronConfig import NeuronConfig
import SNN.models as snn_models

# Options and Datasets
from AbstractModels import Options

import torch
from torch.utils.data import Subset, DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

import argparse

import json

# Used to dynamically import classes for easier model addition and removal
import importlib

from .util.init_params import split_weights
from .util.scheduler import GradualWarmupScheduler

# Module-level aliases for the lookup tables exposed by ``Options`` so the
# code below can refer to them without the ``Options.`` prefix.
encoders, decoders = Options.encoders, Options.decoders
datasets, neuromorphic_datasets = Options.datasets, Options.neuromorphic_datasets
collate_fns, dataset_transforms, dataset_classes = (
    Options.collate_fns,
    Options.dataset_transforms,
    Options.dataset_classes,
)
optimizers, losses, schedulers = (
    Options.optimizers,
    Options.losses,
    Options.schedulers,
)
surrogate_gradients = Options.surrogate_gradients

class Config:
    def __init__(self):
        self.args = self.parse_args()
        all_snn_models: list = snn_models.classification_all.copy()

        self.is_snn: bool = any([self.args.model.lower() == model.lower() for model in all_snn_models])

    def initialize_model(self, rank: None | int = None) -> torch.nn.Module:
        self.model = self._load_model()        
        if self.use_ddp():
            self.model.to(rank)
            self.model.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model.model)
            self.model.model = torch.nn.parallel.DistributedDataParallel(self.model.model, device_ids=[rank])
        return self.model

    def parse_args(self):
        models: list = snn_models.classification_all.copy()

        models = [model.lower() for model in models]

        dataset_choices = datasets.copy()
        dataset_choices.update(neuromorphic_datasets)

        parser = argparse.ArgumentParser()

        parser.add_argument('--model', type=str, default='lenet5', choices=list(models), help='Model to train')
        parser.add_argument('--dataset', type=str, default='mnist', choices=list(dataset_choices.keys()), help='Dataset to train on')
        parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
        parser.add_argument('--checkpoint', type=int, default=-1, help='Number of epochs to save model a checkpoint')
        parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
        parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
        parser.add_argument('--optimizer', type=str, default='adam', choices=list(optimizers.keys()), help='Optimizer to use')
        parser.add_argument('--loss', type=str, default='crossentropy', choices=list(losses.keys()), help='Loss function to use')
        parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for data loading')
        parser.add_argument('--seq_length', type=int, default=16, help='Sequence length')
        parser.add_argument('--sg', type=str, default='rectangle', choices=surrogate_gradients, help='Surrogate gradient method')
        parser.add_argument('--encoder', type=str, default='copy', choices=list(encoders.keys()), help='Encoder to use')
        parser.add_argument('--decoder', type=str, default='mean', choices=list(decoders.keys()), help='Decoder to use')
        parser.add_argument('--subset', type=int, default=-1, help='Train on a small subset of the data')
        parser.add_argument('--load_weights', type=str, default=None, help='Load weights from file')
        parser.add_argument('--validate', type=bool, default=False, help='Validate model instead of training')
        parser.add_argument('--l2', type=float, default=0.0, help='L2 regularization lambda value')
        parser.add_argument('--momentum', type=float, default=0.0, help='Momentum value')
        parser.add_argument('--scheduler', type=str, default=None, choices=list(schedulers.keys()), help='Scheduler to use')
        parser.add_argument('--name', type=str, default=None, help='Name for model and associated metadata. Used to distinguish between different runs and models')
        parser.add_argument('--resume', type=str, default='False', help='Resume training from a checkpoint')
        parser.add_argument('--nesterov', type=str, default='False', help='Use Nesterov momentum')
        parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
        parser.add_argument('--ddp', type=str, default='False', help='Use Distributed Data Parallel Training')
        parser.add_argument('--label_smoothing', type=float, default=0.0, help='Label smoothing value')
        parser.add_argument('--poly_degree', type=int, default=3, help='Degree of polynomial for learnable membrane')
        parser.add_argument('--warmup', type=int, default=-1, help='Number of epochs to warm up the learning rate')

        return parser.parse_args()

    def get_seed(self) -> int:
        return self.args.seed

    def get_name(self) -> str:
        return self.args.name

    def get_epochs(self) -> int:
        return self.args.epochs
    
    def get_timesteps(self) -> int:
        return self.args.seq_length
    
    def use_ddp(self) -> bool:
        return self.args.ddp.lower() == 'true'
    
    def calculate_energy_consumption(self) -> bool:
        return self.args.energy.lower() == 'true'
    
    def get_checkpoint_period(self) -> int:
        return -1 if self.args.checkpoint < 0 else self.args.checkpoint

    def get_dataset(self, train=True) -> Dataset:
        transform = self.get_transform(train)
        if self.args.dataset in neuromorphic_datasets.keys():
            data = neuromorphic_datasets[self.args.dataset](
                root='./data/', 
                train=train,
                transform=transform,
                download=True
            )
        elif self.args.dataset in datasets.keys():
            data = datasets[self.args.dataset](
                root='./data/', 
                train=train, 
                transform=transform,
                download=True
            )
            
        if 0 < self.args.subset < len(data):
            return Subset(data, range(0, self.args.subset))
        return data

    def get_transform(self, train: str) -> torch.nn.Module:
        if self.args.dataset in dataset_transforms.keys():
            return dataset_transforms[self.args.dataset]
        elif f'{self.args.dataset}_train' in dataset_transforms.keys() and train:
            return dataset_transforms[f'{self.args.dataset}_train']
        elif f'{self.args.dataset}_val' in dataset_transforms.keys() and not train:
            return dataset_transforms[f'{self.args.dataset}_val']
        raise ValueError(f"Transform for dataset {self.args.dataset} not found.")

    def get_dataloader(self, train: int = True, distributed: bool = False, rank: int = 0, world_size: int = 0) -> DataLoader:
        dataset = self.get_dataset(train=train)
        return DataLoader(
            dataset, 
            batch_size=self.args.batch_size, 
            shuffle=True if not distributed else False,
            num_workers=self.args.num_workers,
            pin_memory=True,
            persistent_workers=True,
            collate_fn=collate_fns[self.args.dataset],
            drop_last=True,
            sampler=DistributedSampler(dataset) if distributed else None
        )
    
    def get_optimizer(self) -> torch.optim.Optimizer:
        params = split_weights(self.model)
        if self.args.optimizer == 'sgd':
            return optimizers[self.args.optimizer](
                params,
                lr=self.args.lr,
                weight_decay=self.args.l2, 
                momentum=self.args.momentum,
                nesterov=self.args.nesterov.lower() == 'true'
            )
        return optimizers[self.args.optimizer](
            params, 
            lr=self.args.lr, 
            weight_decay=self.args.l2
        )
    
    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        scheduler = None
        if self.args.scheduler == 'None' or self.args.scheduler is None:
            return scheduler
        elif self.args.scheduler == 'cosine':
            scheduler = schedulers[self.args.scheduler](
                optimizer=optimizer,
                T_max=self.args.epochs,
                # T_max=100,
                eta_min=0.00001
            )
        else:
            return ValueError(f"Scheduler {self.args.scheduler} not found.")
        if self.args.warmup > 0:
            return GradualWarmupScheduler(
                optimizer=optimizer,
                multiplier=1,
                total_epoch=self.args.warmup,
                after_scheduler=scheduler
            )
        
        return scheduler
    
    def get_criterion(self) -> torch.nn.Module:
        return losses[self.args.loss](label_smoothing=self.args.label_smoothing)

    def get_lr(self) -> float:
        return self.args.lr

    def validate(self) -> bool:
        return self.args.validate
    
    def resume(self) -> bool:
        return self.args.resume.lower() == 'true'

    def get_model_dir(self) -> str:
        return "./SNN/"
    
    def _load_model(self) -> torch.nn.Module:
        if self.is_snn:
            return self._load_spiking_model()
        
    def _load_spiking_model(self) -> torch.nn.Module:
        snn_config = NeuronConfig(
            timesteps=self.args.seq_length,
            poly_degree=self.args.poly_degree
        )

        model = self.get_class_from_string(self.args.model, snn_models.classification_all, [snn_models.classification])

        return model( 
            num_classes=dataset_classes[self.args.dataset], 
            encoder=encoders[self.args.encoder](seq_length=self.args.seq_length, config=snn_config),
            decoder=decoders[self.args.decoder],
            config=snn_config, 
            method=self.args.sg
        )

    def load_weights(self) -> str:
        return self.args.load_weights

    def save(self, path: str) -> None:
        with open(path, 'w') as f:
            json.dump(self.args.__dict__, f, indent=4)

    def get_class_from_string(self, class_name: str, class_list: list, class_import: list) -> type:
        try:
            for c in class_list:
                if class_name.lower() == c.lower():
                    for import_module in class_import:
                        if c in import_module.__all__:
                            module = importlib.import_module(import_module.__name__)
                            return getattr(module, c)
            else:
                raise ValueError(f"Class '{class_name}' not found in {class_list}.\nMake sure you define the class in the correct module and __all__ variable.")
        except AttributeError:
            raise ValueError(f"Class {class_name} is not defined. Available classes include {class_list}.\nMake sure you define the class in the correct module and __all__ variable.")
