from collections import defaultdict
from collections import deque

import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.utilities.combined_loader import _shutdown_workers_and_reset_iterator
from torch.utils.data import DataLoader

from utils.optim.adam_cpr import AdamCPR


class KappaAdaptation(pl.Callback):
    """Periodically re-estimate the per-parameter ``kappa`` values of an ``AdamCPR`` optimizer.

    Whenever ``trainer.global_step`` has reached ``start_step`` and is a multiple of
    ``step_interval`` (and the first optimizer is an ``AdamCPR``), the callback:

    1. draws up to ``num_batches`` batches from a fresh train dataloader and measures
       the mean training loss on them (the reference loss);
    2. repeatedly takes a gradient step on the regularizer alone (std or l2, depending
       on ``optimizer.mode``) for every constrained parameter, re-measures the loss and
       stops as soon as it exceeds the reference loss; after each accepted step every
       ``kappa`` is lowered to the current regularizer value if that is smaller;
    3. restores the original parameters, takes the minimum ``kappa`` across all
       processes and writes it back into the optimizer state.

    The very first adaptation performs a single step; later ones perform up to
    ``max_adapt_steps`` steps.

    Args:
        start_epoch: Stored on the instance; not used by the step-based trigger below.
        max_adapt_steps: Maximum number of regularizer steps per adaptation (after the first).
        epoch_interval: Stored on the instance; not used by the step-based trigger below.
        start_step: First global step at which adaptation may run.
        step_interval: Adaptation runs when ``global_step % step_interval == 0``.
        num_batches: Number of training batches used to estimate the loss.
    """

    def __init__(self, start_epoch, max_adapt_steps, epoch_interval, start_step, step_interval,
                 num_batches=10):
        super().__init__()

        self.start_epoch = start_epoch
        self.epoch_interval = epoch_interval

        self.start_step = start_step
        self.step_interval = step_interval

        self.max_adapt_steps = max_adapt_steps
        self.num_batches = num_batches

        # True until the first adaptation has run; that one is limited to a single step.
        self.first = True

    @staticmethod
    def _constrained_params(optimizer):
        """Return the ``(name, param)`` pairs of all parameters the optimizer constrains."""
        if 'constrain' not in optimizer.mode:
            return []
        pairs = []
        for group in optimizer.param_groups:
            if group.get('apply_decay'):
                pairs.extend(zip(group['names'], group['params']))
        return pairs

    @staticmethod
    def _mean_loss(pl_module, batches):
        """Return the training loss of ``pl_module`` averaged over ``batches``."""
        total = 0
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            for sample in batches:
                logits = pl_module.model(trg_shf_seq=sample['src_seq'], trg_len=sample['src_len']).detach()
                loss = pl_module.loss_train(logits.view(-1, logits.size(-1)), sample['trg_seq'].view(-1))
                total += loss / len(batches)
        return total

    @staticmethod
    def _regularizer_grad(param, mode):
        """Return the gradient of the regularizer selected by ``mode``, or ``None`` if unknown."""
        if "std" in mode:
            std_dev = param.std()
            n = float(param.numel())
            centered = param.sub(param.mean())
            grad = centered.mul_(2).sub_(2 * centered.mean()).div_(n - 1)
            grad.div_(std_dev.mul_(2))
            return grad
        if "l2" in mode:
            return 2 * param
        return None

    def _collect_batches(self, trainer, pl_module):
        """Fetch up to ``num_batches`` batches from a fresh train dataloader."""
        loader = trainer.datamodule.train_dataloader()
        loader_iter = iter(loader)
        batches = []
        try:
            # zip() asks range() first, so no extra batch is consumed; a short
            # loader simply yields fewer batches instead of raising StopIteration.
            for _, sample in zip(range(self.num_batches), loader_iter):
                if isinstance(sample, dict):
                    sample = {k: v.to(pl_module.device) for k, v in sample.items()}
                batches.append(sample)
        finally:
            # The check must be made on the loader itself: its iterator is never a DataLoader.
            if isinstance(loader, DataLoader):
                _shutdown_workers_and_reset_iterator(loader)
            del loader_iter
        return batches

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Run one kappa adaptation if the current global step is due for it."""
        if trainer.global_step < self.start_step or trainer.global_step % self.step_interval != 0:
            return
        optimizer = trainer.optimizers[0]
        if not isinstance(optimizer, AdamCPR):
            return

        mode = optimizer.mode
        learning_rate = optimizer.param_groups[0]['lr']
        batches = self._collect_batches(trainer, pl_module)
        targets = self._constrained_params(optimizer)

        with torch.no_grad():
            loss_ref = self._mean_loss(pl_module, batches)

            backups = [param.detach().clone() for _, param in targets]
            kappas = [torch.tensor(optimizer.state[param]['kappa'], device=param.device)
                      for _, param in targets]

            if not batches:
                # Nothing to measure the loss on: leave every kappa untouched.
                num_steps = 0
            elif self.first:
                num_steps = 1
                self.first = False
            else:
                num_steps = self.max_adapt_steps

            try:
                for _ in range(num_steps):
                    # Descend on the regularizer only.
                    for _, param in targets:
                        if "mh" in mode and optimizer.state[param]['lagmul'] > 0:
                            continue
                        reg_grad = self._regularizer_grad(param, mode)
                        if reg_grad is not None:
                            param.sub_(learning_rate * reg_grad)

                    # Stop as soon as shrinking the parameters hurts the loss.
                    if self._mean_loss(pl_module, batches) > loss_ref:
                        break

                    for idx, (_, param) in enumerate(targets):
                        if "std" in mode:
                            kappas[idx] = min(kappas[idx], param.std())
                        elif "l2" in mode:
                            kappas[idx] = min(kappas[idx], param.square().mean())
            finally:
                # The adaptation steps are only a probe: always put the real weights back.
                for (_, param), backup in zip(targets, backups):
                    param.copy_(backup)

            # get_world_size() raises without an initialised process group (e.g. single device).
            multi_process = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
            for (_, param), kappa in zip(targets, kappas):
                if multi_process:
                    dist.all_reduce(kappa, op=dist.ReduceOp.MIN)  # for multi GPU
                optimizer.state[param]['kappa'] = kappa.item()
