import sys
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from scipy.ndimage import rotate

from causalEGM import CausalEGM
from model import BaseFullyConnectedNet

class RingGeodesicCausalEGM(CausalEGM):
    """CausalEGM variant whose discrete treatments live on a learned ring.

    Each integer treatment id in ``[0, n_treatments)`` is mapped to a trainable
    ``t_embed_dim``-dimensional vector. During training a geodesic loss pulls
    the pairwise Euclidean distances between these vectors towards the
    shortest-path (hop) distances on a cycle graph, so treatment ``0`` and
    treatment ``n_treatments - 1`` are neighbours.

    ``params`` keys read by ``__init__``: ``n_treatments``, ``t_embed_dim``,
    ``v_dim``, ``z_dims`` (at least two entries), ``e_units``, ``f_units``,
    ``g_units`` and optionally ``geo_lambda`` (default 1.0). ``train`` also
    reads ``self.params['lr']`` and ``self.params['batch_size']``;
    ``self.params`` is presumably set by ``CausalEGM.__init__`` (TODO confirm).
    """

    def __init__(self, params, timestamp=None, random_seed=None):
        super().__init__(params, timestamp, random_seed)

        self.n_treatments = params['n_treatments']
        self.t_embed_dim = params['t_embed_dim']

        # One trainable vector per discrete treatment id.
        self.t_embed_layer = tf.keras.layers.Embedding(
            input_dim=self.n_treatments, output_dim=self.t_embed_dim,
            embeddings_initializer='uniform'
        )

        # Encoder: covariates v -> full latent code z (all z_dims concatenated).
        self.e_net = BaseFullyConnectedNet(
            input_dim=params['v_dim'], output_dim=sum(params['z_dims']),
            model_name='e_net', nb_units=params['e_units']
        )

        # Outcome net: the first two latent blocks plus the treatment embedding -> y.
        f_in_dim = params['z_dims'][0] + params['z_dims'][1] + self.t_embed_dim
        self.f_net = BaseFullyConnectedNet(
            input_dim=f_in_dim, output_dim=1,
            model_name='f_net', nb_units=params['f_units']
        )

        # Decoder: full latent code z -> reconstructed covariates (no treatment input).
        self.g_net = BaseFullyConnectedNet(
            input_dim=sum(params['z_dims']), output_dim=params['v_dim'],
            model_name='g_net', nb_units=params['g_units']
        )

        # Run one forward pass so every sub-network creates its weights now.
        self._build_dummy(params)

        self._build_ring_topology(params.get('geo_lambda', 1.0))

    def _build_dummy(self, params):
        """Call each sub-network once on zero inputs so its variables exist."""
        dummy_v = tf.zeros((1, params['v_dim']))
        z = self.e_net(dummy_v)
        t = self.t_embed_layer(tf.zeros((1,), dtype=tf.int32))
        z_pred_dim = params['z_dims'][0] + params['z_dims'][1]
        self.f_net(tf.concat([z[:, :z_pred_dim], t], axis=-1))
        self.g_net(z)

    def _build_ring_topology(self, geo_lambda):
        """Store the ring's target distance matrix and the geodesic loss weight.

        ``self.target_dist_matrix`` is a float32 (K, K) tensor with
        ``K = n_treatments``. Entry (i, j) is the shortest-path hop count
        between nodes i and j of a K-node cycle graph.
        """
        n = self.n_treatments
        G = nx.cycle_graph(n)
        dist = dict(nx.all_pairs_shortest_path_length(G))

        target = np.array(
            [[dist[i][j] for j in range(n)] for i in range(n)], dtype=np.float32
        ).reshape(n, n)

        self.target_dist_matrix = tf.constant(target)
        self.geo_lambda = geo_lambda
        print(f"💍 Ring Topology Built (K={self.n_treatments}). 0 and {self.n_treatments-1} are neighbors.")

    def train(self, data, n_iter=5000):
        """Jointly fit e_net, f_net, g_net and the treatment embeddings.

        Args:
            data: ``(data_t, data_y, data_v)``, i.e. treatment ids, outcomes and
                covariates. Assumed shapes are (n,), (n,) or (n, 1), and
                (n, v_dim); TODO confirm for callers other than
                ``generate_rotated_mnist``.
            n_iter: number of mini-batch gradient steps. Batches are drawn
                with replacement using the global ``np.random`` state.

        The total loss is the MSE on the standardized outcome, plus the MSE of
        the covariate reconstruction, plus ``geo_lambda`` times the MSE between
        pairwise embedding distances and the ring's hop distances. The
        standardization is stored as ``self.y_norm_stats = (mean, std)`` so
        f_net outputs can be mapped back to raw units.
        """
        data_t, data_y, data_v = data

        data_t = np.asarray(data_t, dtype=np.int32).reshape(-1)
        data_v = np.asarray(data_v, dtype=np.float32)
        # Column vector so targets line up with f_net's (batch, 1) output.
        # A flat (batch,) target would broadcast against (batch, 1) to
        # (batch, batch), comparing every prediction with every target.
        data_y = np.asarray(data_y, dtype=np.float32).reshape(-1, 1)
        y_mean, y_std = float(np.mean(data_y)), float(np.std(data_y))
        data_y_norm = (data_y - y_mean) / (y_std + 1e-8)
        self.y_norm_stats = (y_mean, y_std)

        z_pred_dim = self.params['z_dims'][0] + self.params['z_dims'][1]
        # All variables already exist (created in _build_dummy), so the list is fixed.
        train_vars = (self.e_net.trainable_variables + self.f_net.trainable_variables
                      + self.g_net.trainable_variables + self.t_embed_layer.trainable_variables)

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['lr'])
        print(f"🚀 Training on Rotated MNIST (Ring)...")

        for i in range(n_iter):
            idx = np.random.randint(0, data_t.shape[0], self.params['batch_size'])
            batch_t = tf.convert_to_tensor(data_t[idx], dtype=tf.int32)
            batch_v = tf.convert_to_tensor(data_v[idx], dtype=tf.float32)
            batch_y = tf.convert_to_tensor(data_y_norm[idx], dtype=tf.float32)

            with tf.GradientTape() as tape:
                z = self.e_net(batch_v)
                t_emb = self.t_embed_layer(batch_t)

                f_in = tf.concat([z[:, :z_pred_dim], t_emb], axis=-1)
                y_pred = self.f_net(f_in)

                v_rec = self.g_net(z)

                loss_mse = tf.reduce_mean(tf.square(batch_y - y_pred))
                loss_rec = tf.reduce_mean(tf.square(batch_v - v_rec))

                # Pairwise Euclidean distances between all K treatment embeddings.
                all_embs = self.t_embed_layer.trainable_variables[0]
                diff = tf.expand_dims(all_embs, 1) - tf.expand_dims(all_embs, 0)
                # The epsilon goes inside the sqrt so the gradient stays finite
                # on the zero-distance diagonal. Adding it to `diff` instead
                # would bias every distance, and unequally for (i, j) vs (j, i).
                latent_dist = tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1) + 1e-12)
                loss_geo = tf.reduce_mean(tf.square(latent_dist - self.target_dist_matrix))

                total_loss = loss_mse + 1.0 * loss_rec + self.geo_lambda * loss_geo

            grads = tape.gradient(total_loss, train_vars)
            optimizer.apply_gradients(zip(grads, train_vars))

            if i % 1000 == 0:
                print(f"Iter {i:4d} | MSE: {loss_mse:.4f} | Rec: {loss_rec:.4f} | Geo: {loss_geo:.4f}")


