import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchnlp.nn import Attention
from torch.nn import Linear, LSTM
from torch_geometric.nn import RGCNConv, TopKPooling, FastRGCNConv
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from .rgcn_sag_pooling import RGCNSAGPooling
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from learning.model.intervention import CausalEdgeIntervener

"""Data-driven implementation of CURVE (formerly RS2G/MRGCN baseline code)."""


class CURVE(nn.Module):
    """Data-driven scene-graph extraction followed by an RGCN spatio-temporal classifier.

    ``forward`` takes a sequence of per-frame node-feature tensors. For every frame it
      1. encodes node features into (reparameterized) node embeddings,
      2. scores every unordered node pair with a multi-relation edge encoder whose
         means are passed through ``CausalEdgeIntervener``,
      3. keeps, for each node, its best-scoring partners (at most ``MAX_NEIGHBORS``)
         plus every pair whose best relation probability exceeds ``edge_ext_thresh``.
    The per-frame graphs are then batched and passed through the RGCN layers,
    optional pooling, a graph readout and a temporal module. ``fc2`` finally maps
    the result to per-class logit means and log-variances.

    Args:
        config: object exposing a ``model_config`` mapping with the keys read in
            ``__init__`` (``num_of_classes``, ``num_relations``, ``nclass``, ...).
    """

    # Width of each frame's raw node-feature rows (input of the node encoder).
    NODE_INPUT_DIM = 15
    # Hidden width used by the two-layer node/edge encoders.
    ENCODER_HIDDEN_DIM = 30
    # Upper bound on the top-k neighbours kept per node during graph extraction.
    MAX_NEIGHBORS = 10
    # Passed to CausalEdgeIntervener as ``num_prototypes``
    # (number of environment modes / confounder prototypes).
    NUM_PROTOTYPES = 32

    def __init__(self, config):
        super(CURVE, self).__init__()
        cfg = config.model_config
        # Node-embedding width produced by the node encoder (and RGCN input width).
        self.num_features = cfg['num_of_classes']
        self.num_relations = cfg['num_relations']
        self.num_classes = cfg['nclass']
        self.num_layers = cfg['num_layers']  # number of RGCN conv layers
        self.hidden_dim = cfg['hidden_dim']
        self.layer_spec = None if cfg['layer_spec'] is None else list(
            map(int, cfg['layer_spec'].split(',')))
        self.lstm_dim1 = cfg['lstm_input_dim']
        self.lstm_dim2 = cfg['lstm_output_dim']
        self.rgcn_func = FastRGCNConv if cfg['conv_type'] == "FastRGCNConv" else RGCNConv
        self.activation = F.relu if cfg['activation'] == 'relu' else F.leaky_relu
        self.pooling_type = cfg['pooling_type']
        self.readout_type = cfg['readout_type']
        self.temporal_type = cfg['temporal_type']
        self.device = cfg['device']
        self.dropout = cfg['dropout']
        self.edge_ext_thresh = cfg['edge_ext_thresh']

        # ~~~~~~~~~~~~RGCN backbone~~~~~~~~~~~~~~
        # nn.ModuleList (not a plain list) so the conv weights are registered:
        # otherwise they are missing from parameters()/state_dict() and are
        # neither trained by the optimizer nor saved in checkpoints.
        self.conv = nn.ModuleList()
        total_dim = 0  # width of the concatenated per-layer outputs
        if self.num_layers > 0:
            if self.layer_spec is None:
                layer_dims = [self.hidden_dim] * self.num_layers
            else:
                print("using layer specification and ignoring hidden_dim parameter.")
                print("layer_spec: " + str(self.layer_spec))
                if len(self.layer_spec) < self.num_layers:
                    raise ValueError(
                        f"layer_spec has {len(self.layer_spec)} entries but "
                        f"num_layers is {self.num_layers}")
                layer_dims = self.layer_spec[:self.num_layers]
            in_dim = self.num_features
            for out_dim in layer_dims:
                self.conv.append(self.rgcn_func(in_dim, out_dim, self.num_relations).to(self.device))
                total_dim += out_dim
                in_dim = out_dim
        else:
            # No graph convolution: a single linear projection of node features.
            self.fc0_5 = Linear(self.num_features, self.hidden_dim)
            total_dim += self.hidden_dim

        if self.pooling_type == "sagpool":
            self.pool1 = RGCNSAGPooling(
                total_dim,
                self.num_relations,
                ratio=cfg['pooling_ratio'],
                rgcn_func=cfg['conv_type'],
            )
        elif self.pooling_type == "topk":
            self.pool1 = TopKPooling(total_dim, ratio=cfg['pooling_ratio'])

        self.fc1 = Linear(total_dim, self.lstm_dim1)

        if "lstm" in self.temporal_type:
            self.lstm = LSTM(self.lstm_dim1, self.lstm_dim2, batch_first=True)
            self.attn = Attention(self.lstm_dim2)
            self.lstm_decoder = LSTM(self.lstm_dim2, self.lstm_dim2, batch_first=True)
        else:
            self.fc1_5 = Linear(self.lstm_dim1, self.lstm_dim2)

        # First num_classes outputs are logit means, the rest are logit log-variances.
        self.fc2 = Linear(self.lstm_dim2, self.num_classes * 2)

        # ~~~~~~~~~~~~Data-Driven Graph Encoders~~~~~~~~~~~~~~
        # Node encoder: raw node features -> embedding mean / log-variance.
        node_depth = cfg['node_encoder_dim']
        self.node_mu = self._make_encoder(
            node_depth, self.NODE_INPUT_DIM, self.num_features, self.ENCODER_HIDDEN_DIM, 'node_encoder_dim')
        self.node_logvar = self._make_encoder(
            node_depth, self.NODE_INPUT_DIM, self.num_features, self.ENCODER_HIDDEN_DIM, 'node_encoder_dim')

        # Edge encoder: takes two concatenated node *embeddings* (num_features each)
        # and returns per-relation (multilabel) scores.
        edge_depth = cfg['edge_encoder_dim']
        edge_in_dim = 2 * self.num_features
        self.edge_mu = self._make_encoder(
            edge_depth, edge_in_dim, self.num_relations, self.ENCODER_HIDDEN_DIM, 'edge_encoder_dim')
        self.edge_logvar = self._make_encoder(
            edge_depth, edge_in_dim, self.num_relations, self.ENCODER_HIDDEN_DIM, 'edge_encoder_dim')

        # ~~~~~~~~~~~~CURVE: causal edge intervention module~~~~~~~~~~~~~~
        self.causal_intervener = CausalEdgeIntervener(
            feature_dim=self.num_relations,
            num_prototypes=self.NUM_PROTOTYPES,
        ).to(self.device)

    @staticmethod
    def _make_encoder(depth, in_dim, out_dim, hidden_dim, key):
        """Build a 1-layer (Linear) or 2-layer (Linear-ReLU-Linear) encoder.

        Raises:
            ValueError: if ``depth`` is not 1 or 2 (``key`` names the config entry).
        """
        if depth == 1:
            return nn.Linear(in_dim, out_dim)
        if depth == 2:
            return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
        raise ValueError(f"model_config['{key}'] must be 1 or 2, got {depth!r}")

    def _reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` in training mode, ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def _extract_graph(self, node_features):
        """Build one sparse, typed graph from a single frame's node features.

        Returns:
            (graph, alpha): ``graph`` is a dict with keys ``node_embeddings``
            (activated embeddings), ``node_logvar``, ``edge_attr`` (LongTensor [E] of
            relation ids), ``edge_index`` (LongTensor [2, E]) and ``edge_uncertainty``
            (``exp`` of the chosen relation's log-variance, FloatTensor [E]).
            ``alpha`` is the second output of ``self.causal_intervener``.
        """
        num_nodes = len(node_features)
        node_logvar = self.node_logvar(node_features)
        node_embeddings = self._reparameterize(self.node_mu(node_features), node_logvar)
        device = node_embeddings.device

        # Every unordered pair (u < v), enumerated row-major.
        u, v = torch.triu_indices(num_nodes, num_nodes, offset=1, device=device)
        node_pairs = torch.cat([node_embeddings[u], node_embeddings[v]], dim=-1)

        edge_logvar = self.edge_logvar(node_pairs)
        edge_mu, alpha = self.causal_intervener(self.edge_mu(node_pairs), edge_logvar)
        edge_probs = torch.sigmoid(self._reparameterize(edge_mu, edge_logvar))
        # Best relation per pair: its probability is the pair score, its index the edge type.
        edge_scores, edge_types = torch.max(edge_probs, dim=1)

        # Symmetric dense score / type / log-variance tables for top-k selection.
        score_matrix = torch.zeros((num_nodes, num_nodes), device=device)
        score_matrix[u, v] = edge_scores
        score_matrix[v, u] = edge_scores
        type_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long, device=device)
        type_matrix[u, v] = edge_types
        type_matrix[v, u] = edge_types
        logvar_map = torch.zeros((num_nodes, num_nodes, self.num_relations), device=device)
        logvar_map[u, v] = edge_logvar
        logvar_map[v, u] = edge_logvar

        # Directed edges from every node to its k best-scoring partners (k >= 1).
        k = max(1, min(self.MAX_NEIGHBORS, num_nodes - 1))
        _, top_k_indices = torch.topk(score_matrix, k=k, dim=1)
        sources = torch.arange(num_nodes, device=device).repeat_interleave(k)
        targets = top_k_indices.reshape(-1)
        edge_index = torch.stack([sources, targets], dim=0)
        edge_attr = type_matrix[sources, targets]
        edge_logvar_sel = logvar_map[sources, targets].gather(1, edge_attr.unsqueeze(1)).squeeze(1)

        # Additionally keep every pair above the threshold, in both directions.
        # NOTE: these may duplicate edges already chosen by top-k.
        strong = edge_scores > self.edge_ext_thresh
        if strong.any():
            strong_u, strong_v = u[strong], v[strong]
            strong_types = edge_types[strong]
            strong_logvar = edge_logvar[strong].gather(1, strong_types.unsqueeze(1)).squeeze(1)
            strong_edge_index = torch.stack(
                [torch.cat([strong_u, strong_v]), torch.cat([strong_v, strong_u])], dim=0)
            edge_index = torch.cat([edge_index, strong_edge_index], dim=1)
            edge_attr = torch.cat([edge_attr, strong_types, strong_types], dim=0)
            edge_logvar_sel = torch.cat([edge_logvar_sel, strong_logvar, strong_logvar], dim=0)

        graph = {
            'node_embeddings': self.activation(node_embeddings),
            'node_logvar': node_logvar,
            'edge_attr': edge_attr,  # LongTensor [E]
            'edge_index': edge_index,
            'edge_uncertainty': torch.exp(edge_logvar_sel),  # FloatTensor [E]
        }
        return graph, alpha

    def forward(self, sequence):
        """Extract one graph per frame and classify the whole sequence.

        Args:
            sequence: non-empty iterable of per-frame node-feature tensors. Each is
                fed to the node encoder, so its last dimension must be
                ``NODE_INPUT_DIM`` (assumed ``[num_nodes, NODE_INPUT_DIM]``).

        Returns:
            dict. In training mode: ``logits_mu``, ``logits_logvar``, ``graph_list``,
            ``diversity_loss`` (already scaled by 0.1) and ``mean_alpha``. In eval
            mode: ``output`` (log-softmax of ``logits_mu``), ``uncertainty``,
            ``graph_list``, ``mean_alpha`` and ``causal_feature`` (input of ``fc2``).
            ``mean_alpha`` averages the intervener's alpha over all frames.

        Raises:
            ValueError: if ``sequence`` is empty.
        """
        if len(sequence) == 0:
            raise ValueError("CURVE.forward expects a non-empty sequence of frames.")

        # graph extraction component
        graph_list = []
        alphas = []
        for node_features in sequence:
            graph, alpha = self._extract_graph(node_features)
            graph_list.append(graph)
            alphas.append(alpha)

        graph_data_list = [
            Data(
                x=g['node_embeddings'],
                edge_index=g['edge_index'],
                edge_attr=g['edge_attr'],
                node_logvar=g['node_logvar'],
                edge_uncertainty=g['edge_uncertainty'],
            )
            for g in graph_list
        ]
        # One batch holding every frame's graph, in frame order.
        loader = DataLoader(graph_data_list, batch_size=len(graph_data_list))
        batch_data = next(iter(loader)).to(self.device)
        x, edge_index, edge_attr, batch = batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch
        # NOTE: batch_data.edge_uncertainty is available but not yet used by the convolutions.

        # RGCN backbone + pooling/readout
        attn_weights = {}
        if self.num_layers > 0:
            outputs = []
            for conv in self.conv:
                x = conv(x, edge_index, edge_attr)
                x = self.activation(x)
                x = F.dropout(x, self.dropout, training=self.training)
                outputs.append(x)
            # Jumping-knowledge style: concatenate every layer's output.
            x = torch.cat(outputs, dim=-1)
        else:
            x = self.activation(self.fc0_5(x))

        if self.pooling_type in ("sagpool", "topk"):
            x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(
                x, edge_index, edge_attr=edge_attr, batch=batch
            )
        else:
            attn_weights['batch'] = batch

        if self.readout_type == "add":
            x = global_add_pool(x, attn_weights['batch'])
        elif self.readout_type == "mean":
            x = global_mean_pool(x, attn_weights['batch'])
        elif self.readout_type == "max":
            x = global_max_pool(x, attn_weights['batch'])

        x = self.activation(self.fc1(x))

        # temporal modeling (frames are treated as one batch_first sequence)
        if self.temporal_type == "mean":
            x = self.activation(self.fc1_5(x.mean(dim=0)))
        elif self.temporal_type == "lstm_last":
            x_predicted, (h, c) = self.lstm(x.unsqueeze(0))
            x = h.flatten()
        elif self.temporal_type == "lstm_sum":
            x_predicted, (h, c) = self.lstm(x.unsqueeze(0))
            x = x_predicted.sum(dim=1).flatten()
        elif self.temporal_type == "lstm_attn":
            x_predicted, (h, c) = self.lstm(x.unsqueeze(0))
            x, attn_weights['lstm_attn_weights'] = self.attn(h.view(1, 1, -1), x_predicted)
            x, (h_decoder, c_decoder) = self.lstm_decoder(x, (h, c))
            x = x.flatten()
        elif self.temporal_type == "lstm_seq":
            # Per-step outputs: x becomes [T, lstm_dim2].
            x_predicted, (h, c) = self.lstm(x.unsqueeze(0))
            x = x_predicted.squeeze(0)

        causal_feature = x.clone()
        logits_all = self.fc2(x)  # [..., class0_mu..classN_mu, class0_logvar..classN_logvar]

        # Split along the feature (last) axis so per-step outputs ("lstm_seq") work too.
        logits_mu = logits_all[..., :self.num_classes]
        logits_logvar = logits_all[..., self.num_classes:]

        # diversity loss for confounder/prototype dictionary: push the normalized
        # prototypes towards mutual orthogonality.
        dict_vectors = self.causal_intervener.confounder_dict
        dict_norm = F.normalize(dict_vectors, p=2, dim=1)
        similarity_matrix = torch.matmul(dict_norm, dict_norm.t())
        identity = torch.eye(dict_vectors.size(0), device=dict_vectors.device, dtype=similarity_matrix.dtype)
        diversity_loss = torch.norm(similarity_matrix - identity, p='fro')

        # Average the intervention strength over every frame, not just the last one.
        mean_alpha = torch.cat([a.reshape(-1) for a in alphas]).mean()

        if self.training:
            return {
                'logits_mu': logits_mu,
                'logits_logvar': logits_logvar,
                'graph_list': graph_list,
                'diversity_loss': diversity_loss * 0.1,
                'mean_alpha': mean_alpha,
            }
        return {
            'output': F.log_softmax(logits_mu, dim=-1),
            'uncertainty': torch.exp(0.5 * logits_logvar).mean(),
            'graph_list': graph_list,
            'mean_alpha': mean_alpha,
            'causal_feature': causal_feature,
        }

