from argparse import ArgumentParser
from functools import partial
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional
import multiprocessing as mp

from tqdm import tqdm
import pytorch_lightning as pl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import (GCNConv, MessagePassing)

from equislt.methods.base import (BaseTrain, BasePrune)

try:
    from equislt.find_subset_gurobi import find_subset as gurobi_find_subset
except ImportError:
    _gurobi_available = False
else:
    _gurobi_available = True
from equislt.find_subset_ortools import find_subset as ortools_find_subset


def mock_find_subset(problem):
    """Stand-in solver: draw a random 0/1 mask instead of solving anything.

    Args:
        problem: Pair ``(p_idx, (src, trg))``; ``src`` must expose ``.shape``
            and support element-wise multiplication (a NumPy array in the
            callers visible in this file), ``trg`` is the scalar target.

    Returns:
        Tuple ``(p_idx, mask, rel_error, abs_error)`` where ``mask`` is the
        random selection, ``abs_error`` is ``|sum(src * mask) - trg|`` and
        ``rel_error`` is ``abs_error / |trg|``.
    """
    index, payload = problem
    candidates, target = payload

    # One random bit per candidate value.
    selection = np.random.randint(2, size=candidates.shape)

    achieved = (candidates * selection).sum()
    absolute = abs(achieved - target)
    relative = absolute / abs(target)
    return index, selection, relative, absolute


class GCNNet(nn.Module):
    """Stack of bias-free ``GCNConv`` layers with ReLU and dropout in between.

    The layer sizes are ``dim_in -> width (x depth) -> dim_out``, which gives
    ``depth + 1`` convolutions. The last convolution's output is returned
    without activation or dropout.
    """

    def __init__(self, dim_in, dim_out,
                 depth=4, width=64, dropout=0.5):
        """Build the convolution stack.

        Args:
            dim_in: Size of the input node features.
            dim_out: Size of the output produced per node.
            depth: Number of hidden layers of size ``width``.
            width: Hidden layer size.
            dropout: Dropout probability applied after every hidden layer.
        """
        super().__init__()
        self.dropout = dropout

        layer_sizes = [dim_in, *([width] * depth), dim_out]
        size_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
        self.convs = nn.ModuleList(
            [GCNConv(n_in, n_out, bias=False) for n_in, n_out in size_pairs]
        )

    def forward(self, data):
        """Run the stack on ``data.x`` using the graph in ``data.edge_index``."""
        hidden = data.x
        edges = data.edge_index

        layers = list(self.convs)
        head = layers.pop()  # final layer: no ReLU / dropout afterwards

        for layer in layers:
            hidden = layer(hidden, edges).relu()
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return head(hidden, edges)


class GCN(BaseTrain):
    """Trainable GCN model: a ``GCNNet`` encoder wrapped in ``BaseTrain``."""

    def __init__(self,
                 num_features: int,
                 num_classes: int,
                 depth: int,
                 width: int,
                 dropout: float,
                 **kwargs
    ):
        """Store the architecture settings and build the encoder.

        Args:
            num_features: Input feature size passed to ``GCNNet`` as ``dim_in``.
            num_classes: Output size passed to ``GCNNet`` as ``dim_out``.
            depth: Number of hidden layers of the encoder.
            width: Hidden layer size of the encoder.
            dropout: Dropout probability used by the encoder.
            **kwargs: Forwarded unchanged to ``BaseTrain.__init__``.
        """
        super().__init__(**kwargs)
        self.num_features = num_features
        self.num_classes = num_classes
        self.depth = depth
        self.width = width

        self.encoder = GCNNet(num_features, num_classes,
                              depth=depth,
                              width=width,
                              dropout=dropout)

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the base-class options plus the ``gcn`` argument group.

        Adds ``--depth`` (default 2), ``--width`` (default 16) and
        ``--dropout`` (default 0.5) and returns the extended parser.
        """
        # Static method, so the parent implementation is reached via the
        # two-argument form of ``super``.
        parent_parser = super(GCN, GCN).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("gcn")

        parser.add_argument("--depth", type=int, default=2)
        parser.add_argument("--width", type=int, default=16)
        parser.add_argument("--dropout", type=float, default=0.5)

        return parent_parser

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Return a single parameter group, named ``"model"``, for the encoder."""
        return [{"name": "model", "params": self.encoder.parameters()},]

    def forward(self, data) -> Dict[str, Any]:
        """Run the encoder on ``data`` and return ``{'logits': <encoder output>}``."""
        logits = self.encoder(data)
        return {'logits': logits}


