import ast
import math
import multiprocessing as mp
import warnings
from argparse import ArgumentParser
from collections import OrderedDict
from copy import deepcopy, copy
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional, Text

import pytorch_lightning as pl
from tqdm import tqdm
import torch
import numpy as np

from e2cnn import gspaces
from e2cnn import nn

from equislt.methods.base import (BaseTrain, BasePrune)
from equislt.r2convolution import (OursR2Conv, DefaultR2Conv)

try:
    from equislt.find_subset_gurobi import find_subset as gurobi_find_subset
except ImportError:
    _gurobi_available = False
else:
    _gurobi_available = True

from equislt.find_subset_ortools import find_subset as ortools_find_subset


def mock_find_subset(problem):
    """Random stand-in for the real subset-selection solvers (used for draft runs).

    Mirrors the solver interface used by ``PruneE2CNN.prune``: it receives an
    ``(index, (src, trg))`` pair and returns the index together with a 0/1
    mask over ``src`` and the error of the masked sum against ``trg``.

    Args:
        problem: ``(p_idx, (src, trg))`` where ``src`` is a numpy array of
            candidate values and ``trg`` is the target value.

    Returns:
        ``(p_idx, mask, rel_error, abs_error)``; ``mask`` is a random integer
        0/1 array with ``src``'s shape, ``abs_error = |sum(src * mask) - trg|``
        and ``rel_error = abs_error / |trg|``. When ``trg`` is 0 the relative
        error is 0 for an exact match and ``inf`` otherwise (instead of the
        ``nan``/``inf`` plus RuntimeWarning a raw division would produce).
    """
    p_idx, (src, trg) = problem
    mask = np.random.randint(2, size=src.shape)
    abs_error = abs((src * mask).sum() - trg)
    if trg == 0:
        # Avoid 0/0 -> nan, which max() in the caller would silently ignore.
        rel_error = 0. if abs_error == 0 else math.inf
    else:
        rel_error = abs_error / abs(trg)
    return p_idx, mask, rel_error, abs_error


