from AbstractModels.Model import Model
from AbstractModels.SpikingModel import SpikingModel
from AbstractModels.Config import Config

import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import autocast
import torch.distributed as dist


from tqdm.auto import tqdm
from datetime import datetime
import pandas as pd

from time import time
import inspect

# strftime pattern for the timestamp suffix that fit() appends to the model/run
# name, e.g. "(2024-01-31)_(13-45-10)".
DATE_FORMAT = "(%Y-%m-%d)_(%H-%M-%S)"

class Trainer:
    def __init__(self, model: Model | SpikingModel):
        super(Trainer, self).__init__()
        self.df: pd.DataFrame = pd.DataFrame(
            columns=['Model', 'Epoch', 'Training Loss', 'Training Accuracy', 'Validation Loss', 'Validation Accuracy', 'Training Spike-Rate', 'Validation Spike-Rate', 'Training Time', 'Validation Time']
        )

        self.model:  Model | SpikingModel = model
        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def fit(
        self, 
        epochs: int, 
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.loss._Loss,
        model_dir: str,
        config: Config,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        gpu_id: int = 0,
        distributed_training: bool = False
    ) -> None:
        r'''
        Trains the model and saves the model and metrics to disk.
        If training is stopped due to a KeyboardInterrupt, the metrics and model are not saved to disk.
        '''
        self.config: Config = config
        
        if config.get_name() is not None:
            model_name = f'{config.get_name()}_{datetime.now().strftime(DATE_FORMAT)}'
        else:
            model_name = f'{self.model.name}_{datetime.now().strftime(DATE_FORMAT)}'
        
        self.distributed_training = distributed_training

        self.rank = gpu_id
        if distributed_training:
            self.DEVICE = torch.device(f'cuda:{gpu_id}')

        self.model.to(self.DEVICE)

        if isinstance(criterion, nn.Module):
            criterion.to(self.DEVICE)
        
        start_epoch: int = 0

        if config.load_weights() is not None:
            scheduler, start_epoch, df = self.model.load(model_dir, config.load_weights(), optimizer, scheduler, config.resume())
            if df is not None:
                self.df = df

        if config.resume() and scheduler is None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = config.get_lr()

        if self.rank == 0:
            self.model.mkfolders(model_dir, model_name)
        
        self.get_spike_rate() # get spike rate to reset the counter, discard the value

        self._train(
            start_epoch=start_epoch,
            epochs=epochs,
            checkpoint_period=config.get_checkpoint_period(),
            train_loader=train_loader, 
            val_loader=val_loader, 
            optimizer=optimizer, 
            criterion=criterion,
            model_dir=model_dir,
            model_name=model_name,
            scheduler=scheduler
        )

        if self.rank == 0:
            self.model.save(config, model_dir, model_name, self.df, optimizer, scheduler)

    def _train(
        self,
        start_epoch: int,
        epochs: int,
        checkpoint_period: int,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer, 
        criterion: torch.nn.modules.loss._Loss,
        model_dir: str,
        model_name: str,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None
    ) -> None:
        r'''
        Trains and validates the model for the specified number of epochs.
        '''
        
        tqdm.write(f'Training {self.model.name} on {self.DEVICE}\n')
        tqdm.write(f'Epochs: {epochs}')
        tqdm.write(f'* Dataset size: {len(train_loader.dataset)}')
        tqdm.write(f'* Batch Size: {train_loader.batch_size}\n')
        tqdm.write(f'Loss Function: {criterion}')
        tqdm.write(f'Optimizer: {optimizer.__class__}')
        tqdm.write(f'Scheduler: {scheduler.__class__}\n')
            
        for epoch in tqdm(range(start_epoch, epochs), desc="Train", leave=True):
            if self.distributed_training:
                train_loader.sampler.set_epoch(epoch)
                val_loader.sampler.set_epoch(epoch)

            train_loss, train_accuracy, train_spike_rate, train_time = self._step(
                train_loader=train_loader, 
                optimizer=optimizer, 
                criterion=criterion,
                epoch=epoch
            )

            self.barrier()

            # print(f"[RANK {self.rank}] Starting validation...")
            val_loss, val_accuracy, val_spike_rate, val_time = self._validate(
                val_loader=val_loader, 
                criterion=criterion
            )
            # print(f"[RANK {self.rank}] Finished validation.")

            self.record_metrics(
                epoch=epoch, 
                train_loss=train_loss, 
                train_accuracy=train_accuracy, 
                val_loss=val_loss, 
                val_accuracy=val_accuracy,
                train_spike_rate=train_spike_rate,
                val_spike_rate=val_spike_rate,
                train_time=train_time,
                val_time=val_time
            )

            self.print_step(
                epoch=epoch,
                epochs=epochs,
                lr=optimizer.param_groups[0]['lr'],
            )
            
            self.scheduler_step(scheduler, val_loss)

            if self.rank == 0 and checkpoint_period > 0 and (epoch + 1) % checkpoint_period == 0 and self.rank == 0 and epoch != 0 and epoch != epochs - 1:
                tqdm.write(f'\nSaving checkpoint for epoch {epoch + 1}\n')
                self.model.save_checkpoint(model_dir, model_name, optimizer, scheduler, epoch + 1, self.df)

            self.barrier()
            
        tqdm.write('Training complete.')
    
    def _step(
        self,
        train_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.loss._Loss,
        epoch: int = 0
    ) -> tuple[float, float, float, float]:
        '''
        Trains the model for one epoch.

        Args:
            train_loader (DataLoader): Training data.
            optimizer (torch.optim.Optimizer): Optimizer.
            criterion (torch.nn.modules.loss._Loss): Loss function.
        '''

        self.model.train()

        correct: int = 0
        total: int = 0
        train_loss: float = 0
        start: float = time()

        scalar: torch.amp.GradScaler = torch.amp.GradScaler(self.DEVICE.type)
    
        for _, (data, target) in enumerate(tqdm(train_loader, desc=f"T. Step ID: {self.rank}", leave=False)):
            data, target = data.float().to(self.DEVICE), target.to(self.DEVICE)

            with autocast(device_type=self.DEVICE.type, dtype=torch.float16, enabled=self.DEVICE.type == 'cuda'):
                output: torch.Tensor = self.model(data)
                loss = criterion(output, target)

                train_loss += loss.item()
                
                scalar.scale(loss).backward()
                scalar.step(optimizer)
                scalar.update()
                    
                optimizer.zero_grad()

            correct += (torch.argmax(output, dim=1) == target).sum().item()
            total += target.size(0)

        return (
            train_loss / len(train_loader),
            correct/total * 100, 
            self.get_spike_rate(),
            time() - start
        )

    @torch.no_grad()
    def _validate(
        self,
        val_loader: DataLoader,
        criterion: torch.nn.modules.loss._Loss
    ) -> tuple[float, float, float, float]:
        '''
        Validates the model on the validation set.

        For internal use.

        Args:
            val_loader (DataLoader): Validation data.
            criterion (torch.nn.modules.loss._Loss): Loss function.

        Returns:
            tuple: Validation loss and accuracy.
        '''

        self.model.eval()
        correct: int = 0
        total: int = 0
        val_loss: float = 0
        start: float = time()
        for data, target in tqdm(val_loader, desc="V. Step", leave=False):
            data, target = data.float().to(self.DEVICE), target.to(self.DEVICE)
            output = self.model(data)

            loss = criterion(output, target)
            val_loss += loss.item()

            if output.ndim == 3:
                output = torch.mean(output, dim=0)
    
            correct += (torch.argmax(output, dim=1) == target).sum().item()
            total += target.size(0)
        
        return (
            val_loss / len(val_loader), 
            correct/total * 100, 
            self.get_spike_rate(),
            time() - start
        )
    
    @torch.no_grad()
    def validate(
        self,
        val_loader: DataLoader,
        criterion: torch.nn.modules.loss._Loss,
        config
    ) -> tuple[float, float, float, float]:
        '''
        Validates the model on the validation set.

        For external use.

        Args:
            val_loader (DataLoader): Validation data.
            criterion (torch.nn.modules.loss._Loss): Loss function.
            config (Config): Config object.
        '''
        self.model.to(self.DEVICE)
        criterion.to(self.DEVICE)

        if config.load_weights() is not None:
            self.model.load(config.get_model_dir(), config.load_weights(), None, None)

        loss, accuracy, spike_rate, time = self._validate(val_loader, criterion)

        results = (
            f'Validation Loss: {loss : .6f}, Validation Accuracy: {accuracy : .2f}'
        )

        if spike_rate is not None:
            results += f', Validation Spike-Rate: {spike_rate : .2f}%'

        tqdm.write(results)
        
        return loss, accuracy, spike_rate, time

    def scheduler_step(self, scheduler: torch.optim.lr_scheduler._LRScheduler, val_loss: float) -> None:
        '''
        Step the scheduler.
        '''
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
    
    def barrier(self) -> None:
        '''
        Barrier for distributed training.
        '''
        if self.distributed_training:
            dist.barrier()

    def print_step(
        self,
        epoch: int,
        epochs: int,
        lr: float,
    ) -> None:
        '''
        Prints the metrics for the current epoch.
        '''
        if self.rank == 0:
            train_loss = self.df['Training Loss'].iloc[-1]
            train_accuracy = self.df['Training Accuracy'].iloc[-1]
            
            val_loss = self.df['Validation Loss'].iloc[-1]
            val_accuracy = self.df['Validation Accuracy'].iloc[-1]

            train_spike_rate = self.df['Training Spike-Rate'].iloc[-1]
            val_spike_rate = self.df['Validation Spike-Rate'].iloc[-1]

            results = (
                f'Epoch [{epoch+1 : ^3d}/{epochs : 2d}]: '
                f'T/V Loss:{train_loss : .6f} /{val_loss : .6f}, '
                f'T/V Acc:{train_accuracy : .2f}% /{val_accuracy : .2f}%, '
                f'LR:{lr : .6f}'
            )
            if train_spike_rate is not None and val_spike_rate is not None:
                results += f', T/V SR:{train_spike_rate: .2f}% /{val_spike_rate: .2f}%'
            
            tqdm.write(results)

    def get_spike_rate(self) -> float:
        spike_rate = None
        if isinstance(self.model, SpikingModel):
            spike_rate = self.model.get_avg_spike_rate()
            self.model.reset_total_spikes()
        return spike_rate

    def record_metrics(
        self,
        epoch: int,
        train_loss: float,
        train_accuracy: float,
        val_loss: float,
        val_accuracy: float,
        train_spike_rate: float,
        val_spike_rate: float,
        train_time: float,
        val_time: float
    ) -> None:
        """
        Synchronize and record metrics across GPUs for distributed training.
        """
        if self.distributed_training:
            train_loss = self.sync_metric(train_loss, mode='average')
            train_accuracy = self.sync_metric(train_accuracy, mode='average')
            val_loss = self.sync_metric(val_loss, mode='average')
            val_accuracy = self.sync_metric(val_accuracy, mode='average')
            train_time = self.sync_metric(train_time, mode='sum')
            val_time = self.sync_metric(val_time, mode='sum')
            
            if self.config.is_snn:
                train_spike_rate = self.sync_metric(train_spike_rate, mode='average')
                val_spike_rate = self.sync_metric(val_spike_rate, mode='average')

        # Only log metrics on rank 0
        if self.rank == 0:
            self.df = pd.concat([
                self.df if not self.df.empty else None,
                pd.DataFrame(
                    data=[[
                        self.__class__.__name__,
                        epoch + 1,
                        train_loss,
                        train_accuracy,
                        val_loss,
                        val_accuracy,
                        train_spike_rate,
                        val_spike_rate,
                        train_time,
                        val_time
                    ]],
                    columns=[
                        'Model',
                        'Epoch',
                        'Training Loss',
                        'Training Accuracy',
                        'Validation Loss',
                        'Validation Accuracy',
                        'Training Spike-Rate',
                        'Validation Spike-Rate',
                        'Training Time',
                        'Validation Time'
                    ])
            ])

    def has_epoch_arg(self, func):
        sig = inspect.signature(func)
        return "epoch" in sig.parameters
    
    @staticmethod
    def sync_metric(value: float, mode: str = 'average') -> float:
        """
        Synchronizes a metric across all GPUs.

        Args:
            value (float): The local metric value.
            mode (str): The synchronization mode ('average').

        Returns:
            float: The synchronized metric value.
        """
        tensor = torch.tensor(value, dtype=torch.float32, device='cuda')
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        if mode == 'average':
            tensor /= dist.get_world_size()
        return tensor.item()