import json
import os
import sys
import time
from math import cos, pi
import numpy as np
import torch
import wandb
from barbar import Bar

from optimizers import optimizers
from runner import Runner as BaseRunner
from utilities import Utilities as Utils
from strategies import strategies
from metrics import metrics


class Runner(BaseRunner):
    """Prune-and-retrain runner built on a trained 'Dense' reference run stored in wandb.

    ``run`` looks up the reference run, downloads its trained model, reuses its seed and
    optimizer settings, and then lets the configured pruning strategy drive pruning and
    fine-tuning through the callbacks defined on this class (``fine_tuning``,
    ``restore_model``, ...). The learning-rate schedule used while fine-tuning is selected
    by ``config.lr_rewinding`` (see ``rewinding``).
    """

    # lr_rewinding modes whose per-epoch lr is read from self.copycat_lr in `rewinding`.
    _COPYCAT_REWINDING_MODES = (
        'copycat', 'copycat_end', 'copycat_start', 'copycat_start_momentum', 'copycat_mid',
        'copycat_loss_time', 'follow_along_loss_time', 'copycat_warmup_momentum',
    )
    # lr_rewinding modes for which `fine_tuning` builds self.copycat_lr by compressing the
    # reference lr schedule. NOTE(review): 'follow_along_loss_time' is accepted by `rewinding`
    # but is missing here, so self.copycat_lr is never built for it -- confirm intended.
    _COPYCAT_SCHEDULE_MODES = (
        'copycat', 'copycat_end', 'copycat_start', 'copycat_start_momentum', 'copycat_mid',
        'copycat_loss_time', 'copycat_warmup_momentum',
    )
    # Pruning strategies accepted by `run`; each name is also the class name in `strategies`.
    _STRATEGY_NAMES = ('IMP', 'IRP', 'LAMP_IMP', 'Uniform_IMP', 'UniformPlus_IMP', 'ERK_IMP')

    def __init__(self, config, debug_mode=False):
        super().__init__(config=config, debug_mode=debug_mode)
        # Per-iteration lr warm-up function, applied in train_epoch while not None.
        self.warm_up_fn = None
        # Number of fine-tuning epochs during which the lr warm-up is active.
        self.n_warmup_epochs = None
        # Iteration counter fed to the warm-up functions; reset at the start of each fine_tuning call.
        self.trainIteration = None
        # Per-iteration momentum warm-up function, applied in train_epoch while not None.
        self.momentum_warm_up_fn = None
        # Number of fine-tuning epochs with disabled / warming-up momentum.
        self.no_momentum_epochs = None

    def get_missing_config(self):
        """Fill in configuration values that this run inherits from the reference run.

        Optimizer keys missing from ``self.config`` are copied from the reference run's config.
        The reference run's final learning rate and test accuracy, and the optimizer name 'SGD',
        are always written into ``self.config``. The test accuracy is also kept in
        ``self.trained_test_accuracy`` so ``log`` can report the accuracy change after pruning.
        """
        missing_config_keys = ['momentum',
                               'dampening',
                               'weight_decay_ord',
                               'nesterov',
                               'nepochs']
        additional_dict = {
            'last_training_lr': self.reference_run.summary['trained.learning_rate'],
            'optimizer': 'SGD',
            'trained.test.accuracy': self.reference_run.summary['trained.test']['accuracy'],
        }
        for key in missing_config_keys:
            if key not in self.config:
                self.config[key] = self.reference_run.config[key]
        self.config.update(additional_dict)

        self.trained_test_accuracy = additional_dict['trained.test.accuracy']

    def define_optimizer_scheduler(self):
        """Create the SGD optimizer from the (reference-derived) configuration.

        No scheduler object is created: the learning rate is overwritten per epoch by
        ``rewinding`` and, during warm-up, per iteration in ``train_epoch``.
        """
        self.optimizer = optimizers.SGD(params=self.model.parameters(),
                                        lr=self.config['last_training_lr'],  # lr is adjusted by rewinding routine
                                        momentum=self.config['momentum'], dampening=self.config['dampening'],
                                        weight_decay=self.config['weight_decay'],
                                        weight_decay_ord=self.config['weight_decay_ord'],
                                        nesterov=self.config['nesterov'], global_constraint=None)

    def restore_model(self, from_initial=False) -> None:
        """Rebuild ``self.model`` via ``get_model(load_checkpoint=True, ...)``.

        :param from_initial: if True, ``load_initial=True`` is passed; the log message then names
            ``self.trainedModelFile`` (the downloaded reference model) instead of
            ``self.checkpoint_file``.
        """
        outputStr = 'initial' if from_initial else 'checkpoint'
        sys.stdout.write(f"Restoring {outputStr} model from {self.checkpoint_file if not from_initial else self.trainedModelFile}.\n")
        self.model = self.get_model(load_checkpoint=True, load_initial=(from_initial is True))

    def log(self, runTime, desired_sparsity=None, last_epoch=False):
        """Push the current epoch's metrics to wandb.

        Every call logs ``get_metrics()`` plus the epoch run time under ``finetune``. When
        ``last_epoch`` is True (``fine_tuning`` sets it on the final epoch of the final phase), it
        also logs the metrics measured right after pruning, the total fine-tuning time and the
        test-accuracy change relative to the reference run. It then writes the per-layer sparsity
        distribution as JSON into the wandb run directory and uploads it.

        :param runTime: wall-clock duration of the epoch in seconds.
        :param desired_sparsity: sparsity the results are attributed to; only used when last_epoch.
        :param last_epoch: whether this is the final epoch of the final phase.
        """
        loggingDict = self.get_metrics()
        loggingDict.update({'epoch_run_time': runTime})

        if not last_epoch:
            wandb.log(dict(finetune=loggingDict))
            return

        test_accuracy_delta = self.test_accuracy.result() - self.trained_test_accuracy  # mostly negative
        # The end-of-training entry carries desired_sparsity so runs can be told apart per sparsity in wandb.
        wandb.log(
            dict(finetune=loggingDict,
                 pruned=self.after_pruning_metrics[desired_sparsity],  # Metrics directly after pruning
                 desired_sparsity=desired_sparsity,
                 total_finetune_time=self.totalFinetuneTime,
                 test_accuracy_delta=test_accuracy_delta,
                 reference_n_epochs_finetune=self.config.n_epochs_finetune,
                 ),
        )
        # Dump sparsity distribution to json and upload
        sparsity_distribution = metrics.per_layer_sparsity(model=self.model)
        fPath = os.path.join(wandb.run.dir, f'sparsity_distribution_{desired_sparsity}.json')
        with open(fPath, 'w') as fp:
            json.dump(sparsity_distribution, fp)
        wandb.save(fPath)

    def train_epoch(self, data='train', evaluation_mode=False, epoch=None, include_cls_data=None):
        """Make one pass over a data split, either evaluating or training.

        :param data: 'train' or 'test'; selects the loader and averaged-metric accumulators used
            for evaluation.
        :param evaluation_mode: if truthy, evaluate under ``torch.no_grad()``; otherwise run one
            training epoch on the train loader (training on 'test' is rejected).
        :param epoch: unused here; kept for interface compatibility.
        :param include_cls_data: unused here; kept for interface compatibility.
        :raises ValueError: if ``data`` is neither 'train' nor 'test'.
        """
        if data == 'train':
            loader = self.trainLoader
            mean_loss, mean_accuracy, mean_k_accuracy = self.train_loss, self.train_accuracy, self.train_k_accuracy
        elif data == 'test':
            loader = self.testLoader
            mean_loss, mean_accuracy, mean_k_accuracy = self.test_loss, self.test_accuracy, self.test_k_accuracy
        else:
            raise ValueError(f"Unknown data split '{data}', expected 'train' or 'test'.")
        # Previously compared against ' test' (leading space), so this guard could never fire.
        assert not (data == 'test' and not evaluation_mode), "Can't train on test set."

        if evaluation_mode:
            with torch.no_grad():
                sys.stdout.write(f"Evaluation of {data} data:\n")
                for x_input, y_target in Bar(loader):
                    x_input, y_target = x_input.to(self.device), y_target.to(self.device)  # Move to CUDA if possible
                    output = self.model.eval()(x_input)
                    loss = self.loss_criterion(output, y_target)
                    mean_loss(loss.item(), len(y_target))
                    mean_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output), len(y_target))
                    mean_k_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output, topk=self.k_accuracy),
                                    len(y_target))
            return

        sys.stdout.write("Training:\n")
        # Training always uses the train split, even if asserts are disabled (python -O).
        for x_input, y_target in Bar(self.trainLoader):
            # Per-iteration warm-up of lr and/or momentum; set up in fine_tuning, switched off in rewinding.
            if self.warm_up_fn is not None or self.momentum_warm_up_fn is not None:
                if self.warm_up_fn:
                    iteration_lr = self.warm_up_fn(self.trainIteration)
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = iteration_lr
                if self.momentum_warm_up_fn:
                    iteration_momentum = self.momentum_warm_up_fn(self.trainIteration)
                    for param_group in self.optimizer.param_groups:
                        param_group['momentum'] = iteration_momentum
                self.trainIteration += 1
            x_input, y_target = x_input.to(self.device), y_target.to(self.device)  # Move to CUDA if possible
            self.optimizer.zero_grad()  # Zero the gradient buffers
            output = self.model.train()(x_input)
            loss = self.loss_criterion(output, y_target)
            loss.backward()  # Backpropagation
            self.strategy.during_training(opt=self.optimizer)
            self.optimizer.step()
            self.train_loss(loss.item(), len(y_target))
            self.train_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output), len(y_target))
            self.train_k_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output, topk=self.k_accuracy),
                                  len(y_target))

    def rewinding(self, current_epoch: int, n_epochs_finetune: int, desired_sparsity: float) -> None:
        """Set the learning rate (and, for some modes, momentum) for the upcoming fine-tuning epoch.

        Behavior by ``self.config.lr_rewinding``:

        * ``'constant'``: the optimizer is left untouched.
        * ``'lr-rewinding'``: use the reference lr at epoch
          ``nepochs - n_epochs_finetune + current_epoch`` (learning rate rewinding as in Renda2020).
        * copycat modes: use ``self.copycat_lr[current_epoch]`` precomputed by ``fine_tuning``;
          ``'copycat_start_momentum'`` sets momentum to 0 for the first ``self.no_momentum_epochs``
          epochs and to ``config['momentum']`` afterwards; ``'copycat_warmup_momentum'`` switches
          the momentum warm-up off after ``self.no_momentum_epochs`` epochs.
        * ``'SLR'`` / ``'copycat_lr_momentum_warmup'``: for the first ``self.n_warmup_epochs`` epochs
          the per-iteration warm-up in ``train_epoch`` controls the lr; afterwards the warm-up is
          switched off and ``self.SLR_lr`` / ``self.copycat_lr`` is used.

        :param current_epoch: 1-based epoch within the current fine-tuning phase.
        :param n_epochs_finetune: number of epochs in the current fine-tuning phase.
        :param desired_sparsity: unused here; kept for interface compatibility.
        :raises ValueError: if ``lr_rewinding`` is not a known mode.
        """
        mode = self.config.lr_rewinding
        if mode == 'constant':
            sys.stdout.write(f"Rewinding method {mode}: No rewinding to specific epoch.\n")
            return  # No rewinding

        if mode == 'lr-rewinding':
            # Learning rate rewinding as in Renda2020
            rewind_epoch = self.config.nepochs - n_epochs_finetune + current_epoch
            epoch_lr = self.strategy.lr_dict[rewind_epoch]
        elif mode in self._COPYCAT_REWINDING_MODES:
            rewind_epoch = 'N/A'
            epoch_lr = self.copycat_lr[current_epoch]
            # NOTE(review): fine_tuning does not set self.no_momentum_epochs (None from __init__) or a
            # momentum warm-up for the two modes below; unless set elsewhere, the comparisons fail on None.
            if mode == 'copycat_start_momentum':
                # Momentum is disabled during the first no_momentum_epochs epochs, then re-enabled.
                momentum = 0 if current_epoch <= self.no_momentum_epochs else self.config['momentum']
                for param_group in self.optimizer.param_groups:
                    param_group['momentum'] = momentum
            elif mode == 'copycat_warmup_momentum':
                if current_epoch > self.no_momentum_epochs:
                    self.momentum_warm_up_fn = None
        elif mode in ('SLR', 'copycat_lr_momentum_warmup'):
            if current_epoch <= self.n_warmup_epochs:
                # Warmup is handled in train_epoch
                sys.stdout.write(f"Rewinding method {mode}: No rewinding, still in warmup phase. Current lr: {float(self.optimizer.param_groups[0]['lr'])}\n")
                return
            rewind_epoch = 'N/A'
            self.warm_up_fn = None  # If set to None, then warm up has finished
            if mode == 'SLR':
                epoch_lr = self.SLR_lr[current_epoch]
            else:
                self.momentum_warm_up_fn = None
                epoch_lr = self.copycat_lr[current_epoch]
        else:
            raise ValueError(f"Unknown lr_rewinding method: {mode}")

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = epoch_lr
        if mode == 'copycat_start_momentum' and current_epoch <= self.no_momentum_epochs:
            sys.stdout.write(
                f"Rewinding method {mode}: No rewinding, still in no momentum phase. Current lr: {float(self.optimizer.param_groups[0]['lr'])}\n")
        else:
            sys.stdout.write(f"Rewinding method {mode}: Rewinding to epoch {rewind_epoch}, lr {epoch_lr}.\n")

    @staticmethod
    def _compress_schedule(epoch_lrs, n_epochs):
        """Subsample ``epoch_lrs`` to ``n_epochs`` entries.

        Entry ``j`` is ``epoch_lrs[int(j * width)]`` with ``width = len(epoch_lrs) / n_epochs``.
        Returns an empty list when ``n_epochs <= 0``.
        """
        if n_epochs <= 0:
            return []
        width = len(epoch_lrs) / n_epochs  # in general not an integer
        return [epoch_lrs[int(j * width)] for j in range(n_epochs)]

    @staticmethod
    def _cosine_ramp(start, end, n_iterations):
        """Return ``f(t)`` moving from ``start`` (t=0) to ``end`` (t=n_iterations) along a half cosine.

        The values are bound at creation time, so later rebinding of caller locals cannot change
        the ramp (the original inline closures shared one late-bound variable).
        """
        def ramp(t):
            return end + 0.5 * (start - end) * (1 + cos(pi * t / n_iterations))
        return ramp

    def _build_copycat_schedule(self, desired_sparsity, n_epochs_finetune):
        """Compress the reference lr schedule into ``self.copycat_lr`` (epochs 1..n_epochs_finetune).

        The compressed schedule is rescaled by ``start_lr / lr_dict[1]``. Only 'copycat_loss_time'
        defines ``start_lr``: the reference lr at the epoch whose train loss is closest to the
        loss right after pruning, capped at the lr-rewinding start epoch.

        :raises NotImplementedError: for copycat variants without a start-lr rule (these previously
            failed with UnboundLocalError).
        """
        mode = self.config.lr_rewinding
        lr_dict = self.strategy.lr_dict
        epoch_lrs = [lr_dict[i] for i in range(1, self.config.nepochs + 1)]
        # Use the phase length actually being run, consistent with the epochs `rewinding` reads.
        reduced_lrs = self._compress_schedule(epoch_lrs, n_epochs_finetune)
        start_val = float(lr_dict[1])

        if mode == 'copycat_loss_time':
            # Take the minimum of closest loss occurence and lr-rewinding starting epoch
            pruned_train_loss = self.after_pruning_metrics[desired_sparsity]['train']['loss']
            closest_epoch = min(self.strategy.train_loss_dict.items(),
                                key=lambda item: abs(item[1] - pruned_train_loss))[0]
            rewind_epoch = min(closest_epoch, self.config.nepochs - n_epochs_finetune + 1)  # Can't start later than lr_rewinding
            start_lr = lr_dict[rewind_epoch]
        else:
            raise NotImplementedError(
                f"lr_rewinding '{mode}': no starting learning rate is defined for this copycat variant; "
                f"only 'copycat_loss_time' is implemented.")

        self.copycat_lr = {j + 1: (lr / start_val) * start_lr for j, lr in enumerate(reduced_lrs)}

    def _build_slr_schedule(self, n_epochs_finetune):
        """Set up the 'SLR' schedule: cosine lr warm-up, then the compressed reference schedule.

        The warm-up spans 10% of the phase (``int(0.1 * n_epochs_finetune)`` epochs) and ramps from
        the minimum lr after the reference run's peak to that peak. The remaining epochs follow the
        reference schedule, with any epochs before the peak replaced by the peak lr, stored in
        ``self.SLR_lr``.
        """
        lr_dict = self.strategy.lr_dict
        max_lr = max(lr_dict.values())  # This is equivalent to the initial lr!
        after_warmup_index = min(idx for idx, val in lr_dict.items() if val == max_lr and idx > 0)
        min_lr = min(list(lr_dict.values())[after_warmup_index:])
        self.n_warmup_epochs = int(0.1 * n_epochs_finetune)  # makes sense for n_epochs_finetune >= 10
        n_warmup_iterations = float(len(self.trainLoader) * self.n_warmup_epochs)
        self.warm_up_fn = self._cosine_ramp(start=min_lr, end=max_lr, n_iterations=n_warmup_iterations)

        # After self.n_warmup_epochs, do copycat_start
        epoch_lrs = [lr_dict[i] if i >= after_warmup_index else max_lr for i in range(1, self.config.nepochs + 1)]
        remaining_epochs = n_epochs_finetune - self.n_warmup_epochs
        reduced_lrs = self._compress_schedule(epoch_lrs, remaining_epochs)
        self.SLR_lr = {self.n_warmup_epochs + j + 1: lr for j, lr in enumerate(reduced_lrs)}

    def _build_copycat_lr_momentum_warmup_schedule(self, n_epochs_finetune):
        """Set up 'copycat_lr_momentum_warmup': cosine lr and momentum warm-up, then compressed schedule.

        Both warm-ups span ``int(0.1 * n_epochs_finetune)`` epochs. The lr ramps from the minimum
        to the maximum reference lr and the momentum ramps from 0 to 0.9. Afterwards the compressed
        reference schedule is stored in ``self.copycat_lr``.
        """
        lr_dict = self.strategy.lr_dict
        min_lr, max_lr = min(lr_dict.values()), max(lr_dict.values())
        self.n_warmup_epochs = int(0.1 * n_epochs_finetune)  # makes sense for n_epochs_finetune >= 10
        lr_warmup_iterations = float(len(self.trainLoader) * self.n_warmup_epochs)
        self.warm_up_fn = self._cosine_ramp(start=min_lr, end=max_lr, n_iterations=lr_warmup_iterations)

        self.no_momentum_epochs = int(0.1 * n_epochs_finetune)
        momentum_warmup_iterations = float(len(self.trainLoader) * self.no_momentum_epochs)
        # NOTE(review): target momentum 0.9 is hard-coded (not config['momentum']), and once `rewinding`
        # disables the warm-up the momentum stays at the last warm-up value -- confirm intended.
        self.momentum_warm_up_fn = self._cosine_ramp(start=0, end=0.9, n_iterations=momentum_warmup_iterations)

        # After self.n_warmup_epochs, do copycat_start
        epoch_lrs = [lr_dict[i] for i in range(1, self.config.nepochs + 1)]
        # Use the phase length actually being run, consistent with the epochs `rewinding` reads.
        remaining_epochs = n_epochs_finetune - self.n_warmup_epochs
        reduced_lrs = self._compress_schedule(epoch_lrs, remaining_epochs)
        self.copycat_lr = {self.n_warmup_epochs + j + 1: lr for j, lr in enumerate(reduced_lrs)}

    def _build_lr_schedule(self, desired_sparsity, n_epochs_finetune):
        """Precompute the per-epoch lr table / warm-up functions required by ``config.lr_rewinding``.

        Modes that need no precomputation ('constant', 'lr-rewinding', ...) are left untouched.
        """
        mode = self.config.lr_rewinding
        if mode in self._COPYCAT_SCHEDULE_MODES:
            self._build_copycat_schedule(desired_sparsity=desired_sparsity, n_epochs_finetune=n_epochs_finetune)
        elif mode == 'SLR':
            self._build_slr_schedule(n_epochs_finetune=n_epochs_finetune)
        elif mode == 'copycat_lr_momentum_warmup':
            self._build_copycat_lr_momentum_warmup_schedule(n_epochs_finetune=n_epochs_finetune)

    def fine_tuning(self, desired_sparsity, n_epochs_finetune, phase=1):
        """Fine-tune the pruned model for one phase, evaluating and logging after every epoch.

        Passed to the strategy as ``finetuning_callback`` in ``run``.

        :param desired_sparsity: sparsity this fine-tuning belongs to; used for logging and by the
            'copycat_loss_time' schedule.
        :param n_epochs_finetune: number of epochs in this phase.
        :param phase: 1-based phase index out of ``config.n_phases`` (treated as 1 when unset).
        """
        n_phases = self.config.n_phases or 1
        if phase == 1:
            # The total fine-tuning time (over all phases) is measured from the start of phase 1.
            self.finetuneStartTime = time.time()
        self.trainIteration = 0
        self._build_lr_schedule(desired_sparsity=desired_sparsity, n_epochs_finetune=n_epochs_finetune)

        for epoch in range(1, n_epochs_finetune + 1):
            self.reset_averaged_metrics()
            sys.stdout.write(f"\nDesired sparsity {desired_sparsity} - Finetuning: phase {phase}/{n_phases} | epoch {epoch}/{n_epochs_finetune}\n")
            # Train
            epoch_start_time = time.time()
            self.rewinding(current_epoch=epoch, n_epochs_finetune=n_epochs_finetune, desired_sparsity=desired_sparsity)
            self.train_epoch(data='train')
            self.evaluate_model(data='test')
            sys.stdout.write(
                f"\nTest accuracy after this epoch: {self.test_accuracy.result()} (lr = {float(self.optimizer.param_groups[0]['lr'])})\n")

            last_epoch = epoch == n_epochs_finetune and phase == n_phases
            if last_epoch:
                # Training complete: previously only the last phase was timed, not the total.
                self.totalFinetuneTime = time.time() - self.finetuneStartTime

            # Push information every epoch, but only link it to desired_sparsity at the end
            self.log(runTime=time.time() - epoch_start_time,
                     desired_sparsity=desired_sparsity if last_epoch else None,
                     last_epoch=last_epoch)

    def fill_strategy_information(self):
        """Feed the reference run's per-step learning rate and train loss into the strategy.

        Each wandb history row is passed to ``strategy.at_epoch_end`` with its ``_step`` as the
        epoch, which fills the strategy's dicts used by the rewinding schedules.
        Note: this only works if the reference model has "Dense" strategy.
        """
        for row in self.reference_run.history(keys=['learning_rate', 'train.loss'], pandas=False):
            epoch, epoch_lr, train_loss = row['_step'], row['learning_rate'], row['train.loss']
            self.strategy.at_epoch_end(epoch=epoch, epoch_lr=epoch_lr, train_loss=train_loss)

    def run(self):
        """Find the Dense reference run, restore its trained model and let the strategy prune and retrain.

        Finds a non-failed run in the current wandb project with strategy 'Dense' and the same
        run_id, model and weight_decay that has a 'trained_model_file'. It downloads that file and
        seeds numpy/torch with the reference seed. It then builds loaders, model and optimizer,
        instantiates the configured pruning strategy, and hands control to
        ``strategy.at_train_end`` with this runner's callbacks.

        :raises ValueError: if ``config.strategy`` is not one of ``_STRATEGY_NAMES``.
        """
        # Find the reference run
        self.reference_run = None
        self.trainedModelFile = None  # Stays None (and trips the assertion below) if no run qualifies
        entity, project = wandb.run.entity, wandb.run.project
        api = wandb.Api()
        filterDict = {"$and": [{"config.run_id": self.config.run_id}, {"config.strategy": "Dense"},
                               {"config.model": self.config.model},
                               {"config.weight_decay": self.config.weight_decay}]}
        for ref_run in api.runs(f"{entity}/{project}", filters=filterDict):
            if ref_run.state == 'failed':
                continue  # Ignore failed runs
            self.trainedModelFile = ref_run.summary.get('trained_model_file')
            if self.trainedModelFile is not None:
                ref_run.file(self.trainedModelFile).download(root=wandb.run.dir)
                self.seed = ref_run.config['seed']
                self.reference_run = ref_run
                break
        outputStr = f"Found {self.trainedModelFile}" if self.trainedModelFile is not None else "Nothing found."
        sys.stdout.write(f"Trying to find reference trained model in project: {outputStr}\n")
        assert self.trainedModelFile is not None, "No reference trained model found, Aborting."

        wandb.config.update({'seed': self.seed})  # Push the seed to wandb

        # Reuse the reference run's seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
        torch.cuda.manual_seed(self.seed)  # Safe to call even if CUDA is not available

        torch.backends.cudnn.benchmark = True
        self.get_missing_config()  # Load keys that are missing in the config

        self.trainLoader, self.testLoader = self.get_dataloaders()
        self.model = self.get_model(load_initial=True)  # Load the trained model
        self.define_optimizer_scheduler()
        self.checkpoint_file = os.path.join(wandb.run.dir, self.trainedModelFile)  # Set the checkpoint file to the trainedModelFile
        self.trained_norm_square = Utils.get_model_norm_square(model=self.model)

        # Define strategy
        sparsityList = ([self.config.gsm_desired_sparsity] if self.config.gsm_desired_sparsity is not None
                        else self.desired_sparsities)
        strategy_name = self.config.strategy
        if strategy_name not in self._STRATEGY_NAMES:
            raise ValueError(f"Unknown strategy '{strategy_name}', expected one of {self._STRATEGY_NAMES}.")
        strategy_cls = getattr(strategies, strategy_name)
        self.strategy = strategy_cls(desired_sparsities=sparsityList,
                                     n_epochs_finetune=self.config.n_epochs_finetune,
                                     n_phases=self.config.n_phases)
        self.strategy.after_initialization(model=self.model)  # To ensure that all parameters are properly set
        self.fill_strategy_information()
        # Run the computations
        self.strategy.at_train_end(model=self.model, finetuning_callback=self.fine_tuning,
                                   restore_callback=self.restore_model, save_model_callback=self.save_model,
                                   after_pruning_callback=self.after_pruning_callback, opt=self.optimizer)