def generate_rotated_mnist(n_samples=4000, K=8, digit=3, noise_std=0.05, images=None):
    """Build a synthetic "rotation as treatment" dataset from small digit images.

    For each sample, a random base image and a random treatment ``t`` in
    ``[0, K)`` are drawn. The image is rotated by the t-th of ``K`` equally
    spaced angles in [0, 360). The outcome is
    ``cos(angle) + Normal(0, noise_std)``, so it is periodic in the treatment.

    Args:
        n_samples: number of samples to draw.
        K: number of equally spaced rotation angles (treatments).
        digit: digit class taken from ``sklearn.datasets.load_digits`` when
            ``images`` is not given.
        noise_std: standard deviation of the Gaussian outcome noise.
        images: optional array of 2-D base images, shape (n_images, H, W).
            When given, ``load_digits`` is not called and ``digit`` is ignored.

    Returns:
        ``(T, Y, X)`` where ``T`` holds int treatment ids with shape
        (n_samples,), ``Y`` holds float outcomes with shape (n_samples,), and
        ``X`` holds the flattened rotated images with shape (n_samples, H*W),
        each min-max scaled to [0, 1].

    Raises:
        ValueError: if there are no base images to sample from.

    Randomness comes from the global ``np.random`` state. Call
    ``np.random.seed`` beforehand for reproducible data.
    """
    if images is None:
        digits = load_digits()
        images = digits.images[digits.target == digit]
    base_images = np.asarray(images)
    if len(base_images) == 0:
        raise ValueError("no base images to sample from (empty `images` or unknown `digit`)")

    angles = np.linspace(0, 360, K, endpoint=False)
    n_features = base_images[0].size

    # Preallocate so n_samples == 0 still yields correctly shaped arrays.
    T = np.empty(n_samples, dtype=int)
    Y = np.empty(n_samples, dtype=float)
    X = np.empty((n_samples, n_features), dtype=float)

    for k in range(n_samples):
        # RNG call order (image, treatment, noise) is kept fixed so that a
        # given seed always produces the same dataset.
        img = base_images[np.random.randint(len(base_images))]
        t = np.random.randint(0, K)
        angle = angles[t]

        img_flat = rotate(img, angle, reshape=False, mode='nearest').ravel()
        # Per-image min-max scaling; the epsilon guards against a constant image.
        img_flat = (img_flat - img_flat.min()) / (img_flat.max() - img_flat.min() + 1e-8)

        T[k] = t
        Y[k] = np.cos(np.deg2rad(angle)) + np.random.normal(0, noise_std)
        X[k] = img_flat

    return T, Y, X


