# -*-coding:utf-8-*-
import math

import numpy as np
import torch
from einops import rearrange, repeat


class ImageStitching(torch.nn.Module):
    """Mix images across the batch and stitch the ``M`` images of each sample into a grid.

    Slot ``m`` of output sample ``b`` receives image ``m`` of input sample
    ``(b + m) % B``. The ``M`` slots of every sample are then laid out as an
    ``m1 x (M // m1)`` grid over the last two axes, with ``m1 = isqrt(M)``.
    """

    def forward(self, x: torch.Tensor):
        """Shuffle and stitch ``x``.

        Args:
            x: tensor of shape ``(B, M, C, W, H)``. ``M`` must be divisible by
                ``isqrt(M)`` (e.g. 4 -> 2x2 grid, 2 -> 1x2, 6 -> 2x3).

        Returns:
            ``(x, mix_target, mix2_target, mix2_p_target)`` where

            - ``x`` is the stitched batch of shape ``(B, C, m1 * W, (M // m1) * H)``;
            - ``mix_target`` is ``(B, M)`` with ``[b, m] = (b + m) % B``, the input
              sample whose image sits in slot ``m`` of output ``b``;
            - ``mix2_target`` is ``(B, 2M - 1)`` with ``[b, k] = (b - M + 1 + k) % B``;
            - ``mix2_p_target`` is ``(B * M, M)``; each row of group ``b`` equals
              ``b * M + arange(M)``.

            The three index tensors are created on the CPU.
        """
        B, M, C, W, H = x.shape
        x = rearrange(x, 'b m c w h -> (b m) c w h')

        L = x.shape[0]

        # Flat position b*M + m reads from (b*M + m + m*M) % L == ((b + m) % B) * M + m,
        # i.e. image m of sample (b + m) % B. The index is built on x's device
        # instead of hard-coding CUDA, so the module also works on CPU tensors.
        ids = torch.arange(L, device=x.device)
        indexes = (ids + ids % M * M) % L
        # Row gather along dim 0. This is equivalent to torch.gather with the index
        # repeated to x's full shape, without materialising that int64 tensor.
        x = x.index_select(0, indexes)

        ids = torch.arange(B).view(-1, 1)
        mix_target = (ids + torch.arange(M)) % B
        mix2_target = ((ids - M + 1) + torch.arange(M * 2 - 1) + B) % B

        mix2_p_target = repeat((ids * M + torch.arange(M)) % L, 'b t -> (b m) t', m=M)

        x = rearrange(x, '(b m) c w h -> b m c w h', b=B)
        # isqrt is exact integer arithmetic (same value as int(sqrt(M)) for small M).
        x = rearrange(x, 'b (m1 m2) c w h -> b c (m1 w) (m2 h)', m1=math.isqrt(M))
        return x, mix_target, mix2_target, mix2_p_target