class SteerableCNN(torch.nn.Module):
    """Small group-equivariant CNN classifier built with e2cnn.

    Two steerable convolutions over regular-representation fields are
    followed by group pooling, a global spatial max-pool and a two-layer
    linear head.

    Args:
        dim_in: input shape ``(C, H, W)``; only ``C`` is used to build layers.
        dim_out: number of output logits.
        base_width: number of regular fields after the first convolution.
        fc_width: number of regular fields after the second convolution, and
            width of the hidden linear layer.
        fiber_group: ``'C<N>'`` (``Rot2dOnR2``) or ``'D<N>'`` (``FlipRot2dOnR2``).
        which_conv: ``'default'`` selects ``DefaultR2Conv``; any other value
            selects ``OursR2Conv`` configured with ``r2conv_kwargs``.
        bn: insert batch-norm layers (identity modules otherwise).
        bn_affine: whether the batch-norm layers have affine parameters.
        **r2conv_kwargs: forwarded to ``OursR2Conv`` only; ignored when
            ``which_conv == 'default'``.

    Raises:
        NotImplementedError: if ``fiber_group`` does not start with 'C' or 'D'.
    """

    def __init__(self, dim_in, dim_out,
                 base_width=32, fc_width=64,
                 fiber_group='C8', which_conv='default',
                 bn=True, bn_affine=False,
                 **r2conv_kwargs):
        super().__init__()
        # Only the channel count is needed; unpacking still enforces a 3-tuple.
        C, _H, _W = dim_in

        blocks = []
        # Slicing (not indexing) so that an empty string reaches the error below.
        group_kind, group_order = fiber_group[:1], fiber_group[1:]
        if group_kind == 'C':
            self.r2_act = gspaces.Rot2dOnR2(N=int(group_order))
        elif group_kind == 'D':
            self.r2_act = gspaces.FlipRot2dOnR2(N=int(group_order))
        else:
            raise NotImplementedError(
                f"unsupported fiber_group {fiber_group!r}; expected 'C<N>' or 'D<N>'")

        R2Conv = DefaultR2Conv if which_conv == 'default' else partial(OursR2Conv, **r2conv_kwargs)

        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, C * [self.r2_act.trivial_repr])
        self.input_type = in_type
        out_type = nn.FieldType(self.r2_act, base_width*[self.r2_act.regular_repr])
        # NOTE: construction order matters (parameter initialisation consumes
        # the RNG, and ModuleList indices become state_dict keys).
        blocks.extend([
            #  nn.MaskModule(in_type, H, margin=1),
            R2Conv(in_type, out_type, kernel_size=9, padding=0, bias=False, padding_mode='circular'),
            nn.InnerBatchNorm(out_type, affine=bn_affine) if bn else nn.IdentityModule(out_type),
            nn.PointwiseMaxPool(out_type, kernel_size=2),
            nn.ReLU(out_type, inplace=True)
            ])  # spatial size for a 28x28 input: 28 -> 20 (conv 9) -> 10 (pool 2)

        in_type = out_type
        out_type = nn.FieldType(self.r2_act, fc_width*[self.r2_act.regular_repr])
        blocks.extend([
            R2Conv(in_type, out_type, kernel_size=7, padding=0, bias=False, padding_mode='circular'),
            nn.InnerBatchNorm(out_type, affine=bn_affine) if bn else nn.IdentityModule(out_type),
            #  nn.PointwiseMaxPool(out_type, kernel_size=2),
            nn.ReLU(out_type, inplace=True),
            ])  # 10 -> 4 (conv 7); the max-pool above is disabled

        # Pool over the group dimension: each regular field becomes one invariant channel.
        group_pool = nn.GroupPooling(out_type)
        out_type = group_pool.out_type
        blocks.append(group_pool)

        # Global spatial max-pool down to 1x1.
        global_pool = nn.PointwiseAdaptiveMaxPool(out_type, 1)
        out_type = global_pool.out_type
        blocks.append(global_pool)

        self.gcnn = nn.ModuleList(blocks)

        self.fcn = torch.nn.ModuleList([
            torch.nn.Linear(out_type.size, fc_width, bias=False),
            torch.nn.BatchNorm1d(fc_width, affine=bn_affine) if bn else torch.nn.Identity(),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(fc_width, dim_out, bias=False),
            ])

    def forward(self, x: torch.Tensor):
        """Map a raw input tensor to logits of shape ``(batch, dim_out)``."""
        x = nn.GeometricTensor(x, self.input_type)
        for mod in self.gcnn:
            x = mod(x)
        # Unwrap the GeometricTensor and drop the 1x1 spatial dims left by the global pool.
        x = x.tensor.flatten(1)
        for mod in self.fcn:
            x = mod(x)
        return x


