import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from .base import BasePredictor


class BRSBasePredictor(BasePredictor):
    """Common base for predictors that refine predictions by optimizing with L-BFGS.

    Subclasses keep the optimized parameters in ``self.opt_data`` (warm-started
    across clicks) and may cache intermediate network outputs in
    ``self.input_data``. ``opt_functor`` is the objective handed to
    ``scipy.optimize.fmin_l_bfgs_b``.
    """

    def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
        """
        Args:
            model: network wrapped by the predictor (forwarded to ``BasePredictor``).
            device: torch device on which click maps are placed.
            opt_functor: objective for ``fmin_l_bfgs_b``; subclasses also use its
                ``init_click``, ``optimizer_params``, ``best_prediction`` and
                ``unpack_opt_params`` members.
            optimize_after_n_clicks: subclasses skip the optimization until the
                number of clicks exceeds this value.
            **kwargs: forwarded to ``BasePredictor``.
        """
        super().__init__(model, device, **kwargs)
        self.optimize_after_n_clicks = optimize_after_n_clicks
        self.opt_functor = opt_functor

        # Optimized parameters (reused as the L-BFGS starting point) and cached
        # network intermediates; both are invalidated by a new image.
        self.opt_data = None
        self.input_data = None

    def set_input_image(self, image):
        """Set a new image and drop all optimization state tied to the old one."""
        super().set_input_image(image)
        self.opt_data = None
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        """Rasterize clicks into positive and negative binary maps.

        Each click paints a square of side ``2 * radius + 1`` centred on its
        rounded coordinates. The square is clipped to the image bounds.

        Args:
            clicks_lists: one list of clicks per batch element; each click exposes
                ``coords`` as ``(y, x)`` and a boolean ``is_positive``.
            image_shape: spatial shape ``(H, W)`` of the output maps.
            radius: half-size of the painted square, in pixels.

        Returns:
            ``(pos_clicks_map, neg_clicks_map)``: float32 tensors on
            ``self.device`` of shape ``(len(clicks_lists), 1, H, W)``, with 1.0 at
            painted pixels and 0.0 elsewhere.
        """
        maps_shape = (len(clicks_lists), 1) + tuple(image_shape)
        pos_clicks_map = np.zeros(maps_shape, dtype=np.float32)
        neg_clicks_map = np.zeros(maps_shape, dtype=np.float32)

        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:
                y, x = click.coords
                y, x = int(round(y)), int(round(x))
                # Clamp the window at 0: a negative slice bound would wrap to the
                # opposite edge of the array. For example, [-1:2] is empty, so a
                # click near the top/left border would be silently dropped.
                # Upper bounds past the image are already truncated by numpy.
                y1, x1 = max(y - radius, 0), max(x - radius, 0)
                y2, x2 = max(y + radius + 1, 0), max(x + radius + 1, 0)

                target_map = pos_clicks_map if click.is_positive else neg_clicks_map
                target_map[list_indx, 0, y1:y2, x1:x2] = 1.0

        pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
        neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        """Return a snapshot of the transform states and optimization data."""
        return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}

    def set_states(self, states):
        """Restore a snapshot previously produced by ``get_states``."""
        self._set_transform_states(states['transform_states'])
        self.opt_data = states['opt_data']