class MultiObjStitching:
    """Optionally stitch a batch with :class:`ImageStitching` and build smoothed soft targets.

    On each call, with probability ``mix_p`` (and only when ``mix_n > 1``):

    - The batch is stitched.
    - Three smoothed multi-hot target matrices are derived from the stitching
      indices. ``target`` is ignored in this case.

    Otherwise:

    - ``x`` is returned unchanged.
    - ``target`` is converted to one smoothed one-hot matrix, which is returned
      in all three target slots.

    All target indices are shifted by ``B * rank``, where ``B`` is the local
    batch size. NOTE(review): this presumably assumes ``num_classes`` equals the
    global batch size (``B * world_size``). Confirm against the caller.
    """

    def __init__(self, num_classes, mix_n=1, mix_p=0.0, smoothing=0.0, device='cuda'):
        """
        Args:
            num_classes: width of the one-hot target matrices.
            mix_n: number of images stitched per sample; mixing is disabled when
                this is ``<= 1``.
            mix_p: probability of applying stitching on a call.
            smoothing: label-smoothing amount ``s``, spread uniformly over the
                columns of each target matrix.
            device: device on which the returned target tensors are created.
        """
        self.mix_p = mix_p
        self.mix_n = mix_n
        self.mix = ImageStitching()
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.device = device

    def _one_hot(self, target, num_classes, on_value=1., off_value=0., device='cuda'):
        """Return an ``(N, num_classes)`` matrix filled with ``off_value``.

        For each row ``i``, the columns listed in ``target[i]`` are set to
        ``on_value``. ``target`` must be a 2-D long tensor on ``device``.
        ``on_value`` is either a scalar, or a tensor whose ``[i, k]`` entry is
        written at column ``target[i, k]`` (``Tensor.scatter_`` semantics).
        """
        return torch.full((target.size(0), num_classes), off_value, device=device).scatter_(1, target, on_value)

    @staticmethod
    def _rank():
        """Return the ``torch.distributed`` rank, or 0 when no process group is initialized."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def _mixed_targets(self, mix_target, mix2_target, mix2_p_target, m, offset):
        """Turn the index tensors returned by ``self.mix`` into smoothed target matrices."""
        device = self.device
        mix_target = (mix_target + offset).to(device)
        mix2_target = (mix2_target + offset).to(device)
        mix2_p_target = (mix2_p_target + offset * m).to(device)

        off_value = self.smoothing / self.num_classes
        on_value = (1.0 - self.smoothing) / m + off_value
        mix_target = self._one_hot(mix_target, self.num_classes, on_value, off_value, device=device)

        # mix2_p targets span num_classes * m columns. The smoothing mass is spread
        # over all of those columns so that each row sums to 1.
        p_classes = int(self.num_classes * m)
        p_off_value = self.smoothing / p_classes
        p_on_value = (1.0 - self.smoothing) / m + p_off_value
        mix2_p_target = self._one_hot(mix2_p_target, p_classes, p_on_value, p_off_value, device=device)

        # Triangular weights 1/m, ..., 1, ..., 1/m, one per column of mix2_target
        # (2M - 1 neighbours). For M == m they sum to m, giving each row an
        # "on" mass of 1 - smoothing.
        ids = torch.arange(mix2_target.shape[1], device=device)
        weights = 1.0 - torch.abs(m - ids - 1) / m
        on_values = ((1.0 - self.smoothing) * weights / m + off_value).expand(mix2_target.shape[0], -1)
        mix2_target = self._one_hot(mix2_target, self.num_classes, on_values, off_value, device=device)
        return mix_target, mix2_target, mix2_p_target

    def _plain_target(self, target, offset):
        """Convert per-sample indices to a smoothed ``(N, num_classes)`` one-hot matrix."""
        target = (target.view(-1, 1) + offset).to(self.device)
        off_value = self.smoothing / self.num_classes
        on_value = (1.0 - self.smoothing) + off_value
        return self._one_hot(target, self.num_classes, on_value, off_value, device=self.device)

    @torch.no_grad()
    def __call__(self, x, target):
        """Return ``(x, mix_target, mix2_target, mix2_p_target)``.

        Args:
            x: input batch. When stitching is applied it is passed to
                ``self.mix``. ``ImageStitching`` expects shape ``(B, M, C, W, H)``.
            target: per-sample integer indices. Used only when stitching is not
                applied, in which case it is flattened to one index per row.

        Returns:
            With stitching: the stitched ``x``, and ``(B, num_classes)``,
            ``(B, num_classes)`` and ``(B * M, num_classes * mix_n)`` target
            matrices. Without stitching: ``x`` unchanged and the same one-hot
            matrix in all three target slots.
        """
        B = x.shape[0]
        m = self.mix_n
        # np.random.rand() is evaluated first, so the NumPy RNG advances on every call.
        use_mix = np.random.rand() < self.mix_p and m > 1
        offset = B * self._rank()
        if use_mix:
            x, mix_target, mix2_target, mix2_p_target = self.mix(x)
            mix_target, mix2_target, mix2_p_target = self._mixed_targets(
                mix_target, mix2_target, mix2_p_target, m, offset)
            return x, mix_target, mix2_target, mix2_p_target

        target = self._plain_target(target, offset)
        return x, target, target, target