class PruneGCN(BasePrune):
    """Pruning driver for :class:`GCN`.

    Each target weight is approximated by selecting a subset of products of
    source weights; the selection is delegated to a subset solver (gurobi,
    ortools, or a random mock). Instance attributes read here (``solver``,
    ``eps``, ``check_w_lt_eps``, ``num_threads``, ``timeout``, ``debug``,
    ``num_workers``, ``source_model``, ``target_model``) are expected to be
    provided by ``BasePrune``.
    """
    arch = GCN

    def init_src_from_trg_mod(self, mod):
        """Call the base implementation with PyG's ``Linear`` as ``type_``."""
        return super().init_src_from_trg_mod(mod, type_=torch_geometric.nn.dense.linear.Linear)

    def prepare_problems(self, target_model):
        """Collect the base-class problems for every convolution's ``lin`` layer.

        Each convolution ``i`` of ``target_model.convs`` is submitted on its
        own with prefix ``gcn_{i}``; the resulting dicts are merged.
        """
        problems = dict()
        for i, conv in enumerate(target_model.convs):
            problems.update(super().prepare_problems([conv.lin,], prefix=f'gcn_{i}'))
        return problems

    def build_src_model(self, solutions, source_model):
        """Replace every convolution's ``lin`` by the pruned module from ``solutions``.

        The ``_0`` key suffix is presumably appended by the base
        ``prepare_problems`` for the single module submitted per convolution
        -- TODO confirm against ``BasePrune``.
        """
        for i, conv in enumerate(source_model.convs):
            conv.lin = solutions[f'gcn_{i}_0']['model']
        return source_model

    def _make_subset_solver(self, draft):
        """Return the callable that solves one ``(p_idx, (src, trg))`` problem.

        ``draft=True`` forces the random mock solver; otherwise ``self.solver``
        selects gurobi or ortools, and any other value falls back to the mock.

        Raises:
            RuntimeError: If gurobi is requested but its solver module could
                not be imported.
        """
        solver = 'mock' if draft else self.solver
        if solver == 'gurobi':
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O`` and the failure would surface as a NameError.
            if not _gurobi_available:
                raise RuntimeError(
                    "solver 'gurobi' was requested but "
                    "equislt.find_subset_gurobi could not be imported")
            return partial(gurobi_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           debug=self.debug)
        if solver == 'ortools':
            return partial(ortools_find_subset,
                           eps=self.eps,
                           check_trg_lt_eps=self.check_w_lt_eps,
                           num_threads=self.num_threads,
                           timeout=self.timeout,
                           debug=self.debug)
        return mock_find_subset

    @torch.no_grad()
    def prune(self, problems, draft=False):
        """Solve every pruning problem and zero the unselected source weights.

        Args:
            problems: Mapping ``name -> (type_, over_fact, src_mod, trg_mod)``.
                ``src_mod[0]`` and ``src_mod[2]`` must expose ``.weight``
                (presumably two linear layers around an activation -- TODO
                confirm); ``trg_mod.weight`` is the 2-D target matrix.
            draft: When true, use the random mock solver regardless of
                ``self.solver``.

        Returns:
            Dict ``name -> {model, mask, max_rel_error, max_abs_error}`` where
            ``model`` is the pruned source module re-wrapped in
            ``torch.nn.Sequential`` and ``mask`` is the uint8 keep-mask that
            was applied to ``src_mod[2].weight``.
        """
        find_subset = self._make_subset_solver(draft)

        solutions = dict()
        for name, (type_, over_fact, src_mod, trg_mod) in tqdm(problems.items()):
            trg_W = trg_mod.weight
            src_W1 = src_mod[0].weight
            src_W2 = src_mod[2].weight
            # Start from "prune everything": entries the solver never reports
            # on (e.g. NaN weights matching neither sign test below) must not
            # be left as uninitialised memory.
            mask_W2 = torch.zeros_like(src_W2, dtype=torch.uint8)
            n_out, n_in = trg_W.shape

            idx_iter = torch.cartesian_prod(torch.arange(n_out), torch.arange(n_in))
            problem_space = []
            for i, j in idx_iter:
                trg_w = float(trg_W[i, j])
                # The ``over_fact`` source units dedicated to target input j.
                n_int_slice = slice(over_fact*j, over_fact*(j+1))
                src_W1_ = src_W1[n_int_slice, j]
                src_ws = src_W1_ * src_W2[i, n_int_slice]
                # Split the candidates by the sign of the first-layer weight;
                # each half becomes its own problem with the same target.
                src_ws_pos = src_ws[src_W1_ >= 0.]
                src_ws_neg = src_ws[src_W1_ < 0.]
                src_ws_pos = src_ws_pos.cpu().numpy()
                src_ws_neg = src_ws_neg.cpu().numpy()

                problem_space.append((src_ws_pos, trg_w))  # positives
                problem_space.append((src_ws_neg, trg_w))  # negatives

            # Workers are started with the "spawn" method rather than the
            # platform default.
            with mp.get_context("spawn").Pool(processes=self.num_workers) as pool:
                masks = list(tqdm(pool.imap_unordered(find_subset, enumerate(problem_space)), total=len(problem_space)))
            masks.sort(key=lambda x: x[0])  # First is the index of the problem

            max_rel_error = 0.
            max_abs_error = 0.
            for (p_idx, mask, rel_error, abs_error) in masks:
                # Problems were appended in (positive, negative) pairs, so
                # p_idx // 2 recovers the (i, j) entry and p_idx % 2 the sign.
                i, j = idx_iter[p_idx // 2]
                n_int_slice = slice(over_fact*j, over_fact*(j+1))
                src_W1_ = src_W1[n_int_slice, j]
                which = src_W1_ >= 0. if p_idx % 2 == 0 else src_W1_ < 0
                # The solver mask is a CPU NumPy array; move it to the mask's
                # device so the assignment also works for models on GPU.
                mask_values = torch.from_numpy(mask).to(device=mask_W2.device,
                                                        dtype=torch.uint8)
                mask_W2[i, n_int_slice][which] = mask_values
                max_rel_error = max(max_rel_error, rel_error)
                max_abs_error = max(max_abs_error, abs_error)

            # Zero every second-layer weight that was not selected.
            src_mod[2].weight.masked_fill_(~mask_W2.bool(), 0.)
            src_mod = torch.nn.Sequential(*src_mod)
            solutions[name] = dict(model=src_mod, mask=mask_W2,
                                   max_rel_error=max_rel_error,
                                   max_abs_error=max_abs_error)
        return solutions

    @torch.no_grad()
    def test_src_model(self, dataloader, device=None):
        """Compare source and target models on the nodes in ``data.test_mask``.

        Args:
            dataloader: Iterable of graph batches exposing ``y``,
                ``test_mask`` and ``to(device)``.
            device: Device both models and every batch are moved to.

        Returns:
            Dict with ``src_acc`` and ``trg_acc`` (test accuracy of each
            model) and ``avg_rel_out_error`` / ``max_rel_out_error`` (per-node
            norm of the logit difference divided by the norm of the target
            logits), all as Python floats.
        """
        # NOTE(review): neither model is switched to eval mode here; callers
        # presumably do that beforehand, otherwise dropout affects the
        # comparison -- confirm.
        self.source_model.to(device)
        self.target_model.to(device)

        N = 0
        src_acc = 0.
        trg_acc = 0.
        avg_rel_out_error = 0.
        max_rel_out_error = 0.
        for data in tqdm(dataloader):
            data = data.to(device)

            y = data.y[data.test_mask]
            N += y.size(0)
            src_logits = self.source_model(data)[data.test_mask]
            src_acc += src_logits.argmax(-1).eq(y).float().sum()

            trg_logits = self.target_model(data)[data.test_mask]
            norm_trg_logits = torch.linalg.norm(trg_logits, dim=-1)
            trg_acc += trg_logits.argmax(-1).eq(y).float().sum()

            # Per-node relative error of the source logits w.r.t. the target.
            rel_error =  torch.linalg.norm(src_logits - trg_logits, dim=-1) / norm_trg_logits
            avg_rel_out_error += rel_error.sum()
            max_rel_out_error = max(max_rel_out_error, float(rel_error.max()))

        # Sums were accumulated over all test nodes; turn them into means.
        src_acc /= N
        trg_acc /= N
        avg_rel_out_error /= N

        return dict(src_acc=float(src_acc),
                    trg_acc=float(trg_acc),
                    avg_rel_out_error=float(avg_rel_out_error),
                    max_rel_out_error=float(max_rel_out_error))