class FeatureBRSPredictor(BRSBasePredictor):
    """Feature-level BRS predictor for DeepLab-style feature extractors.

    Instead of optimizing the network input, this predictor caches intermediate
    features of ``self.net`` and optimizes a per-sample, per-channel affine
    transform ``features * scale + bias`` applied to them. Only the
    ``2 * num_channels`` values per sample are optimized (with L-BFGS). The part
    of the network after the insertion point is re-run on every evaluation.

    ``insertion_mode`` chooses where the transform is inserted:
      * ``'after_deeplab'``: on the output of ``net.feature_extractor``, i.e. the
        input of ``net.head``;
      * ``'after_c4'``: on the backbone's ``c4`` output, before ASPP;
      * ``'after_aspp'``: on the concatenation of the upsampled ASPP output and
        the projected ``c1`` features, before ``feature_extractor.head``.
    """

    def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
        """
        Args:
            model: segmentation network; its ``feature_extractor`` attributes
                determine ``num_channels``.
            device: torch device for tensors created by the predictor.
            opt_functor: optimization objective passed to ``fmin_l_bfgs_b``.
            insertion_mode: ``'after_deeplab'``, ``'after_c4'`` or ``'after_aspp'``.
            **kwargs: forwarded to ``BRSBasePredictor``.

        Raises:
            NotImplementedError: for an unknown ``insertion_mode``.
        """
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Projected c1 features cached by _get_head_input in 'after_c4' mode.
        self._c1_features = None

        if self.insertion_mode == 'after_deeplab':
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == 'after_c4':
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == 'after_aspp':
            # ASPP output concatenated with projected c1 features; presumably
            # `ch` + 32 (skip_project output width) channels - TODO confirm.
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Optionally optimize the feature transform, then return logits.

        Args:
            image_nd: input batch; with ``self.with_flip`` the batch holds two
                copies of each sample (presumably original, then flipped).
            clicks_lists: one list of clicks per batch element.
            is_image_changed: when true, the cached features are recomputed.

        Returns:
            Logits tensor bilinearly resized to ``image_nd``'s spatial size.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        # With flip augmentation a sample and its flipped copy share the same
        # optimized parameters (they are tiled with `repeat(2, ...)` below).
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # Flat parameter vector: 2 * num_channels values per sample, split into
        # (scale, bias) by opt_functor.unpack_opt_params. Reallocate if the batch
        # size no longer matches.
        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Features are recomputed while num_clicks <= net_clicks_limit, on an image
        # change, or when nothing is cached; otherwise earlier features are reused.
        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            # Run the rest of the network on `input_data * scale + bias`.
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'after_c4':
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(x,
                                  mode='bilinear',
                                  size=self._c1_features.size()[2:],
                                  align_corners=True)
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == 'after_aspp':
                scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(pred_logits,
                                        size=image_nd.size()[2:],
                                        mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor,
                                       x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # fmin_l_bfgs_b returns float64. Cast to float32 so the features
                # are not promoted to float64 before the float32 layers.
                opt_data_nd = torch.from_numpy(self.opt_data).float().to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute, without gradients, the features at the insertion point.

        Args:
            image_nd: input batch, as accepted by ``net.prepare_input``.
            points: click points tensor produced by ``get_points_nd``.

        Returns:
            Feature tensor to which the optimized scale/bias is applied. In
            ``'after_c4'`` mode this also caches the projected c1 features in
            ``self._c1_features``, because they are needed after ASPP.

        Raises:
            NotImplementedError: if the network has neither ``rgb_conv`` nor
                ``maps_transform``.
        """
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                # Click maps are fused with the image before the backbone.
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                # Click maps are encoded separately and passed alongside the image.
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)
            else:
                raise NotImplementedError('network must define either rgb_conv or maps_transform')

            if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
                c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == 'after_aspp':
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[0]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    """Feature-level BRS predictor for HRNet(-OCR) feature extractors.

    Caches intermediate features of ``self.net`` and optimizes (with L-BFGS) a
    per-sample, per-channel affine transform ``features * scale + bias`` applied
    to them. The network tail after the insertion point is re-run on each
    evaluation.

    ``insertion_mode`` chooses where the transform is inserted:
      * ``'A'``: on the output of ``feature_extractor.compute_hrnet_feats``,
        before the OCR modules (used only when ``ocr_width > 0``) and
        ``cls_head``;
      * ``'C'``: on the output of ``ocr_distri_head``, right before ``cls_head``.
    """

    def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
        """
        Args:
            model: segmentation network; its ``feature_extractor`` attributes
                determine ``num_channels``.
            device: torch device for tensors created by the predictor.
            opt_functor: optimization objective passed to ``fmin_l_bfgs_b``.
            insertion_mode: ``'A'`` or ``'C'``.
            **kwargs: forwarded to ``BRSBasePredictor``.

        Raises:
            NotImplementedError: for an unknown ``insertion_mode``.
        """
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Not used by this class; kept for attribute parity with FeatureBRSPredictor.
        self._c1_features = None

        if self.insertion_mode == 'A':
            # Presumably the concatenation of the four HRNet branches of widths
            # width, 2*width, 4*width and 8*width - TODO confirm.
            self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
        elif self.insertion_mode == 'C':
            # Assumes ocr_distri_head outputs 2 * ocr_width channels - TODO confirm.
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Optionally optimize the feature transform, then return logits.

        Args:
            image_nd: input batch; with ``self.with_flip`` the batch holds two
                copies of each sample (presumably original, then flipped).
            clicks_lists: one list of clicks per batch element.
            is_image_changed: when true, the cached features are recomputed.

        Returns:
            Logits tensor bilinearly resized to ``image_nd``'s spatial size.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])
        # A sample and its flipped copy share optimized parameters (see repeat below).
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # Flat parameter vector: 2 * num_channels values per sample, split into
        # (scale, bias) by opt_functor.unpack_opt_params.
        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Features are recomputed while num_clicks <= net_clicks_limit, on an image
        # change, or when nothing is cached; otherwise earlier features are reused.
        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            # Run the network tail on `input_data * scale + bias`.
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'A':
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
                    feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == 'C':
                pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(pred_logits,
                                        size=image_nd.size()[2:],
                                        mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor,
                                       x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # fmin_l_bfgs_b returns float64. Cast to float32 so the features
                # are not promoted to float64 before the float32 layers.
                opt_data_nd = torch.from_numpy(self.opt_data).float().to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute, without gradients, the features at the insertion point.

        Args:
            image_nd: input batch, as accepted by ``net.prepare_input``.
            points: click points tensor produced by ``get_points_nd``.

        Returns:
            HRNet features (mode ``'A'``) or the OCR-refined features (mode ``'C'``).

        Raises:
            NotImplementedError: for an unknown ``insertion_mode``, or if the
                network has neither ``rgb_conv`` nor ``maps_transform``.
        """
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                # Click maps are fused with the image before the backbone.
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                # Click maps are encoded separately and passed alongside the image.
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)
            else:
                raise NotImplementedError('network must define either rgb_conv or maps_transform')

            feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)

            if self.insertion_mode == 'A':
                backbone_features = feats
            elif self.insertion_mode == 'C':
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
            else:
                raise NotImplementedError

        return backbone_features


class InputBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimizes an additive bias on the network input.

    The bias is a full-resolution map per sample, optimized with L-BFGS and
    warm-started from the previous click's solution. ``optimize_target`` chooses
    what it is added to:
      * ``'rgb'``: the prepared image;
      * ``'dmaps'``: the coordinate/click feature maps. When
        ``net.with_prev_mask`` is set, channel 0 (presumably the previous mask)
        is left untouched;
      * ``'all'``: the output of ``net.rgb_conv``. This has no effect when the
        network has no ``rgb_conv``.
    """

    def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
        """
        Args:
            model: segmentation network (forwarded to ``BRSBasePredictor``).
            device: torch device on which the optimized bias is allocated.
            opt_functor: optimization objective passed to ``fmin_l_bfgs_b``.
            optimize_target: ``'rgb'``, ``'dmaps'`` or ``'all'``.
            **kwargs: forwarded to ``BRSBasePredictor``.
        """
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Optionally optimize the input bias, then return logits.

        Args:
            image_nd: input batch (N, C, H, W); with ``self.with_flip`` the batch
                holds two copies of each sample.
            clicks_lists: one list of clicks per batch element.
            is_image_changed: when true, the optimized bias is reset to zeros.

        Returns:
            Logits tensor bilinearly resized to ``image_nd``'s spatial size.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        # Allocate a zero bias of shape (bs, opt_channels, H, W) for a new image;
        # otherwise keep the previous solution as the L-BFGS starting point.
        if self.opt_data is None or is_image_changed:
            if self.optimize_target == 'dmaps':
                opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
            else:
                # 'rgb' and 'all'; for 'all' this assumes rgb_conv outputs
                # 3 channels - TODO confirm.
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                                        device=self.device,
                                        dtype=torch.float32)

        def get_prediction_logits(opt_bias):
            # Full forward pass with `opt_bias` injected at `optimize_target`.
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == 'rgb':
                input_image = input_image + opt_bias
            elif self.optimize_target == 'dmaps':
                if self.net.with_prev_mask:
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == 'all':
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)
            else:
                raise NotImplementedError('network must define either rgb_conv or maps_transform')

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
            pred_logits = F.interpolate(pred_logits,
                                        size=image_nd.size()[2:],
                                        mode='bilinear',
                                        align_corners=True)

            return pred_logits

        self.opt_functor.init_click(get_prediction_logits,
                                    pos_mask,
                                    neg_mask,
                                    self.device,
                                    shape=self.opt_data.shape)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor,
                                       x0=self.opt_data.cpu().numpy().ravel(),
                                       **self.opt_functor.optimizer_params)

            self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(
                self.device)

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # After an optimization, opt_data is float64 because
                # fmin_l_bfgs_b returns float64. Cast it so the input is not
                # promoted to float64 before float32 layers.
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data.float())
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits
