import os
import tqdm
import math
import time
import copy
import pygod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import BaseTransform, ToUndirected
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.utils import add_self_loops, negative_sampling

from .sage import SAGEConv
from .consistency import MLPConsistency, ConsistencyModel, consistency_sampling, ConsistencyCluster


class VGAE_Cluster(nn.Module):
    
    def __init__(
            self,
            in_dim,
            hid_dim,
            temporal=True,
            t_min=0,
            t_max=1024,
            etypes=1,
            threshold=0.5,
            encoder_backend: str = 'sage'
    ):
        """Create the encoder convolutions and the linear decoder heads.

        Args:
            in_dim: Size of the node feature vectors.
            hid_dim: Size of the latent node embeddings.
            temporal: When true, a time length is derived from ``t_min`` and
                ``t_max`` and handed to the encoder convolutions, and
                ``decode``/``sample`` also predict edge times.
            t_min: Lower end of the edge-time range.
            t_max: Upper end of the edge-time range.
            etypes: Number of edge types; ``decode``/``sample`` only use the
                type head when this is greater than one.
            threshold: Edge-probability cut-off applied in ``sample``.
            encoder_backend: ``'hetgnn'`` makes ``encode`` aggregate per edge
                type; any other value uses a single pass over all edges.
        """
        super().__init__()

        # Plain configuration values.
        self.in_dim, self.hid_dim = in_dim, hid_dim
        self.temporal = temporal
        self.t_min, self.t_max = t_min, t_max
        if temporal:
            self.time_len = t_max - t_min + 1
        else:
            self.time_len = None
        self.etypes = etypes
        self.threshold = threshold
        self.encoder_backend = encoder_backend

        # Encoder backbone (default: GraphSAGE). The hetero backend currently
        # reuses SAGEConv with relation/time encoding; a dedicated HetGNN
        # implementation can be plugged in here without changing interfaces.
        conv_options = dict(temporal=temporal,
                            time_len=self.time_len,
                            etypes=etypes)
        self.enc_shared = SAGEConv(in_dim, hid_dim, **conv_options)
        self.enc_mu = SAGEConv(hid_dim, hid_dim, **conv_options)
        self.enc_sigma = SAGEConv(hid_dim, hid_dim, **conv_options)

        # Decoder heads. Edge-level heads consume the concatenated embeddings
        # of both endpoints, hence twice the latent width.
        pair_dim = 2 * hid_dim
        self.dec_attr = nn.Linear(hid_dim, in_dim)
        self.dec_time = nn.Linear(pair_dim, 1)
        self.dec_type = nn.Linear(pair_dim, etypes)
        self.dec_stru = nn.Linear(pair_dim, 1)

        # Bias-free projections of the label column into feature/latent space.
        self.map_label_e = nn.Linear(1, in_dim, bias=False)
        self.map_label_d = nn.Linear(1, hid_dim, bias=False)

    def forward(self, x, pos_edge_index, neg_edge_index,
                label, edge_time=None, edge_type=None):
        self.mean, self.log_std = self.encode(x, pos_edge_index, label,
                                              edge_time, edge_type=edge_type)
        noise = torch.randn_like(self.mean)
        z = self.mean + noise * torch.exp(self.log_std)
        x_, edge_pred, t_, p_ = self.decode(z, pos_edge_index,
                                            neg_edge_index, label)
        return x_, edge_pred, t_, p_

    def encode(self, h, edge_index, label, edge_time=None, edge_type=None):
        h += self.map_label_e(label)

        if self.encoder_backend == 'hetgnn' and self.etypes > 1 and edge_type is not None:
            # Relation-wise aggregation then fuse (mean) to mimic HetGNN-style relation aggregation
            h_shared_list = []
            for r in range(self.etypes):
                rel_mask = (edge_type == r)
                if rel_mask.sum() == 0:
                    continue
                ei_r = edge_index[:, rel_mask]
                et_r = edge_type[rel_mask] if edge_type is not None else None
                etime_r = edge_time[rel_mask] if edge_time is not None else None
                h_shared_list.append(self.enc_shared(h, ei_r, etime_r, et_r))
            if len(h_shared_list) == 0:
                h_shared = torch.zeros_like(h)
            else:
                h_shared = torch.stack(h_shared_list, dim=0).mean(0)
            h = torch.relu(h_shared)

            mean_list = []
            for r in range(self.etypes):
                rel_mask = (edge_type == r)
                if rel_mask.sum() == 0:
                    continue
                ei_r = edge_index[:, rel_mask]
                et_r = edge_type[rel_mask] if edge_type is not None else None
                etime_r = edge_time[rel_mask] if edge_time is not None else None
                mean_list.append(self.enc_mu(h, ei_r, etime_r, et_r))
            mean = torch.stack(mean_list, dim=0).mean(0) if len(mean_list) > 0 else torch.zeros_like(h)

            sigma_list = []
            for r in range(self.etypes):
                rel_mask = (edge_type == r)
                if rel_mask.sum() == 0:
                    continue
                ei_r = edge_index[:, rel_mask]
                et_r = edge_type[rel_mask] if edge_type is not None else None
                etime_r = edge_time[rel_mask] if edge_time is not None else None
                sigma_list.append(self.enc_sigma(h, ei_r, etime_r, et_r))
            log_std = torch.stack(sigma_list, dim=0).mean(0) if len(sigma_list) > 0 else torch.zeros_like(h)
        else:
            h = self.enc_shared(h, edge_index, edge_time, edge_type)
            h = torch.relu(h)
            mean = self.enc_mu(h, edge_index, edge_time, edge_type)
            log_std = self.enc_sigma(h, edge_index, edge_time, edge_type)
        return mean, log_std

    def decode(self, z, pos_edge_index, neg_edge_index, label):
        z += self.map_label_d(label)
        x_ = self.dec_attr(z)

        pos_ze = torch.cat([z[pos_edge_index[0]], z[pos_edge_index[1]]], dim=1)
        neg_ze = torch.cat([z[neg_edge_index[0]], z[neg_edge_index[1]]], dim=1)

        pos_edge_pred = self.dec_stru(pos_ze).squeeze(-1)
        neg_edge_pred = self.dec_stru(neg_ze).squeeze(-1)
        edge_pred = torch.cat([pos_edge_pred, neg_edge_pred], dim=0)

        t_ = self.dec_time(pos_ze).squeeze(-1) if self.temporal else None
        p_ = self.dec_type(pos_ze) if self.etypes > 1 else None

        return x_, edge_pred, t_, p_

    def sample(self, z, label):
        """Decode latent embeddings into a complete synthetic graph.

        Every ordered node pair is scored with the structure head, so a dense
        ``(N, N, 2 * hid_dim)`` tensor is materialised for ``N`` nodes.

        Args:
            z: Latent node embeddings. The tensor is not modified.
            label: Per-node label column; ``map_label_d`` projects it onto the
                latent space and the result is added to ``z``.

        Returns:
            Tuple ``(x_, edge_index, t_, p_)``: decoded features, the edges
            whose probability exceeds ``self.threshold`` plus one self loop
            per node, edge times scaled to ``[0, t_max - t_min]`` (``None``
            when not temporal) and edge type ids (``None`` for a single type).
        """
        # Out-of-place so the caller's latent tensor is left untouched.
        z = z + self.map_label_d(label)
        x_ = self.dec_attr(z)

        n = z.size(0)
        z1 = z.unsqueeze(1).expand(-1, n, -1)
        z2 = z.unsqueeze(0).expand(n, -1, -1)
        ze = torch.cat((z1, z2), dim=2)

        adj = torch.sigmoid(self.dec_stru(ze)).squeeze(-1)
        keep = adj > self.threshold
        # add_self_loops appends (i, i) for every node without checking for
        # existing ones, so predicted self loops are dropped first to avoid
        # duplicate edges.
        keep = keep & ~torch.eye(n, dtype=torch.bool, device=keep.device)
        edge_index = keep.nonzero().T
        edge_index = add_self_loops(edge_index, num_nodes=n)[0]

        pos_ze = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
        if self.temporal:
            t_ = self.dec_time(pos_ze).squeeze(-1)
            # The time head is read as a fraction of the time range.
            t_ = torch.clamp(t_, min=0, max=1)
            t_ = t_ * (self.t_max - self.t_min)
        else:
            t_ = None
        p_ = self.dec_type(pos_ze).argmax(-1) if self.etypes > 1 else None

        return x_, edge_index, t_, p_


