from copy import deepcopy, copy
import math
from collections import OrderedDict
from argparse import ArgumentParser
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional, Text
import multiprocessing as mp

from tqdm import tqdm
import torch
from torch import nn
import numpy as np

from equislt import graph_equivariant as eq
from equislt.methods.base import (BaseTrain, BasePrune)

try:
    from equislt.find_subset_gurobi import find_subset as gurobi_find_subset
except ImportError:
    _gurobi_available = False
else:
    _gurobi_available = True
from equislt.find_subset_ortools import find_subset as ortools_find_subset


def mock_find_subset(problem):
    """Stand-in subset solver that picks a random 0/1 mask.

    Args:
        problem: ``(p_idx, (src, trg))`` where ``src`` is a NumPy array of
            candidate values and ``trg`` is the scalar they should sum to.

    Returns:
        ``(p_idx, mask, rel_error, abs_error)``: the problem index passed
        through unchanged, the random mask (same shape as ``src``), and the
        relative / absolute error of the masked sum with respect to ``trg``.
    """
    p_idx, pair = problem
    src, trg = pair

    # No optimisation takes place: every element is kept or dropped at random.
    mask = np.random.randint(2, size=src.shape)

    achieved = (src * mask).sum()
    abs_error = abs(achieved - trg)
    return p_idx, mask, abs_error / abs(trg), abs_error


