import torch 
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from .utils import construct_incr, clone_layer
import sys
sys.path.append('.../utils')
from utils.softdtw_cuda import SoftDTW

class AvgPoolLrp(nn.Module):
    """Relevance propagation through an average-pooling layer.

    The wrapped layer is cloned with ``clone_layer`` and only the
    ``"epsilon"`` entry of ``rule`` is forwarded to ``construct_incr``.
    """

    def __init__(self, layer, rule):
        super().__init__()

        # Only the epsilon rule is supported: every other key is dropped.
        rule = {k: v for k, v in rule.items() if k == "epsilon"}
        self.layer = clone_layer(layer)
        # Callable applied to the layer output before the division below
        # (presumably the epsilon stabiliser -- see utils.construct_incr).
        self.incr = construct_incr(**rule)

    def forward(self, Rj, Ai):
        """Redistribute the output relevance ``Rj`` onto the input ``Ai``.

        Computes ``Ai * d/dAi sum(Z * S)`` with ``Z = incr(layer(Ai))`` and
        ``S = Rj / Z`` held constant.

        Args:
            Rj: relevance of the layer output; must be divisible by ``Z``
                (same shape or broadcastable).
            Ai: input activation of the layer.

        Returns:
            A detached tensor with the shape of ``Ai``.
        """
        # The gradient is needed even when the caller runs under
        # ``torch.no_grad()``, so autograd is enabled locally.
        with torch.enable_grad():
            # Fresh leaf: the caller's tensor and its ``.grad`` stay untouched.
            Ai = Ai.detach().requires_grad_(True)
            # ``forward`` is called directly (not ``__call__``), so hooks
            # registered on the cloned layer are not triggered.
            Z = self.incr(self.layer.forward(Ai))
            # Detached so that S acts as a constant in the gradient.
            S = (Rj / Z).detach()
            (Ci,) = torch.autograd.grad((Z * S).sum(), Ai)

        return (Ai * Ci).detach()


class MaxPoolLrp(nn.Module):
    """Relevance propagation through a max-pooling layer.

    Relevance is not propagated through the max-pool itself but through an
    ``AvgPool2d`` surrogate that covers the same windows. Only the
    ``"epsilon"`` entry of ``rule`` is forwarded to ``construct_incr``.
    """

    def __init__(self, layer, rule):
        super().__init__()

        # Only the epsilon rule is supported: every other key is dropped.
        rule = {k: v for k, v in rule.items() if k == "epsilon"}
        self.layer = self._avg_pool_like(layer)
        # Callable applied to the surrogate output before the division
        # (presumably the epsilon stabiliser -- see utils.construct_incr).
        self.incr = construct_incr(**rule)

    @staticmethod
    def _avg_pool_like(layer):
        """Build an ``AvgPool2d`` with the same window geometry as ``layer``.

        ``stride``, ``padding`` and ``ceil_mode`` are copied when the layer
        exposes them, so the surrogate output has the same shape as the
        output of the wrapped pooling layer (and hence as its relevance).
        Missing attributes fall back to the ``AvgPool2d`` defaults.

        NOTE(review): ``AvgPool2d`` has no ``dilation``; a dilated max-pool
        cannot be mirrored exactly.
        """
        return torch.nn.AvgPool2d(
            kernel_size=layer.kernel_size,
            stride=getattr(layer, "stride", None),
            padding=getattr(layer, "padding", 0),
            ceil_mode=getattr(layer, "ceil_mode", False),
        )

    def forward(self, Rj, Ai):
        """Redistribute the output relevance ``Rj`` onto the input ``Ai``.

        Computes ``Ai * d/dAi sum(Z * S)`` with ``Z = incr(layer(Ai))`` and
        ``S = Rj / Z`` held constant, where ``layer`` is the average-pooling
        surrogate.

        Args:
            Rj: relevance of the pooling output; must be divisible by ``Z``
                (same shape or broadcastable).
            Ai: input activation of the pooling layer.

        Returns:
            A detached tensor with the shape of ``Ai``.
        """
        # The gradient is needed even when the caller runs under
        # ``torch.no_grad()``, so autograd is enabled locally.
        with torch.enable_grad():
            # Fresh leaf: the caller's tensor and its ``.grad`` stay untouched.
            Ai = Ai.detach().requires_grad_(True)
            Z = self.incr(self.layer.forward(Ai))
            # Detached so that S acts as a constant in the gradient.
            S = (Rj / Z).detach()
            (Ci,) = torch.autograd.grad((Z * S).sum(), Ai)

        return (Ai * Ci).detach()
    

