from time import time, sleep
from datetime import datetime
import torch
from collections import deque
from torch import autocast
from typing import Union, Tuple, List
import numpy as np

from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.training.lr_scheduler.polylr import WarmupPolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from additional_logger import AdamCPRLogger
from torch import distributed as dist
from nnunetv2.utilities.helpers import dummy_context
from adam_cpr import AdamCPR, group_cpr_parameters


class nnUNetTrainerAdamCPR(nnUNetTrainer):
    """nnU-Net trainer that optimizes with ``AdamCPR``.

    Differences from the stock ``nnUNetTrainer``:

    * ``configure_optimizers`` builds an ``AdamCPR`` optimizer over the parameter groups
      returned by ``group_cpr_parameters`` together with a ``WarmupPolyLRScheduler``.
    * Results are written to an extra sub-directory of the usual output folder whose name
      encodes the AdamCPR hyper-parameters, so runs with different settings are kept apart.
    * ``train_step`` additionally reports, per reported parameter, the optimizer-state
      entries ``'lagmul'`` and ``'kappa'`` and the mean squared parameter value; these are
      averaged per epoch and logged through ``AdamCPRLogger``.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        adam_cpr_mode: str = "l2_constrain_mh",
        adam_cpr_lagmul_rate: float = 1.,
        adam_cpr_kappa: float = 1.,
        adam_cpr_kappa_adapt: bool = False,
        adam_cpr_kappa_init_dependent: float | bool = False,
        adam_cpr_kappa_init_warm_start: int | bool = False,
        **kwargs
    ):
        """Initialize the trainer and redirect its output folder and log file.

        Args:
            plans, configuration, fold, dataset_json, unpack_dataset, device:
                Forwarded unchanged to ``nnUNetTrainer.__init__``.
            adam_cpr_mode, adam_cpr_lagmul_rate, adam_cpr_kappa, adam_cpr_kappa_adapt,
            adam_cpr_kappa_init_dependent, adam_cpr_kappa_init_warm_start:
                Stored in ``self.adam_cpr_config`` (keys without the ``adam_cpr_`` prefix)
                and passed to ``group_cpr_parameters`` / ``AdamCPR`` in
                ``configure_optimizers``. They are also encoded in the output folder name.
            **kwargs: Accepted and ignored.
        """
        super().__init__(
            plans, configuration, fold, dataset_json, unpack_dataset, device
        )
        self.initial_lr = 1e-2
        self.warmup_steps = 20

        self.adam_cpr_config = {
            "mode": adam_cpr_mode,
            "lagmul_rate": adam_cpr_lagmul_rate,
            "kappa": adam_cpr_kappa,
            "kappa_adapt": adam_cpr_kappa_adapt,
            "kappa_init_dependent": adam_cpr_kappa_init_dependent,
            "kappa_init_warm_start": adam_cpr_kappa_init_warm_start,
        }

        # One sub-directory per hyper-parameter combination.
        exp_string = (
            f"{adam_cpr_mode}_lagmul{adam_cpr_lagmul_rate}_kappa{adam_cpr_kappa}"
            f"_adapt{adam_cpr_kappa_adapt}_dependent{adam_cpr_kappa_init_dependent}"
            f"_warmupSteps{adam_cpr_kappa_init_warm_start}"
        )

        # NOTE(review): when nnUNet_results is None, output_folder_base is None and the
        # join below raises TypeError (the same as the base trainer's construction).
        self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
                                       self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration, exp_string) \
            if nnUNet_results is not None else None
        self.output_folder = join(self.output_folder_base, f'fold_{fold}')

        # Point the log file at the experiment-specific fold folder.
        timestamp = datetime.now()
        maybe_mkdir_p(self.output_folder)
        self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
                             (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
                              timestamp.second))

    def configure_optimizers(self):
        """Build the AdamCPR optimizer and its warm-up + poly learning-rate scheduler.

        Also creates ``self.adam_cpr_logger`` for the parameter names of the first group
        returned by ``group_cpr_parameters``.

        Returns:
            Tuple ``(optimizer, lr_scheduler)``.
        """
        # Each keyword must be a separate list entry. The former single string
        # 'bias,norm' could only match a parameter name containing that literal text.
        # NOTE(review): assumes group_cpr_parameters matches each keyword as a substring
        # of the parameter name and excludes matches from the constraint - confirm.
        parameters = group_cpr_parameters(self.network, self.adam_cpr_config, avoid_keywords=['bias', 'norm'])

        optimizer = AdamCPR(
            params=parameters,
            lr=self.initial_lr,
            betas=(0.9, 0.99),
            mode=self.adam_cpr_config["mode"],
            apply_decay=None,
            lagmul_rate=self.adam_cpr_config["lagmul_rate"],
            kappa=self.adam_cpr_config["kappa"],
            kappa_adapt=self.adam_cpr_config["kappa_adapt"],
            kappa_init_dependent=self.adam_cpr_config["kappa_init_dependent"],
            kappa_init_warm_start=self.adam_cpr_config["kappa_init_warm_start"],
        )

        # The schedule length is measured in self.num_epochs steps.
        lr_scheduler = WarmupPolyLRScheduler(
            optimizer=optimizer,
            initial_lr=self.initial_lr,
            max_steps=self.num_epochs,
            warmup_steps=self.warmup_steps,
        )

        self.adam_cpr_logger = AdamCPRLogger(parameters[0]['names'])

        return optimizer, lr_scheduler

    def train_step(self, batch: dict) -> dict:
        """Run one optimization step and collect AdamCPR diagnostics.

        Args:
            batch: Dict with ``'data'`` (a tensor) and ``'target'`` (a tensor or a list
                of tensors); both are moved to ``self.device``.

        Returns:
            ``{'loss': <numpy value>}``. For every parameter in a param group with a
            truthy ``'apply_decay'`` and a ``'mode'`` containing ``'constrain'``, it also
            includes ``'lagmul_<name>'`` and ``'kappa_<name>'`` (read from the optimizer
            state) and ``'l2_<name>'`` (mean of the squared parameter values), as floats.
        """
        data = batch['data'].to(self.device, non_blocking=True)
        target = batch['target']
        if isinstance(target, list):
            target = [t.to(self.device, non_blocking=True) for t in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad()

        # Autocast is used on CUDA only; other devices run without it.
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            output = self.network(data)
            l = self.loss(output, target)

        if self.grad_scaler is not None:
            self.grad_scaler.scale(l).backward()
            # Unscale before clipping so the norm threshold applies to the real gradients.
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            l.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        lagmul_dict, kappa_dict, l2_dict = {}, {}, {}
        opt_state = self.optimizer.state
        for param_group in self.optimizer.param_groups:
            # Report only decayed groups running a constrain mode. For these, the
            # optimizer state is expected to hold 'lagmul' and 'kappa' per parameter.
            if not param_group.get('apply_decay') or 'constrain' not in param_group['mode']:
                continue
            for name, param in zip(param_group['names'], param_group['params']):
                param_state = opt_state[param]
                lagmul_dict['lagmul_' + name] = param_state['lagmul'].detach().item()
                kappa_dict['kappa_' + name] = param_state['kappa'].item()
                l2_dict['l2_' + name] = (param.detach() ** 2).mean().item()

        # Keep the key order: loss, then all lagmul, all kappa and all l2 entries.
        return_d = {'loss': l.detach().cpu().numpy()}
        return_d.update(lagmul_dict)
        return_d.update(kappa_dict)
        return_d.update(l2_dict)
        return return_d

    def on_train_epoch_end(self, train_outputs: List[dict]):
        """Aggregate the epoch's ``train_step`` outputs and log them.

        The loss is averaged (gathered across ranks under DDP) and logged to
        ``self.logger`` as ``'train_losses'``. Every other key is averaged over this
        rank's steps and logged to ``self.adam_cpr_logger``.
        """
        outputs = collate_outputs(train_outputs)

        if self.is_ddp:
            losses_tr = [None for _ in range(dist.get_world_size())]
            dist.all_gather_object(losses_tr, outputs['loss'])
            loss_here = np.vstack(losses_tr).mean()
        else:
            loss_here = np.mean(outputs['loss'])

        # NOTE(review): the AdamCPR statistics are not gathered under DDP, so each rank
        # logs only its local means.
        mean_dict = {k: np.mean(v) for k, v in outputs.items() if k != 'loss'}
        for k, v in mean_dict.items():
            self.adam_cpr_logger.log(k, v, self.current_epoch)

        self.logger.log('train_losses', loss_here, self.current_epoch)

    def train_step_no_backward(self, batch: dict) -> dict:
        """Compute the loss for ``batch`` without a backward pass or optimizer step.

        Gradients are zeroed and the forward pass runs under the same autocast policy
        as ``train_step``. Autograd is not disabled here.

        Returns:
            ``{'loss': <numpy value>}``.
        """
        data = batch['data'].to(self.device, non_blocking=True)
        target = batch['target']
        if isinstance(target, list):
            target = [t.to(self.device, non_blocking=True) for t in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad()

        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            output = self.network(data)
            l = self.loss(output, target)

        return {'loss': l.detach().cpu().numpy()}

    def on_epoch_end(self):
        """Log epoch statistics, handle checkpointing and plotting, and advance the epoch.

        Mirrors the base bookkeeping and additionally plots ``self.adam_cpr_logger`` on
        local rank 0.
        """
        self.logger.log('epoch_end_timestamps', time(), self.current_epoch)

        logs = self.logger.my_fantastic_logging
        self.print_to_log_file('train_loss', np.round(logs['train_losses'][-1], decimals=4))
        self.print_to_log_file('val_loss', np.round(logs['val_losses'][-1], decimals=4))
        self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
                                               logs['dice_per_class_or_region'][-1]])
        self.print_to_log_file(
            f"Epoch time: {np.round(logs['epoch_end_timestamps'][-1] - logs['epoch_start_timestamps'][-1], decimals=2)} s")

        # Periodic checkpoint, skipped on the final epoch.
        current_epoch = self.current_epoch
        if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
            self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))

        # 'Best' checkpoint, tracked by the logger's ema_fg_dice.
        if self._best_ema is None or logs['ema_fg_dice'][-1] > self._best_ema:
            self._best_ema = logs['ema_fg_dice'][-1]
            self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
            self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))

        if self.local_rank == 0:
            self.logger.plot_progress_png(self.output_folder)
            self.adam_cpr_logger.plot_progress_png(self.output_folder)

        self.current_epoch += 1
        

class nnUNetTrainerQuickAdamCPR(nnUNetTrainerAdamCPR):
    """AdamCPR trainer with a shorter schedule: 500 epochs and ``warmup_steps = 8``."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        adam_cpr_mode: str = "l2_constrain",
        adam_cpr_kappa: float = 10.,
        adam_cpr_kappa_factor: float | bool = False,
        adam_cpr_lagmul_rate: float = 1.,
        adam_cpr_kappa_adapt: bool = False,
        adam_cpr_kappa_adapt_grad: bool = False,
        adam_cpr_kappa_init_steps: int | bool = False,
        **kwargs
    ):
        """Initialize with this trainer's defaults and forward AdamCPR options to the parent.

        ``adam_cpr_kappa_factor`` and ``adam_cpr_kappa_init_steps`` are not parent
        parameters. Previously they were passed by name and silently dropped in the
        parent's ``**kwargs``. They are now forwarded as ``adam_cpr_kappa_init_dependent``
        and ``adam_cpr_kappa_init_warm_start`` respectively, unless those parent names
        are given explicitly in ``kwargs``, which take precedence.
        ``adam_cpr_kappa_adapt_grad`` has no parent counterpart; a warning is emitted if
        it is set.
        """
        import warnings

        # NOTE(review): treats the legacy names as aliases of the parent's options
        # (identical types) - confirm against the AdamCPR version in use.
        kwargs.setdefault("adam_cpr_kappa_init_dependent", adam_cpr_kappa_factor)
        kwargs.setdefault("adam_cpr_kappa_init_warm_start", adam_cpr_kappa_init_steps)
        if adam_cpr_kappa_adapt_grad:
            warnings.warn(
                "adam_cpr_kappa_adapt_grad is not supported by nnUNetTrainerAdamCPR and is ignored.",
                stacklevel=2,
            )

        super().__init__(
            plans=plans, configuration=configuration, fold=fold, dataset_json=dataset_json,
            unpack_dataset=unpack_dataset, device=device,
            adam_cpr_mode=adam_cpr_mode, adam_cpr_kappa=adam_cpr_kappa,
            adam_cpr_lagmul_rate=adam_cpr_lagmul_rate, adam_cpr_kappa_adapt=adam_cpr_kappa_adapt,
            **kwargs,
        )
        self.initial_lr = 1e-2
        self.warmup_steps = 8
        self.num_epochs = 500