class BasicInvariant(nn.Module):
    """Stack of graph-equivariant layers, a sum pooling and a small MLP head.

    The equivariant part chains two ``eq.layer_2_to_2`` layers and one
    ``eq.layer_2_to_1`` layer with ReLUs in between; the head is
    ``Linear -> ReLU -> Linear``.

    Args:
        dim_in: input depth given to the first equivariant layer.
        dim_out: output size of the last linear layer.
        base_width: depth of the first equivariant layer; the following
            layers use 2x and 4x this value.
        fc_width: hidden size of the MLP head.
        bias: forwarded to every equivariant and linear layer.
    """

    def __init__(self, dim_in, dim_out,
                 base_width=24, fc_width=96,
                 bias=False):
        super().__init__()

        double_width = base_width * 2
        quad_width = base_width * 4

        # Equivariant feature extractor (layers are created in forward order).
        equivariant_stack = [
            eq.layer_2_to_2(dim_in, base_width, bias=bias),
            nn.ReLU(inplace=True),
            eq.layer_2_to_2(base_width, double_width, bias=bias),
            nn.ReLU(inplace=True),
            eq.layer_2_to_1(double_width, quad_width, bias=bias),
        ]
        self.equi_layers = nn.ModuleList(equivariant_stack)

        # Fully connected head applied after pooling.
        head = [
            nn.Linear(quad_width, fc_width, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(fc_width, dim_out, bias=bias),
        ]
        self.fcn = nn.ModuleList(head)

    def forward(self, x):
        """Apply the equivariant layers, sum over dim 2, then the MLP head."""
        for layer in self.equi_layers:
            x = layer(x)

        # Summing over dim 2 is the invariant global pooling step.
        x = x.sum(dim=2)

        for layer in self.fcn:
            x = layer(x)
        return x


class InvariantGraphNets(BaseTrain):
    """Training wrapper around a :class:`BasicInvariant` encoder.

    Args:
        feature_size: passed to ``BasicInvariant`` as its ``dim_in``.
            NOTE(review): annotated as ``Tuple[int]`` but used directly as a
            layer depth -- confirm whether callers pass an int.
        num_classes: passed to ``BasicInvariant`` as its ``dim_out``.
        base_width: base depth of the equivariant layers.
        fc_width: hidden size of the fully connected head.
        **kwargs: forwarded unchanged to ``BaseTrain.__init__``.
    """

    def __init__(self,
                 feature_size: Tuple[int],
                 num_classes: int,
                 base_width: int,
                 fc_width: int,
                 **kwargs
    ):
        super().__init__(**kwargs)
        self.feature_size = feature_size
        self.num_classes = num_classes

        # The encoder is always built without bias terms.
        self.encoder = BasicInvariant(feature_size, num_classes,
            base_width=base_width, fc_width=fc_width, bias=False)


    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser):
        """Register ``--base_width`` and ``--fc_width`` on ``parent_parser``.

        The parent class' arguments are registered first. Returns the same
        (parent) parser that was extended.
        """
        # Zero-argument ``super()`` is unavailable in a staticmethod, hence the
        # explicit two-argument form.
        parent_parser = super(InvariantGraphNets, InvariantGraphNets).add_model_specific_args(parent_parser)
        # NOTE(review): the group title "e2cnn" looks carried over from another
        # model; it only affects the ``--help`` output.
        parser = parent_parser.add_argument_group("e2cnn")
        parser.add_argument("--base_width", type=int, default=24)
        parser.add_argument("--fc_width", type=int, default=96)
        return parent_parser

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Single parameter group, named ``"model"``, with the encoder parameters."""
        return [{"name": "model", "params": self.encoder.parameters()},]

    def forward(self, data) -> Dict[str, Any]:
        """Run the encoder on the inputs of a two-element ``data`` batch.

        Args:
            data: a two-element iterable; only the first element is fed to
                the encoder, the second is ignored here.

        Returns:
            A dict with the single key ``'logits'``.
        """
        # Unpacking still enforces that ``data`` has exactly two elements.
        x, _ = data
        logits = self.encoder(x)
        return {'logits': logits}


class PruneInvariantGraphNets(BasePrune):
    """Pruning procedure for :class:`InvariantGraphNets` models.

    Each target layer is paired with a two-layer source block
    ``[layer1, ReLU, layer2]``. ``prune`` then selects, for every target
    weight, which entries of ``layer2`` to keep so that the sum of the kept
    ``layer1 * layer2`` weight products approximates the target weight. The
    selection is delegated to a subset solver (gurobi, ortools or a random
    mock).
    """
    arch = InvariantGraphNets

    def init_src_from_trg_mod(self, mod):
        """Build the over-parameterised source block for one target module.

        The parent implementation is tried first; only if it raises
        ``TypeError`` and ``mod`` is a graph-equivariant layer is the block
        built here.

        Args:
            mod: target module (``eq.layer_2_to_2`` or ``eq.layer_2_to_1`` for
                the path implemented in this class).

        Returns:
            ``(overparam_factor, src_mod)`` where ``src_mod`` is the list
            ``[layer1, ReLU, layer2]``.

        Raises:
            TypeError: if the parent rejects ``mod`` and it is not one of the
                two graph-equivariant layer types.
        """
        try:
            return super().init_src_from_trg_mod(mod)
        except TypeError:
            if not isinstance(mod, (eq.layer_2_to_2, eq.layer_2_to_1)):
                raise TypeError('not graph equivariant')

        n_in = mod.input_depth
        n_out = mod.output_depth
        is_2_to_2 = isinstance(mod, eq.layer_2_to_2)
        mul = 15 if is_2_to_2 else 5  # XXX n**2 + 1 is certainly larger

        # Number of intermediate channels allotted to each input channel.
        overparam_factor = 2 * int(math.ceil(math.log2(self.target_net_depth * mul * n_in * n_out/self.eps) * self.overparam_factor))

        n_int = n_in * overparam_factor

        layer1 = eq.layer_2_to_2(n_in, n_int, bias=False)
        with torch.no_grad():
            # We pre-prune so that:
            # A. Only the identity operator at the origin is considered (basis = 0)
            # B. Diamond shape: input channel n only feeds its own block of
            #    `overparam_factor` intermediate channels.
            layer1.coeffs.fill_(0.)
            for n in range(n_in):
                # Element 0 is the identity
                layer1.coeffs[n, n*overparam_factor:(n+1)*overparam_factor, 0].uniform_(-self.bound, self.bound)

        # The second layer has the same type as the target and is fully random.
        layer2 = type(mod)(n_int, n_out, bias=False)
        with torch.no_grad():
            layer2.coeffs.uniform_(-self.bound, self.bound)

        src_mod = [
            layer1,
            nn.ReLU(inplace=True),
            layer2,
        ]
        return overparam_factor, src_mod

    def prepare_problems(self, target_model):
        """Collect the problems of the equivariant layers and of the MLP head.

        The parent implementation is called once per sub-network (prefixes
        ``'graph'`` and ``'fcn'``) and the results are merged into one mapping.
        """
        problems = super().prepare_problems(target_model.equi_layers, prefix='graph')
        problems.update(super().prepare_problems(target_model.fcn, prefix='fcn'))
        return problems

    def build_src_model(self, solutions, source_model):
        """Replace both sub-networks of ``source_model`` using ``solutions``.

        ``source_model`` is modified in place and also returned.
        """
        source_model.equi_layers = super().build_src_model(solutions, source_model.equi_layers, prefix='graph')
        source_model.fcn = super().build_src_model(solutions, source_model.fcn, prefix='fcn')
        return source_model

    def _build_subset_solver(self, draft):
        """Return the callable that solves one ``(p_idx, (src, trg))`` problem."""
        solver = 'mock' if draft else self.solver
        if solver == 'gurobi':
            # An explicit check (rather than `assert`) so that it is not
            # stripped under `python -O`.
            if not _gurobi_available:
                raise RuntimeError(
                    "solver 'gurobi' was requested but "
                    "equislt.find_subset_gurobi could not be imported")
            return partial(gurobi_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           debug=self.debug)
        if solver == 'ortools':
            return partial(ortools_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           timeout=self.timeout,
                           debug=self.debug)
        # Any other value (including draft mode) falls back to the random mock.
        return mock_find_subset

    @torch.no_grad()
    def _solve_and_write_masks(self, find_subset, entries):
        """Solve the subset problems of one layer and fill in its mask.

        Args:
            find_subset: callable mapping ``(p_idx, (src, trg))`` to
                ``(p_idx, mask, rel_error, abs_error)``.
            entries: list of ``(trg_w, w1, w2, mask_view)`` tuples, one per
                target weight. ``w1`` / ``w2`` are the slices of the first /
                second source layer that can contribute to that weight and
                ``mask_view`` is the view of the layer mask aligned with
                ``w2``; it is written in place.

        Returns:
            ``(max_rel_error, max_abs_error)`` over all solved problems.
        """
        # Every target weight yields two problems: one over the products whose
        # first-layer weight is >= 0 (even p_idx), one over those < 0 (odd).
        problem_space = []
        for trg_w, w1, w2, _ in entries:
            products = w1 * w2
            problem_space.append((products[w1 >= 0.].cpu().numpy(), trg_w))  # positives
            problem_space.append((products[w1 < 0.].cpu().numpy(), trg_w))  # negatives

        # Worker processes use the "spawn" start method.
        with mp.get_context("spawn").Pool(processes=self.num_workers) as pool:
            results = list(tqdm(pool.imap_unordered(find_subset, enumerate(problem_space)), total=len(problem_space)))
        results.sort(key=lambda x: x[0])  # First is the index of the problem

        max_rel_error = 0.
        max_abs_error = 0.
        for p_idx, mask, rel_error, abs_error in results:
            _, w1, _, mask_view = entries[p_idx // 2]
            which = w1 >= 0. if p_idx % 2 == 0 else w1 < 0.
            # The solver mask lives on the CPU; move it next to the layer mask.
            mask_view[which] = torch.from_numpy(mask).to(device=mask_view.device, dtype=torch.uint8)
            max_rel_error = max(max_rel_error, rel_error)
            max_abs_error = max(max_abs_error, abs_error)
        return max_rel_error, max_abs_error

    @torch.no_grad()
    def prune(self, problems, draft=False):
        """Prune the second layer of every source block in ``problems``.

        Args:
            problems: mapping ``name -> (type_, over_fact, src_mod, trg_mod)``
                with ``type_`` either ``'graph'`` or ``'fcn'``; ``src_mod`` is
                indexable with the first layer at 0 and the second at 2.
            draft: if true, use the random mock solver instead of
                ``self.solver``.

        Returns:
            dict ``name -> {'model', 'mask', 'max_rel_error', 'max_abs_error'}``
            where ``'model'`` is the pruned block wrapped in ``nn.Sequential``
            and ``'mask'`` is the uint8 keep-mask of its second layer.

        Raises:
            RuntimeError: if the gurobi solver is selected but unavailable.
            ValueError: if a problem has an unknown ``type_``.
        """
        find_subset = self._build_subset_solver(draft)

        solutions = dict()
        for name, (type_, over_fact, src_mod, trg_mod) in tqdm(problems.items()):
            if type_ == 'graph':
                trg_W = trg_mod.coeffs
                src_W1 = src_mod[0].coeffs
                src_W2 = src_mod[2].coeffs
                # Zero-initialised so that no entry is ever left undefined.
                mask_W2 = torch.zeros_like(src_W2, dtype=torch.uint8)
                n_in, n_out, B = trg_W.shape

                entries = []
                for b in range(B):
                    for i in range(n_out):
                        for j in range(n_in):
                            # Intermediate channels owned by input channel j.
                            block = slice(over_fact * j, over_fact * (j + 1))
                            entries.append((float(trg_W[j, i, b]),
                                            src_W1[j, block, 0],
                                            src_W2[block, i, b],
                                            mask_W2[block, i, b]))
            elif type_ == 'fcn':
                trg_W = trg_mod.weight
                src_W1 = src_mod[0].weight
                src_W2 = src_mod[2].weight
                # Zero-initialised so that no entry is ever left undefined.
                mask_W2 = torch.zeros_like(src_W2, dtype=torch.uint8)
                n_out, n_in = trg_W.shape

                entries = []
                for i in range(n_out):
                    for j in range(n_in):
                        # Hidden units owned by input feature j.
                        block = slice(over_fact * j, over_fact * (j + 1))
                        entries.append((float(trg_W[i, j]),
                                        src_W1[block, j],
                                        src_W2[i, block],
                                        mask_W2[i, block]))
            else:
                raise ValueError('neither graph equivariant or linear')

            max_rel_error, max_abs_error = self._solve_and_write_masks(find_subset, entries)

            # Zero out every second-layer weight that was not selected.
            src_W2.masked_fill_(~mask_W2.bool(), 0.)
            src_mod = nn.Sequential(*src_mod)

            solutions[name] = dict(model=src_mod, mask=mask_W2,
                                   max_rel_error=max_rel_error,
                                   max_abs_error=max_abs_error)
        return solutions

