import warnings

import numpy as np
import torch
from torch_geometric.data import Data, Batch
import torch.nn as nn
from torch_geometric.nn import GATConv
from statsmodels.tsa.stattools import grangercausalitytests
import torch.nn.functional as F
from sklearn.metrics import mutual_info_score

class NodeFeatureEncoder(nn.Module):
    """Per-time-step sequence encoder: a batch-first GRU followed by a linear head.

    Args:
        input_dim: feature size of each time step.
        output_dim: size of the returned per-step embedding.
        hidden_dim: GRU hidden size.
    """

    def __init__(self, input_dim, output_dim=64, hidden_dim=128):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[B, T, input_dim]`` to ``[B, T, output_dim]``.

        Every GRU output step (not only the last one) is projected.
        """
        hidden_states = self.rnn(x)[0]  # [B, T, hidden_dim]; final hidden state discarded
        return self.fc(hidden_states)


class ERA5FeatureEncoder(nn.Module):
    """Encode stacks of 2-D fields, one channel at a time, into feature vectors.

    Each channel is passed separately through the same small CNN (two strided
    convolutions and global average pooling to 32 features) and a linear
    projection to ``output_dim``.

    Args:
        output_dim: size of the feature vector produced per channel.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(32, output_dim)

    def get_1time(self, x):
        """Encode a single time step.

        Args:
            x: ``[B, C, H, W]`` tensor (e.g. ``C=8`` fields on an ``81x81``
                grid); every channel is encoded independently.

        Returns:
            ``[B, C, output_dim]`` tensor.
        """
        B, C, H, W = x.shape
        # Fold the channels into the batch so the shared CNN runs once;
        # row b*C + c holds channel c of sample b.
        feat = self.conv(x.reshape(B * C, 1, H, W))  # [B*C, 32, 1, 1]
        return self.fc(feat.flatten(1)).view(B, C, -1)

    def forward(self, x):
        """Encode a sequence of time steps.

        Args:
            x: ``[B, T, C, H, W]`` tensor (e.g. ``[B, T, 8, 81, 81]``).

        Returns:
            ``[B, T, C, output_dim]`` tensor.
        """
        # The previous ``B, T, D = x.shape`` could not unpack a 5-D input.
        B, T, C, H, W = x.shape
        out = self.get_1time(x.reshape(B * T, C, H, W))  # [B*T, C, output_dim]
        return out.view(B, T, C, -1)

