import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import autocast

from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from datetime import datetime
import pandas as pd
import os
import shutil

from time import time, sleep

from AbstractModels.util import mixup_criterion, cutmix

# Path fragment concatenated directly after `model_dir` to form the root folder of all run outputs.
OUTPUT_DIR = './output/'
# Path fragment concatenated inside a run's output folder; the TensorBoard writer logs beneath it.
LOG_DIR = './logs/'
# strftime pattern used to timestamp run names, e.g. "(2024-01-31)_(13-05-59)".
DATE_FORMAT = "(%Y-%m-%d)_(%H-%M-%S)"

class ConvolutionModel(nn.Module):
    def __init__(self, num_classes: int, quantized: bool = False, is_snn: bool = False):
        super(ConvolutionModel, self).__init__()
        self.df: pd.DataFrame = pd.DataFrame(
            columns=['Model', 'Epoch', 'Training Loss', 'Training Accuracy', 'Validation Loss', 'Validation Accuracy', 'Training Spike-Rate', 'Validation Spike-Rate', 'Training Time', 'Validation Time']
        )
        self.writer: SummaryWriter = None
        self.name: str = None

        self.num_classes: int = num_classes
        self.quantized: bool = quantized
        self.is_snn: bool = is_snn

        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, _: torch.Tensor):
        pass

    def summary(self):
        pass
    
    def fit(
        self,
        epochs: int,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.loss._Loss,
        model_dir: str,
        config,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        gradient_steps: int = 1
    ) -> None:
        r'''
        Trains the model and saves the model and metrics to disk.

        If training is stopped by a KeyboardInterrupt, the final `save` is skipped. The
        TensorBoard writer is closed and, when the run folder contains exactly one entry,
        the run folder is removed.

        Args:
            epochs (int): Upper bound of the epoch range passed to `_train`.
            train_loader (DataLoader): Training data.
            val_loader (DataLoader): Validation data.
            optimizer (torch.optim.Optimizer): Optimizer.
            criterion (torch.nn.modules.loss._Loss): Loss function; moved to the device only when it is an `nn.Module`.
            model_dir (str): Path prefix that `OUTPUT_DIR` is concatenated onto.
            config (Config): Config object; stored on `self.config` and queried for the run name, weights to load, resume flag, learning rate and checkpoint period.
            scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Learning-rate scheduler. Defaults to None.
            gradient_steps (int, optional): Forwarded to `_train`. Defaults to 1.
        '''
        self.config = config

        # Resolve the run name before the try block so the interrupt handler below
        # can always refer to it, no matter how early the interrupt arrives.
        base_name = config.get_name()
        if base_name is None:
            base_name = self.name
        model_name = f'{base_name}_{datetime.now().strftime(DATE_FORMAT)}'
        run_dir = f'{model_dir}{OUTPUT_DIR}{model_name}'

        try:
            self.to(self.DEVICE)
            # check if criterion is a loss function or a callable
            if isinstance(criterion, nn.Module):
                criterion.to(self.DEVICE)

            start_epoch: int = 0

            if config.load_weights() is not None:
                scheduler, start_epoch = self.load(model_dir, config.load_weights(), optimizer, scheduler, config.resume())

            # When resuming without a scheduler, force the configured learning rate on every param group.
            if config.resume() and scheduler is None:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = config.get_lr()

            self.mkfolders(model_dir, model_name)
            self.writer = SummaryWriter(f'{run_dir}/{LOG_DIR}{model_name}')

            self.get_spike_rate(rate=1) # get spike rate to reset the counter, discard the value

            self._train(
                start_epoch=start_epoch,
                epochs=epochs,
                checkpoint_period=config.get_checkpoint_period(),
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer=optimizer,
                criterion=criterion,
                model_dir=model_dir,
                model_name=model_name,
                scheduler=scheduler,
                gradient_steps=gradient_steps
            )

            self.save(config, model_dir, model_name, optimizer, scheduler)
        except KeyboardInterrupt:
            # Stop Tensorboard log from being generated if training is stopped by user.
            print('Keyboard Interrupt: Training stopped by user.')

            # Release the event-file handle before touching the folder; otherwise the
            # writer stays open for the rest of the process.
            writer = getattr(self, 'writer', None)
            if writer is not None:
                writer.close()

            # A single entry means nothing besides the log folder was written for this run.
            if os.path.exists(run_dir) and len(os.listdir(run_dir)) == 1:
                print(f'Removing Log: {OUTPUT_DIR}{model_name}')
                shutil.rmtree(run_dir)


    def _train(
        self,
        start_epoch: int,
        epochs: int,
        checkpoint_period: int,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.loss._Loss,
        model_dir: str,
        model_name: str,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        gradient_steps: int = 1
    ) -> None:
        r'''
        Runs the epoch loop: a training pass, a validation pass, metric logging,
        a scheduler step and, when due, a periodic checkpoint.

        Args:
            start_epoch (int): First epoch index of the loop.
            epochs (int): Exclusive upper bound of the epoch range.
            checkpoint_period (int): Save a checkpoint every this many epochs; values <= 0 disable checkpointing.
            train_loader (DataLoader): Training data.
            val_loader (DataLoader): Validation data.
            optimizer (torch.optim.Optimizer): Optimizer.
            criterion (torch.nn.modules.loss._Loss): Loss function.
            model_dir (str): Directory prefix used when saving checkpoints.
            model_name (str): Run name used when saving checkpoints.
            scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Stepped once per epoch when given. Defaults to None.
            gradient_steps (int, optional): Forwarded to `_step`. Defaults to 1.
        '''
        announce = tqdm.write

        # Run banner.
        announce(f'Training {self.name} on {self.DEVICE}\n')
        announce(f'Epochs: {epochs}')
        announce(f'* Dataset size: {len(train_loader.dataset)}')
        announce(f'* Batch Size: {train_loader.batch_size}\n')
        announce(f'Loss Function: {criterion}')
        announce(f'Optimizer: {optimizer.__class__}')
        announce(f'Scheduler: {scheduler.__class__}\n')

        for epoch in tqdm(range(start_epoch, epochs), desc="Train", leave=True):
            # Pause for a minute on every 25th epoch index (but never on index 0).
            if epoch != 0 and epoch % 25 == 0:
                announce('\nSleeping for 1 minute to prevent overheating.\n')
                sleep(60)

            # NOTE(review): hard stop at epoch index 251 - the process exits with status 1
            # without calling `save`. Looks like a leftover guard; confirm it is intended.
            if epoch == 251:
                exit(1)

            t_loss, t_acc, t_rate, t_time = self._step(
                train_loader=train_loader,
                optimizer=optimizer,
                criterion=criterion,
                gradient_steps=gradient_steps
            )

            v_loss, v_acc, v_rate, v_time = self._validate(
                val_loader=val_loader,
                criterion=criterion
            )

            self.record_metrics(
                epoch=epoch,
                train_loss=t_loss,
                train_accuracy=t_acc,
                val_loss=v_loss,
                val_accuracy=v_acc,
                train_spike_rate=t_rate,
                val_spike_rate=v_rate,
                train_time=t_time,
                val_time=v_time
            )

            # The learning rate shown is the first param group's, read before the scheduler steps.
            current_lr = optimizer.param_groups[0]['lr']
            self.print_step(
                epoch=epoch,
                epochs=epochs,
                train_loss=t_loss,
                train_accuracy=t_acc,
                val_loss=v_loss,
                val_accuracy=v_acc,
                lr=current_lr,
                train_spike_rate=t_rate,
                val_spike_rate=v_rate
            )

            self.scheduler_step(scheduler)

            # Checkpoint on every `checkpoint_period`-th completed epoch, skipping
            # epoch index 0 and the final epoch.
            completed = epoch + 1
            period_hit = checkpoint_period > 0 and completed % checkpoint_period == 0
            if period_hit and epoch != 0 and epoch != epochs - 1:
                announce(f'\nSaving checkpoint for epoch {completed}\n')
                self.save_checkpoint(model_dir, model_name, optimizer, scheduler, completed)

        announce('Training complete.')
    
    def _step(
        self,
        train_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.modules.loss._Loss,
        gradient_steps: int = 1
    ) -> tuple[float, float, float, float]:
        '''
        Trains the model for one epoch.

        Args:
            train_loader (DataLoader): Training data.
            optimizer (torch.optim.Optimizer): Optimizer.
            criterion (torch.nn.modules.loss._Loss): Loss function.
            gradient_steps (int, optional): Number of batches whose gradients are accumulated before the weights are updated. Values below 1 are treated as 1. Defaults to 1.

        Returns:
            tuple: Mean loss per batch, accuracy in percent, the value of `get_spike_rate` for the samples seen, and the elapsed time in seconds.
        '''

        self.train()

        correct: int = 0
        total: int = 0
        train_loss: float = 0
        start: float = time()

        use_amp: bool = self.DEVICE.type == 'cuda'
        accumulation: int = max(1, int(gradient_steps or 1))
        num_batches: int = len(train_loader)

        scaler: torch.amp.GradScaler = torch.amp.GradScaler(self.DEVICE.type)

        progress = tqdm(train_loader, desc="T. Step", leave=False)
        for batch_idx, (data, target) in enumerate(progress, start=1):
            data, target = data.float().to(self.DEVICE), target.to(self.DEVICE)

            # Only the forward pass and the loss run under autocast; the backward pass
            # and optimizer step are kept outside, as the AMP documentation recommends.
            with autocast(device_type=self.DEVICE.type, dtype=torch.float16, enabled=use_amp):
                if self.config.cutmix():
                    data, target_a, target_b, lam = cutmix(data, target)
                    output: torch.Tensor = self.forward(data)
                    loss = mixup_criterion(criterion, output, target_a, target_b, lam)
                else:
                    output: torch.Tensor = self.forward(data)
                    loss = criterion(output, target)

            # The reported loss is the unscaled per-batch value.
            train_loss += loss.item()

            # Average over the accumulation window so the summed gradient matches one
            # larger batch. A shorter final window is still divided by the full size.
            if accumulation > 1:
                loss = loss / accumulation
            scaler.scale(loss).backward()

            # Update on every `accumulation`-th batch, and flush on the last batch so
            # no accumulated gradient is carried into the next epoch.
            if batch_idx % accumulation == 0 or batch_idx == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # 3-D outputs are averaged over dim 0 before taking the arg-max.
            if output.ndim == 3:
                output = torch.mean(output, dim=0)

            # NOTE: accuracy is always measured against `target`, also when CutMix mixed the labels.
            correct += (torch.argmax(output, dim=1) == target).sum().item()
            total += target.size(0)

        return (
            train_loss / num_batches,
            correct/total * 100,
            self.get_spike_rate(rate=total),
            time() - start
        )

    @torch.no_grad()
    def _validate(
        self,
        val_loader: DataLoader,
        criterion: torch.nn.modules.loss._Loss
    ) -> tuple[float, float, float, float]:
        '''
        Runs one pass over the validation set with gradients disabled.

        For internal use.

        Args:
            val_loader (DataLoader): Validation data.
            criterion (torch.nn.modules.loss._Loss): Loss function.

        Returns:
            tuple: Mean loss per batch, accuracy in percent, the value of `get_spike_rate` for the samples seen, and the elapsed time in seconds.
        '''

        self.eval()

        hits: int = 0
        seen: int = 0
        loss_sum: float = 0
        began: float = time()

        for data, target in tqdm(val_loader, desc="V. Step", leave=False):
            data = data.float().to(self.DEVICE)
            target = target.to(self.DEVICE)

            output = self.forward(data)
            loss_sum += criterion(output, target).item()

            # 3-D outputs are averaged over dim 0 before taking the arg-max.
            scores = torch.mean(output, dim=0) if output.ndim == 3 else output

            hits += (torch.argmax(scores, dim=1) == target).sum().item()
            seen += target.size(0)

        mean_loss = loss_sum / len(val_loader)
        accuracy = hits/seen * 100
        spike_rate = self.get_spike_rate(rate=seen)
        elapsed = time() - began

        return mean_loss, accuracy, spike_rate, elapsed
    
    @torch.no_grad()
    def validate(
        self,
        val_loader: DataLoader,
        criterion: torch.nn.modules.loss._Loss,
        config
    ) -> tuple[float, float, float, float]:
        '''
        Validates the model on the validation set and prints the results.

        For external use.

        Args:
            val_loader (DataLoader): Validation data.
            criterion (torch.nn.modules.loss._Loss): Loss function; moved to the device only when it is an `nn.Module`.
            config (Config): Config object; queried for the weights to load and the model directory.

        Returns:
            tuple: Validation loss, accuracy, spike rate and elapsed time, as returned by `_validate`.
        '''
        self.to(self.DEVICE)
        # A plain callable has no `.to`; only module-based criteria are moved (the same check `fit` performs).
        if isinstance(criterion, nn.Module):
            criterion.to(self.DEVICE)

        if config.load_weights() is not None:
            self.load(config.get_model_dir(), config.load_weights(), None, None)

        # `elapsed` rather than `time`, so the imported `time()` function is not shadowed.
        loss, accuracy, spike_rate, elapsed = self._validate(val_loader, criterion)

        results = (
            f'Validation Loss: {loss : .6f}, Validation Accuracy: {accuracy : .2f}'
        )

        if spike_rate is not None:
            results += f', Validation Spike-Rate: {spike_rate : .2f}'

        tqdm.write(results)

        return loss, accuracy, spike_rate, elapsed

    def scheduler_step(self, scheduler: torch.optim.lr_scheduler._LRScheduler) -> None:
        '''
        Step the scheduler.
        '''
        if scheduler is not None:
            scheduler.step()
    
    def print_step(
        self,
        epoch: int,
        epochs: int,
        train_loss: float,
        train_accuracy: float,
        val_loss: float,
        val_accuracy: float,
        lr: float,
        train_spike_rate: int = None,
        val_spike_rate: int = None
    ) -> None:
        '''
        Writes a one-line summary of the epoch's metrics through `tqdm.write`.

        Args:
            epoch (int): Zero-based epoch index; shown as `epoch + 1`.
            epochs (int): Total number of epochs shown after the slash.
            train_loss (float): Training loss.
            train_accuracy (float): Training accuracy in percent.
            val_loss (float): Validation loss.
            val_accuracy (float): Validation accuracy in percent.
            lr (float): Learning rate to display.
            train_spike_rate (int, optional): Appended to the line when not None. Defaults to None.
            val_spike_rate (int, optional): Appended to the line when not None. Defaults to None.
        '''
        pieces = [
            f'Epoch [{epoch+1 : ^3d}/{epochs : 2d}]: ',
            f'T. Loss: {train_loss : .6f}, ',
            f'V. Loss: {val_loss : .6f}, ',
            f'T. Acc: {train_accuracy : .2f}%, ',
            f'V. Acc: {val_accuracy : .2f}%, ',
            f'LR: {lr : .6f}',
        ]

        # Spike rates are optional trailing fields.
        if train_spike_rate is not None:
            pieces.append(f', T. SR: {train_spike_rate: .2f}')
        if val_spike_rate is not None:
            pieces.append(f', V. SR: {val_spike_rate: .2f}')

        tqdm.write(''.join(pieces))

    def get_spike_rate(self, rate: int) -> float:
        spike_rate = None
        if self.is_snn:
            spike_rate = self.get_total_spikes() / rate
            self.reset_total_spikes()
        return spike_rate

    def record_metrics(
        self,
        epoch: int,
        train_loss: float,
        train_accuracy: float,
        val_loss: float,
        val_accuracy: float,
        train_spike_rate: float,
        val_spike_rate: float,
        train_time: float,
        val_time: float
    ) -> None:
    
        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/validation', val_loss, epoch)
        self.writer.add_scalar('Accuracy/train', train_accuracy, epoch)
        self.writer.add_scalar('Accuracy/validation', val_accuracy, epoch)
        self.writer.add_scalar('Time/train', train_time, epoch)
        self.writer.add_scalar('Time/validation', val_time, epoch)

        if train_spike_rate is not None:
            self.writer.add_scalar('SpikeRate/train', train_spike_rate, epoch)
        if val_spike_rate is not None:
            self.writer.add_scalar('SpikeRate/validation', val_spike_rate, epoch)

        self.df = pd.concat([
            self.df if not self.df.empty else None, 
            pd.DataFrame(
                data=[[
                    self.__class__.__name__, 
                    epoch+1, 
                    train_loss,
                    train_accuracy,
                    val_loss, 
                    val_accuracy,
                    train_spike_rate,
                    val_spike_rate,
                    train_time,
                    val_time
                ]],
                columns=[
                    'Model', 
                    'Epoch', 
                    'Training Loss',
                    'Training Accuracy',
                    'Validation Loss', 
                    'Validation Accuracy',
                    'Training Spike-Rate',
                    'Validation Spike-Rate',
                    'Training Time',
                    'Validation Time'
                ])
        ])

    def mkfolders(
        self,
        model_dir: str,
        model_name: str
    ) -> None:
        '''
        Creates the output and log folders for the model.

        The deepest folder (the log folder) is created together with every missing
        parent, so the output root and the run folder are covered by the same call.
        Folders that already exist are left untouched.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
        '''
        log_path = f'{model_dir}{OUTPUT_DIR}{model_name}/{LOG_DIR}{model_name}'
        # exist_ok avoids the check-then-create race of testing os.path.exists first.
        os.makedirs(log_path, exist_ok=True)

    def save(
        self,
        config,
        model_dir: str,
        model_name: str,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None
    ) -> None:
        '''
        Saves the model and metrics to disk.

        Writes the weights/optimizer/scheduler checkpoint, the metrics CSV and the
        config JSON into the model's output folder, removes intermediate checkpoint
        files found there, and closes the TensorBoard writer.

        Args:
            config (Config): Config object.
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model; '_quantized' is appended when `self.quantized` is set.
            optimizer (torch.optim.Optimizer): Optimizer whose state dict is stored.
            scheduler (torch.optim.lr_scheduler._LRScheduler | None): Scheduler, stored as the object itself (not a state dict).
        '''
        if self.quantized:
            model_name += '_quantized'

        out_dir = f'{model_dir}{OUTPUT_DIR}{model_name}'

        checkpoint = {
            "model": self.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler
        }
        self.remove_checkpoints(model_dir, model_name)

        # `fit` creates the run folder under the unsuffixed name, so the '_quantized'
        # folder may not exist yet; without it every write below would fail.
        os.makedirs(out_dir, exist_ok=True)

        self._save(checkpoint, f"{out_dir}/{model_name}")
        self.writer.close()
        self.df.to_csv(f'{out_dir}/{model_name}.csv', index=False)
        config.save(f'{out_dir}/{model_name}.json')
        # NOTE(review): other config values in this class are read through method calls;
        # if `is_snn` is a method too, this test is always true - confirm it is an attribute.
        if config.is_snn:
            self.save_params(f'{out_dir}/params.json')
            

    def remove_checkpoints(self, model_dir: str, model_name: str) -> None:
        '''
        Deletes every '.pth' file whose name contains 'checkpoint' from the model's output folder.

        A missing folder or file is reported with a message instead of raising.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
        '''
        folder = f'{model_dir}{OUTPUT_DIR}{model_name}'
        try:
            stale = [
                entry for entry in os.listdir(folder)
                if entry.endswith('.pth') and 'checkpoint' in entry
            ]
            for entry in stale:
                os.remove(f'{folder}/{entry}')
        except FileNotFoundError:
            print('No checkpoints found.')

    def save_checkpoint(
        self,
        model_dir: str,
        model_name: str,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None,
        epoch: int
    ) -> None:
        '''
        Writes an intermediate checkpoint and the current metrics CSV to the model's output folder.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
            optimizer (torch.optim.Optimizer): Optimizer whose state dict is stored.
            scheduler (torch.optim.lr_scheduler._LRScheduler | None): Scheduler, stored as the object itself.
            epoch (int): Epoch number used in the checkpoint file name.
        '''
        run_folder = f'{model_dir}{OUTPUT_DIR}{model_name}'

        snapshot = dict(
            model=self.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler,
        )

        self._save(snapshot, f"{run_folder}/{model_name}_checkpoint_{epoch}")
        self.df.to_csv(f'{run_folder}/{model_name}.csv', index=False)

    def _save(
        self,
        checkpoint: dict[str, nn.Module],
        path: str
    ) -> None:
        torch.save(checkpoint, f"{path}.pth")

    def load(
        self,
        model_dir: str,
        model_name: str,
        optimizer: torch.optim.Optimizer | None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None,
        resume: bool = False
    ) -> tuple[torch.optim.lr_scheduler._LRScheduler | None, int]:
        '''
        Loads a model from disk.

        The model weights are always restored. Optimizer state, scheduler and the
        metrics history are only restored when `resume` is true.

        Args:
            model_dir (str): Directory to save the model.
            model_name (str): Name of the model.
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Scheduler.
            resume (bool, optional): Whether to restore the training state as well. Defaults to False.

        Returns:
            tuple: The scheduler to continue with and the epoch index to start from (0 unless a metrics CSV was restored).

        Raises:
            ValueError: If resuming with a scheduler but without an optimizer.
        '''
        run_dir = f"{model_dir}{OUTPUT_DIR}{model_name}"

        checkpoint = self._load(f"{run_dir}/{model_name}")
        self.load_state_dict(checkpoint['model'])

        if resume and optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])

        if resume and scheduler is not None:
            # Validate before replacing anything, so a failed call leaves the caller's scheduler untouched.
            if optimizer is None:
                raise ValueError('Optimizer must be provided if scheduler is provided.')

            # The checkpoint stores the scheduler object itself. It can be None when the
            # run was saved without a scheduler; in that case keep the one passed in.
            saved_scheduler = checkpoint.get('scheduler')
            if saved_scheduler is not None:
                scheduler = saved_scheduler
                # Re-attach the restored scheduler to the live optimizer.
                scheduler.optimizer = optimizer

        start_epoch: int = 0

        csv_path = f'{run_dir}/{model_name}.csv'
        if resume and os.path.exists(csv_path):
            self.df = pd.read_csv(csv_path)
            # The last recorded (1-based) epoch is the 0-based index to continue from.
            # An empty history keeps the start at 0.
            if not self.df.empty:
                start_epoch = int(self.df['Epoch'].iloc[-1])

        return scheduler, start_epoch

    def load_by_path(
        self, 
        path: str,
        optimizer: torch.optim.Optimizer | None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None
    ) -> torch.optim.lr_scheduler._LRScheduler | None:
        '''
        Loads a model from disk via an absolute path to the weight file.

        Args:
            config (Config): Config object.
            model_name (str): Name of the model.
        '''
        checkpoint = self._load(path)

        self.load_state_dict(checkpoint['model'])

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])

        if scheduler is not None:
            scheduler = checkpoint['scheduler']

        return scheduler

    def _load(self, path: str) -> dict[str, nn.Module]:
        '''
        Loads a module from disk via a relative path to the weight file.

        Args:
            path (str): Path to the weight file.
        '''
        try:
            checkpoint = torch.load(f"{path}.pth", map_location=self.DEVICE)
            print(f'Loaded from {path}.pth')
            return checkpoint
        except Exception as e:
            print(f'Could not find {path}.pth')
            print(e)
            exit(1)