class SoftDTWLrp(nn.Module):
    """Relevance propagation through the temporal pooling layer.

    ``layer`` is indexed positionally: ``layer[0]`` is the soft-DTW module
    (its ``align`` method is used), ``layer[1]`` the prototypes and, when
    ``args.model == 'ConvSwitch'``, ``layer[2]`` / ``layer[3]`` the switch
    encoder and the switch weights. Each of them is cloned with
    ``clone_layer``. Only the ``"epsilon"`` entry of ``rule`` is forwarded to
    ``construct_incr``.

    The pooling that relevance is propagated through is selected by
    ``args.pool`` (``'DTP'``, ``'STP'``, anything else is treated as global
    pooling) and ``args.pool_op`` (``'MAX'``, anything else is averaging).
    """

    def __init__(self, args, layer, rule):
        super().__init__()

        # Only the epsilon rule is supported: every other key is dropped.
        rule = {k: v for k, v in rule.items() if k == "epsilon"}
        self.softdtw = clone_layer(layer[0])
        self.protos = clone_layer(layer[1])
        if args.model == 'ConvSwitch':
            self.encoding = clone_layer(layer[2])
            self.switch = clone_layer(layer[3])
        # Callable applied to the pooled output before the division
        # (presumably the epsilon stabiliser -- see utils.construct_incr).
        self.incr = construct_incr(**rule)
        self.args = args

    @staticmethod
    def _segment_sizes(length, n):
        """Split ``length`` into ``n`` chunks; the last one takes the remainder."""
        sizes = [length // n] * n
        sizes[-1] += length - sum(sizes)
        return sizes

    def _switch_output(self, Ai):
        """Evaluate the ConvSwitch pooling selection on ``Ai``.

        Returns the input with dim 2 squeezed and the pooled tensor picked by
        the switch (``args.switch_op == 'ensem'`` gathers the top-``n``
        columns of the concatenated GTP/STP/DTP max-pool outputs, otherwise
        one of the three outputs is chosen from the mean arg-max index).
        """
        Ai = Ai.squeeze(2)
        A = self.softdtw.align(self.protos.repeat(Ai.shape[0], 1, 1), Ai)
        n = self.protos.size(1)

        # Global max, repeated n times along the last dim.
        out1 = torch.max(Ai, dim=2)[0].unsqueeze(2).repeat(1, 1, n)

        # Max over n static segments of the last dim.
        hs = torch.split(Ai, self._segment_sizes(Ai.shape[2], n), dim=2)
        hs = [h_.max(dim=2)[0].unsqueeze(dim=2) for h_ in hs]
        out2 = torch.cat(hs, dim=2)

        # Max over the last dim after weighting with the alignment A.
        hs = Ai.unsqueeze(dim=2) * A.unsqueeze(dim=1)
        out3 = hs.max(dim=3)[0]

        concat_out = torch.cat([out1, out2, out3], dim=-1)

        raw_attn = self.switch.repeat(Ai.shape[0], 1, 1)

        encode_attn = concat_out * raw_attn
        attn = F.softmax(self.encoding(encode_attn.unsqueeze(2)), dim=-1).squeeze(1)

        if self.args.switch_op == 'ensem':
            ind = torch.topk(attn, n)[1].squeeze(1)
            dummy = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), concat_out.size(1))
            Z = torch.gather(concat_out.transpose(1, 2), 1, dummy).transpose(1, 2)
        else:
            ind = torch.mean(torch.max(attn, dim=2)[1].squeeze(1).float()).item()
            if ind < n + 1:
                Z = out1
            elif ind >= n + 1 and ind <= n * 2 + 1:
                Z = out2
            else:
                Z = out3
        return Ai, Z

    def _pooled_output(self, Ai):
        """Re-run the configured pooling on a fresh autograd leaf.

        Returns ``(leaf, Z)`` where ``Z`` is the pooled output and ``leaf`` is
        the tensor the gradient of ``Z`` has to be taken with respect to.
        For DTP/MAX the leaf is the alignment-weighted product
        ``Ai.unsqueeze(2) * A.unsqueeze(1)``; in every other case it is the
        input with dim 2 squeezed.
        """
        pool, pool_op = self.args.pool, self.args.pool_op
        Ai = Ai.squeeze(2)

        if pool == 'DTP':
            # The alignment is computed from the raw input and is treated as
            # a constant with respect to the leaf created below.
            A = self.softdtw.align(self.protos.repeat(Ai.shape[0], 1, 1), Ai)
            if pool_op == 'MAX':
                # DTP_MAX: relevance is attributed to the weighted product.
                Ai = (Ai.unsqueeze(2) * A.unsqueeze(dim=1)).detach().requires_grad_(True)
                Z = Ai.max(dim=3)[0]
            else:
                # DTP_AVG: normalise the alignment along dim 2 (on a copy).
                Ai = Ai.detach().requires_grad_(True)
                A = A.clone()
                A /= A.sum(dim=2, keepdim=True)
                Z = torch.bmm(Ai, A.transpose(1, 2))
        elif pool == 'STP':
            Ai = Ai.detach().requires_grad_(True)
            n = self.protos.size(1)
            hs = torch.split(Ai, self._segment_sizes(Ai.shape[2], n), dim=2)
            if pool_op == 'MAX':
                hs = [h_.max(dim=2)[0].unsqueeze(dim=2) for h_ in hs]
            else:
                hs = [h_.mean(dim=2, keepdim=True) for h_ in hs]
            Z = torch.cat(hs, dim=2)
        else:
            # GTP
            Ai = Ai.detach().requires_grad_(True)
            if pool_op == 'MAX':
                Z = torch.max(Ai, dim=2)[0]
            else:
                Z = torch.mean(Ai, dim=2)
        return Ai, Z

    def forward(self, Rj, Ai):
        """Redistribute the pooled relevance ``Rj`` onto the pooling input.

        Computes ``leaf * d/dleaf sum(Z * S)`` with ``Z = incr(pool(leaf))``
        and ``S = Rj / Z`` held constant.

        Args:
            Rj: relevance of the pooling output; must be divisible by ``Z``
                (same shape or broadcastable).
            Ai: input activation of the pooling layer; dim 2 is squeezed
                before pooling.

        Returns:
            A detached tensor. For DTP/MAX it has the shape of the
            alignment-weighted product; otherwise it is the squeezed input
            shape with dim 2 re-inserted via ``unsqueeze(2)``.
        """
        # The gradient is needed even when the caller runs under
        # ``torch.no_grad()``, so autograd is enabled locally.
        with torch.enable_grad():
            if self.args.model == 'ConvSwitch':
                # NOTE(review): the tensor selected by the switch is
                # discarded -- relevance is always propagated through the
                # pooling chosen by args.pool / args.pool_op below. The call
                # is kept only so existing side effects and errors stay the
                # same; TODO confirm it can be dropped.
                Ai, _ = self._switch_output(Ai)

            Ai, Z = self._pooled_output(Ai)
            Z = self.incr(Z)
            # Detached so that S acts as a constant in the gradient.
            S = (Rj / Z).detach()
            # autograd.grad returns only d/dAi and leaves the ``.grad`` of
            # every other tensor in the graph (e.g. the prototypes) alone.
            (Ci,) = torch.autograd.grad((Z * S).sum(), Ai)

        if self.args.pool == 'DTP' and self.args.pool_op == 'MAX':
            return (Ai * Ci).detach()
        return (Ai.unsqueeze(2) * Ci.unsqueeze(2)).detach()

