import os
import pickle
import tempfile
from typing import *

import hydra
import mlflow
import torch
from torch import Tensor, nn

from utils import assert_type, myprint as print


class SPG(nn.Module):
    """Wrap a target module and keep per-task, per-parameter gradient masks.

    Gradients collected through :meth:`register_grad` are squashed into
    ``[0, 1)`` masks by :meth:`freeze_masks`; :meth:`block` then scales the
    gradients of the wrapped parameters by ``1 - mask``.

    Attributes written by this class:
        history_mask: ``idx_task -> {param name -> mask}``, filled by
            :meth:`freeze_masks`.
        dict_amax: ``idx_task -> {param name -> mask}``, cache filled by
            :meth:`a_max`.
        dict__idx_task__t__h: ``idx_task -> t -> {param name -> summed
            gradient}``, filled by :meth:`register_grad`.
    """

    def __init__(self, target_module: nn.Module, shift: float, target_name: str, seqname: str, has_task_head: bool = False):
        """Store the wrapped module and initialise the empty bookkeeping dicts.

        Args:
            target_module: module whose parameters are tracked. When
                ``has_task_head`` is true it must be an ``nn.ModuleList`` or
                ``nn.ModuleDict`` (checked with ``assert_type``).
            shift: factor applied to standardized values before ``tanh`` in
                :meth:`standardize_pm1`.
            target_name: label used in the artifact file name written by
                :meth:`freeze_masks`.
            seqname: stored as-is; not read anywhere inside this class.
            has_task_head: whether ``target_module`` holds one sub-module per
                task that has to be selected by index/key.
        """
        super().__init__()
        assert_type(target_module, nn.Module)
        assert_type(has_task_head, bool)

        self.target_module = target_module
        self.target_name = target_name
        self.seqname = seqname
        self.has_task_head = has_task_head
        if self.has_task_head:
            assert_type(self.target_module, [nn.ModuleList, nn.ModuleDict])
        # endif

        self.history_mask = dict()  # type: Dict[int, Dict[str, Tensor]]
        self.dict_amax = {}  # type: Dict[int, Dict[str, Tensor]]

        self.dict__idx_task__t__h = {}  # type: Dict[int, Dict[int, Dict[str, Tensor]]]

        self.shift = shift
    # enddef

    def forward(self, x: Tensor, idx: Optional[Union[int, str]] = None) -> Tensor:
        """Apply the wrapped module (or the head selected by ``idx``) to ``x``.

        Args:
            x: input tensor.
            idx: index/key of the task head; required when the instance was
                built with ``has_task_head=True`` and must be ``None``
                otherwise (both enforced with ``assert``).
        """
        assert_type(x, Tensor)

        if self.has_task_head:
            assert idx is not None

            out = self.target_module[idx](x)
        else:
            assert idx is None

            out = self.target_module(x)
        # endif

        return out
    # enddef

    def standardize_pm1(self, x: Tensor) -> Tensor:
        """Squash ``x`` into ``(-1, 1)`` via ``tanh(standardize(x) * shift)``.

        An all-zero tensor skips standardization, so it maps to all zeros.
        """
        if not torch.all(x == 0):
            x = self.standardize(x)
        # endif

        return torch.tanh(x * self.shift)
    # enddef

    @classmethod
    def standardize(cls, x: Tensor) -> Tensor:
        """Return ``(x - mean) / std`` computed over all elements of ``x``.

        The result has the shape of ``x``. Inputs for which the statistic is
        degenerate (fewer than two elements, or zero standard deviation) are
        mapped to zeros instead of NaN/inf, because ``x - mean`` is zero for
        them anyway and a NaN here would end up in the gradient masks.
        """
        # reshape (not view) so that non-contiguous tensors are accepted too.
        flat = x.reshape(-1)

        if flat.numel() < 2:
            # The sample standard deviation of fewer than two values is undefined.
            return torch.zeros_like(x)
        # endif

        std = flat.std()
        if std == 0:
            # Constant tensor: avoid 0 / 0.
            return torch.zeros_like(x)
        # endif

        return ((flat - flat.mean()) / std).reshape(x.shape)
    # enddef

    def register_grad(self, idx_task: int, t: int, epoch: int, h: Dict[str, Tensor]):
        """Add the gradients in ``h`` to the running sums stored for ``(idx_task, t)``.

        Args:
            idx_task: outer key of the accumulator.
            t: inner key of the accumulator.
            epoch: currently unused; plain summation is applied (averaging or
                annealing by epoch would be the alternatives that need it).
            h: mapping from parameter name to gradient tensor.
        """
        per_t = self.dict__idx_task__t__h.setdefault(idx_task, {})
        accumulated = per_t.setdefault(t, {})

        for name, grad in h.items():
            # ``0 + grad`` on first sight yields a new tensor, so the caller's
            # tensor is never aliased by the accumulator.
            accumulated[name] = accumulated.get(name, 0) + grad
        # endfor
    # enddef

    def freeze_masks(self, idx_task: int, save_artifact: bool = True):
        """Turn the gradients accumulated for ``idx_task`` into a mask.

        For every parameter name the mask is the element-wise maximum over all
        registered ``t`` of ``|standardize_pm1(gradient)|``; it is stored in
        ``self.history_mask[idx_task]``. Nothing happens when no gradient was
        registered for ``idx_task``.

        Args:
            idx_task: task whose accumulated gradients are frozen.
            save_artifact: when true, additionally pickle the masks split into
                ``history_prev`` (``t < idx_task``) and ``history_curr``
                (``t == idx_task``) and log the file through ``mlflow``.

        Raises:
            KeyError: if gradients exist for ``idx_task`` but none with
                ``t == idx_task``.
            AssertionError: if the parameter names differ between ``t`` values.
            ValueError: if ``save_artifact`` is true and some ``t`` is larger
                than ``idx_task`` (raised after ``history_mask`` was updated).
        """
        if idx_task not in self.dict__idx_task__t__h:
            # ablation can take this route.
            return
        # endif

        grads_by_t = self.dict__idx_task__t__h[idx_task]
        names = grads_by_t[idx_task].keys()

        # Normalise every gradient once; the values are reused for the artifact.
        magnitudes = {}  # type: Dict[int, Dict[str, Tensor]]
        history = {}  # type: Dict[str, Tensor]
        for t, dict__name__h in grads_by_t.items():
            assert names == dict__name__h.keys()

            magnitudes[t] = {name: self.standardize_pm1(h).abs() for name, h in dict__name__h.items()}
            for name, mag in magnitudes[t].items():
                if name not in history:
                    history[name] = torch.zeros_like(mag)
                # endif

                history[name] = torch.max(history[name], mag)
            # endfor
        # endfor
        self.history_mask[idx_task] = history.copy()

        if not save_artifact:
            return
        # endif

        # saving artifacts
        history_prev = {}  # type: Dict[str, Tensor]
        history_curr = {}  # type: Dict[str, Tensor]
        for t, mags in magnitudes.items():
            if t < idx_task:
                for name, mag in mags.items():
                    if name not in history_prev:
                        history_prev[name] = torch.zeros_like(mag)
                    # endif

                    history_prev[name] = torch.max(history_prev[name], mag)
                # endfor
            elif t == idx_task:
                history_curr = dict(mags)
            else:
                raise ValueError(t)
            # endif
        # endfor

        with tempfile.TemporaryDirectory() as tmp_dir:
            path__history = os.path.join(tmp_dir, f'{idx_task}_{self.target_name}_histories.pkl')
            obj_history = {
                'history_prev': history_prev,
                'history_curr': history_curr,
                }

            with open(path__history, 'wb') as f:
                pickle.dump(obj_history, f)
            # endwith

            # Must run while the temporary directory still exists.
            mlflow.log_artifact(path__history)
        # endwith
    # enddef

    def compute_a_by_p(self,
                       following_modules: nn.Module,
                       h: Tensor,
                       idx: Optional[Union[int, str]] = None,
                       ) -> Dict[str, Tensor]:
        """Back-propagate ``following_modules(h)`` and collect the target's gradients.

        ``h`` is detached first, ``following_modules`` is run on it and
        ``backward()`` is called on the result without arguments, so the
        result has to be a scalar. The gradients of the target module's
        parameters are then copied to CPU.

        Only ``following_modules`` has its gradients cleared (before and
        after); presumably it contains the target module - confirm with the
        caller.

        Args:
            following_modules: module evaluated on ``h``.
            h: input tensor for ``following_modules``.
            idx: task head to read gradients from; required iff the instance
                has task heads.

        Returns:
            Mapping from parameter name to a detached CPU copy of its gradient.
        """
        if self.has_task_head:
            assert idx is not None
        else:
            assert idx is None
        # endif

        history = {}  # type: Dict[str, Tensor]

        following_modules.zero_grad(set_to_none=True)

        output = following_modules(h.detach())
        output.backward()

        tgt = self.target_module[idx] if self.has_task_head else self.target_module
        for n, p in tgt.named_parameters():
            grad = p.grad
            assert grad is not None

            if grad is None:
                # Only reachable when assertions are disabled (python -O).
                continue
            # endif

            history[n] = grad.data.detach().clone().cpu()
        # endfor

        following_modules.zero_grad(set_to_none=True)

        return history
    # enddef

    def a_max(self, idx_task: int, latest_module: nn.Module) -> Optional[Dict[str, Tensor]]:
        """Return the accumulated mask that applies while training ``idx_task``.

        For ``idx_task == 0`` there is no earlier task and ``None`` is
        returned. Otherwise the result is, per parameter of ``latest_module``,
        the element-wise maximum of ``history_mask[idx_task - 1]`` and the
        cached ``dict_amax[idx_task - 1]`` (when that cache entry exists). The
        result is cached in ``dict_amax[idx_task]``.

        NOTE(review): the running maximum covers all earlier tasks only if
        this method was called for every preceding task index in order;
        otherwise just ``history_mask[idx_task - 1]`` is used.

        Raises:
            KeyError: if no mask was frozen for ``idx_task - 1`` or a
                parameter name is missing from it.
        """
        if idx_task == 0:
            return None
        # endif

        if idx_task not in self.dict_amax:
            latest = self.history_mask[idx_task - 1]
            earlier = self.dict_amax.get(idx_task - 1)

            ret = {}  # type: Dict[str, Tensor]
            for name_param, _ in latest_module.named_parameters():
                curr = latest[name_param]
                prev = curr if earlier is None else earlier[name_param]

                ret[name_param] = torch.max(prev, curr)
            # endfor

            self.dict_amax[idx_task] = ret
        # endif

        return self.dict_amax[idx_task]
    # enddef

    def block(self, idx_task: int, name: str, show: bool):
        """Scale the gradients of the target's parameters by ``1 - a_max``, in place.

        Parameters without a gradient are skipped. For the first task
        (``a_max`` is ``None``) there is nothing to block and the gradients
        are left untouched.

        Args:
            idx_task: current task; also selects the head when the instance
                has task heads.
            name: prefix used in the statistics line.
            show: when true, print per-parameter statistics of the scaling
                factor (how many entries are exactly 0 and how many are
                <= 0.1).
        """
        tgt = self.target_module[idx_task] if self.has_task_head else self.target_module

        a_max = self.a_max(idx_task, tgt)
        if a_max is None:
            # No earlier task, hence no mask to apply.
            return
        # endif

        for n, p in tgt.named_parameters():
            if p.grad is None:
                continue
            # endif

            red = (1 - a_max[n]).to(p.device)
            p.grad.data *= red

            if show:
                num_0 = red[red == 0].numel()
                num_09 = red[red <= 0.1].numel()
                num_all = red.numel()
                msg = f'[{name}.{n}]' \
                      f' dead: {num_0}/{num_all}({num_0 / num_all * 100:.1f}%)' \
                      f' 0.9: {num_09}/{num_all}({num_09 / num_all * 100:.1f}%)'
                print(msg)
            # endif
        # endfor
    # enddef

    def count_consumtion(self, idx_task: int, strict: bool) -> Tuple[int, Union[int, float]]:
        """Count how much of the target's parameters is covered by ``a_max``.

        Only parameters that currently have a gradient are considered.

        Args:
            idx_task: current task; also selects the head when the instance
                has task heads.
            strict: when true count mask entries that are exactly 1,
                otherwise sum the mask values.

        Returns:
            ``(num_all, num_blocked)``: the number of considered elements and
            the blocked amount (0 for the first task, where no mask exists).
        """
        tgt = self.target_module[idx_task] if self.has_task_head else self.target_module

        a_max = self.a_max(idx_task, tgt)

        num_all = 0
        num_blocked = 0

        for n, p in tgt.named_parameters():
            if p.grad is None:
                continue
            # endif

            if a_max is None:
                # First task: nothing is blocked yet, only count the elements.
                num_all += p.numel()
                continue
            # endif

            mask = a_max[n]
            num_all += mask.numel()
            if strict:
                num_blocked += (mask == 1).sum().item()
            else:
                num_blocked += mask.sum().item()
            # endif
        # endfor

        return num_all, num_blocked
    # enddef

# endclass