class GOCM_Consistency_Cluster(BaseTransform):
    
    def __init__(self,
                 name="",
                 hid_dim=None,
                 cons_dim=None,
                 vae_epochs=50,
                 cons_epochs=50,
                 patience=50,
                 lr=0.001,
                 wd=0.,
                 batch_size=4096,
                 threshold=0.75,
                 wx=1.,
                 we=.5,
                 beta=1e-3,
                 wt=1.,
                 time_attr='edge_time',
                 type_attr='edge_type',
                 wp=.3,
                 gen_nodes=None,
                 sample_steps=1,
                 device=0,
                 verbose=False,
                 encoder_backend: str = 'sage',
                 reuse_ae: bool = False,
                 reuse_cm: bool = False,
                 neg_resample_epochs: int = 5,
                 gen_ratio: float = 1.0,
                 cached_neg: bool = False,
                 # options forwarded to consistency_sampling
                 mc_schedule: str = 'linear',
                 mc_eta: float = 0.0,
                 mc_s_min: float = 0.002,
                 mc_step_clip: float = None,
                 mc_rho: float = 7.0,
                 mc_heun: bool = False,
                 mc_single_use_s_min: bool = False):
        """Store the configuration; models are only built in ``forward``.

        Args:
            name: Prefix of every checkpoint and cache path under ``ckpt/``.
            hid_dim: Latent width; derived in ``arg_parse`` when ``None``.
            cons_dim: Passed to the consistency model as ``dim_t``; set to
                ``2 * hid_dim`` in ``arg_parse`` when ``None``.
            vae_epochs: Maximum number of auto-encoder training epochs.
            cons_epochs: Maximum number of consistency-model training epochs.
            patience: Epochs without improvement before training stops early.
            lr: Adam learning rate for both training stages.
            wd: Adam weight decay for both training stages.
            batch_size: Node budget per graph partition and the largest chunk
                generated at once.
            threshold: Edge-probability cut-off handed to ``VGAE_Cluster``.
            wx: Weight of the feature term in ``recon_loss``.
            we: Weight of the edge term in ``recon_loss``.
            beta: Weight of the KL term in ``train_ae``.
            wt: Weight of the time term in ``recon_loss``.
            time_attr: Name of the edge-time attribute on the data object.
            type_attr: Name of the edge-type attribute on the data object.
            wp: Weight of the type term in ``recon_loss``.
            gen_nodes: Number of nodes to generate; derived in ``arg_parse``
                when ``None``.
            sample_steps: Steps handed to ``consistency_sampling``; with a
                non-positive value the raw noise is decoded directly.
            device: Passed through ``pygod.utils.validate_device``.
            verbose: Print progress information.
            encoder_backend: Handed to ``VGAE_Cluster`` and part of the
                auto-encoder checkpoint name.
            reuse_ae: Load the auto-encoder checkpoint when it exists.
            reuse_cm: Load the consistency-model checkpoint when it exists.
            neg_resample_epochs: Interval, in epochs, at which the negative
                edge cache is cleared (values below one become one).
            gen_ratio: Multiplier used in ``arg_parse`` to derive
                ``gen_nodes``.
            cached_neg: Reuse negative edges between cache clears.
            mc_schedule, mc_eta, mc_s_min, mc_step_clip, mc_rho, mc_heun,
            mc_single_use_s_min: Forwarded to ``consistency_sampling`` as
                ``schedule``, ``eta``, ``s_min``, ``step_clip``, ``rho``,
                ``heun`` and ``single_use_s_min``.
        """
        # --- identity, devices, logging ---------------------------------
        self.name = name
        self.device = pygod.utils.validate_device(device)
        self.verbose = verbose

        # --- model sizes and backend ------------------------------------
        self.hid_dim, self.cons_dim = hid_dim, cons_dim
        self.encoder_backend = encoder_backend
        self.threshold = threshold

        # --- optimisation -------------------------------------------------
        self.vae_epochs, self.cons_epochs = vae_epochs, cons_epochs
        self.patience = patience
        self.lr, self.wd = lr, wd
        self.batch_size = batch_size

        # --- loss weights -------------------------------------------------
        self.wx, self.we, self.wt, self.wp = wx, we, wt, wp
        self.beta = beta

        # --- edge attributes; temporal/etypes are refined in arg_parse ----
        self.time_attr, self.type_attr = time_attr, type_attr
        self.temporal = True
        self.etypes = 1

        # --- generation ---------------------------------------------------
        self.gen_nodes = gen_nodes
        self.gen_ratio = float(gen_ratio)
        self.sample_steps = sample_steps

        # --- checkpoint reuse and negative-edge caching --------------------
        self.reuse_ae, self.reuse_cm = reuse_ae, reuse_cm
        self.neg_resample_epochs = max(1, int(neg_resample_epochs))
        self.cached_neg = bool(cached_neg)
        self._neg_cache = {}

        # --- consistency sampling options ----------------------------------
        self.mc_schedule = str(mc_schedule)
        self.mc_eta = float(mc_eta)
        self.mc_s_min = float(mc_s_min)
        if mc_step_clip is None:
            self.mc_step_clip = None
        else:
            self.mc_step_clip = float(mc_step_clip)
        self.mc_rho = float(mc_rho)
        self.mc_heun = bool(mc_heun)
        self.mc_single_use_s_min = bool(mc_single_use_s_min)

        # --- state filled in later by forward/arg_parse/preprocess ---------
        self.ae = None
        self.cm = None
        self.y0 = None
        self.emap = None
        self.mean = None
        self.std = None
        self.t_min = None
        self.t_max = None
        self.last_gen_time = 0.0

    def forward(self, data):
        """Augment ``data`` with synthetic nodes and return the merged graph.

        Steps: validate and normalise the input, partition it with
        ``ClusterData``, obtain the auto-encoder and the consistency model
        (loaded from ``ckpt/`` when reuse is enabled and the checkpoint
        exists, trained otherwise), generate ``self.gen_nodes`` nodes in
        chunks of at most ``self.batch_size`` and batch them with the
        original graph.

        Args:
            data: Input ``Data`` object; ``arg_parse`` lists the attributes it
                must carry.

        Returns:
            The batched original-plus-generated graph after ``postprocess``.
            The same object is also saved to ``ckpt/<name>_aug_data.pt``.
        """
        self.arg_parse(data)
        data = self.preprocess(data)

        def load_module(path):
            # Checkpoints are whole pickled modules written by train_ae /
            # train_cm, so they cannot be read in weights-only mode, which is
            # the torch.load default on recent PyTorch. Only point this at
            # files produced by this class: unpickling runs arbitrary code.
            try:
                return torch.load(path, map_location=self.device,
                                  weights_only=False)
            except TypeError:
                # Older PyTorch: torch.load has no ``weights_only`` argument.
                return torch.load(path, map_location=self.device)

        # One partition per ``batch_size`` nodes, at least one. ``save_dir``
        # is handed to ClusterData; creating it also creates ``ckpt``.
        num_parts = max(1, data.num_nodes // self.batch_size)
        save_dir = os.path.join('ckpt', 'cluster_cache', f'{self.name}_parts{num_parts}')
        os.makedirs(save_dir, exist_ok=True)
        if self.verbose:
            print(f'[GOCM_Consistency_Cluster] : nodes={data.num_nodes}, num_parts={num_parts}', flush=True)

        cluster_data = ClusterData(data,
                                   num_parts=num_parts,
                                   log=self.verbose,
                                   save_dir=save_dir)
        dataloader = ClusterLoader(
            cluster_data,
            batch_size=2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=1
        )

        # ---- auto-encoder ------------------------------------------------
        self.ae = VGAE_Cluster(
            data.num_node_features,
            self.hid_dim,
            temporal=self.temporal,
            t_min=self.t_min,
            t_max=self.t_max,
            etypes=self.etypes,
            threshold=self.threshold,
            encoder_backend=self.encoder_backend
        ).to(self.device)

        ae_ckpt = f"ckpt/{self.name}_{self.encoder_backend}_ae_consistency_cluster.pt"
        if self.reuse_ae and os.path.exists(ae_ckpt):
            if self.verbose:
                print(f"[Reuse] load VGAE: {ae_ckpt}")
            try:
                self.ae = load_module(ae_ckpt)
                self.ae = self.ae.to(self.device)
            except Exception as e:
                # Best effort: an unreadable checkpoint falls back to training.
                print(f"[Reuse] load VGAE failed, will train VGAE. : {e}")
                self.train_ae(dataloader)
        else:
            self.train_ae(dataloader)

        # ---- consistency model -------------------------------------------
        self.cm = ConsistencyCluster(
            d_in=self.hid_dim,
            dim_t=self.cons_dim,
            device=self.device
        ).to(self.device)

        cm_ckpt = f"ckpt/{self.name}_consistency_cluster.pt"
        if self.reuse_cm and os.path.exists(cm_ckpt):
            if self.verbose:
                print(f"[Reuse] load Consistency: {cm_ckpt}")
            try:
                self.cm = load_module(cm_ckpt)
                self.cm = self.cm.to(self.device)
            except Exception as e:
                # Best effort: an unreadable checkpoint falls back to training.
                print(f"[Reuse] load Consistency failed, will train Consistency. : {e}")
                self.train_cm(dataloader)
        else:
            self.train_cm(dataloader)

        # ---- generation ----------------------------------------------------
        gen_start_time = time.time()
        gen_gs = []

        # Split the requested node count into chunks of at most batch_size.
        chunks = []
        if self.gen_nodes > 0:
            full_chunks, remainder = divmod(self.gen_nodes, self.batch_size)
            chunks = [self.batch_size] * full_chunks
            if remainder > 0:
                chunks.append(remainder)

        if isinstance(self.device, torch.device) and self.device.type == 'cuda' and len(chunks) > 1:
            # Issue up to two chunks at a time, each on its own CUDA stream,
            # and wait for the device before collecting the results.
            max_streams = min(2, len(chunks))
            for start in range(0, len(chunks), max_streams):
                group = chunks[start:start + max_streams]
                streams = [torch.cuda.Stream(device=self.device) for _ in range(len(group))]
                tmp_list = [None] * len(group)
                for idx, (csize, stream) in enumerate(zip(group, streams)):
                    with torch.cuda.stream(stream):
                        tmp_list[idx] = self.sample(self.cm, csize)
                torch.cuda.synchronize(self.device)
                gen_gs.extend(tmp_list)
        else:
            for csize in chunks:
                gen_gs.append(self.sample(self.cm, csize))
        self.last_gen_time = time.time() - gen_start_time

        # Restore the full label vector that arg_parse masked for training.
        data.y = self.y0

        # Original graph first, generated graphs after it.
        aug_data = Batch.from_data_list([data] + gen_gs)

        aug_data = self.postprocess(aug_data)
        save_path = 'ckpt/' + self.name + '_aug_data.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(aug_data, save_path)

        return aug_data

    def arg_parse(self, data):
        """Validate ``data`` and derive the settings that depend on it.

        Side effects: stores the full label vector in ``self.y0``, replaces
        ``data.y`` with a copy in which every non-training node is labelled
        0, and fills in ``hid_dim``, ``cons_dim``, ``temporal``/``wt``,
        ``etypes`` and ``gen_nodes`` where they depend on the data.

        Args:
            data: The graph to augment.

        Raises:
            TypeError: If ``data`` is not a ``Data`` object.
            ValueError: If features, labels or any of the three split masks
                are missing.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be torch_geometric.data.Data')

        # ``x`` and ``y`` are properties on ``Data`` that yield None when
        # unset, so ``hasattr`` cannot detect that they are missing.
        if getattr(data, 'x', None) is None:
            raise ValueError('data must have feature x')

        if getattr(data, 'y', None) is None:
            raise ValueError('data must have label y')

        required_masks = ('train_mask', 'val_mask', 'test_mask')
        if any(not hasattr(data, mask) for mask in required_masks):
            raise ValueError('data must have train_mask, val_mask, test_mask')

        if self.hid_dim is None:
            # Largest power of two below the feature width, halved; the
            # exponent is clamped so a single feature column still yields an
            # integer width of 1 instead of 2 ** -1.
            exponent = int(math.log2(data.x.size(1)) - 1)
            self.hid_dim = 2 ** max(0, exponent)
        if self.cons_dim is None:
            self.cons_dim = 2 * self.hid_dim

        # Mask out val and test nodes for training. The masking is done on a
        # copy so the caller's label tensor is never written to.
        self.y0 = copy.deepcopy(data.y)
        masked_y = data.y.clone()
        masked_y[data.train_mask == 0] = 0
        data.y = masked_y

        if not hasattr(data, self.time_attr):
            self.temporal = False
            self.wt = 0.

        if hasattr(data, self.type_attr):
            self.etypes = getattr(data, self.type_attr).unique().size(0)

        if self.gen_nodes is None:
            # Label sum over the training nodes, scaled by gen_ratio; at
            # least one node is generated.
            base = data.y[data.train_mask].sum()
            self.gen_nodes = int(max(1, float(base) * max(0.0, self.gen_ratio)))
    def preprocess(self, data):
        """Bring the graph into the form used for training.

        Makes the graph undirected if needed, standardises the node features
        column-wise, re-indexes edge types to consecutive ids and shifts edge
        times so they start at zero. The statistics needed to undo this are
        kept in ``self.mean``, ``self.std``, ``self.emap``, ``self.t_min``
        and ``self.t_max``.

        Returns:
            The processed data object.
        """
        if data.is_directed():
            data = ToUndirected(reduce='min')(data)

        # Column-wise standardisation; constant columns get a unit divisor
        # so they do not cause a division by zero.
        self.mean = data.x.mean(0)
        safe_std = data.x.std(0).clone()
        constant_cols = safe_std == 0
        if constant_cols.any():
            safe_std[constant_cols] = 1.0
        self.std = safe_std
        data.x = (data.x - self.mean) / safe_std

        # Edge types become indices into the sorted unique values.
        if self.etypes > 1:
            raw_types = getattr(data, self.type_attr)
            self.emap, type_ids = raw_types.unique(return_inverse=True)
            setattr(data, self.type_attr, type_ids)

        # Edge times are shifted so the earliest one is zero.
        if self.temporal:
            raw_times = getattr(data, self.time_attr)
            self.t_min = raw_times.min()
            self.t_max = raw_times.max()
            setattr(data, self.time_attr, raw_times - self.t_min)

        return data

    def postprocess(self, data):
        """Undo the feature, edge-type and edge-time changes of ``preprocess``.

        Returns:
            The same data object with de-normalised features, original edge
            type values and un-shifted edge times.
        """
        # Features: scale back, then shift back.
        rescaled = data.x * self.std
        data.x = rescaled + self.mean

        # Edge types: look the ids up in the stored unique values.
        if self.etypes > 1:
            type_ids = getattr(data, self.type_attr)
            setattr(data, self.type_attr, self.emap[type_ids])

        # Edge times: add the minimum that was subtracted.
        if self.temporal:
            shifted_times = getattr(data, self.time_attr)
            setattr(data, self.time_attr, shifted_times + self.t_min)

        return data

    def train_ae(self, dataloader):
        """Train ``self.ae`` and keep the weights of the best epoch.

        After every epoch that improves the node-weighted mean loss the whole
        module is saved to
        ``ckpt/<name>_<encoder_backend>_ae_consistency_cluster.pt``. Training
        stops early after ``self.patience`` epochs without improvement. At
        the end the checkpoint written by this run is loaded back into
        ``self.ae``; when no epoch produced one, the current weights are kept.

        Args:
            dataloader: Iterable of mini-batches carrying ``x``,
                ``edge_index``, ``y`` and, where configured, the time and
                type attributes.
        """
        if self.verbose:
            print('train autoencoder...')
        optimizer = torch.optim.Adam(self.ae.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.wd)
        save_path = f"ckpt/{self.name}_{self.encoder_backend}_ae_consistency_cluster.pt"

        best_loss = float('inf')
        patience = 0
        saved = False  # whether this run has written a checkpoint
        for epoch in range(self.vae_epochs):
            # Periodically drop cached negatives so they get resampled.
            if epoch % self.neg_resample_epochs == 0:
                self._neg_cache.clear()
            start = time.time()
            self.ae.train()
            total_loss = 0
            num_nodes = 0
            for batch_idx, batch in enumerate(dataloader):
                batch_size = batch.x.size(0)
                x = batch.x.to(self.device, non_blocking=True)
                pos_edge_index = batch.edge_index.to(self.device, non_blocking=True)
                neg_edge_index = self._negative_edges(batch_idx,
                                                      pos_edge_index,
                                                      batch_size)

                edge_label = torch.cat([torch.ones_like(pos_edge_index[0]),
                                        torch.zeros_like(neg_edge_index[0])])

                y = batch.y.float().unsqueeze(1).to(self.device, non_blocking=True)

                t = getattr(batch, self.time_attr).to(self.device, non_blocking=True) \
                    if self.temporal else None
                p = getattr(batch, self.type_attr).to(self.device, non_blocking=True) \
                    if self.etypes > 1 else None

                x_, edge_pred, t_, p_ = self.ae(x, pos_edge_index,
                                                neg_edge_index, y, t, p)
                loss = self.recon_loss(x, x_, edge_label.float(), edge_pred,
                                       t, t_, p, p_)

                # ``kl_div`` is the negative KL divergence, so subtracting it
                # adds the KL penalty to the loss.
                kl_div = (0.5 / x_.size(0) *
                          (1 + 2 * self.ae.log_std - self.ae.mean ** 2 -
                           torch.exp(self.ae.log_std) ** 2).sum(1).mean())
                loss = loss - self.beta * kl_div

                # Checked after the KL term so that an overflowing exp() in
                # the KL cannot slip through and poison the weights.
                if not torch.isfinite(loss):
                    if self.verbose:
                        print("find NaN/Inf loss, skip this batch.")
                    optimizer.zero_grad()
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
                optimizer.step()

                total_loss += loss.item() * batch_size
                num_nodes += batch_size

            # An epoch in which every batch was skipped counts as "no
            # improvement" instead of raising ZeroDivisionError.
            curr_loss = total_loss / num_nodes if num_nodes > 0 else float('inf')

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience = 0
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save(self.ae, save_path)
                saved = True
            else:
                patience += 1
                if patience == self.patience:
                    if self.verbose:
                        print('early stop')
                    break

            epoch_t = time.time() - start
            if self.verbose:
                print(f'epoch: {epoch:03d}, loss: {curr_loss:.6f}, '
                      f'cost time: {epoch_t:.4f}')

        # Reload only a checkpoint written by this run; a file left over from
        # an earlier run must not replace the model that was just trained.
        if saved:
            try:
                # The file is a whole pickled module written above, which
                # weights-only loading (the default on recent PyTorch) rejects.
                self.ae = torch.load(save_path, map_location=self.device,
                                     weights_only=False)
            except TypeError:
                # Older PyTorch: torch.load has no ``weights_only`` argument.
                self.ae = torch.load(save_path, map_location=self.device)

    def _negative_edges(self, batch_idx, pos_edge_index, num_nodes):
        """Return negative edges for one mini-batch, cached when enabled."""
        if not self.cached_neg:
            return negative_sampling(edge_index=pos_edge_index,
                                     num_nodes=num_nodes)

        cached = self._neg_cache.get(batch_idx)
        # A cached set is unusable when it refers to node ids outside this
        # batch; an empty set has no maximum and is always usable.
        stale = cached is None or (cached.numel() > 0
                                   and int(cached.max()) >= num_nodes)
        if stale:
            cached = negative_sampling(edge_index=pos_edge_index,
                                       num_nodes=num_nodes).cpu()
            self._neg_cache[batch_idx] = cached
        return cached.to(self.device, non_blocking=True)

    def train_cm(self, dataloader):
        """Train the consistency model (cluster version).

        Each mini-batch is encoded with ``self.ae.encode`` and the resulting
        mean is fed, together with the labels, to ``self.cm``, whose output is
        used as the loss. After every epoch that improves the sample-weighted
        mean loss the whole module is saved to
        ``ckpt/<name>_consistency_cluster.pt``; training stops early after
        ``self.patience`` epochs without improvement. At the end the
        checkpoint written by this run is loaded back into ``self.cm``; when
        no epoch produced one, the current weights are kept.

        Args:
            dataloader: Iterable of mini-batches carrying ``x``,
                ``edge_index``, ``y`` and, where configured, the time and
                type attributes.
        """
        if self.verbose:
            print('train Consistency Model...')
        optimizer = torch.optim.Adam(self.cm.parameters(), lr=self.lr,
                                     weight_decay=self.wd)
        # NOTE(review): ``verbose`` is deprecated for LR schedulers in recent
        # PyTorch releases — confirm against the pinned version.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9,
                                      patience=20, verbose=self.verbose)
        save_path = 'ckpt/' + self.name + '_consistency_cluster.pt'

        self.cm.train()
        best_loss = float('inf')
        patience = 0
        saved = False  # whether this run has written a checkpoint
        for epoch in range(self.cons_epochs):
            pbar = tqdm.tqdm(dataloader, total=len(dataloader),
                             disable=not self.verbose)
            pbar.set_description(f"epoch {epoch}")

            batch_loss = 0.0
            len_input = 0
            for batch_idx, batch in enumerate(pbar):
                x = batch.x.to(self.device, non_blocking=True)
                edge_index = batch.edge_index.to(self.device, non_blocking=True)
                y = batch.y.float().unsqueeze(1).to(self.device, non_blocking=True)
                t = getattr(batch, self.time_attr).to(self.device, non_blocking=True) \
                    if self.temporal else None
                p = getattr(batch, self.type_attr).to(self.device, non_blocking=True) \
                    if self.etypes > 1 else None

                # Only the posterior mean is used as the training input.
                inputs, _ = self.ae.encode(x, edge_index, y, t, p)

                loss = self.cm(inputs, y)

                if torch.isnan(loss) or torch.isinf(loss):
                    if self.verbose:
                        print("find NaN/Inf loss (Consistency), skip this batch.")
                    optimizer.zero_grad()
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.cm.parameters(), 1.0)
                optimizer.step()

                batch_loss += loss.item() * len(inputs)
                len_input += len(inputs)

                pbar.set_postfix({"loss": loss.item()})

            # An epoch in which every batch was skipped counts as "no
            # improvement" instead of raising ZeroDivisionError.
            curr_loss = batch_loss / len_input if len_input > 0 else float('inf')
            scheduler.step(curr_loss)

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience = 0
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save(self.cm, save_path)
                saved = True
            else:
                patience += 1
                if patience == self.patience:
                    if self.verbose:
                        print('early stop')
                    break

        # Reload only a checkpoint written by this run; a file left over from
        # an earlier run must not replace the model that was just trained.
        if saved:
            try:
                # The file is a whole pickled module written above, which
                # weights-only loading (the default on recent PyTorch) rejects.
                self.cm = torch.load(save_path, map_location=self.device,
                                     weights_only=False)
            except TypeError:
                # Older PyTorch: torch.load has no ``weights_only`` argument.
                self.cm = torch.load(save_path, map_location=self.device)

    def sample(self, model, graph_size):
        """Generate one synthetic graph with ``graph_size`` nodes.

        Gaussian noise is turned into latent embeddings with
        ``consistency_sampling`` (or used as-is when ``self.sample_steps`` is
        not positive) and decoded by ``self.ae.sample``. Every generated node
        gets label 1 and is placed in the training split.

        Args:
            model: Trained consistency model; its ``model`` attribute is
                handed to ``consistency_sampling``.
            graph_size: Number of nodes to generate.

        Returns:
            A CPU ``Data`` object with ``x``, ``edge_index``, ``y``, the three
            split masks and, where enabled, the configured time and type
            attributes.
        """
        noise = torch.randn(graph_size, self.hid_dim).to(self.device)
        label = torch.ones(graph_size).unsqueeze(1).to(self.device)

        if self.sample_steps > 0:
            z = consistency_sampling(
                model.model, noise, label,
                num_steps=self.sample_steps,
                schedule=self.mc_schedule,
                eta=self.mc_eta,
                s_min=self.mc_s_min,
                single_use_s_min=self.mc_single_use_s_min,
                step_clip=self.mc_step_clip,
                rho=self.mc_rho,
                heun=self.mc_heun
            )
        else:
            z = noise

        x_, edge_index, t_, p_ = self.ae.sample(z, label)

        # squeeze(1) rather than squeeze(): a one-node chunk must still give
        # a 1-D label vector so it can be concatenated with other graphs.
        data = Data(x=x_, edge_index=edge_index,
                    y=label.squeeze(1).long()).cpu().detach()
        # Edge attributes are stored only under the configured names so the
        # generated graph carries the same keys as the input graph.
        if self.temporal:
            setattr(data, self.time_attr, t_.cpu().detach())
        if self.etypes > 1:
            setattr(data, self.type_attr, p_.cpu().detach())
        data.train_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        return data

    def recon_loss(self, x, x_, edge_label, edge_pred,
                   t=None, t_=None, p=None, p_=None):
        """Weighted sum of the feature, edge, time and type reconstruction terms.

        Args:
            x: Target node features.
            x_: Reconstructed node features.
            edge_label: Target edge labels for ``edge_pred``.
            edge_pred: Edge logits.
            t: Target edge times (used when temporal).
            t_: Predicted edge times; the time term is 0 when this is None.
            p: Target edge types (used with more than one edge type).
            p_: Edge type logits.

        Returns:
            ``wx * feature + we * edge + wt * time + wp * type``.
        """
        log = self.verbose

        # Feature term: non-finite predictions are replaced before the MSE.
        if not torch.isfinite(x_).all():
            x_ = torch.nan_to_num(x_, nan=0.0, posinf=1e6, neginf=-1e6)
        loss_x = F.mse_loss(x_, x)
        if log:
            print("    epoch loss: feature: {:.4f}".format(loss_x.item()),
                  end=' ')

        # Edge term: non-finite logits are replaced before the BCE.
        if not torch.isfinite(edge_pred).all():
            edge_pred = torch.nan_to_num(edge_pred, nan=0.0, posinf=50.0, neginf=-50.0)
        loss_e = F.binary_cross_entropy_with_logits(edge_pred, edge_label)
        if log:
            print("edge: {:.4f}".format(loss_e.item()), end=' ')

        # Time term: targets are divided by the time range (1.0 if the range
        # is zero) before being compared with the prediction.
        loss_t = 0
        if self.temporal:
            denom = (self.t_max - self.t_min)
            if denom == 0:
                denom = 1.0
            if t_ is not None:
                loss_t = F.mse_loss(t_, t / denom)
            if log:
                print("time: {:.4f}".format(loss_t.item() if isinstance(loss_t, torch.Tensor) else loss_t), end=' ')

        # Type term: only present with more than one edge type.
        loss_p = 0
        if self.etypes > 1:
            loss_p = F.cross_entropy(p_, p)
            if log:
                print("type: {:.4f}".format(loss_p.item()), end=' ')

        total = (self.wx * loss_x + self.we * loss_e +
                 self.wt * loss_t + self.wp * loss_p)
        if log:
            print()
        return total


# Short alias for the cluster-based consistency augmentation transform.
GOCM = GOCM_Consistency_Cluster