import torch
import torch.nn.functional as F
from genotypes import PRIMITIVES, Genotype
from operations import *


def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis of ``x`` (shape ``(N, C, H, W)``) is split into
    ``groups`` consecutive groups of ``C // groups`` channels. The group axis
    and the within-group axis are swapped, and the result is flattened back
    to ``(N, C, H, W)``. Channel ``g * (C // groups) + i`` therefore moves to
    position ``i * groups + g``.
    """
    n, c, h, w = x.size()
    per_group = c // groups
    grouped = x.view(n, groups, per_group, h, w)
    # transpose yields a non-contiguous view; view() below needs contiguous memory.
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(n, -1, h, w)


class MixedOp(nn.Module):
    """Partial-channel mixed operation over every primitive in ``PRIMITIVES``.

    The first ``C // k`` input channels are sent through each candidate
    primitive, and the outputs are summed, each scaled by its entry in
    ``weights``. The remaining ``C - C // k`` channels bypass the
    primitives. If the primitives changed the spatial size, the bypass
    channels are passed through ``conv_1x1`` (a stride-2 1x1 conv). The two
    parts are concatenated and channel-shuffled with ``k`` groups. With
    ``k == 1`` every channel goes through the primitives and the weighted
    sum is returned directly.
    """

    def __init__(self, C, stride, k):
        super(MixedOp, self).__init__()
        self.k = k
        self.C = C
        self._ops = nn.ModuleList()

        # Projection for the C - C // k channels that skip the primitives.
        self.conv_1x1 = self._make_bypass_conv(C - C // self.k)
        for primitive in PRIMITIVES:
            op = OPS[primitive](C // self.k, stride, False)
            self._ops.append(op)

    @staticmethod
    def _make_bypass_conv(channels):
        """Build the stride-2 1x1 conv applied to the bypass channels."""
        return nn.Conv2d(channels, channels, kernel_size=1, stride=2, padding=0)

    def forward(self, x, weights):
        """Apply the weighted mixture to ``x``.

        Args:
            x: input feature map; channel axis is dim 1, spatial height is dim 2.
            weights: one weight per primitive, zipped with ``self._ops``.

        Returns:
            The weighted primitive sum (``k == 1``), otherwise the
            channel-shuffled concatenation of that sum and the bypass channels.
        """
        dim_2 = x.shape[1]
        xtemp = x[:, : dim_2 // self.k, :, :]
        xtemp2 = x[:, dim_2 // self.k:, :, :]
        temp1 = sum(w * op(xtemp) for w, op in zip(weights, self._ops))

        if self.k == 1:
            return temp1

        if temp1.shape[2] == x.shape[2]:
            ans = torch.cat([temp1, xtemp2], dim=1)
        else:
            # Primitives downsampled: bring the bypass channels to the same size.
            ans = torch.cat([temp1, self.conv_1x1(xtemp2)], dim=1)

        return channel_shuffle(ans, self.k)

    def wider(self, k):
        """Switch to channel proportion ``k``, resizing every primitive.

        ``conv_1x1`` is rebuilt whenever the bypass width ``C - C // k``
        changes. Otherwise it would keep its old channel count and the next
        downsampling forward pass would fail with a channel mismatch. The
        rebuilt conv is freshly initialised on the old conv's device and
        dtype. NOTE(review): an optimizer holding the old parameters
        presumably has to be recreated after widening; confirm with the caller.
        """
        self.k = k
        for op in self._ops:
            op.wider(self.C // k, self.C // k)

        bypass = self.C - self.C // k
        if self.conv_1x1.in_channels != bypass:
            reference = self.conv_1x1.weight
            self.conv_1x1 = self._make_bypass_conv(bypass).to(
                device=reference.device, dtype=reference.dtype
            )

    def get_in_channels(self):
        """Return the number of channels fed to each primitive (``C // k``)."""
        return self.C // self.k

    def get_out_channels(self):
        """Return the configured channel count ``C``."""
        return self.C

    def get_kernel_size(self):
        """Return the constant 1."""
        return 1


class ECALayer(nn.Module):
    """Efficient-channel-attention gate.

    Each channel is global-average-pooled, a 1-D convolution (kernel
    ``k_size``, no bias) runs across the channel axis, and the input is
    rescaled by the sigmoid of the result. ``channel`` is accepted but not
    used; the kernel size comes only from ``k_size``.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        same_pad = (k_size - 1) // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=same_pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, chans, _, _ = x.size()
        # Lay the pooled per-channel descriptors out as a 1-D signal: (B, 1, C).
        descriptor = self.avg_pool(x).view(batch, 1, chans)
        gate = self.sigmoid(self.conv(descriptor)).view(batch, chans, 1, 1)
        return x * gate.expand_as(x)


class SELayer(nn.Module):
    """Squeeze-and-excitation channel gate.

    Global average pooling is followed by a 1x1-conv bottleneck down to
    ``max(1, channel // reduction)`` channels, a ReLU, and a 1x1 conv back to
    ``channel`` channels. The sigmoid of the result multiplies the input.
    """

    def __init__(self, channel: int, reduction: int = 16):
        super(SELayer, self).__init__()
        bottleneck = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeeze = nn.Conv2d(channel, bottleneck, kernel_size=1, bias=True)
        excite = nn.Conv2d(bottleneck, channel, kernel_size=1, bias=True)
        self.fc = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.sigmoid(self.fc(self.avg_pool(x)))
        return x * gate


class _CBAMChannel(nn.Module):
    """CBAM channel-attention map.

    A shared bias-free 1x1-conv MLP (``channel`` -> ``max(1, channel //
    reduction)`` -> ``channel``) is applied to both the global-average- and
    global-max-pooled input. The sigmoid of the sum is returned as a
    ``(B, C, 1, 1)`` gate; it is not applied to ``x`` here.
    """

    def __init__(self, channel: int, reduction: int = 16):
        super(_CBAMChannel, self).__init__()
        bottleneck = max(1, channel // reduction)
        self.mlp = nn.Sequential(
            nn.Conv2d(channel, bottleneck, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, channel, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        from_avg = self.mlp(self.avg_pool(x))
        from_max = self.mlp(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)

class _CBAMSpatial(nn.Module):
    """CBAM spatial-attention map.

    The channel-wise mean and max maps are stacked into a 2-channel tensor.
    A single bias-free conv with "same" padding maps it to one channel, and
    the sigmoid of that is returned as a ``(B, 1, H, W)`` gate. The gate is
    not applied to ``x`` here.

    Raises:
        ValueError: if ``kernel_size`` is not 3 or 7.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are removed under ``python -O``.
        if kernel_size not in (3, 7):
            raise ValueError(f"kernel_size must be 3 or 7, got {kernel_size!r}")
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = torch.cat([avg_out, max_out], dim=1)
        y = self.conv(y)
        return self.sigmoid(y)

class CBAM(nn.Module):
    """Convolutional block attention: channel gate, then spatial gate.

    The channel gate from ``_CBAMChannel`` is applied first. The spatial gate
    from ``_CBAMSpatial`` is then computed on, and applied to, the
    channel-gated features.
    """

    def __init__(self, channel: int, reduction: int = 16, spatial_kernel: int = 7):
        super(CBAM, self).__init__()
        self.ca = _CBAMChannel(channel, reduction=reduction)
        self.sa = _CBAMSpatial(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_gated = x * self.ca(x)
        return channel_gated * self.sa(channel_gated)


class _ECAMChannelECA(nn.Module):
    """ECA-style channel gate used by ``ECAM``.

    A 1-D convolution (kernel ``k_size``, no bias) is run over the
    global-average-pooled channel descriptors. The sigmoid of the result is
    returned as a ``(B, C, 1, 1)`` gate; it is not applied to ``x``.
    ``channel`` is accepted but not used.
    """

    def __init__(self, channel: int, k_size: int = 3):
        super(_ECAMChannelECA, self).__init__()
        same_pad = (k_size - 1) // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=same_pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, chans, _, _ = x.size()
        descriptor = self.avg_pool(x).view(batch, 1, chans)  # (B, 1, C)
        return self.sigmoid(self.conv(descriptor)).view(batch, chans, 1, 1)

class _ECAMSpatialLite(nn.Module):
    """Lightweight spatial gate used by ``ECAM``.

    The channel-wise mean and max maps are averaged into a single
    ``(B, 1, H, W)`` hint. A single-channel bias-free conv with "same"
    padding is applied to it, and the sigmoid of that is returned as the
    spatial gate.
    """

    def __init__(self, kernel_size: int = 5):
        super(_ECAMSpatialLite, self).__init__()
        same_pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=same_pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True).values
        blended = (mean_map + peak_map) * 0.5
        return self.sigmoid(self.conv(blended))

class ECAM(nn.Module):
    """Sequential attention: ECA channel gate, then lite spatial gate.

    ``k_c`` is the 1-D kernel of the channel gate. ``k_s`` is the 2-D kernel
    of the spatial gate, which is computed on the channel-gated features.
    """

    def __init__(self, channel: int, k_c: int = 3, k_s: int = 5):
        super(ECAM, self).__init__()
        self.eca = _ECAMChannelECA(channel, k_size=k_c)
        self.spa = _ECAMSpatialLite(kernel_size=k_s)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_gated = x * self.eca(x)          # gate shape (B, C, 1, 1)
        return channel_gated * self.spa(channel_gated)  # gate shape (B, 1, H, W)

class DualBranchCell(nn.Module):
    """Two parallel searched cells whose outputs are fused into one feature map.

    ``cell_normal`` and ``cell_robust`` are built from identical arguments
    but driven by separate architecture weights. Their final outputs are
    concatenated on the channel axis (robust first) and reduced from
    ``2 * multiplier * C`` back to ``multiplier * C`` channels by a
    ReLU -> 1x1 conv -> BatchNorm -> ECAM attention stack.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, k, num_layers):
        super(DualBranchCell, self).__init__()

        # Always False, whatever `reduction` is: Network.forward reads this flag
        # and, when False, passes the [normal, robust] weight pair that forward()
        # below indexes as weights[0] / weights[1].
        self.reduction = False
        cell_args = (steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, k, num_layers)
        self.cell_robust = Cell(*cell_args)
        self.cell_normal = Cell(*cell_args)

        out_channels = multiplier * C
        # ECALayer, CBAM or SELayer from this file take the same leading
        # channel argument and can be swapped in for ECAM.
        self.fusion = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(2 * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            ECAM(out_channels, k_c=3, k_s=5),
        )

    def forward(self, states, weights, cell_index):
        """Run both branches and fuse their last features.

        Args:
            states: the network's list of per-cell state lists.
            weights: pair indexed as ``weights[0]`` (normal branch) and
                ``weights[1]`` (robust branch).
            cell_index: 1-based position of this cell.
        """
        normal_out = self.cell_normal(states, weights[0], cell_index)[-1]['feature']
        robust_out = self.cell_robust(states, weights[1], cell_index)[-1]['feature']
        fused = self.fusion(torch.cat([robust_out, normal_out], dim=1))
        # NOTE(review): unlike Cell.forward, this returns the outer `states` list
        # with one state dict appended. The callers visible in this file only read
        # [-1]['feature'] from it.
        return [*states, {'feature': fused, 'cell_index': cell_index, 'node_index': -1}]

    def wider(self, k):
        """Propagate a new channel proportion ``k`` to both branches."""
        for branch in (self.cell_robust, self.cell_normal):
            branch.wider(k)


class Cell(nn.Module):
    """Searchable cell of ``steps`` intermediate nodes joined by ``MixedOp`` edges.

    The two preprocessed inputs are nodes 0 and 1. Each new node sums one
    ``MixedOp`` output per earlier node. In a reduction cell, the edges
    leaving the two input nodes use stride 2. The cell output concatenates
    the last ``multiplier`` nodes along the channel axis.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, k, num_layers):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.k = k
        self.num_layers = num_layers  # stored; not read inside this class
        # After a reduction cell, s0 presumably needs downsampling to match s1,
        # hence FactorizedReduce (assumption based on its use here; confirm in operations).
        self.preprocess0 = (
            FactorizedReduce(C_prev_prev, C, affine=False)
            if reduction_prev
            else ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        )
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier

        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()  # never populated or read in this class
        # One MixedOp per (node, predecessor) edge, node by node.
        self._ops.extend(
            MixedOp(C, 2 if (reduction and j < 2) else 1, self.k)
            for i in range(self._steps)
            for j in range(2 + i)
        )

    def forward(self, states, weights, cell_index):
        """Run the cell.

        Args:
            states: list of per-cell state lists; ``states[0]`` holds the stem states.
            weights: per-edge architecture weights, indexed by edge number.
            cell_index: 1-based position of this cell in the network.

        Returns:
            This cell's state dicts: the two inputs (node_index 0 and 1), the
            intermediate nodes (node_index 2 .. steps + 1), then the
            concatenated output.
        """
        previous = states[cell_index - 1]
        # The first cell takes both inputs from the stem state list; later cells
        # take the final output of each of the two preceding cells.
        source0 = previous[-2] if cell_index == 1 else states[cell_index - 2][-1]
        s0 = self.preprocess0(source0['feature'])
        s1 = self.preprocess1(previous[-1]['feature'])

        nodes = [
            {'feature': s0, 'cell_index': cell_index, 'node_index': 0},
            {'feature': s1, 'cell_index': cell_index, 'node_index': 1},
        ]

        edge = 0
        for step in range(self._steps):
            fan_in = len(nodes)
            node_sum = sum(
                self._ops[edge + j](node['feature'], weights[edge + j])
                for j, node in enumerate(nodes)
            )
            edge += fan_in
            nodes.append({'feature': node_sum, 'cell_index': cell_index, 'node_index': step + 2})

        output = torch.cat([node['feature'] for node in nodes[-self._multiplier:]], dim=1)
        # NOTE(review): node_index is the literal 4 regardless of `steps`, which
        # collides with intermediate node 4 when steps >= 3. Kept as-is in case
        # downstream code matches on it.
        nodes.append({'feature': output, 'cell_index': cell_index, 'node_index': 4})
        return nodes

    def wider(self, k):
        """Propagate a new channel proportion ``k`` to every edge."""
        self.k = k
        for mixed in self._ops:
            mixed.wider(k)

class RobustStem(nn.Module):
    """Alternative stem: four conv-BN-ReLU stages, ECA gating, then a 1x1 reduction.

    The stages are:
      1. 3x3 conv, 3 -> C channels.
      2. 5x5 conv, C -> 2C channels.
      3. 3x3 conv with dilation 2.
      4. 3x3 grouped conv with 2 groups.
    After the stages, ECA gating is applied to the 2C channels and a 1x1
    conv maps back to C channels. Every conv's padding preserves the
    spatial size.
    """

    def __init__(self, C_curr):
        super(RobustStem, self).__init__()
        wide = 2 * C_curr

        self.conv1 = nn.Conv2d(3, C_curr, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(C_curr)
        self.relu1 = nn.ReLU(inplace=True)

        # Larger 5x5 kernel.
        self.conv2 = nn.Conv2d(C_curr, wide, kernel_size=5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(wide)
        self.relu2 = nn.ReLU(inplace=True)

        # Dilation 2 with padding 2 keeps H and W unchanged.
        self.dilated_conv = nn.Conv2d(wide, wide, kernel_size=3, padding=2, dilation=2, bias=False)
        self.bn_dilated = nn.BatchNorm2d(wide)
        self.relu_dilated = nn.ReLU(inplace=True)

        self.group_conv = nn.Conv2d(wide, wide, kernel_size=3, padding=1, groups=2, bias=False)
        self.bn_group = nn.BatchNorm2d(wide)
        self.relu_group = nn.ReLU(inplace=True)

        self.eca = ECALayer(wide)

        self.reduce_channels = nn.Conv2d(wide, C_curr, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.dilated_conv, self.bn_dilated, self.relu_dilated),
            (self.group_conv, self.bn_group, self.relu_group),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return self.reduce_channels(self.eca(x))



class _ArchParams(list):
    """List of architecture parameters that can also be called.

    ``Network._initialize_alphas`` stores an instance of this as the
    instance attribute ``arch_parameters``. That attribute shadows the
    ``arch_parameters()`` method of the same name. With a plain list, every
    ``model.arch_parameters()`` call, including the one in ``Network.new``,
    raised ``TypeError: 'list' object is not callable``. Making the list
    callable keeps both spellings working: ``model.arch_parameters`` is the
    list, and ``model.arch_parameters()`` returns that same list.
    """

    def __call__(self):
        return self


class Network(nn.Module):
    """Search network: stem, a stack of cells, global pooling and a linear head.

    Positions ``layers // 3`` and ``2 * layers // 3`` hold reduction
    ``Cell``s that double the channel count. Every other position holds a
    ``DualBranchCell`` driven by both the ``normal`` and ``robust`` alphas.
    The alphas are created with ``requires_grad=False``.
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3, k=4):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        # Remembered so new() can rebuild an identically configured network.
        self._stem_multiplier = stem_multiplier
        self.k = k
        self.input_size = 32

        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        # Alternative stem: self.stem = RobustStem(C_curr)

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
                cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, k, layers)
            else:
                reduction = False
                cell = DualBranchCell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, k,
                                      layers)

            reduction_prev = reduction
            self.cells.append(cell)
            # Each cell outputs `multiplier` concatenated nodes of C_curr channels.
            C_prev_prev, C_prev = C_prev, multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        self._initialize_alphas()

    def new(self):
        """Return a new CUDA ``Network`` with the same configuration and copied alphas."""
        # Forward every hyper-parameter: with only C/num_classes/layers/criterion,
        # a non-default `steps` produced alphas of a different shape and copy_ failed.
        model_new = Network(
            self._C, self._num_classes, self._layers, self._criterion,
            steps=self._steps,
            multiplier=self._multiplier,
            # Fallback for instances pickled before _stem_multiplier was stored.
            stem_multiplier=getattr(self, '_stem_multiplier', 3),
            k=self.k,
        ).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def wider(self, k):
        """Propagate a new channel proportion ``k`` to every cell."""
        self.k = k
        for cell in self.cells:
            cell.wider(k)

    def get_softmax(self):
        """Return the row-wise softmax of each alpha tensor, keyed by cell type."""
        weights_normal = F.softmax(self.alphas_normal, dim=-1)
        weights_reduce = F.softmax(self.alphas_reduce, dim=-1)
        weights_robust = F.softmax(self.alphas_robust, dim=-1)
        return {'normal': weights_normal, 'reduce': weights_reduce, 'robust': weights_robust}

    def get_equal_softmax(self):
        """Return the softmax of fresh near-zero random logits: almost uniform weights."""
        alphas_normal = nn.Parameter(1e-3 * torch.randn(self.num_edges, self.num_ops))
        alphas_reduce = nn.Parameter(1e-3 * torch.randn(self.num_edges, self.num_ops))
        alphas_robust = nn.Parameter(1e-3 * torch.randn(self.num_edges, self.num_ops))
        weights_normal = F.softmax(alphas_normal, dim=-1)
        weights_reduce = F.softmax(alphas_reduce, dim=-1)
        weights_robust = F.softmax(alphas_robust, dim=-1)
        return {'normal': weights_normal, 'reduce': weights_reduce, 'robust': weights_robust}

    def get_equal_projected_weights(self, cell_type):
        """Return the near-uniform weights for ``cell_type`` ('normal', 'reduce' or 'robust')."""
        weights = self.get_equal_softmax()[cell_type]
        return weights

    def get_projected_weights(self, cell_type):
        """Return the softmaxed alphas for ``cell_type`` ('normal', 'reduce' or 'robust')."""
        weights = self.get_softmax()[cell_type]
        return weights

    def forward(self, input, weights_dict=None, rm_key=False):
        """Compute class logits for ``input``.

        Args:
            input: image batch fed to the stem.
            weights_dict: used only when ``rm_key`` is true; must provide
                ``'reduce'`` and, for dual-branch cells, ``'normal'`` and
                ``'robust'`` weights.
            rm_key: when false, the model's own softmaxed alphas are used and
                ``weights_dict`` is ignored.
        """
        s0 = s1 = self.stem(input)

        state = [{'feature': s0, 'cell_index': -1, 'node_index': 0},
                 {'feature': s1, 'cell_index': -1, 'node_index': 1}]

        states = [state]

        if not rm_key:
            # The projected weights are identical for every cell, so compute them once.
            weights_dict = {
                'normal': self.get_projected_weights('normal'),
                'reduce': self.get_projected_weights('reduce'),
                'robust': self.get_projected_weights('robust'),
            }

        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = weights_dict['reduce']
            else:
                weights = [weights_dict['normal'], weights_dict['robust']]

            # Cells are numbered from 1; index 0 of `states` is the stem.
            states.append(cell(states, weights, cell_index=i + 1))

        out = self.global_pooling(states[-1][-1]['feature'])
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

    def _loss(self, input, target):
        """Return ``criterion(self(input), target)``."""
        logits = self(input)
        loss = self._criterion(logits, target)
        return loss

    def _initialize_alphas(self):
        """Create the (num_edges, num_ops) alphas for normal, reduce and robust cells."""
        k = sum(1 for i in range(self._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        self.num_ops = num_ops
        self.num_edges = k

        self.alphas_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops), requires_grad=False)
        self.alphas_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops), requires_grad=False)
        self.alphas_robust = nn.Parameter(1e-3 * torch.randn(k, num_ops), requires_grad=False)
        # This instance attribute shadows the arch_parameters() method below.
        # _ArchParams is callable, so `model.arch_parameters()` keeps working.
        self.arch_parameters = _ArchParams([
            self.alphas_normal,
            self.alphas_reduce,
            self.alphas_robust
        ])

    def arch_parameters(self):
        """Return the architecture parameter list.

        Normally shadowed by the callable instance attribute set in
        ``_initialize_alphas``. Kept for calls made through the class.
        """
        return self.arch_parameters

    def genotype(self):
        """Derive a ``Genotype``: keep the 2 strongest incoming edges per node, each with its best op."""

        def _parse(weights):
            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                end = start + n
                W = weights[start:end].copy()
                edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
                for j in edges:
                    k_best = None
                    for k in range(len(W[j])):
                        if k_best is None or W[j][k] > W[j][k_best]:
                            k_best = k
                    gene.append((PRIMITIVES[k_best], j))
                start = end
                n += 1
            return gene

        gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
        gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
        gene_robust = _parse(F.softmax(self.alphas_robust, dim=-1).data.cpu().numpy())

        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat,
            robust=gene_robust, robust_concat=concat
        )
        return genotype