class AdaptiveAvgPoolLrp(nn.Module):
    """Relevance propagation through an adaptive average-pooling layer.

    The wrapped layer is cloned with ``clone_layer`` and only the
    ``"epsilon"`` entry of ``rule`` is forwarded to ``construct_incr``.
    """

    def __init__(self, layer, rule):
        super().__init__()

        # Only the epsilon rule is supported: every other key is dropped.
        rule = {k: v for k, v in rule.items() if k == "epsilon"}
        self.layer = clone_layer(layer)
        # Callable applied to the layer output before the division below
        # (presumably the epsilon stabiliser -- see utils.construct_incr).
        self.incr = construct_incr(**rule)

    def forward(self, Rj, Ai):
        """Redistribute the output relevance ``Rj`` onto the input ``Ai``.

        Computes ``Ai * d/dAi sum(Z * S)`` with ``Z = incr(layer(Ai))`` and
        ``S = Rj / Z`` held constant.

        Args:
            Rj: relevance of the layer output; must be divisible by ``Z``
                (same shape or broadcastable).
            Ai: input activation of the layer.

        Returns:
            A detached tensor with the shape of ``Ai``.
        """
        # The gradient is needed even when the caller runs under
        # ``torch.no_grad()``, so autograd is enabled locally.
        with torch.enable_grad():
            # Fresh leaf: the caller's tensor and its ``.grad`` stay untouched.
            Ai = Ai.detach().requires_grad_(True)
            # ``forward`` is called directly (not ``__call__``), so hooks
            # registered on the cloned layer are not triggered.
            Z = self.incr(self.layer.forward(Ai))
            # Detached so that S acts as a constant in the gradient.
            S = (Rj / Z).detach()
            (Ci,) = torch.autograd.grad((Z * S).sum(), Ai)

        return (Ai * Ci).detach()
