import torch

from .attn_grad_rollout import GradAttnRollout, GradAttn
from .quantify_query import QuantifyQuery
from .._reducer import ModulePosthocReduce

class GradCrossAttnRollout(GradAttn):
    """Gradient-weighted attention solver adapted for cross-attention modules.

    Delegates to ``GradAttn.solve`` but hands it ``forward_in[0]`` in the
    ``forward_out`` slot (the real ``forward_out`` is not used). The result is
    then compressed with ``compress_reduce``, sliced along dim 1 with
    ``clip``, and given a singleton axis at position 1.

    Args:
        discard_ratio, multihead_reduce, residual_connect, norm: forwarded
            unchanged to ``GradAttn.__init__``.
        clip: ``(start, stop)`` slice bounds applied to dim 1 of the
            compressed map. With the default ``(0, -1)`` the last position is
            excluded.
        compress_reduce: callable ``f(x, dim=...)``. The default ignores
            ``dim`` and takes index 0 along dim 1.
    """

    def __init__(self, discard_ratio=0.9, multihead_reduce='mean', residual_connect=True, norm=True,
                 clip=(0, -1), compress_reduce=lambda x, dim=1: x[:, 0]):
        super().__init__(discard_ratio, multihead_reduce, residual_connect, norm)
        self.clip = clip
        self.compress_reduce = compress_reduce

    def solve(self, forward_in, forward_out, backward_in, backward_out):
        """Return the clipped, compressed rollout with a new axis at dim 1."""
        rollout = super().solve(forward_in, forward_in[0], backward_in, backward_out)
        compressed = self.compress_reduce(rollout, dim=1)
        lo, hi = self.clip[0], self.clip[1]
        # Equivalent to ``[:, None, :]`` on the sliced tensor.
        return compressed[:, lo:hi].unsqueeze(1)

class QuantifyQueryIn(QuantifyQuery):
    """Score positions by the rectified product of activation and gradient.

    ``solve`` multiplies ``backward_in`` by ``forward_out`` element-wise,
    applies ReLU, and reduces the result with ``self.reduce_method``. That
    attribute is presumably built from the ``reduce_method`` name by
    ``QuantifyQuery.__init__`` (TODO confirm). It then optionally normalizes
    along the last dim and clips dim 1.

    Args:
        reduce_method, norm: forwarded unchanged to ``QuantifyQuery.__init__``.
        clip: ``(start, stop)`` slice bounds applied to dim 1 of the reduced
            map. With the default ``(0, -1)`` the last position is excluded.
    """

    def __init__(self, reduce_method='sum', norm=True, clip=(0, -1)):
        super().__init__(reduce_method, norm)
        self.clip = clip

    def solve(self, forward_in, forward_out, backward_in, backward_out):
        """Return the (optionally normalized) relu(grad * activation) map.

        The result is sliced to ``self.clip`` on dim 1 and given a singleton
        axis at position 1. ``forward_in`` and ``backward_out`` are unused.
        """
        h = forward_out
        dh = backward_in
        dhh = torch.relu(dh * h)
        dhh = self.reduce_method(dhh)
        if self.norm:
            total = dhh.sum(dim=-1, keepdim=True)
            # If ReLU zeroed every product, the row sums to 0 and 0/0 would
            # yield NaN. Dividing such rows by 1 keeps them at zero instead.
            total = total.masked_fill(total == 0, 1)
            dhh = dhh / total
        return dhh[:, self.clip[0]:self.clip[1]][:, None, :]

def _max_pos_merge(_x_next, _x_prev):
    _r = torch.concat([_x_next, _x_prev], dim=1)
    # _r = _r.sum(dim=1, keepdim=True)
    _r = _r.max(dim=1, keepdim=True).values
    return _r


def _attn_weighted_merge(c, a):
    """Fold a per-position weight map ``c`` into an attention map ``a``.

    ``c`` is divided by its sum along the last dim, so rows with a nonzero sum
    then sum to one. Its last two dims are swapped with
    ``permute([0, 2, 1])``, which requires ``c`` to be 3-D. The result is
    batch-matrix-multiplied on the right of ``a``.
    """
    total = c.sum(dim=-1, keepdim=True)
    # Rows that sum to zero would otherwise become NaN/inf through 0/0.
    # Dividing them by 1 leaves them unchanged instead.
    total = total.masked_fill(total == 0, 1)
    return a @ (c / total).permute([0, 2, 1])


ModulePosthocReduce.register(GradCrossAttnRollout, GradCrossAttnRollout, method=_max_pos_merge)
ModulePosthocReduce.register(QuantifyQueryIn, GradCrossAttnRollout, method=_max_pos_merge)
ModulePosthocReduce.register(GradCrossAttnRollout, QuantifyQueryIn, method=_max_pos_merge)
ModulePosthocReduce.register(GradCrossAttnRollout, GradAttnRollout, method=_attn_weighted_merge)
ModulePosthocReduce.register(GradCrossAttnRollout, GradAttn, method=_attn_weighted_merge)
ModulePosthocReduce.register(QuantifyQueryIn, GradAttnRollout, method=_attn_weighted_merge)
ModulePosthocReduce.register(QuantifyQueryIn, GradAttn, method=_attn_weighted_merge)


class QuantifyQueryInAvg(QuantifyQueryIn):
    """``QuantifyQueryIn`` whose registrations below merge positions by mean instead of max."""


def _mean_pos_merge(_x_next, _x_prev):
    """Merge two maps by their mean over dim 1, keeping it as a size-1 axis."""
    stacked = torch.cat((_x_next, _x_prev), dim=1)
    return stacked.mean(dim=1, keepdim=True)


class GradCrossAttnRolloutAvg(GradCrossAttnRollout):
    """``GradCrossAttnRollout`` whose registrations below merge positions by mean instead of max."""


ModulePosthocReduce.register(GradCrossAttnRolloutAvg, GradCrossAttnRolloutAvg, method=_mean_pos_merge)
ModulePosthocReduce.register(QuantifyQueryInAvg, GradCrossAttnRolloutAvg, method=_mean_pos_merge)
ModulePosthocReduce.register(GradCrossAttnRolloutAvg, QuantifyQueryInAvg, method=_mean_pos_merge)
# NOTE(review): the max family registers GradCrossAttnRollout -> GradAttnRollout/GradAttn,
# but the Avg family previously registered only QuantifyQueryInAvg for these targets.
# These two lines are added for parity. They matter if ModulePosthocReduce resolves by
# exact type; confirm against _reducer.
ModulePosthocReduce.register(GradCrossAttnRolloutAvg, GradAttnRollout, method=_attn_weighted_merge)
ModulePosthocReduce.register(GradCrossAttnRolloutAvg, GradAttn, method=_attn_weighted_merge)
ModulePosthocReduce.register(QuantifyQueryInAvg, GradAttnRollout, method=_attn_weighted_merge)
ModulePosthocReduce.register(QuantifyQueryInAvg, GradAttn, method=_attn_weighted_merge)