class RingVisualizer:
    """Plotting helpers for a trained ``RingGeodesicCausalEGM``.

    The treatment embedding table is copied when the visualizer is
    constructed, so create it after training. Treatment ``t`` is labelled
    with the angle ``t * 360 / K``, matching the equally spaced angles used by
    ``generate_rotated_mnist``.
    """

    def __init__(self, model):
        self.model = model
        # Snapshot of the embedding table: one row per treatment id.
        self.embs = model.t_embed_layer.get_weights()[0]
        self.K = model.n_treatments

    def _angle_label(self, t):
        """Format treatment id ``t`` as its rotation angle (e.g. ``"45°"`` for t=1, K=8)."""
        return f"{t * 360 / self.K:g}°"

    def _encode_zero_image(self):
        """Return the f_net-relevant slice of e_net's code for an all-zero image, shape (1, z0+z1)."""
        params = self.model.params
        dummy_v = tf.zeros((1, params['v_dim']), dtype=tf.float32)
        z = self.model.e_net(dummy_v)
        return z[:, :params['z_dims'][0] + params['z_dims'][1]]

    def plot_latent_ring(self):
        """Project the embeddings to 2-D with PCA and draw consecutive ids as a closed ring."""
        pca = PCA(n_components=2)
        embs_2d = pca.fit_transform(self.embs)

        plt.figure(figsize=(6, 6))
        # Dashed edge from every id to its successor; (K-1) -> 0 closes the ring.
        for i in range(self.K):
            j = (i + 1) % self.K
            p1, p2 = embs_2d[i], embs_2d[j]
            plt.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k--', alpha=0.5)

        plt.scatter(embs_2d[:, 0], embs_2d[:, 1], c=range(self.K), cmap='hsv', s=500, edgecolors='k')
        for i in range(self.K):
            plt.text(embs_2d[i, 0], embs_2d[i, 1], self._angle_label(i), ha='center', va='center', fontweight='bold')

        plt.title("Latent Space: Cyclic Topology (Rotated MNIST)")
        plt.axis('equal')
        plt.show()

    def plot_image_interpolation(self, t_start=0, t_end=7):
        """Announce an image interpolation, then plot the outcome interpolation instead.

        No image can be generated along the embedding path. In
        ``RingGeodesicCausalEGM``, ``g_net`` decodes the latent code ``z``
        alone and never receives the treatment embedding, so moving along that
        path cannot change the decoded image. This method therefore delegates
        to ``plot_outcome_interpolation``.
        """
        print(f"🎨 Generating Image Interpolation: {self._angle_label(t_start)} -> {self._angle_label(t_end)}")
        self.plot_outcome_interpolation(t_start, t_end)

    def plot_outcome_interpolation(self, t_start, t_end):
        """Plot f_net's prediction along the straight line between two treatment embeddings.

        The covariate code is fixed to the encoding of an all-zero image. 20
        evenly spaced points on the segment from ``embs[t_start]`` to
        ``embs[t_end]`` are fed through ``f_net`` in a single batch. The
        predictions are in whatever units ``f_net`` was trained on, and
        ``train`` standardizes y.
        """
        alphas = np.linspace(0, 1, 20)
        e_start = self.embs[t_start]
        e_end = self.embs[t_end]

        # Row k is (1 - alphas[k]) * e_start + alphas[k] * e_end.
        e_path = (1 - alphas)[:, None] * e_start + alphas[:, None] * e_end
        e_path = tf.convert_to_tensor(e_path, dtype=tf.float32)

        z_pred = self._encode_zero_image()
        z_rep = tf.tile(z_pred, [len(alphas), 1])
        preds = self.model.f_net(tf.concat([z_rep, e_path], axis=-1)).numpy()[:, 0]

        plt.figure(figsize=(8, 4))
        plt.plot(alphas, preds, 'r-o', linewidth=2, label='Geodesic Path')
        plt.title(f"Outcome Interpolation: {self._angle_label(t_start)} -> {self._angle_label(t_end)}")
        plt.xlabel("Interpolation Alpha")
        plt.ylabel("Outcome Y")
        plt.legend()
        plt.grid(True)
        plt.show()

if __name__ == "__main__":
    # Number of rotation angles, which is also the number of discrete treatments.
    K = 8
    data = generate_rotated_mnist(n_samples=3000, K=K)

    params = {
        'dataset': 'Rotated_MNIST',
        'n_treatments': K,
        't_embed_dim': 2,
        # Flattened image width taken from the data itself (64 for 8x8 digits).
        'v_dim': data[2].shape[1],
        'lr': 0.001,
        'batch_size': 64,
        'geo_lambda': 5.0,

        'z_dims': [4, 4, 4, 4],
        'e_units': [64, 64],
        'f_units': [64, 32],
        'g_units': [64, 64],

        'h_units': [1], 'q_units': [1], 'dv_units': [1], 'dz_units': [1],
        'binary_treatment': False, 'save_model': False, 'save_res': False, 'output_dir': '.'
    }

    model = RingGeodesicCausalEGM(params)
    model.train(data, n_iter=4000)

    viz = RingVisualizer(model)

    viz.plot_latent_ring()

    # Treatments 0 and K-1 are neighbours on the ring.
    viz.plot_outcome_interpolation(0, K - 1)

    # Treatments 0 and K//2 are antipodal on the ring.
    viz.plot_outcome_interpolation(0, K // 2)