import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import logging

# Module-level logger named after this module; ExperimentDiagnostics uses it
# for its progress messages and diagnostic warnings.
logger = logging.getLogger(__name__)

class ExperimentDiagnostics:
    """Write input, interpolation and hidden-state diagnostics for a model.

    All artefacts (JSON summaries, CSV samples, PNG plots) are written into
    ``save_dir``, which is created on construction if it does not exist.
    """

    def __init__(self, cfg, save_dir, add_time_bool):
        """Store settings and make sure the output directory exists.

        Args:
            cfg: Experiment configuration; stored on the instance but not
                read by any method of this class.
            save_dir: Directory that receives every diagnostic file.
            add_time_bool: When true, channel 0 is treated as the time
                channel when labelling derivative columns, plotting
                trajectories and splitting the vector field output.
        """
        self.cfg = cfg
        self.save_dir = save_dir
        self.add_time = add_time_bool
        os.makedirs(self.save_dir, exist_ok=True)

    def run_suite(self, model, train_loader, device):
        """Run every diagnostic on the first batch of ``train_loader``.

        Args:
            model: Model under inspection, passed to each analysis step.
            train_loader: Iterable yielding ``(coeffs, targets)`` pairs;
                only the first batch is used and the targets are ignored.
            device: Device the coefficient batch is moved to.

        Returns:
            None. If the loader yields nothing, a warning is logged and no
            diagnostics are produced.
        """
        logger.info(">>> RUNNING DIAGNOSTICS SUITE (plot_data=True)")

        try:
            batch_coeffs, _ = next(iter(train_loader))
        except StopIteration:
            # Previously a silent return; say why no files were produced.
            logger.warning(
                "DIAGNOSTICS: train_loader yielded no batches. Skipping diagnostics."
            )
            return

        batch_coeffs = batch_coeffs.to(device)

        self._analyze_inputs(batch_coeffs, model)
        self._analyze_model_mechanics(model, batch_coeffs)
        self._analyze_hidden_states(model, batch_coeffs)

        logger.info(">>> Diagnostics complete. Results saved to %s", self.save_dir)

    def _save_json(self, data, filename):
        """Serialise ``data`` to ``save_dir/filename`` as indented JSON.

        Tensors and NumPy arrays are written as nested lists, NumPy scalars
        (floats, integers, booleans) as their native Python value, and any
        other non-serialisable object as ``str(obj)``.
        """
        path = os.path.join(self.save_dir, filename)

        def converter(obj):
            if isinstance(obj, torch.Tensor):
                return obj.detach().cpu().tolist()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                value = obj.item()
                # Some NumPy scalar types have no Python equivalent and
                # return themselves from item(); fall through to str() then
                # so the encoder is never handed the same object back.
                if not isinstance(value, np.generic):
                    return value
            return str(obj)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, default=converter)

    @staticmethod
    def _tensor_summary(tensor):
        """Return min/max/mean/std of ``tensor`` as a dict of Python numbers.

        An empty tensor yields ``None`` for every statistic instead of
        raising. Non-floating tensors are cast to float so that mean and
        std are defined.
        """
        if tensor.numel() == 0:
            return {'min': None, 'max': None, 'mean': None, 'std': None}
        if not tensor.is_floating_point():
            tensor = tensor.float()
        return {
            'min': tensor.min().item(), 'max': tensor.max().item(),
            'mean': tensor.mean().item(), 'std': tensor.std().item()
        }

    def _analyze_inputs(self, coeffs, model):
        """Summarise the input batch and dump a raw sample of it.

        Writes ``input_stats.json`` (shape, time/feature statistics and
        sanity checks) and ``input_raw_sample.csv`` (up to 20 rows of the
        first sample).

        Args:
            coeffs: Input batch. It is indexed as ``coeffs[:, :, 0]`` and
                ``coeffs[0, :n, :]``, so it must be 3-D; presumably laid out
                as (batch, time, channels) - TODO confirm against caller.
            model: If it has a non-None ``t_grid`` attribute, that tensor is
                used as the time axis ("separated_time"); otherwise channel
                0 of ``coeffs`` is used ("embedded_time").
        """
        stats = {}

        t_grid = getattr(model, 't_grid', None)

        if t_grid is not None:
            t_channel = t_grid
            feats = coeffs
            stats['mode'] = 'separated_time'
        else:
            t_channel = coeffs[:, :, 0]
            feats = coeffs[:, :, 1:]
            stats['mode'] = 'embedded_time'

        stats['dX_shape'] = list(coeffs.shape)

        t_flat = t_channel if t_channel.dim() == 1 else t_channel.flatten()

        stats['time_channel'] = self._tensor_summary(t_flat)
        # feats is empty when the input carries only a time channel.
        stats['features'] = self._tensor_summary(feats)

        checks = {'time_monotonic': True, 'time_normalized': False}

        # The checks and the CSV use a single 1-D time series: the grid
        # itself, or the first row when the time tensor has a leading axis.
        t_sample = t_channel if t_channel.dim() == 1 else t_channel[0]
        t_sample = t_sample.reshape(-1)

        diffs = t_sample[1:] - t_sample[:-1]
        if (diffs < -1e-5).any():
            checks['time_monotonic'] = False
            logger.warning("DIAGNOSTICS: Time channel is NOT monotonic!")

        if t_sample.numel() > 0:
            in_unit_range = bool(t_sample.max() <= 1.0) and bool(t_sample.min() >= 0.0)
            if in_unit_range:
                checks['time_normalized'] = True

        stats['checks'] = checks
        self._save_json(stats, 'input_stats.json')

        num_rows = min(20, coeffs.size(1))
        if t_grid is not None:
            # The inserted time column must be exactly as long as the frame,
            # so never take more rows than the grid can provide.
            num_rows = min(num_rows, t_sample.numel())

        sample_np = coeffs[0, :num_rows, :].detach().cpu().numpy()
        cols = [f"Ch_{i}" for i in range(sample_np.shape[1])]

        if t_grid is None and cols:
            cols[0] = "Time_Input"

        df_raw = pd.DataFrame(sample_np, columns=cols)

        if t_grid is not None:
            # Use the same 1-D series as the checks above; slicing t_grid
            # directly would cut along the batch axis for a 2-D grid.
            t_np = t_sample[:num_rows].detach().cpu().numpy()
            df_raw.insert(0, "Time_Grid", t_np)

        df_raw.to_csv(os.path.join(self.save_dir, "input_raw_sample.csv"), index=False)

    def _analyze_model_mechanics(self, model, coeffs):
        """Inspect the interpolation object built by the model.

        Skipped with a warning when the model has no ``make_interpolation``.
        Otherwise writes ``dX_derivative_sample.csv`` and
        ``trajectories_sample0.png``, plus ``normalization_stats.json`` when
        the interpolation exposes ``get_debug_info`` and ``weights_heatmap.png``
        when it exposes ``weights``.

        Args:
            model: Model providing ``make_interpolation(coeffs)``; its
                optional ``num_heads`` attribute (default 1) bounds how many
                traces are plotted (at most 4).
            coeffs: Input batch handed to ``make_interpolation``.
        """
        if not hasattr(model, 'make_interpolation'):
            logger.warning("Model does not support 'make_interpolation'. Skipping mechanics analysis.")
            return

        X_interp = model.make_interpolation(coeffs)
        t_start = X_interp.interval[0].item()
        t_end = X_interp.interval[1].item()

        if hasattr(X_interp, 'grid_points'):
            t_eval = X_interp.grid_points[:20]
        else:
            t_eval = torch.linspace(t_start, t_end, steps=20, device=coeffs.device)

        # Derivative of the first element along the leading axis at each
        # evaluation time (presumably sample 0 of the batch - TODO confirm).
        derivs_np = np.array(
            [X_interp.derivative(t)[0].detach().cpu().numpy() for t in t_eval]
        )
        d_cols = [f"dX_dt_Ch_{i}" for i in range(derivs_np.shape[1])]
        if d_cols and self.add_time:
            d_cols[0] = "dX_dt_Time_Should_Be_1"

        df_deriv = pd.DataFrame(derivs_np, columns=d_cols)
        df_deriv.insert(0, "t_eval", t_eval.detach().cpu().numpy())
        df_deriv.to_csv(os.path.join(self.save_dir, "dX_derivative_sample.csv"), index=False)

        t_plot = torch.linspace(t_start, t_end, steps=100, device=coeffs.device)
        res_denom = None

        if hasattr(X_interp, 'get_debug_info'):
            # The third returned item is not used by these diagnostics.
            res_val, res_denom, _ = X_interp.get_debug_info(t_plot)
        else:
            res_val = torch.stack([X_interp.evaluate(t) for t in t_plot], dim=1)

        if res_denom is not None:
            denom_flat = res_denom.detach().cpu().numpy().flatten()
            norm_stats = {
                'mean': float(np.mean(denom_flat)),
                'std': float(np.std(denom_flat)),
                'deviation_from_1_max': float(np.max(np.abs(denom_flat - 1.0)))
            }
            self._save_json(norm_stats, 'normalization_stats.json')

        t_plot_np = t_plot.cpu().numpy()
        res_val_np = res_val.detach().cpu().numpy()

        num_heads = getattr(model, 'num_heads', 1)
        heads_to_plot = min(num_heads, 4, res_val_np.shape[0])
        # With a time channel at index 0, the first real feature is index 1.
        feat_idx = 1 if self.add_time else 0

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            for h in range(heads_to_plot):
                if feat_idx < res_val_np.shape[2]:
                    data_trace = res_val_np[h, :, feat_idx]
                    ax.plot(t_plot_np, data_trace, label=f'Head {h} Feat {feat_idx}')

                if self.add_time and h == 0:
                    time_trace = res_val_np[h, :, 0]
                    ax.plot(t_plot_np, time_trace, '--', label='Time Channel', alpha=0.5)

            ax.set_title(f"Reconstructed Trajectories X(t)\nModel: {type(model).__name__}, AddTime: {self.add_time}")
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, "trajectories_sample0.png"))
        finally:
            # Close this figure even if plotting or saving fails, so figures
            # do not accumulate in pyplot's global registry.
            plt.close(fig)

        if hasattr(X_interp, 'weights'):
            raw_w = X_interp.weights
            if raw_w.size(0) >= num_heads:
                # Index 0 on the last axis is presumably sample 0 - TODO confirm.
                sample0_w = raw_w[:num_heads, :, 0].detach().cpu().numpy()
                fig_w, ax_w = plt.subplots(figsize=(12, 6))
                try:
                    sns.heatmap(sample0_w, cmap='viridis', robust=True, ax=ax_w)
                    ax_w.set_title("Attention Weights Heatmap (Sample 0)")
                    fig_w.savefig(os.path.join(self.save_dir, "weights_heatmap.png"))
                finally:
                    plt.close(fig_w)

    def _analyze_hidden_states(self, model, coeffs):
        """Record the input fed to ``model.readout`` and probe the vector field.

        Runs one forward pass with a temporary forward hook on
        ``model.readout`` and writes ``hidden_state_stats.json`` from the
        first positional input that module received. The hook is always
        removed, even if the forward pass raises. If the model has no
        ``readout`` attribute, a warning is logged and this part is skipped.
        Afterwards the vector field at the start of the interval is
        analysed, see ``_analyze_vector_field``.

        Args:
            model: Callable model under inspection.
            coeffs: Input batch passed to ``model`` and ``make_interpolation``.
        """
        readout = getattr(model, 'readout', None)
        if readout is None:
            logger.warning("Model has no 'readout' module. Skipping hidden state statistics.")
        else:
            hidden_activations = []

            def hook_fn(module, inputs, output):
                hidden_activations.append(inputs[0].detach().cpu())

            handle = readout.register_forward_hook(hook_fn)
            try:
                model(coeffs)
            finally:
                # Never leave the diagnostic hook attached to the model.
                handle.remove()

            if hidden_activations:
                z_T = hidden_activations[0]
                stats = {
                    'shape': list(z_T.shape),
                    'mean': float(z_T.mean()),
                    'std': float(z_T.std()),
                    'norm_L2_mean': float(torch.norm(z_T, dim=1).mean())
                }
                self._save_json(stats, 'hidden_state_stats.json')

        self._analyze_vector_field(model, coeffs)

    def _analyze_vector_field(self, model, coeffs):
        """Evaluate the model's vector field at the start time and summarise it.

        Writes ``vector_field_stats.json``. Returns without writing when the
        model lacks ``make_interpolation``, lacks both ``initial`` and
        ``initial_layer``, or lacks both ``func`` and ``cde_func``.
        """
        if not hasattr(model, 'make_interpolation'):
            return

        with torch.no_grad():
            X = model.make_interpolation(coeffs)
            t0 = X.interval[0]
            X0 = X.evaluate(t0)

            if hasattr(model, 'initial'):
                z0 = model.initial(X0)
            elif hasattr(model, 'initial_layer'):
                z0 = model.initial_layer(X0)
            else:
                return

            func = getattr(model, 'func', getattr(model, 'cde_func', None))
            if func is None:
                return

            vec_field_out = func(t0, z0)

            vf_stats = {'vf_shape': list(vec_field_out.shape)}

            if self.add_time:
                # Assumes the last axis of the vector field output indexes
                # input channels with time at index 0 - TODO confirm.
                vf_time = vec_field_out[..., 0]
                vf_feats = vec_field_out[..., 1:]

                if vf_time.dim() > 1:
                    mag_time = vf_time.norm(p=2, dim=-1).mean().item()
                else:
                    mag_time = vf_time.abs().mean().item()
                mag_feats = vf_feats.norm(p=2, dim=-1).mean().item()

                vf_stats['time_sensitivity_mean'] = mag_time
                vf_stats['feature_sensitivity_mean'] = mag_feats
                # Small epsilon avoids division by zero.
                vf_stats['ratio_time_to_features'] = mag_time / (mag_feats + 1e-9)
            else:
                mag_feats = vec_field_out.norm(p=2, dim=-1).mean().item()
                vf_stats['time_sensitivity_mean'] = 0.0
                vf_stats['feature_sensitivity_mean'] = mag_feats
                vf_stats['ratio_time_to_features'] = 0.0

            self._save_json(vf_stats, 'vector_field_stats.json')