class nnUNetTrainerQuickerAdamCPR(nnUNetTrainerAdamCPR):
    """AdamCPR trainer with a very short schedule: 100 epochs and ``warmup_steps = 8``."""

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
        adam_cpr_mode: str = "l2_constrain",
        adam_cpr_kappa: float = 10.,
        adam_cpr_kappa_factor: float | bool = False,
        adam_cpr_lagmul_rate: float = 1.,
        adam_cpr_kappa_adapt: bool = False,
        adam_cpr_kappa_adapt_grad: bool = False,
        adam_cpr_kappa_init_steps: int | bool = False,
        **kwargs
    ):
        """Initialize with this trainer's defaults and forward AdamCPR options to the parent.

        ``adam_cpr_kappa_factor`` and ``adam_cpr_kappa_init_steps`` are not parent
        parameters. Previously they were passed by name and silently dropped in the
        parent's ``**kwargs``. They are now forwarded as ``adam_cpr_kappa_init_dependent``
        and ``adam_cpr_kappa_init_warm_start`` respectively, unless those parent names
        are given explicitly in ``kwargs``, which take precedence.
        ``adam_cpr_kappa_adapt_grad`` has no parent counterpart; a warning is emitted if
        it is set.
        """
        import warnings

        # NOTE(review): treats the legacy names as aliases of the parent's options
        # (identical types) - confirm against the AdamCPR version in use.
        kwargs.setdefault("adam_cpr_kappa_init_dependent", adam_cpr_kappa_factor)
        kwargs.setdefault("adam_cpr_kappa_init_warm_start", adam_cpr_kappa_init_steps)
        if adam_cpr_kappa_adapt_grad:
            warnings.warn(
                "adam_cpr_kappa_adapt_grad is not supported by nnUNetTrainerAdamCPR and is ignored.",
                stacklevel=2,
            )

        super().__init__(
            plans=plans, configuration=configuration, fold=fold, dataset_json=dataset_json,
            unpack_dataset=unpack_dataset, device=device,
            adam_cpr_mode=adam_cpr_mode, adam_cpr_kappa=adam_cpr_kappa,
            adam_cpr_lagmul_rate=adam_cpr_lagmul_rate, adam_cpr_kappa_adapt=adam_cpr_kappa_adapt,
            **kwargs,
        )
        self.initial_lr = 1e-2
        self.warmup_steps = 8
        self.num_epochs = 100
