import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as D
import random


# For the dataset, auxiliary functions, and benchmark code, see: https://github.com/BestActionNow/EGMN/tree/main


# Optional diagnostics hook: fall back to a no-op when the project-local
# `diagnosis` module cannot be imported.
try:
    from diagnosis import debug_model_outputs
except ImportError:
    def debug_model_outputs(*args, **kwargs):
        """No-op stand-in used when `diagnosis` is unavailable."""
        pass

# The metric helpers and the data loader live in different project-local
# modules, so they are imported independently: a missing `utils` must not
# prevent `dataloader` from being imported, and vice versa.
try:
    from utils import eval_mae, eval_xauc
except ImportError:
    def eval_mae(y, y_pred):
        """Fallback mean absolute error between two NumPy arrays."""
        return np.mean(np.abs(y - y_pred))

    def eval_xauc(y, y_pred):
        """Placeholder XAUC: always returns 0.5 (not a real metric value)."""
        return 0.5

try:
    from dataloader import KUAIRECDataLoader
except ImportError:
    class KUAIRECDataLoader:
        """Stand-in that fails with a clear message when instantiated."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "KUAIRECDataLoader could not be imported from the project-local "
                "'dataloader' module; make sure it is on the Python path."
            )

def eval_kl(samples_p, samples_q, bins=100, epsilon=1e-10):
    """Estimate KL(P || Q) between two sample sets using shared histograms.

    The bin edges are derived from ``samples_p`` only and reused for
    ``samples_q``; values of ``samples_q`` outside that range are ignored.
    Each histogram is normalised to per-bin probabilities over the samples
    that fall inside the range, floored at ``epsilon`` (without
    re-normalising), and the discrete KL sum is returned.

    Args:
        samples_p: Reference samples (array-like; flattened by the histogram).
        samples_q: Comparison samples (array-like).
        bins: Bin specification forwarded to ``np.histogram`` for
            ``samples_p``.
        epsilon: Lower bound applied to every bin probability so the
            logarithm stays finite.

    Returns:
        The KL estimate as a float. ``0.0`` is returned when ``samples_p``
        is constant (no spread to histogram).

    Raises:
        ValueError: If ``samples_p`` is empty, or ``samples_q`` is empty
            while ``samples_p`` is not constant.
    """
    samples_p = np.asarray(samples_p)
    samples_q = np.asarray(samples_q)

    if samples_p.size == 0:
        raise ValueError("samples_p must contain at least one value")
    if samples_p.max() == samples_p.min():
        return 0.0
    if samples_q.size == 0:
        raise ValueError("samples_q must contain at least one value")

    counts_p, bin_edges = np.histogram(samples_p, bins=bins)
    counts_q, _ = np.histogram(samples_q, bins=bin_edges)

    # Normalise counts explicitly instead of using density=True: when no
    # sample of Q lands inside P's range, density=True computes 0/0 and the
    # whole KL becomes NaN. With explicit normalisation Q simply has zero
    # mass everywhere, which the epsilon floor turns into a large finite KL.
    prob_p = counts_p / counts_p.sum()
    total_q = counts_q.sum()
    if total_q > 0:
        prob_q = counts_q / total_q
    else:
        prob_q = np.zeros(len(counts_q), dtype=float)

    prob_p = np.clip(prob_p, epsilon, None)
    prob_q = np.clip(prob_q, epsilon, None)

    return float(np.sum(prob_p * np.log(prob_p / prob_q)))

def mae_rescale_to_second(dataset, mae):
    """Rescale an MAE value according to the dataset it was computed on.

    For ``'kuairec'`` the value is multiplied by 999639 and divided by 1000;
    for any other dataset name it is returned unchanged.
    NOTE(review): 999639 is presumably the maximum duration in milliseconds
    used to normalise the labels, making the result seconds -- confirm
    against the data loader.
    """
    if dataset != 'kuairec':
        return mae
    return mae * 999639 / 1000

def estimate_duration_bounds(dataloader, device, margin=0.05):
    """Return a padded (t_min, t_max) range covering every label in a loader.

    Each batch is either a dict holding the labels under ``'duration'`` or a
    two-element ``(features, labels)`` pair. All labels are concatenated,
    moved to ``device``, and the observed min/max are widened by
    ``margin`` times the observed span. The lower bound never drops below
    ``1e-4``.

    Args:
        dataloader: Iterable of batches in one of the two formats above.
        device: Device the concatenated labels are moved to before reducing.
        margin: Fraction of the observed span added on each side.

    Returns:
        Tuple ``(t_min, t_max)`` of Python floats.
    """
    def _labels_of(batch):
        if isinstance(batch, dict):
            return batch['duration']
        _, labels = batch
        return labels

    durations = torch.cat([_labels_of(batch) for batch in dataloader]).to(device)
    lowest = durations.min().item()
    highest = durations.max().item()
    pad = margin * (highest - lowest)
    return max(1e-4, lowest - pad), highest + pad

def get_quantile_threshold(dataloader, quantile=0.4):
    """Return the requested quantile of all labels yielded by a loader.

    Each batch is either a dict holding the labels under ``'duration'`` or a
    two-element ``(features, labels)`` pair. The labels are concatenated,
    cast to float, and passed to ``torch.quantile``.

    Args:
        dataloader: Iterable of batches in one of the two formats above.
        quantile: Quantile in [0, 1] forwarded to ``torch.quantile``.

    Returns:
        The quantile value as a Python float.
    """
    def _iter_labels():
        for batch in dataloader:
            if isinstance(batch, dict):
                yield batch['duration']
            else:
                _, labels = batch
                yield labels

    stacked = torch.cat(list(_iter_labels())).float()
    return torch.quantile(stacked, quantile).item()

def evaluate(model, loaders, device, t_min, t_max, dataset_name, threshold):
    """Run the model over ``loaders['test']`` and compute MAE, XAUC and KL.

    Each batch is either a dict holding the labels under ``'duration'``
    (every other entry is a model input) or a ``(inputs_dict, labels)``
    pair. Inputs are moved to ``device``; predictions come from
    ``model.predict`` and are brought back to the CPU.

    Args:
        model: Object exposing ``eval()`` and
            ``predict(x, t_min, t_max, switch_threshold=...)``.
        loaders: Mapping with a ``'test'`` iterable of batches.
        device: Device the input tensors are moved to.
        t_min: Lower duration bound forwarded to ``model.predict``.
        t_max: Upper duration bound forwarded to ``model.predict``.
        dataset_name: Dataset identifier used to rescale the MAE.
        threshold: Switch threshold forwarded to ``model.predict``.

    Returns:
        Tuple ``(mae, xauc, kl)``.
    """
    model.eval()
    true_chunks, pred_chunks = [], []

    with torch.no_grad():
        for batch in loaders['test']:
            if isinstance(batch, dict):
                x = {k: v.to(device) for k, v in batch.items() if k != 'duration'}
                y = batch['duration'].float()
            else:
                x, y = batch
                x = {k: v.to(device) for k, v in x.items()}
                y = y.float()

            p = model.predict(x, t_min, t_max, switch_threshold=threshold).cpu()
            # Flatten per batch so labels and predictions always line up
            # element-wise. Without this, labels shaped (N, 1) against
            # predictions shaped (N,) would broadcast to (N, N) inside the
            # metric functions.
            true_chunks.append(y.cpu().numpy().reshape(-1))
            pred_chunks.append(p.numpy().reshape(-1))

    y_true = np.concatenate(true_chunks) if true_chunks else np.array([])
    y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([])

    mae = mae_rescale_to_second(dataset_name, eval_mae(y_true, y_pred))
    xauc = eval_xauc(y_true, y_pred)
    kl = eval_kl(y_true, y_pred)

    return mae, xauc, kl

class MultiLayerPerceptron(nn.Module):
    """Feed-forward stack of Linear -> LayerNorm -> ReLU -> Dropout stages.

    Args:
        input_dim: Size of the last dimension of the input.
        dims: Output width of each hidden stage, in order.
        dropout: Dropout probability used in every stage.
        output_layer: When true, a final ``Linear(..., 1)`` is appended.
    """

    def __init__(self, input_dim, dims, dropout, output_layer=True):
        super().__init__()
        widths = [input_dim, *dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ))
        if output_layer:
            # Scalar projection from the last hidden width (or from the
            # input width when ``dims`` is empty).
            stages.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)

class EGMN_Head(nn.Module):
    """Mixture head: one exponential component plus several Gaussians.

    From a hidden representation the head predicts the exponential rate,
    the mixture logits (exponential first, then the Gaussians), and the
    mean / standard deviation of each Gaussian component.

    Args:
        hidden_dim: Width of the hidden representation fed to the head.
        output_dim: Accepted for interface compatibility; not used.
        dropout: Accepted for interface compatibility; not used.
        comp_num: Number of Gaussian components (default 10, the value that
            was previously hard-coded).
    """

    def __init__(self, hidden_dim, output_dim, dropout=0.0, comp_num=10):
        super().__init__()
        self.lambda_layer = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softplus(beta=0.5)
        )
        # One logit for the exponential component plus one per Gaussian.
        self.mixture_logits = nn.Linear(hidden_dim, comp_num + 1)
        self.gauss_mu = nn.Sequential(
            nn.Linear(hidden_dim, comp_num),
            nn.Softplus()
        )
        self.gauss_sigma = nn.Sequential(
            nn.Linear(hidden_dim, comp_num),
            nn.Softplus()
        )

    def forward(self, hidden):
        """Return ``(pi_logits, lambda_, mu, sigma)`` for ``hidden``.

        ``lambda_`` and ``sigma`` get a ``1e-6`` offset so they stay strictly
        positive; every Gaussian mean is shifted by ``1 / lambda_`` (the
        exponential component's mean), so it is never smaller than that.
        """
        lambda_ = self.lambda_layer(hidden) + 1e-6
        pi = self.mixture_logits(hidden)
        mu = self.gauss_mu(hidden) + 1.0 / lambda_
        sigma = self.gauss_sigma(hidden) + 1e-6
        return pi, lambda_, mu, sigma

    def loss(self, y_true, pi, lambda_, mu, sigma):
        """Mean negative log-likelihood of ``y_true`` under the mixture.

        Args:
            y_true: Targets with one value per sample (any shape that
                flattens to the batch size).
            pi: Unnormalised mixture logits, exponential component first.
            lambda_: Exponential rate per sample.
            mu: Gaussian means, one column per component.
            sigma: Gaussian standard deviations, one column per component.

        Returns:
            Scalar tensor with the batch-averaged negative log-likelihood.
        """
        batch_size = y_true.shape[0]
        target = y_true.view(-1)

        exp_dist = D.Exponential(rate=lambda_.view(-1))
        log_prob_short = exp_dist.log_prob(target).view(batch_size, 1)

        # All Gaussian components are evaluated in one broadcasted call
        # instead of a Python loop over columns.
        normal_dist = D.Normal(loc=mu, scale=sigma)
        # Each Gaussian is truncated below at zero: renormalise by the mass
        # that lies above zero (the 1e-6 keeps the log finite).
        prob_long = 1.0 - normal_dist.cdf(torch.zeros_like(mu))
        log_prob_long = normal_dist.log_prob(target.view(-1, 1)) - torch.log(prob_long + 1e-6)

        log_prob_all = torch.cat([log_prob_short, log_prob_long], dim=1)
        log_mix_probs = torch.log_softmax(pi, dim=1)
        total_log_prob = torch.logsumexp(log_mix_probs + log_prob_all, dim=1, keepdim=True)

        return -torch.mean(total_log_prob)

    def predict_expectation(self, hidden):
        """Return the mixture-weighted mean of the component locations.

        The exponential component contributes ``1 / lambda_`` and each
        Gaussian contributes its ``mu`` (the truncation at zero is not
        accounted for here).
        """
        pi, lambda_, mu, sigma = self.forward(hidden)
        pi = torch.softmax(pi, dim=1)
        components = torch.cat([1.0 / lambda_, mu], dim=1)
        return torch.sum(pi * components, dim=1)

class ExtDDM_Head(nn.Module):
    """Density head built from a learned threshold and a time-varying drift.

    ``u_mlp`` (plus ``u_bias``) maps the hidden vector to a positive
    threshold; ``mu_mlp`` maps the hidden vector combined with hand-crafted
    time features to a positive drift value. ``_calculate_pdf`` combines the
    threshold, the integrated drift and the final drift into a density.

    Args:
        hidden_dim: Width of the hidden representation.
        mu_dims: Hidden widths of the drift MLP.
        u_dims: Hidden widths of the threshold MLP.
        dropout: Dropout probability for both MLPs.
    """

    def __init__(self, hidden_dim, mu_dims, u_dims, dropout=0.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.u_mlp = nn.Sequential(
            MultiLayerPerceptron(hidden_dim, u_dims, dropout),
            nn.Softplus(),
        )
        self.u_bias = nn.Parameter(torch.tensor(1.0))

        # Drift input = hidden + 13 time features + hidden * t + hidden / t.
        drift_in_dim = 3 * hidden_dim + 13
        self.mu_mlp = nn.Sequential(
            MultiLayerPerceptron(drift_in_dim, mu_dims, dropout),
            nn.Softplus(),
        )

    def _time_features(self, t):
        """Expand a column of times into 13 features (clamped at 1e-5)."""
        t = torch.clamp(t, min=1e-5)
        feature_columns = [
            torch.log(t),
            t,
            t**2,
            t**3,
            torch.sqrt(t),
            1.0 / (t + 0.01),
            1.0 / (t**2),
            torch.cos(t),
            torch.sin(t),
            torch.atan(t),
            torch.exp(-1.0 * t),
            torch.exp(-2.0 * t),
            # Growing exponential, capped so it cannot overflow.
            torch.exp(torch.clamp(t, max=10.0)),
        ]
        return torch.cat(feature_columns, dim=1)

    def get_drift_trajectory(self, hidden, t_grid):
        """Evaluate the drift for every sample at every time in ``t_grid``.

        Args:
            hidden: Hidden vectors, one row per sample.
            t_grid: Two-dimensional tensor of times, one row per sample.

        Returns:
            Tensor with the same shape as ``t_grid`` holding the drift.
        """
        n_rows, n_times = t_grid.shape
        # Repeat each hidden row once per time step, then flatten so the MLP
        # sees one (sample, time) pair per row.
        tiled_hidden = hidden.unsqueeze(1).repeat(1, n_times, 1).view(-1, self.hidden_dim)
        times = t_grid.view(-1, 1)

        mlp_input = torch.cat(
            [
                tiled_hidden,
                self._time_features(times),
                tiled_hidden * times,
                tiled_hidden / (times + 0.0001),
            ],
            dim=1,
        )
        return self.mu_mlp(mlp_input).view(n_rows, n_times)

    def get_threshold(self, hidden):
        """Return the positive threshold ``u`` for each hidden row."""
        return self.u_mlp(hidden) + self.u_bias

    def _calculate_pdf(self, hidden, t_val):
        """Evaluate the density at time ``t_val`` for each sample.

        The drift is integrated from 0 to ``t_val`` with the trapezoidal rule
        on a 30-point grid that is denser near zero (squared spacing).
        """
        n_steps = 30
        barrier = self.get_threshold(hidden)

        unit_grid = torch.linspace(0, 1, n_steps, device=t_val.device).view(1, -1) ** 2
        time_grid = t_val.view(-1, 1) * unit_grid

        drift = self.get_drift_trajectory(hidden, time_grid)
        cumulative = torch.trapz(drift, time_grid, dim=1).view(-1, 1)
        drift_at_t = drift[:, -1].view(-1, 1)

        root_t = torch.sqrt(t_val + 1e-8)
        gap = barrier - cumulative
        z = gap / root_t

        gaussian = torch.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
        rate = gap / (t_val + 1e-8) + drift_at_t
        # Negative rates are clipped so the density never goes below zero.
        return F.relu(rate / root_t) * gaussian

    def predict_expectation(self, hidden, t_min, t_max):
        """Return the expected time under the density on ``[t_min, t_max]``.

        The density is evaluated on 50 evenly spaced times, normalised so
        it sums to one over the grid, and used to weight the grid times.
        """
        n_steps = 50
        batch = hidden.size(0)
        t_grid = torch.linspace(t_min, t_max, n_steps, device=hidden.device).view(1, -1)
        dt = (t_max - t_min) / (n_steps - 1)

        density_columns = [
            self._calculate_pdf(hidden, t_grid[:, idx:idx + 1].repeat(batch, 1))
            for idx in range(n_steps)
        ]
        pdf = torch.cat(density_columns, dim=1)

        norm = torch.sum(pdf * dt, dim=1, keepdim=True) + 1e-8
        pdf = pdf / norm
        return torch.sum(t_grid * pdf * dt, dim=1)

class HybridModel(nn.Module):
    """Shared feature encoder feeding an EGMN head and an ExtDDM head.

    Targets are multiplied by ``scale_factor`` (100) before being passed to
    either head, and predictions are divided by it again.

    Args:
        description: Iterable of ``(name, size, type)`` triples; only types
            ``'ctn'``, ``'spr'`` and ``'seq'`` are used.
        embed_dim: Embedding width for ``'spr'`` and ``'seq'`` features.
        share_dims: Hidden widths of the shared encoder MLP; the last entry
            is the hidden width given to both heads.
        dropout: Dropout probability of the shared encoder.
    """

    def __init__(self, description, embed_dim, share_dims, dropout=0.0):
        super().__init__()
        self.scale_factor = 100.0

        self.features = {n: (s, t) for n, s, t in description if t in ['ctn', 'spr', 'seq']}
        self.emb_layer = nn.ModuleDict()
        self.ctn_layer = nn.ModuleDict()
        embed_out = 0
        for name, (size, t) in self.features.items():
            if t in ('spr', 'seq'):
                # Sparse and sequence features both use an embedding table.
                self.emb_layer[name] = nn.Embedding(size, embed_dim)
                embed_out += embed_dim
            elif t == 'ctn':
                # Continuous features get a single learnable scale.
                self.ctn_layer[name] = nn.Linear(1, 1, bias=False)
                embed_out += 1

        self.encoder = MultiLayerPerceptron(embed_out, share_dims, dropout, False)
        self.hidden_dim = share_dims[-1]

        self.egmn = EGMN_Head(self.hidden_dim, output_dim=1)
        self.ddm = ExtDDM_Head(self.hidden_dim, [64, 32], [64, 32])

    def load_pretrained_egmn(self, path, device):
        """Best-effort load of EGMN weights from a checkpoint file.

        Checkpoint keys are remapped to this model's naming
        (``share_mlp`` -> ``encoder``, ``ctn_linear_layer`` -> ``ctn_layer``,
        EGMN head layers -> ``egmn.*``) and loaded non-strictly. Loading
        never raises: any failure, a checkpoint that matches no parameter,
        and remapped keys the model does not have are all reported through
        ``warnings.warn`` so they are not silently ignored.

        Args:
            path: Checkpoint path readable by ``torch.load``.
            device: ``map_location`` used when loading the checkpoint.
        """
        import warnings

        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(path, map_location=device)
            state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

            new_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith('share_mlp'):
                    if 'share_mlp.net' in key:
                        new_key = key.replace('share_mlp', 'encoder')
                    else:
                        new_key = key.replace('share_mlp', 'encoder.net')
                    new_state_dict[new_key] = value
                elif key.startswith('ctn_linear_layer'):
                    new_state_dict[key.replace('ctn_linear_layer', 'ctn_layer')] = value
                elif key.startswith('emb_layer'):
                    new_state_dict[key] = value
                elif any(k in key for k in ['lambda_layer', 'mixture_logits', 'gauss_mu', 'gauss_sigma']):
                    new_state_dict[f"egmn.{key}"] = value

            result = self.load_state_dict(new_state_dict, strict=False)
        except Exception as exc:
            # Loading stays best-effort, but the failure must be visible:
            # otherwise training silently starts from random weights.
            warnings.warn(
                f"Could not load pretrained EGMN weights from {path!r}: "
                f"{type(exc).__name__}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            return

        unexpected = list(result.unexpected_keys)
        if len(unexpected) >= len(new_state_dict):
            warnings.warn(
                f"Pretrained EGMN checkpoint {path!r} matched no parameter of "
                "this model; nothing was loaded.",
                RuntimeWarning,
                stacklevel=2,
            )
        elif unexpected:
            warnings.warn(
                f"Pretrained EGMN checkpoint {path!r}: {len(unexpected)} remapped "
                f"key(s) do not exist in this model and were ignored: {unexpected}",
                RuntimeWarning,
                stacklevel=2,
            )

    def get_features(self, x):
        """Embed every configured feature and run the shared encoder.

        ``x`` maps feature names to tensors; each ``'seq'`` feature also
        needs a ``"<name>mask"`` entry, used for a masked mean over
        dimension 1.
        NOTE(review): 'spr' inputs are presumably index tensors of shape
        (batch, 1) given the ``squeeze(1)`` -- confirm against the loader.
        """
        embs, lins = [], []
        for name, (_, t) in self.features.items():
            v = x[name]
            if t == 'spr':
                embs.append(self.emb_layer[name](v).squeeze(1))
            elif t == 'ctn':
                lins.append(self.ctn_layer[name](v))
            elif t == 'seq':
                emb = self.emb_layer[name](v)
                mask = x[f"{name}mask"].unsqueeze(2)
                # Masked mean; the 1e-6 guards against an all-zero mask.
                embs.append((emb * mask).sum(1) / (mask.sum(1) + 1e-6))
        return self.encoder(torch.cat(embs + lins, dim=1))

    def forward(self, x, y, training_threshold=0.01):
        """Return ``(loss_egmn, loss_ddm)`` for one batch.

        The EGMN loss uses every sample. The DDM loss is the mean negative
        log-density over samples whose unscaled target exceeds
        ``training_threshold``; it is a constant zero tensor when no sample
        qualifies.
        """
        hidden = self.get_features(x)
        y_scaled = y * self.scale_factor

        pi, lambda_, mu, sigma = self.egmn(hidden)
        loss_egmn = self.egmn.loss(y_scaled, pi, lambda_, mu, sigma)

        mask_long = (y > training_threshold).view(-1)

        if mask_long.sum() > 0:
            hidden_long = hidden[mask_long]
            y_long_scaled = y_scaled[mask_long].view(-1, 1)
            pdf = self.ddm._calculate_pdf(hidden_long, y_long_scaled)
            loss_ddm = -torch.log(pdf + 1e-8).mean()
        else:
            loss_ddm = torch.tensor(0.0, device=y.device)

        return loss_egmn, loss_ddm

    def predict(self, x, t_min, t_max, switch_threshold=0.01):
        """Predict targets, switching from the EGMN to the DDM head.

        The EGMN expectation is used by default. Samples whose EGMN
        prediction exceeds ``switch_threshold`` (compared in scaled units)
        are replaced by the DDM expectation over ``[t_min, t_max]``.
        Puts the module in eval mode and runs without gradients.

        Returns:
            Predictions in the original (unscaled) target units.
        """
        self.eval()
        with torch.no_grad():
            hidden = self.get_features(x)

            pred_egmn_scaled = self.egmn.predict_expectation(hidden)
            switch_thresh_scaled = switch_threshold * self.scale_factor
            is_long = pred_egmn_scaled > switch_thresh_scaled

            final_preds_scaled = pred_egmn_scaled.clone()

            if is_long.sum() > 0:
                hidden_long = hidden[is_long]
                t_min_scaled = t_min * self.scale_factor
                t_max_scaled = t_max * self.scale_factor
                pred_ddm_scaled = self.ddm.predict_expectation(hidden_long, t_min_scaled, t_max_scaled)
                final_preds_scaled[is_long] = pred_ddm_scaled

            return final_preds_scaled / self.scale_factor

def get_args():
    """Parse the command-line options of the training script.

    Returns:
        ``argparse.Namespace`` with dataset, optimisation, loss-weight,
        threshold, checkpoint and EGMN fine-tuning options.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), kept in declaration order.
    option_specs = (
        ('--dataset_name', {'default': 'kuairec'}),
        ('--dataset_path', {'default': './dataset/'}),
        ('--epoch', {'type': int, 'default': 20}),
        ('--bsz', {'type': int, 'default': 512}),
        ('--lr', {'type': float, 'default': 0.00005}),
        ('--share_dims', {'type': int, 'nargs': '+', 'default': [128, 64]}),
        ('--egmn_w', {'type': float, 'default': 1.0}),
        ('--ddm_w', {'type': float, 'default': 1.0}),
        ('--threshold', {'type': float, 'default': 1.0}),
        ('--quantile', {'type': float, 'default': 0.40}),
        ('--save_dir', {'default': './checkpoints'}),
        ('--pretrained_egmn', {'type': str, 'default': None, 'help': 'Path to egmn_model.pt'}),
        ('--freeze_egmn', {'action': 'store_true', 'help': 'Freeze EGMN and Encoder weights'}),
        ('--egmn_lr', {'type': float, 'default': 0.00001,
                       'help': 'Specific LR for EGMN/Encoder. If None, uses --lr'}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

if __name__ == '__main__':
    # Training entry point: build the loaders and model, optionally load
    # pretrained EGMN weights, train for the requested number of epochs
    # (saving a checkpoint after each one), then evaluate on the test split.
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    loaders = KUAIRECDataLoader(
        args.dataset_name,
        os.path.join(args.dataset_path, args.dataset_name, f"{args.dataset_name}_data.pkl"),
        device,
        bsz=args.bsz
    )

    # Duration range passed to the model at evaluation time.
    t_min, t_max = estimate_duration_bounds(loaders['train'], device)

    # NOTE: the --threshold command-line value is always replaced by the
    # --quantile quantile of the training labels.
    args.threshold = get_quantile_threshold(loaders['train'], quantile=args.quantile)

    model = HybridModel(
        loaders.description,
        embed_dim=16,
        share_dims=args.share_dims
    ).to(device)

    if args.pretrained_egmn:
        if os.path.exists(args.pretrained_egmn):
            model.load_pretrained_egmn(args.pretrained_egmn, device)
        else:
            # Make a mistyped path visible instead of silently training
            # from random initialisation.
            print(f"WARNING: pretrained EGMN checkpoint not found: {args.pretrained_egmn}")

    # Everything that produces the shared representation belongs to the
    # EGMN/encoder group. The embedding tables and continuous-feature layers
    # are included so they are actually optimised (or frozen); previously
    # they were in no optimizer group and therefore never updated.
    shared_params = (
        list(model.egmn.parameters())
        + list(model.encoder.parameters())
        + list(model.emb_layer.parameters())
        + list(model.ctn_layer.parameters())
    )
    ddm_params = list(model.ddm.parameters())

    if args.freeze_egmn:
        for p in shared_params:
            p.requires_grad = False
        optimizer = optim.Adam(ddm_params, lr=args.lr, weight_decay=1e-5)
    else:
        lr_egmn = args.egmn_lr if args.egmn_lr is not None else args.lr

        optimizer = optim.Adam([
            {'params': ddm_params, 'lr': args.lr},
            {'params': shared_params, 'lr': lr_egmn}
        ], weight_decay=1e-5)

    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(1, args.epoch + 1):
        model.train()
        total_loss, steps = 0.0, 0
        l_egmn_acc, l_ddm_acc = 0.0, 0.0

        for batch in loaders['train']:
            if isinstance(batch, dict):
                x = {k: v.to(device) for k, v in batch.items() if k != 'duration'}
                y = batch['duration'].to(device).float()
            else:
                x, y = batch
                x = {k: v.to(device) for k, v in x.items()}
                y = y.to(device).float()

            le, ld = model(x, y, training_threshold=args.threshold)
            loss = args.egmn_w * le + args.ddm_w * ld

            optimizer.zero_grad()
            # With a frozen EGMN and no sample above the threshold, the loss
            # has no graph attached; skip the update instead of crashing.
            if loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()
            l_egmn_acc += le.item()
            l_ddm_acc += ld.item()
            steps += 1

        # Guard against an empty training loader.
        denom = max(steps, 1)
        print(f"Epoch {epoch:03d} | Loss: {total_loss/denom:.4f} (EGMN: {l_egmn_acc/denom:.4f}, DDM: {l_ddm_acc/denom:.4f})")

        save_path = os.path.join(args.save_dir, f"hybrid_model_epoch_{epoch}.pt")
        torch.save(model.state_dict(), save_path)

    mae, xauc, kl = evaluate(model, loaders, device, t_min, t_max, args.dataset_name, args.threshold)
    print("\nFINAL RESULTS")
    print(f"MAE:  {mae:.4f}")
    print(f"XAUC: {xauc:.4f}")
    print(f"KL:   {kl:.4f}")