class FullGraphEncoder(nn.Module):
    """Embed every modality per time step and flatten them into graph nodes.

    Each input goes through its own ``NodeFeatureEncoder`` (GRU + linear) to
    ``output_dim`` features. Per time step the nodes are ordered
    ``[lo, la, pressure, wind, era5, D_0 .. D_{K-1}]``; the ``K`` structure
    nodes come from ``D`` and are repeated at every time step.

    Args:
        output_dim: embedding size of every node.
    """

    def __init__(self, output_dim=64):
        super().__init__()
        self.output_dim = output_dim

        self.track_encoder_lo = NodeFeatureEncoder(1, output_dim)
        self.track_encoder_la = NodeFeatureEncoder(1, output_dim)
        self.pressure_encoder = NodeFeatureEncoder(1, output_dim)
        self.wind_encoder = NodeFeatureEncoder(1, output_dim)
        # ERA5 context is consumed as a 32-dim vector per time step; the
        # previously used alternative was ERA5FeatureEncoder(output_dim).
        self.era5_encoder = NodeFeatureEncoder(32, output_dim)
        self.structure_encoder = NodeFeatureEncoder(16, output_dim)

    def forward(self, node_1d_lo, node_1d_la, node_1d_p, node_1d_v, z200_z500_uv, D):
        """Build the flattened node-embedding sequence.

        Args:
            node_1d_lo, node_1d_la, node_1d_p, node_1d_v: ``[B, T, 1]`` series
                (their encoders take one input feature).
            z200_z500_uv: ``[B, T, 32]`` ERA5 feature vectors.
            D: ``[B, K, 16]`` structure features (``K = 9`` in this project),
                shared by all time steps.

        Returns:
            ``[B, T * (5 + K), output_dim]`` node embeddings, time-major: the
            ``5 + K`` nodes of step 0 come first, then those of step 1, ...
        """
        track_feat_lo = self.track_encoder_lo(node_1d_lo)  # [B, T, D]
        track_feat_la = self.track_encoder_la(node_1d_la)  # [B, T, D]
        pres_feat = self.pressure_encoder(node_1d_p)  # [B, T, D]
        wind_feat = self.wind_encoder(node_1d_v)  # [B, T, D]

        D_feat = self.structure_encoder(D)  # [B, K, 16] -> [B, K, D]
        era5_feats = self.era5_encoder(z200_z500_uv)  # [B, T, 32] -> [B, T, D]
        B, T, D_ = era5_feats.shape

        all_feats = torch.cat([
            track_feat_lo.unsqueeze(2),  # [B, T, 1, D]
            track_feat_la.unsqueeze(2),
            pres_feat.unsqueeze(2),
            wind_feat.unsqueeze(2),
            era5_feats.unsqueeze(2),
            # Same structure nodes at every step; expand avoids a copy before cat.
            D_feat.unsqueeze(1).expand(-1, T, -1, -1),  # [B, T, K, D]
        ], dim=2)  # [B, T, 5 + K, D]

        # Derive the node count instead of hard-coding 5 + 9.
        nodes_per_step = all_feats.size(2)
        return all_feats.reshape(B, T * nodes_per_step, D_)

class GATGraphEncoder(nn.Module):
    """Single-head ``GATConv`` layer applied to a PyG ``Data``/``Batch`` object.

    Args:
        input_dim: node feature size of ``data.x``.
        output_dim: node feature size of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.gnn = GATConv(input_dim, output_dim, heads=1, concat=False)

    def forward(self, data, return_attention=False):
        """Run the layer on ``data.x`` / ``data.edge_index``.

        Returns:
            The node outputs; with ``return_attention=True`` a tuple
            ``(out, edge_index, attention_weights)`` where the edge index and
            weights are those returned by ``GATConv`` for
            ``return_attention_weights=True``.
        """
        x, edges = data.x, data.edge_index
        if not return_attention:
            return self.gnn(x, edges)
        out, attention = self.gnn(x, edges, return_attention_weights=True)
        att_edge_index, att_weights = attention
        return out, att_edge_index, att_weights

class VAEEncoder(nn.Module):
    """MLP encoder with separate mean and log-variance heads.

    Args:
        input_dim: size of the input feature vector.
        latent_dim: size of ``mu`` and ``logvar``.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from a shared ReLU hidden layer."""
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

class VAEDecoder(nn.Module):
    """Two-layer MLP (128 hidden units, ReLU) from latent codes to features.

    Args:
        latent_dim: size of the latent code.
        output_dim: size of the reconstructed feature vector.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc_out = nn.Linear(128, output_dim)

    def forward(self, z):
        """Decode ``z`` into an ``output_dim`` feature vector."""
        hidden = F.relu(self.fc1(z))
        reconstruction = self.fc_out(hidden)
        return reconstruction

