import random
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import networkx as nx

from ..layers import SpatialConvOrderK
from lib.utils import parse_mask_ratio


class DRIK(nn.Module):
    """Spatio-temporal graph imputation model trained by hiding random known nodes.

    During training a random subset of the (optionally sub-sampled) known nodes has
    its inputs zeroed and the network reconstructs it; at inference time every node
    outside ``known_set`` is imputed and observed entries are copied back. The
    backbone is three ``SpatialConvOrderK`` layers, each fed a window of ``t_dim``
    consecutive time steps, with an optional cross-node "reference" feature transfer
    between the second and the third layer.

    Args:
        adj: adjacency matrix over all nodes (array-like or tensor); stored as a
            float32 buffer.
        d_in: input feature size (stored only; ``fc_1`` is built with input size 1).
        d_hidden: hidden feature size.
        args: configuration namespace. Attributes read here include those registered
            by ``add_model_specific_args`` plus ``use_subgraph``, ``dataset_name``,
            ``include_self``, ``use_adj_drop`` and ``use_init`` defined elsewhere.
    """

    def __init__(self,
                 adj,
                 d_in,
                 d_hidden,
                 args
                 ):
        super(DRIK, self).__init__()
        self.use_subgraph = args.use_subgraph
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.dataset_name = args.dataset_name
        # Either a float (fixed ratio) or a (low, high) pair sampled uniformly per step.
        self.ratio = parse_mask_ratio(args.known_mask_ratio)

        self.dual_adj = args.dual_adj
        # Number of consecutive time steps stacked into each graph-convolution input.
        self.t_dim = 3
        # as_tensor avoids torch.tensor()'s copy-construct warning for tensor inputs;
        # detach().clone() keeps the buffer from aliasing the caller's matrix.
        self.register_buffer('adj', torch.as_tensor(adj, dtype=torch.float32).detach().clone())
        self.fc_1 = nn.Linear(1, d_hidden)

        # gcn_1 is sized for twice as many supports when dual_adj is on (see impute()).
        self.gcn_1 = SpatialConvOrderK(c_in=self.t_dim * d_hidden, c_out=d_hidden, support_len=2 * (2 if self.dual_adj else 1), order=1, include_self=args.include_self)
        self.gcn_2 = SpatialConvOrderK(c_in=self.t_dim * d_hidden, c_out=d_hidden, support_len=2 * 1, order=1, include_self=args.include_self)
        self.gcn_3 = SpatialConvOrderK(c_in=self.t_dim * d_hidden, c_out=d_hidden, support_len=2 * 1, order=1, include_self=args.include_self)

        self.smooth = nn.Linear(2 * d_hidden, d_hidden)
        self.fc_2 = nn.Linear(d_hidden, 1)

        self.relu = nn.ReLU(inplace=True)
        # Feature dropout after each graph convolution, controlled by --gnn-dropout.
        self.dropout = nn.Dropout(p=args.gnn_dropout) if args.gnn_dropout > 0 else nn.Identity()

        self.use_layer_norm = args.use_layer_norm
        if self.use_layer_norm:
            self.norm_in = nn.LayerNorm(d_hidden)
            self.norm_gcn1 = nn.LayerNorm(d_hidden)
            self.norm_gcn2 = nn.LayerNorm(d_hidden)
            self.norm_gcn3 = nn.LayerNorm(d_hidden)

        # Dropout on the graph supports only (used by adj_drop). Kept separate from
        # self.dropout so that --gnn-dropout still controls feature dropout.
        if args.use_adj_drop:
            print("use adj dropout...")
            self.adj_dropout = nn.Dropout(p=0.1)
        else:
            self.adj_dropout = nn.Identity()

        if args.use_init:
            print("use init...")
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_normal_(m.weight, gain=1)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

        self.mask_fwd_mask = args.mask_fwd_mask
        self.loss_all = args.loss_all
        self.drop_mask_edge = args.drop_mask_edge
        self.use_cross_ref = args.use_cross_ref
        self.use_soft_cross_ref = args.use_soft_cross_ref
        self.use_residual = args.use_residual

    def adj_drop(self, supp, mask):
        """Return detached copies of the graph supports with ``adj_dropout`` applied.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        return [self.adj_dropout(s.clone().detach()) for s in supp]

    @staticmethod
    def _drop_edges_among(adj, nodes):
        """Return a copy of ``adj`` with every edge between two nodes of ``nodes`` zeroed.

        Edges between a listed node and an unlisted node are kept.
        """
        is_listed = torch.zeros(adj.shape[0], dtype=torch.bool, device=adj.device)
        is_listed[nodes] = True
        return adj.masked_fill(torch.outer(is_listed, is_listed), 0)

    def _split_masks(self, mask, hidden):
        """Return ``(loss_mask, input_mask)`` for a list of hidden node positions.

        ``input_mask`` is ``mask`` with the hidden nodes cleared (what the network may
        rely on). ``loss_mask`` is ``mask`` restricted to the hidden nodes, or all of
        ``mask`` when ``loss_all`` is set.
        """
        is_hidden = np.zeros(mask.shape[2], dtype=bool)
        is_hidden[hidden] = True
        loss_mask = mask.detach().clone()
        if not self.loss_all:
            loss_mask[:, :, ~is_hidden] = 0
        input_mask = mask.detach().clone()
        input_mask[:, :, is_hidden] = 0
        return loss_mask, input_mask

    def forward(self, x, mask=None, known_set=None, mask_ratio=None, keep_set=None, perturbed_adj=None):
        """Run one self-supervised training pass, or impute all unknown nodes at eval time.

        Args:
            x: input indexed as ``x[:, :, node, :]``; presumably (batch, steps, nodes, 1)
                — TODO confirm against the data pipeline.
            mask: observation mask laid out like ``x``. It is used as a ``torch.where``
                condition, so it must be boolean.
            known_set: list of node indices whose values may be used as inputs.
            mask_ratio: fraction of the training (sub)graph's nodes to hide; when None it
                comes from the ratio parsed from ``args.known_mask_ratio``.
            keep_set: optional list of node indices restricting the graph. When given,
                ``known_set`` is re-expressed as positions inside ``keep_set``.
            perturbed_adj: optional adjacency used instead of the stored ``self.adj``.

        Returns:
            Training without ``keep_set``: ``(imputation, loss_mask, known_set)``.
            Training with ``keep_set``: the same three items restricted to the known
            positions, followed by ``(imputation_i, loss_mask_i)`` from the inverse pass.
            Evaluation: the imputation with the observed entries of ``x`` copied back.
        """
        if perturbed_adj is None:
            adj = self.adj.detach().clone()
        elif torch.is_tensor(perturbed_adj):
            # torch.tensor() on a tensor warns; an explicit detached copy is equivalent.
            adj = perturbed_adj.detach().to(device=x.device, dtype=torch.float32).clone()
        else:
            adj = torch.tensor(perturbed_adj, device=x.device, dtype=torch.float32)

        if keep_set is not None:
            adj = adj[keep_set, :][:, keep_set]
            x = x[:, :, keep_set, :]
            mask = mask[:, :, keep_set, :]
            # Re-express known_set as positions inside keep_set. A dict lookup replaces the
            # quadratic list.index() scan; setdefault keeps the first position like .index.
            position = {}
            for pos, node in enumerate(keep_set):
                position.setdefault(node, pos)
            known_set = [position[node] for node in known_set if node in position]

        if not self.training:
            # Inference: every node outside known_set is imputed; observed values are kept.
            unknown_set = list(set(range(x.shape[2])) - set(known_set))
            if self.drop_mask_edge:
                adj = self._drop_edges_among(adj, unknown_set)
            imputation = self.impute(x, mask, adj, unknown_set)
            return torch.where(mask, x, imputation)

        if self.use_subgraph:
            # Train on a random non-empty subset of the known nodes.
            n_sub = random.randint(1, len(known_set))
            known_set = random.sample(known_set, n_sub)

        if keep_set is None:
            # Without keep_set, training only ever sees the known nodes.
            adj = adj[known_set, :][:, known_set]
            x = x[:, :, known_set, :]
            mask = mask[:, :, known_set, :]

        if mask_ratio is None:
            if isinstance(self.ratio, float):
                mask_ratio = self.ratio
            else:
                mask_ratio = random.uniform(self.ratio[0], self.ratio[1])
        n_nodes = x.shape[2]
        # Positions whose inputs are hidden and must be reconstructed (at least one).
        hidden = random.sample(range(0, n_nodes), max(1, int(n_nodes * mask_ratio)))

        if keep_set is not None:
            # Positions in keep_set that are not known are always hidden too; dedupe.
            hidden.extend(list(set(range(len(keep_set))) - set(known_set)))
            hidden = list(set(hidden))

        if self.drop_mask_edge:
            adj = self._drop_edges_among(adj, hidden)

        x_ = x.detach().clone()
        x_[:, :, hidden] = 0
        loss_mask, input_mask = self._split_masks(mask, hidden)
        imputation = self.impute(x_, input_mask, adj, hidden)

        if keep_set is None:
            return imputation, loss_mask, known_set

        # Inverse pass: hide the complement of `hidden` and reconstruct it from the first
        # pass's imputation, with observed ground truth restored on the known positions.
        hidden_i = list(set(range(len(keep_set))) - set(hidden))
        # NOTE(review): this rebuilds the adjacency from self.adj, so a perturbed_adj passed
        # to forward() is not used for the inverse pass — confirm this is intended.
        adj_i = self.adj.detach().clone()
        adj_i = adj_i[keep_set, :][:, keep_set]
        if self.drop_mask_edge:
            adj_i = self._drop_edges_among(adj_i, hidden_i)

        x_i = imputation.detach().clone()
        x_i[:, :, known_set] = torch.where(mask[:, :, known_set], x[:, :, known_set], x_i[:, :, known_set])
        x_i[:, :, hidden_i] = 0
        loss_mask_i, input_mask_i = self._split_masks(mask, hidden_i)
        imputation_i = self.impute(x_i, input_mask_i, adj_i, hidden_i)

        return (imputation[:, :, known_set, :], loss_mask[:, :, known_set, :], known_set,
                imputation_i[:, :, known_set, :], loss_mask_i[:, :, known_set, :])

    def _temporal_unfold(self, h):
        """Stack each node's features over a window of ``t_dim`` time steps.

        ``h`` is (batch, channels, nodes, steps). The result is
        (batch, t_dim * channels, nodes, steps). The window is zero-padded by
        ``t_dim // 2`` steps at both ends of the sequence.
        """
        b, d, n, s = h.shape
        h = rearrange(h, 'b d n s -> b d s n')
        h = F.unfold(h, kernel_size=(self.t_dim, n), padding=(self.t_dim // 2, 0), stride=(1, 1))
        return h.reshape(b, self.t_dim * d, -1, s)

    def _activate(self, h, norm_name):
        """Apply the optional LayerNorm over channels, then ReLU and feature dropout.

        ``h`` is (batch, channels, nodes, steps). ``norm_name`` names the LayerNorm
        attribute, which only exists when ``use_layer_norm`` is set.
        """
        if self.use_layer_norm:
            h = getattr(self, norm_name)(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.dropout(self.relu(h))

    def _cross_reference(self, x2, mask):
        """Concatenate features transferred between observed and unobserved nodes.

        Per (batch, step) slice, node similarity is the cosine similarity of the ``x2``
        features rescaled to [0, 1]. Unobserved nodes receive features from observed
        nodes, and observed nodes from unobserved ones. The soft variant uses a
        similarity-weighted average; the hard variant uses the single most similar node
        scaled by its similarity. The result is fused with ``x2`` through ``smooth``.
        The mask is assumed to have a single channel (it broadcasts against (b s) n n).
        """
        b, _, _, s = x2.shape
        feat = rearrange(x2, 'b d n s -> (b s) n d')
        feat_mask = rearrange(mask, 'b s n d -> (b s) n d')
        # The mask is boolean, and `1. - mask` raises for bool tensors, so the mask
        # arithmetic below uses a float copy.
        feat_mask_f = feat_mask.float()

        cosine_eps = 1e-7
        feat_t = feat.transpose(-2, -1)                # (b s) d n
        q_norm = torch.norm(feat, 2, 2, True)          # (b s) n 1
        k_norm = torch.norm(feat_t, 2, 1, True)        # (b s) 1 n
        cos_sim = torch.bmm(feat, feat_t) / (torch.bmm(q_norm, k_norm) + cosine_eps)
        cos_sim = (cos_sim + 1.) / 2.                  # [-1, 1] -> [0, 1]
        v = feat_t                                     # (b s) d n

        if self.use_soft_cross_ref:
            observed = feat_mask.bool().any(dim=-1)    # (b s) n
            obs = observed.unsqueeze(1).float()        # (b s) 1 n
            unobs = (~observed).unsqueeze(1).float()   # (b s) 1 n
            w_obs = cos_sim * obs
            w_unobs = cos_sim * unobs
            w_obs = w_obs / (w_obs.sum(dim=2, keepdim=True) + 1e-8)
            w_unobs = w_unobs / (w_unobs.sum(dim=2, keepdim=True) + 1e-8)
            v_unobs = torch.bmm(v, w_obs.transpose(1, 2))    # (b s) d n
            v_obs = torch.bmm(v, w_unobs.transpose(1, 2))    # (b s) d n
        else:
            # For each node: the most similar observed node, and the most similar
            # unobserved node (max over the masked similarity rows).
            obs_score, obs_index = torch.max(cos_sim * feat_mask_f, dim=1)            # (b s) n
            unobs_score, unobs_index = torch.max(cos_sim * (1. - feat_mask_f), dim=1)  # (b s) n
            v_unobs = self.bis(v, 2, obs_index) * obs_score.unsqueeze(1)      # (b s) d n
            v_obs = self.bis(v, 2, unobs_index) * unobs_score.unsqueeze(1)    # (b s) d n

        v_unobs = rearrange(v_unobs, '(b s) d n -> b d n s', b=b, s=s)
        v_obs = rearrange(v_obs, '(b s) d n -> b d n s', b=b, s=s)
        blend = rearrange(feat_mask_f, '(b s) n d -> b d n s', b=b, s=s)
        feat_transfer = v_unobs * (1. - blend) + v_obs * blend   # b d n s

        out = torch.cat([x2, feat_transfer], dim=1)
        out = rearrange(out, 'b d n s -> b s n d')
        out = self.relu(self.smooth(out))
        return rearrange(out, 'b s n d -> b d n s')

    def impute(self, x, mask, adj, adj_mask_set):
        """Run the three graph-convolution layers and return per-node predictions.

        Args:
            x: (batch, steps, nodes, 1) input with hidden entries zeroed. The size-1
                last axis follows from ``fc_1 = nn.Linear(1, d_hidden)``.
            mask: observation mask laid out like ``x``. It is used by the cross-reference
                step; ``adj_drop`` ignores it.
            adj: dense adjacency over the nodes of ``x``.
            adj_mask_set: node positions whose adjacency columns (and, with
                ``dual_adj``, rows in a second copy) are zeroed in the first layer's
                supports when ``mask_fwd_mask`` is set.

        Returns:
            (batch, steps, nodes, 1) tensor produced by ``fc_2``.
        """
        # NOTE(review): with dual_adj=True but mask_fwd_mask=False, supp1 holds a single
        # compute_support() result although gcn_1 was sized for support_len=4 — confirm.
        if not self.mask_fwd_mask:
            supp1 = SpatialConvOrderK.compute_support(adj, x.device)
        elif self.dual_adj:
            cols_cut, rows_cut = adj.detach().clone(), adj.detach().clone()
            cols_cut[:, adj_mask_set] = 0
            rows_cut[adj_mask_set, :] = 0
            supp1 = SpatialConvOrderK.compute_support(cols_cut, x.device)
            supp1 = supp1 + SpatialConvOrderK.compute_support(rows_cut, x.device)
        else:
            cols_cut = adj.detach().clone()
            cols_cut[:, adj_mask_set] = 0
            supp1 = SpatialConvOrderK.compute_support(cols_cut, x.device)

        h = self.fc_1(x)
        if self.use_layer_norm:
            h = self.norm_in(h)
        h = self.relu(h)
        h = rearrange(h, 'b s n d -> b d n s')

        # Layer 1 uses the (possibly edge-masked) supports; layers 2 and 3 use the full graph.
        x1 = self.gcn_1(self._temporal_unfold(h), self.adj_drop(supp1, mask))
        x1 = self._activate(x1, 'norm_gcn1')

        supp = SpatialConvOrderK.compute_support(adj, h.device)
        x2 = self.gcn_2(self._temporal_unfold(x1), self.adj_drop(supp, mask))
        x2 = self._activate(x2, 'norm_gcn2')
        if self.use_residual:
            x2 = x2 + x1

        if self.use_cross_ref:
            x2 = self._cross_reference(x2, mask)

        out = self.gcn_3(self._temporal_unfold(x2), self.adj_drop(supp, mask))
        out = self._activate(out, 'norm_gcn3')
        if self.use_residual:
            out = out + x2

        return self.fc_2(rearrange(out, 'b d n s -> b s n d'))

    def bis(self, input, dim, index):
        """Batched index-select: gather positions ``index`` along ``dim`` for each batch row.

        Args:
            input: tensor whose first axis is the batch axis.
            dim: axis (> 0) to select along.
            index: (batch, k) integer tensor of positions along ``dim``.

        Returns:
            ``input`` with axis ``dim`` replaced by the ``k`` selected positions.
        """
        views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    @staticmethod
    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

        argparse's ``type=bool`` turns every non-empty string, including "False",
        into True, so boolean flags use this explicit parser instead.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    @staticmethod
    def add_model_specific_args(parser):
        """Register DRIK's hyper-parameters on ``parser`` and return it.

        Boolean flags take an explicit value (e.g. ``--dual-adj true``) parsed by
        ``_str2bool``; their defaults are unchanged.
        """
        str2bool = DRIK._str2bool
        parser.add_argument('--d-hidden', type=int, default=64)
        parser.add_argument('--known-mask-ratio', type=str, default='0.5')
        parser.add_argument('--dual-adj', type=str2bool, default=False)
        parser.add_argument('--mask-fwd-mask', type=str2bool, default=False)
        parser.add_argument('--drop-mask-edge', type=str2bool, default=False)
        parser.add_argument('--loss-all', type=str2bool, default=False)
        parser.add_argument('--use-cross-ref', type=str2bool, default=False)
        parser.add_argument('--use-soft-cross-ref', type=str2bool, default=False)
        parser.add_argument('--use-residual', type=str2bool, default=False)
        parser.add_argument('--use-layer-norm', type=str2bool, default=False)
        parser.add_argument('--gnn-dropout', type=float, default=0.1)
        return parser