class E2CNN(BaseTrain):
    """``BaseTrain`` model whose encoder is a :class:`SteerableCNN`.

    The ``basis_*`` arguments are forwarded to ``SteerableCNN`` as R2Conv
    keyword arguments; ``SteerableCNN`` only passes them on when
    ``which_conv`` is not ``'default'``. Remaining ``kwargs`` go to
    ``BaseTrain``.
    """

    # NOTE(review): the constructor defaults for basis_margin_radius (2.),
    # basis_margin_arc_edges (1.) and basis_decay_arc_edges (1.) differ from
    # the CLI defaults in add_model_specific_args (0., 0.5, 0.75) — confirm
    # which values are intended.
    def __init__(self,
                 feature_size: Tuple[int],
                 num_classes: int,
                 base_width: int,
                 fc_width: int,
                 fiber_group: Text,
                 which_conv: Text,
                 bn: bool, bn_affine: bool,
                 basis_zero: float = 1e-3,
                 basis_upfactor: int = 3,
                 basis_margin_radius: float = 2.,
                 basis_decay_radius: float = 1.,
                 basis_margin_arc_edges: float = 1.,
                 basis_decay_arc_edges: float = 1.,
                 **kwargs
    ):
        super().__init__(**kwargs)
        self.feature_size = feature_size
        self.num_classes = num_classes

        self.encoder = SteerableCNN(feature_size, num_classes,
                                    base_width=base_width, fc_width=fc_width,
                                    fiber_group=fiber_group, which_conv=which_conv,
                                    bn=bn, bn_affine=bn_affine,
                                    basis_zero=basis_zero, basis_upfactor=basis_upfactor,
                                    basis_margin_radius=basis_margin_radius,
                                    basis_decay_radius=basis_decay_radius,
                                    basis_margin_arc_edges=basis_margin_arc_edges,
                                    basis_decay_arc_edges=basis_decay_arc_edges)

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser):
        """Register the ``e2cnn`` argument group on top of ``BaseTrain``'s options."""
        parent_parser = super(E2CNN, E2CNN).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("e2cnn")
        parser.add_argument("--base_width", type=int, default=24)
        parser.add_argument("--fc_width", type=int, default=48)
        parser.add_argument("--fiber_group", type=str, default='C8')
        parser.add_argument("--which_conv", type=str, default='ours', choices=['default', 'ours'])
        # ast.literal_eval parses "True"/"False" exactly like eval did, but
        # never executes arbitrary expressions and raises ValueError (which
        # argparse reports as a usage error) on non-literals.
        parser.add_argument("--bn", type=ast.literal_eval, default=True, choices=[True, False])
        parser.add_argument("--bn_affine", type=ast.literal_eval, default=False, choices=[True, False])

        parser.add_argument("--basis_zero", type=float, default=1e-3)
        parser.add_argument("--basis_upfactor", type=int, default=3)
        parser.add_argument("--basis_margin_radius", type=float, default=0.)
        parser.add_argument("--basis_decay_radius", type=float, default=1.)
        parser.add_argument("--basis_margin_arc_edges", type=float, default=0.5)
        parser.add_argument("--basis_decay_arc_edges", type=float, default=0.75)
        return parent_parser

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Parameter groups to optimise: all encoder parameters under the name 'model'."""
        return [{"name": "model", "params": self.encoder.parameters()},]

    def forward(self, data) -> Dict[str, torch.Tensor]:
        """Run the encoder on the inputs of ``data``.

        ``data`` is unpacked as ``(x, y)``; only ``x`` is used here
        (presumably ``y`` holds the labels — confirm against BaseTrain).

        Returns:
            ``{'logits': tensor}`` with the encoder output.
        """
        x, y = data
        logits = self.encoder(x)
        return {'logits': logits}


class PruneE2CNN(BasePrune):
    """Pruning driver for :class:`E2CNN` target networks.

    Every target layer is paired with an over-parameterised two-layer source
    block. R2Conv layers get ``conv1 -> ReLU -> conv2``, built here; other
    layers get whatever ``BasePrune.init_src_from_trg_mod`` builds. ``prune``
    then chooses a binary mask over the second layer's weights. It solves one
    subset-selection problem per target weight, so that masked sums of
    source products approximate that weight.
    """
    arch = E2CNN

    def init_src_from_trg_mod(self, mod):
        """Build the over-parameterised source replacement for target ``mod``.

        Modules the base class accepts are delegated to it. For an R2Conv on
        which the base class raises ``TypeError``, a ``conv1 (1x1) -> ReLU ->
        conv2`` block is built: conv1 widens each input field into
        ``overparam_factor`` fields, and conv2 has ``mod``'s kernel geometry.

        Returns:
            ``(overparam_factor, [conv1, ReLU, conv2])`` for R2Conv modules,
            otherwise whatever the base class returns.

        Raises:
            TypeError: if ``mod`` is rejected by the base class and is not an R2Conv.
        """
        try:
            return super().init_src_from_trg_mod(mod)
        except TypeError as err:
            if not isinstance(mod, (DefaultR2Conv, OursR2Conv)):
                raise TypeError('not R2Conv') from err
        assert mod.bias is None  # the source block is built without biases
        c_in = mod.in_type
        c_out = mod.out_type
        n_in = len(c_in)
        n_out = len(c_out)
        d = mod.kernel_size
        card = c_out.gspace.fibergroup.order()
        # NOTE(review): assumes every input field shares the first field's representation.
        inp_repr = c_in.representations[0]
        mul = d**2 * card**2 if not inp_repr.is_trivial() else d**2

        # Width multiplier per input field: 2 * ceil(factor * log2(depth * mul * n_in * n_out / eps)).
        overparam_factor = 2 * int(math.ceil(math.log2(self.target_net_depth * mul * n_in * n_out/self.eps) * self.overparam_factor))

        n_int = n_in * overparam_factor
        c_int = nn.FieldType(c_out.gspace, n_int * [inp_repr,])

        R2Conv = type(mod)
        # issubclass keeps this consistent with prune(), which uses isinstance.
        is_ours = issubclass(R2Conv, OursR2Conv)
        conv1 = R2Conv(c_in, c_int, kernel_size=1, padding=0, bias=False, initialize=False)
        with torch.no_grad():
            # We pre-prune so that:
            # A. Only the identity operator at the origin is considered (basis = 0)
            # B. Diamond shape: input field n only feeds its own block of
            #    overparam_factor intermediate fields
            conv1.weights.fill_(0.)
            for n in range(n_in):
                if is_ours:
                    # Element 0 is the identity
                    conv1.weights[0, n*overparam_factor:(n+1)*overparam_factor, n].uniform_(-self.bound, self.bound)
                else:
                    # For the default implementation the 1x1 convs do not
                    # contain immediately the identity.. we would have to append
                    # it or construct it and add it to the basis tensor
                    conv1.weights.view(n_int, n_in, -1)[n*overparam_factor:(n+1)*overparam_factor, n, 0].uniform_(-self.bound, self.bound)

        conv2 = R2Conv(c_int, c_out, kernel_size=mod.kernel_size,
                       padding=mod.padding, bias=False,
                       padding_mode=mod.padding_mode,
                       stride=mod.stride, dilation=mod.dilation,
                       groups=mod.groups, initialize=False)
        with torch.no_grad():
            conv2.weights.uniform_(-self.bound, self.bound)
            if is_ours:
                # Share the target's basis so weight indices line up with trg_mod's.
                conv2.basis.copy_(mod.basis)

        src_mod = [
            conv1,
            nn.ReLU(c_int, inplace=True),
            conv2,
        ]
        return overparam_factor, src_mod

    def prepare_problems(self, target_model):
        """Collect pruning problems for the encoder's ``gcnn`` and ``fcn`` stacks."""
        problems = super().prepare_problems(target_model.gcnn, prefix='gcnn')
        problems.update(super().prepare_problems(target_model.fcn, prefix='fcn'))
        return problems

    def build_src_model(self, solutions, source_model):
        """Replace ``source_model``'s ``gcnn`` and ``fcn`` stacks with their pruned versions."""
        source_model.gcnn = super().build_src_model(solutions, source_model.gcnn, prefix='gcnn')
        source_model.fcn = super().build_src_model(solutions, source_model.fcn, prefix='fcn')
        return source_model

    def _build_find_subset(self, solver):
        """Return the per-problem solver callable for ``solver``.

        Raises:
            ImportError: if ``solver == 'gurobi'`` but the gurobi backend failed to import.
        """
        if solver == 'gurobi':
            # Explicit check: an assert would be stripped under ``python -O``.
            if not _gurobi_available:
                raise ImportError("solver='gurobi' requested but "
                                  "equislt.find_subset_gurobi could not be imported")
            return partial(gurobi_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           debug=self.debug)
        if solver == 'ortools':
            return partial(ortools_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           timeout=self.timeout,
                           debug=self.debug)
        if solver != 'mock':
            # Preserve the historical fallback, but make it visible: a typo in
            # the solver name would otherwise silently produce random masks.
            warnings.warn(f"unknown solver {solver!r}; falling back to "
                          "mock_find_subset (random masks)", stacklevel=3)
        return mock_find_subset

    @staticmethod
    def _r2conv_weight_view(mod, ours):
        """View an R2Conv's weights as ``(basis, n_out_fields, n_in_fields)``.

        ``OursR2Conv`` weights are already stored that way. Default R2Conv weights
        are a flat vector, laid out as (out, in, basis) by the view below.
        """
        if ours:
            return mod.weights
        return mod.weights.view(len(mod.out_type), len(mod.in_type), -1).permute(2, 0, 1)

    def _solve_layer_masks(self, find_subset, trg_W, src_W1, src_W2, mask_W2, over_fact):
        """Fill ``mask_W2`` in place for one layer and return ``(max_rel_error, max_abs_error)``.

        All tensors are 3-D: ``trg_W`` is ``(B, n_out, n_in)``, ``src_W1`` is
        ``(B1, n_int, n_in)``, of which only basis element 0 is used, and
        ``src_W2``/``mask_W2`` are ``(B, n_out, n_int)`` with
        ``n_int = n_in * over_fact``. Each target weight ``trg_W[b, i, j]`` gives
        two problems: even indices take products with ``src_W1 >= 0``, odd
        indices take products with ``src_W1 < 0``.
        """
        B, n_out, n_in = trg_W.shape
        # Plain ints keep the slicing below in basic (view-returning) indexing.
        indices = torch.cartesian_prod(
            torch.arange(B), torch.arange(n_out), torch.arange(n_in)).tolist()

        problem_space = []
        for b, i, j in indices:
            trg_w = float(trg_W[b, i, j])
            n_int_slice = slice(over_fact*j, over_fact*(j+1))
            src_W1_ = src_W1[0, n_int_slice, j]
            src_ws = src_W1_ * src_W2[b, i, n_int_slice]
            problem_space.append((src_ws[src_W1_ >= 0.].cpu().numpy(), trg_w))  # positives
            problem_space.append((src_ws[src_W1_ < 0.].cpu().numpy(), trg_w))  # negatives

        with mp.get_context("spawn").Pool(processes=self.num_workers) as pool:
            results = list(tqdm(pool.imap_unordered(find_subset, enumerate(problem_space)),
                                total=len(problem_space)))
        results.sort(key=lambda r: r[0])  # First is the index of the problem

        max_rel_error = 0.
        max_abs_error = 0.
        for p_idx, mask, rel_error, abs_error in results:
            b, i, j = indices[p_idx // 2]
            n_int_slice = slice(over_fact*j, over_fact*(j+1))
            src_W1_ = src_W1[0, n_int_slice, j]
            which = src_W1_ >= 0. if p_idx % 2 == 0 else src_W1_ < 0.
            # mask_W2[b, i, slice] is a view, so the boolean assignment writes
            # through to mask_W2. as_tensor also moves the mask to its device.
            mask_W2[b, i, n_int_slice][which] = torch.as_tensor(
                mask, dtype=torch.uint8, device=mask_W2.device)
            max_rel_error = max(max_rel_error, rel_error)
            max_abs_error = max(max_abs_error, abs_error)
        return max_rel_error, max_abs_error

    @torch.no_grad()
    def prune(self, problems, draft=False):
        """Solve every pruning problem and build the masked source blocks.

        Args:
            problems: mapping ``name -> (type_, over_fact, src_mod, trg_mod)``,
                where ``type_`` is ``'gcnn'`` or ``'fcn'``.
            draft: if True, use the random ``mock_find_subset`` instead of
                ``self.solver``.

        Returns:
            mapping ``name -> dict(model=..., mask=..., max_rel_error=...,
            max_abs_error=...)``. ``model`` is the masked source block.

        Raises:
            ImportError: if the gurobi solver is requested but unavailable.
            ValueError: if a problem has an unknown ``type_``.
        """
        find_subset = self._build_find_subset('mock' if draft else self.solver)

        solutions = dict()
        for name, (type_, over_fact, src_mod, trg_mod) in tqdm(problems.items()):
            if type_ == 'gcnn':
                ours = isinstance(trg_mod, OursR2Conv)
                trg_W = self._r2conv_weight_view(trg_mod, ours)
                src_W1 = self._r2conv_weight_view(src_mod[0], ours)
                src_W2 = self._r2conv_weight_view(src_mod[2], ours)
                # zeros (not empty) so no entry can be left uninitialised.
                mask_W2 = torch.zeros_like(src_W2, dtype=torch.uint8)
                max_rel_error, max_abs_error = self._solve_layer_masks(
                    find_subset, trg_W, src_W1, src_W2, mask_W2, over_fact)

                if not ours:
                    # Back to the flat (out, in, basis) layout of the weight vector.
                    mask_W2 = mask_W2.permute(1, 2, 0).reshape(-1)
                src_mod[2].weights.masked_fill_(~mask_W2.bool(), 0.)
                src_mod = nn.SequentialModule(*src_mod)
            elif type_ == 'fcn':
                trg_W = trg_mod.weight
                src_W1 = src_mod[0].weight
                src_W2 = src_mod[2].weight
                mask_W2 = torch.zeros_like(src_W2, dtype=torch.uint8)
                # Treat linear weights as a single-basis-element R2Conv: the
                # unsqueezed views share storage, so mask_W2 is filled in place.
                max_rel_error, max_abs_error = self._solve_layer_masks(
                    find_subset, trg_W.unsqueeze(0), src_W1.unsqueeze(0),
                    src_W2.unsqueeze(0), mask_W2.unsqueeze(0), over_fact)

                src_mod[2].weight.masked_fill_(~mask_W2.bool(), 0.)
                src_mod = torch.nn.Sequential(*src_mod)
            else:
                raise ValueError('neither r2conv or linear')

            solutions[name] = dict(model=src_mod, mask=mask_W2,
                                   max_rel_error=max_rel_error,
                                   max_abs_error=max_abs_error)
        return solutions