class TimestepEncoding(nn.Module):
    """Learned embedding for every slot of a flattened ``T * N`` node sequence.

    Args:
        T: number of time steps the table is sized for.
        N: number of nodes per time step the table is sized for.
        embed_dim: embedding size.
    """

    def __init__(self, T, N, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(T * N, embed_dim)

    def forward(self, B, T, N):
        """Return the embeddings of slots ``0 .. T*N-1`` tiled over a batch.

        Returns:
            Tensor of shape ``[B, T*N, embed_dim]``.
        """
        slots = torch.arange(T * N, device=self.embed.weight.device)
        batched_slots = slots.repeat(B, 1)  # [B, T*N]
        return self.embed(batched_slots)


class CauModule(nn.Module):
    """GAT graph encoder whose attention is supervised by causal-strength estimates.

    ``encoder_all_modal_by_graph`` turns the multimodal inputs into one fully
    connected graph of ``his_len * node_num`` nodes per sample and runs a
    single GAT layer over the PyG batch. ``forward`` then scores node pairs
    with four estimators (time-delayed mutual information, Granger causality,
    VAE cross-reconstruction and node interventions) and returns the MSE
    between each of them and the GAT attention coefficients.

    Args:
        node_num: number of graph nodes per time step.
        his_len: number of history time steps.
    """

    def __init__(self, node_num, his_len):
        super().__init__()
        self.his_len = his_len
        self.graph_encoder = FullGraphEncoder()
        self.gat = GATGraphEncoder(64, 64)
        # Registered sub-modules follow ``model.to(device)`` / ``model.cuda()``
        # together with the rest of the model, so no device is hard-coded.
        self.vae_encoder = VAEEncoder(64, latent_dim=16)
        self.vae_decoder = VAEDecoder(latent_dim=16, output_dim=64)
        self.vae_time_embed = TimestepEncoding(T=his_len, N=node_num, embed_dim=64)

    def forward(self, all_modal, image_obs, node_num):
        """Encode the inputs and compute the four causal-supervision losses.

        Args:
            all_modal: tuple ``(node_1d_lo, node_1d_la, node_1d_p, node_1d_v,
                z200_z500_uv, D)`` as consumed by the graph encoder; the first
                four are indexed ``[B, T, F]`` and their first feature is used
                as the target series. ``D`` is ``[B, 9, 16]`` per the original
                comments.
            image_obs: gridded field, cast to float and reshaped as
                ``[B, 1, T, -1]`` by the env-data builder (e.g.
                ``[32, 1, 8, 64, 64]``).
            node_num: number of graph nodes per time step.

        Returns:
            ``(graph_repr, (loss_tdmi, loss_granger, loss_vae, loss_calculus))``
            where ``graph_repr`` is the per-sample mean of the GAT node outputs.
        """
        node_1d_lo, node_1d_la, node_1d_p, node_1d_v, _, D = all_modal
        B = node_1d_lo.size(0)
        T = self.his_len

        # Graph encoder: node embeddings plus GAT edge index / attention.
        graph_repr, ei, attn, node_embeddings = self.encoder_all_modal_by_graph(all_modal)

        # First channel of each 1-D modality forms the four target series.
        target_data = torch.cat((node_1d_lo[:, :, 0:1], node_1d_la[:, :, 0:1],
                                 node_1d_p[:, :, 0:1], node_1d_v[:, :, 0:1]), dim=2)
        image_obs = image_obs.float()

        # Time-delayed mutual information, env -> target.
        tdmi_env2target = self.compute_tdmi_matrix(node_1d=target_data, z200_z500_uv=image_obs, D_data=D,
                                                   target_data=target_data, max_tau=3)
        tdmi_matrix = self.build_tdmi_supervision_matrix(tdmi_env2target, T=T, num_nodes=node_num)
        loss_tdmi = self.tdmi_constraint_loss_batched(ei, attn, tdmi_matrix, batch_size=B, T=T, alpha=1.0)

        # Granger causality, env -> target.
        granger_matrix = self.compute_granger_matrix(node_1d=target_data, z200_z500_uv=image_obs, D_data=D,
                                                     target_data=target_data)
        granger_matrix_all_nodes = self.build_tdmi_supervision_matrix(granger_matrix, T=T, num_nodes=node_num)
        loss_granger = self.tdmi_constraint_loss_batched(ei, attn, granger_matrix_all_nodes, batch_size=B, T=T,
                                                         alpha=1.0)

        # VAE cross-reconstruction scores.
        causal_matrix_vae = self.evaluate_vae_causal_matrix(node_embeddings, self.vae_encoder, self.vae_decoder,
                                                            self.vae_time_embed, T, node_num)
        loss_vae = self.tdmi_constraint_loss_batched(ei, attn, causal_matrix_vae, batch_size=B, T=T, alpha=1.0)

        # do-calculus style interventions on the same fully connected graph.
        edge_index = self._full_edge_index(T * node_num, node_embeddings.device)
        causal_matrix_calculus = self.compute_pyg_batch_causal_strength(
            model=self.gat,
            node_embeddings=node_embeddings,  # [B, T*node_num, D]
            edge_index=edge_index,
            method='mean',  # 'zero', 'mean' or 'random'
        )  # [B, T*node_num, T*node_num]
        loss_calculus = self.do_calculus_constraint_loss_batched(ei, attn, causal_matrix_calculus,
                                                                 batch_size=B, T=T, alpha=1.0)

        loss_cau = (loss_tdmi, loss_granger, loss_vae, loss_calculus)
        return graph_repr, loss_cau

    @staticmethod
    def _full_edge_index(num_nodes, device):
        """Return every ordered pair of distinct nodes as a ``[2, n*(n-1)]`` long tensor."""
        edge_index = torch.combinations(torch.arange(num_nodes), r=2).T
        return torch.cat([edge_index, edge_index[[1, 0]]], dim=1).long().to(device)

    def do_calculus_constraint_loss_batched(self, edge_index, attention, do_calculus_matrix, batch_size, T=8,
                                            alpha=1.0):
        """MSE between attention coefficients and per-sample intervention strengths.

        ``edge_index`` holds global node ids of a PyG batch of graphs with
        ``N`` nodes each; an edge ``src -> dst`` is looked up as
        ``do_calculus_matrix[src // N, src % N, dst % N]``.

        Args:
            edge_index: ``[2, E]`` global edge index.
            attention: ``[E]`` or ``[E, 1]`` attention coefficients.
            do_calculus_matrix: ``[B, N, N]`` target strengths.
            batch_size: unused; kept for interface symmetry.
            T: number of time steps; ``N`` is rounded down to a multiple of it.
            alpha: loss weight.
        """
        N = T * (do_calculus_matrix.shape[-1] // T)
        src, dst = edge_index  # [E]
        do_calculus_matrix = do_calculus_matrix.to(src.device)

        batch_id = torch.div(src, N, rounding_mode='trunc')
        target_vals = do_calculus_matrix[batch_id, src % N, dst % N]  # [E]

        return alpha * F.mse_loss(attention.squeeze(), target_vals.to(attention.device))

    @staticmethod
    def _build_env_data(node_1d, z200_z500_uv, D_data):
        """Stack candidate driver series into ``[B, T, F + 1 + K]``.

        Columns: the ``F`` series of ``node_1d``; the mean of the central
        ``[20:41, 20:41]`` crop of ``z200_z500_uv`` per time step (reshaped as
        ``[B, 1, T, -1]``, so one channel per step is assumed); and ``D_data``
        (``[B, K, ...]``) averaged over dim 2 and repeated over time.
        """
        B, T, _ = node_1d.shape
        crop = z200_z500_uv[:, :, :, 20:41, 20:41]  # 21x21 centre crop
        era5_mean = crop.reshape(B, 1, T, -1).mean(dim=-1).permute(0, 2, 1)  # [B, T, 1]
        d_mean = D_data.mean(dim=2).unsqueeze(1).repeat(1, T, 1)  # [B, T, K]
        return torch.cat((node_1d, era5_mean, d_mean), dim=2)

    def compute_tdmi_matrix(self, node_1d, z200_z500_uv, D_data, target_data, max_tau=3):
        """Time-delayed mutual information from every env column to every target.

        Entry ``[e, t]`` is the largest ``mutual_info_score`` over lags
        ``tau = 1..max_tau`` between env column ``e`` at time ``k`` and target
        ``t`` at time ``k + tau`` (pooled over the batch). Values are
        discretised to a 0.01 grid first so they can be used as labels.

        Returns:
            CPU float tensor ``[E, num_targets]`` scaled so its maximum is 1
            (left as zeros when no dependency was found, instead of NaN).
        """
        env_data = self._build_env_data(node_1d, z200_z500_uv, D_data)
        B, T, E = env_data.shape
        _, T_, num_targets = target_data.shape
        assert T == T_

        # Move to NumPy and discretise once instead of once per (e, t, tau).
        env_q = np.round(env_data.detach().cpu().numpy() * 100).astype(int)
        target_q = np.round(target_data.detach().cpu().numpy() * 100).astype(int)

        tdmi_matrix = torch.zeros((E, num_targets))
        for e in range(E):
            for t in range(num_targets):
                max_mi = 0.0
                for tau in range(1, max_tau + 1):
                    if T - tau <= 1:
                        break
                    x = env_q[:, :-tau, e].reshape(-1)
                    y = target_q[:, tau:, t].reshape(-1)
                    max_mi = max(max_mi, mutual_info_score(x, y))
                tdmi_matrix[e, t] = max_mi

        # Normalise for stability; guard the all-zero case (was 0/0 -> NaN).
        max_val = tdmi_matrix.max()
        if max_val > 0:
            tdmi_matrix = tdmi_matrix / max_val
        return tdmi_matrix

    def build_tdmi_supervision_matrix(self, tdmi_env2target, T=8, num_nodes=15):
        """Tile an env -> target score table into a ``[T*num_nodes, T*num_nodes]`` matrix.

        Within each time step the first ``num_target = tdmi_env2target.size(1)``
        nodes are targets and the remaining ``num_env`` nodes are env drivers.
        Entry ``[t_src*num_nodes + num_target + j, t_dst*num_nodes + i]`` holds
        the score of env node ``j`` -> target ``i`` for every pair of time
        steps; all other entries are 0.

        ``tdmi_env2target`` may list only the ``num_env`` env rows, or all
        nodes of a step with the targets first (as returned by
        ``compute_tdmi_matrix`` / ``compute_granger_matrix``). In both cases
        its last ``num_env`` rows are the env rows, so row alignment matches
        the node layout.

        Raises:
            ValueError: if the table has fewer than ``num_env`` rows.
        """
        num_target = tdmi_env2target.size(1)
        num_env = num_nodes - num_target

        # Scores for one (t_src, t_dst) pair; identical for every pair.
        block = torch.zeros((num_nodes, num_nodes))
        if num_env > 0:
            row_offset = tdmi_env2target.size(0) - num_env
            if row_offset < 0:
                raise ValueError(
                    f"expected at least {num_env} env rows, got {tdmi_env2target.size(0)}")
            block[num_target:, :num_target] = tdmi_env2target[row_offset:row_offset + num_env]

        return block.repeat(T, T)  # [T*num_nodes, T*num_nodes]

    def tdmi_constraint_loss_batched(self, edge_index, attention, tdmi_matrix, batch_size, T=8, alpha=1.0):
        """MSE between attention coefficients and a per-graph supervision matrix.

        The same ``tdmi_matrix`` applies to every graph in the PyG batch, so a
        global node id ``k`` maps to local id ``k % N``. This is the lookup
        ``tdmi_matrix.repeat(batch_size, batch_size)[src, dst]`` performed
        without materialising the tiled ``[B*N, B*N]`` matrix. The matrix and
        indices are moved to the attention's device before indexing.
        ``batch_size`` and ``T`` are unused.
        """
        device = attention.device
        matrix = tdmi_matrix.to(device)
        src = edge_index[0].to(device)
        dst = edge_index[1].to(device)
        target_vals = matrix[src % matrix.shape[0], dst % matrix.shape[1]]
        return alpha * F.mse_loss(attention.squeeze(), target_vals)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def evaluate_vae_causal_matrix(self, X, encoder, decoder, time_embed, his_len, node_num, use_cuda=True):
        """Score ``i -> j`` by how well node ``i``'s VAE code reconstructs node ``j``.

        Args:
            X: ``[B, TN, D]`` node embeddings; node ``k`` belongs to time step
                ``k // node_num``.
            encoder: callable returning ``(mu, logvar)``.
            decoder: callable mapping a latent sample back to ``[B, D]``.
            time_embed: optional; called as ``time_embed(B, T=his_len,
                N=node_num)`` and added to ``X``.
            his_len: number of time steps.
            node_num: nodes per time step.
            use_cuda: unused.

        Returns:
            ``[TN, TN]`` tensor. Only pairs with ``t_i < t_j`` are scored; their
            negated mean reconstruction errors are min-max scaled to ``[0, 1]``
            (lower error -> higher score). All other entries are 0.

        NOTE(review): scores are computed under ``torch.no_grad()``, so this
        path does not train the encoder, decoder or time embedding.
        """
        B, TN, _ = X.shape
        device = X.device

        if time_embed is not None:
            X = X + time_embed(B, T=his_len, N=node_num)

        # Time step from the real node count (was a hard-coded 16).
        time_step = torch.arange(TN, device=device) // node_num
        valid = time_step.unsqueeze(1) < time_step.unsqueeze(0)  # [TN, TN]

        recon_errors = torch.zeros((TN, TN), device=device)
        with torch.no_grad():
            for i in range(TN):
                mu, logvar = encoder(X[:, i, :])
                z = self.reparameterize(mu, logvar)
                x_hat = decoder(z)  # [B, D]; independent of j, so decode once
                # Per-sample MSE against every node j, averaged over the batch.
                recon_errors[i] = (x_hat.unsqueeze(1) - X).pow(2).mean(dim=2).mean(dim=0)

        # Normalise over scored pairs only; unscored pairs previously kept
        # error 0 and therefore received the maximum score of 1.
        scores = torch.zeros((TN, TN), device=device)
        if valid.any():
            neg_err = -recon_errors[valid]
            scores[valid] = (neg_err - neg_err.min()) / (neg_err.max() - neg_err.min() + 1e-8)
        return scores

    def compute_pyg_batch_causal_strength(self, model, node_embeddings, edge_index, method='mean', perturb_scale=20):
        """Intervention-based causal strength between the nodes of each sample.

        For every node ``i`` its embedding is replaced (``'zero'``: zeros;
        ``'mean'``: the sample's mean node embedding times ``perturb_scale``;
        ``'random'``: standard-normal noise), the batch is re-encoded with
        ``model`` and entry ``[b, i, k]`` is the mean absolute change of node
        ``k``'s output. Runs in eval mode without gradients; the model's
        training flag is restored afterwards, also when an error occurs.

        Returns:
            ``[B, N, N]`` tensor on ``node_embeddings.device``.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if method not in ('zero', 'mean', 'random'):
            raise ValueError(f"Unknown method: {method}")

        was_training = model.training
        B, N, D = node_embeddings.shape
        device = node_embeddings.device

        def _as_batch(x):
            # One PyG graph per sample, all sharing the same edge_index.
            return Batch.from_data_list([Data(x=x[j], edge_index=edge_index) for j in range(B)])

        causal_strength = torch.zeros(B, N, N, device=device)
        model.eval()
        try:
            with torch.no_grad():
                output_orig, _, _ = model(_as_batch(node_embeddings), return_attention=True)  # [B*N, D']
                output_orig = output_orig.view(B, N, -1)

                for i in range(N):  # do(x_i)
                    x_cf = node_embeddings.clone()
                    if method == 'zero':
                        x_cf[:, i] = 0
                    elif method == 'mean':
                        x_cf[:, i] = x_cf.mean(dim=1) * perturb_scale
                    else:  # 'random'
                        x_cf[:, i] = torch.randn_like(x_cf[:, i])

                    output_cf, _, _ = model(_as_batch(x_cf), return_attention=True)
                    output_cf = output_cf.view(B, N, -1)
                    causal_strength[:, i] = (output_cf - output_orig).abs().mean(dim=2)  # [B, N]
        finally:
            if was_training:
                model.train()

        return causal_strength

    def compute_granger_matrix(self, node_1d, z200_z500_uv, D_data,
                               target_data, max_lag=1, significance_threshold=0.05):
        """Granger-causality scores from every env column to every target.

        For each (env column, target) pair a test is run per sample; the
        smallest ``ssr_chi2test`` p-value over lags ``1..max_lag`` is averaged
        over the samples whose test succeeded. The score is ``1 - mean_p``,
        or 0 when ``mean_p > significance_threshold``.

        Returns:
            Float32 tensor ``[E, num_targets]`` scaled so its maximum is 1 when
            any score is positive.
        """
        env_data = self._build_env_data(node_1d, z200_z500_uv, D_data)
        B, T, E = env_data.shape
        _, T2, Tgt = target_data.shape
        assert T == T2, "env_data -- target_data: different dimension"

        # Convert once rather than once per (e, t, b).
        env_np = env_data.detach().cpu().numpy()
        target_np = target_data.detach().cpu().numpy()

        granger_matrix = np.zeros((E, Tgt))
        for e in range(E):
            for t in range(Tgt):
                all_pvals = []
                for b in range(B):
                    # statsmodels tests whether column 2 (env) Granger-causes column 1 (target).
                    series = np.stack([target_np[b, :, t], env_np[b, :, e]], axis=-1)  # (T, 2)
                    try:
                        with warnings.catch_warnings():
                            warnings.simplefilter("ignore")
                            test_result = grangercausalitytests(series, maxlag=max_lag, verbose=False)
                    except Exception:
                        # Degenerate series (e.g. constant) can make the test fail; skip the sample.
                        continue

                    try:
                        min_pval = np.min([test_result[lag][0]['ssr_chi2test'][1] for lag in range(1, max_lag + 1)])
                    except Exception:
                        min_pval = 1

                    all_pvals.append(min_pval)

                if all_pvals:
                    mean_pval = np.mean(all_pvals)
                    granger_matrix[e, t] = 0.0 if mean_pval > significance_threshold else 1 - mean_pval

        if granger_matrix.max() > 0:
            granger_matrix = granger_matrix / granger_matrix.max()

        return torch.tensor(granger_matrix, dtype=torch.float32)

    def encoder_all_modal_by_graph(self, all_modal):
        """Embed all modalities as nodes and run the GAT on a fully connected graph.

        Returns:
            ``(graph_repr, edge_index, attention, node_embeddings)`` where
            ``node_embeddings`` is ``[B, his_len * nodes_per_step, D]``,
            ``graph_repr`` is the per-sample mean of the GAT outputs, and the
            edge index / attention are those returned by the GAT layer.
        """
        node_1d_lo, node_1d_la, node_1d_p, node_1d_v, z200_z500_uv, D = all_modal
        node_embeddings = self.graph_encoder(node_1d_lo, node_1d_la, node_1d_p, node_1d_v, z200_z500_uv, D)
        B, total_nodes, _ = node_embeddings.shape

        # Use the configured history length instead of a hard-coded 8 (and 5 + 9 nodes).
        T = self.his_len
        assert total_nodes % T == 0, (
            f"{total_nodes} nodes cannot be split into his_len={T} time steps")

        edge_index = self._full_edge_index(total_nodes, node_embeddings.device)
        batch = Batch.from_data_list([Data(x=node_embeddings[i], edge_index=edge_index) for i in range(B)])

        output, ei, attn = self.gat(batch, return_attention=True)  # output: [B*total_nodes, D']
        graph_repr = output.view(B, total_nodes, -1).mean(dim=1)  # [B, D']

        return graph_repr, ei, attn, node_embeddings

