# %% IMPORTS
## automatically reload libraries when they change
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
sys.path.append("..")
sys.path.append("../ALAE")
from alae_ffhq_inference import load_model, encode, decode
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch.nn as nn
from tqdm import tqdm
from geomloss import SamplesLoss
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import LineCollection
from datetime import datetime
from torch.utils.data import DataLoader, TensorDataset
from frechetdist import frdist

# %% DATA LOADING
# The first `train_size` rows form the training split; every row after that
# is treated as the test split (see the slicing below).
train_size = 60000
# NOTE(review): `test_size` is not referenced anywhere in the visible code;
# the test split is simply "everything after train_size".
test_size = 10000

# Pre-computed arrays stored on disk, indexed row-by-row in the same order.
# presumably these are ALAE latent codes and per-image attributes — TODO confirm
latents = np.load("../data/latents.npy")
gender = np.load("../data/gender.npy")
age = np.load("../data/age.npy")
test_inp_images = np.load("../data/test_images.npy")

train_latents, test_latents = latents[:train_size], latents[train_size:]
train_latents_ds = TensorDataset(torch.from_numpy(train_latents)) 
test_latents_ds = TensorDataset(torch.from_numpy(test_latents)) 
train_gender, test_gender = gender[:train_size], gender[train_size:]

## convert str labels to unique int labels
# np.unique returns the distinct labels in sorted order, so the integer id of
# each label is its rank in that sorted list.
unique_labels = np.unique(train_gender)
label_to_int = {label: i for i, label in enumerate(unique_labels)}
train_gender_int = np.array([label_to_int[label] for label in train_gender])
# A test label that never appears in the training split raises KeyError here.
test_gender_int = np.array([label_to_int[label] for label in test_gender])

# Print the integer ids next to the original labels so the mapping is visible.
labels = np.arange(len(unique_labels))
print(labels, unique_labels)
train_gender_int_ds = TensorDataset(torch.from_numpy(train_gender_int)) 
test_gender_int_ds = TensorDataset(torch.from_numpy(test_gender_int)) 

# Torch versions for indexing in sampling helpers
train_gender_int_t = torch.from_numpy(train_gender_int).long()
train_latents_t = torch.from_numpy(train_latents).float()
# test tensors for class-conditional sampling in evaluation
test_gender_int_t = torch.from_numpy(test_gender_int).long()
test_latents_t = torch.from_numpy(test_latents).float()

# train_age, test_age = age[:train_size], age[train_size:]
# Width of one latent vector; used as the model input/output dimension.
data_dim = train_latents.shape[1]

def sample_source(batch_size, device):
    """Draw a batch of standard-normal source points.

    Returns a ``(batch_size, data_dim)`` tensor of i.i.d. N(0, 1) samples
    created directly on ``device``. ``data_dim`` is the module-level latent
    width.
    """
    shape = (batch_size, data_dim)
    return torch.randn(shape, device=device)

def _sample_train_latents_with_label(batch_size, device, label):
    """Pick random training latents whose integer label equals ``label``.

    Args:
        batch_size: Maximum number of rows to return.
        device: Device the returned tensor is moved to.
        label: Integer class id, either a Python int or a one-element tensor.

    Returns:
        A tensor of training latents drawn without replacement from the
        requested class. If the class holds fewer than ``batch_size`` rows,
        all of them are returned (in random order), so the result can be
        shorter than ``batch_size``.

    Raises:
        ValueError: If no training sample carries the requested label.
    """
    label_val = int(label.item()) if isinstance(label, torch.Tensor) else int(label)
    idx = torch.where(train_gender_int_t == label_val)[0]
    if idx.numel() == 0:
        raise ValueError(f"No samples found for label {label_val}. Check label mapping printed earlier.")
    # Shuffle the class members and keep the first `batch_size` of them.
    perm = torch.randperm(idx.numel(), device=idx.device)[:batch_size]
    return train_latents_t[idx[perm]].to(device)


def sample_source_class(batch_size, device, label):
    """Sample source-side latents: random training latents of class ``label``.

    See ``_sample_train_latents_with_label`` for argument and error details.
    """
    return _sample_train_latents_with_label(batch_size, device, label)


def sample_target_class(batch_size, device, label):
    """Sample target-side latents: random training latents of class ``label``.

    See ``_sample_train_latents_with_label`` for argument and error details.
    """
    return _sample_train_latents_with_label(batch_size, device, label)

def sample_target(batch_size, device, labels=None):
    """Sample training latents uniformly at random (with replacement).

    Args:
        batch_size: Number of rows to draw.
        device: Device used for index generation and for the returned tensors.
        labels: Optional array-like aligned with ``train_latents``. When
            given, the matching label of every drawn row is returned as well.

    Returns:
        A float tensor of latents when ``labels`` is None, otherwise a tuple
        ``(latents, labels)`` where the labels are also cast to float.
    """
    rows = torch.randint(0, train_latents.shape[0], (batch_size,), device=device)
    # The latents live in a NumPy array, so the indices go through the CPU.
    rows_np = rows.cpu().numpy()
    batch = torch.from_numpy(train_latents[rows_np]).float().to(device)
    if labels is None:
        return batch

    label_array = np.asarray(labels)
    picked_labels = torch.from_numpy(label_array[rows_np]).float().to(device)
    return batch, picked_labels

def log_prob_source(x):
    """Log-density of ``x`` under the standard-normal source distribution.

    The source is a diagonal Gaussian with zero mean and unit variance over
    ``data_dim`` coordinates. The per-coordinate terms are summed over
    dimension 1, giving one log-probability per row of ``x``.
    """
    mean = torch.zeros(data_dim, device=x.device)
    variance = torch.ones(data_dim, device=x.device) ** 2
    two_pi = 2 * torch.tensor(torch.pi, device=x.device)

    squared_z = ((x - mean) ** 2) / variance
    log_normaliser = torch.log(two_pi * variance)
    return -0.5 * (squared_z + log_normaliser).sum(dim=1)

def logit_normal_timestep_sample(P_mean: float, P_std: float, num_samples: int, device: torch.device) -> torch.Tensor:
    """Sample timesteps from a logit-normal distribution.

    A Gaussian draw with mean ``P_mean`` and standard deviation ``P_std`` is
    squashed through a sigmoid, then clamped to ``[0, 1]``.

    Returns:
        A 1-D tensor of ``num_samples`` values on ``device``.
    """
    gaussian = torch.randn((num_samples,), device=device)
    logits = gaussian * P_std + P_mean
    return torch.sigmoid(logits).clamp(min=0.0, max=1.0)

# %% MODEL DEFINITION
class VectorField(nn.Module):
    """Time-conditioned MLP velocity field ``v(t, x)``.

    The scalar time is lifted to ``hidden`` features by a linear layer,
    concatenated with the state, and passed through a three-layer tanh MLP
    that maps back to ``dim`` outputs.
    """

    def __init__(self, dim, hidden=128):
        super().__init__()

        # Submodule names and creation order are kept stable so that saved
        # state dicts and seeded initialisation stay compatible.
        self.time_embed = nn.Linear(1, hidden)

        layers = [
            nn.Linear(dim + hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, x):
        """Evaluate the field at time ``t`` for a batch of states ``x``.

        ``t`` may be a single-element tensor (as passed by an ODE solver), in
        which case it is broadcast to one time per row of ``x``. Otherwise it
        is fed to the time embedding unchanged.
        """
        per_sample_t = t.expand(x.shape[0], 1) if t.numel() == 1 else t
        time_features = self.time_embed(per_sample_t)
        return self.net(torch.cat([x, time_features], dim=1))

class MeanFlowMLP(nn.Module):
    """MLP taking a state ``z``, a raw scalar ``r`` and an embedded time ``t``.

    The network input is the concatenation of ``z`` (``dim`` features), ``r``
    (one feature, used as-is) and the linear embedding of ``t`` (``hidden``
    features), i.e. ``dim + 1 + hidden`` features per sample.
    """

    def __init__(self, dim, hidden=128):
        super().__init__()

        # Submodule names and creation order are kept stable so that saved
        # state dicts and seeded initialisation stay compatible.
        self.time_embed = nn.Linear(1, hidden)

        layers = [
            nn.Linear(dim + 1 + hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``[z, r, time_embed(t)]``.

        A single-element ``t`` is broadcast to one value per row of ``z``.
        ``r`` is concatenated unchanged, so it must already have one column
        per sample.
        """
        per_sample_t = t.expand(z.shape[0], 1) if t.numel() == 1 else t
        time_features = self.time_embed(per_sample_t)
        features = torch.cat([z, r, time_features], dim=-1)  # (B, dim + 1 + hidden)
        return self.net(features)

class ODEFunc(nn.Module):
    """Thin ``nn.Module`` adapter exposing a vector field as ``forward(t, x)``.

    Instances are handed to ``odeint`` elsewhere in this file; every call is
    forwarded unchanged to the wrapped callable ``vf``.
    """
    def __init__(self, vf):
        super().__init__()
        # Wrapped vector field; it is called as vf(t, x).
        self.vf = vf
    def forward(self, t, x):
        # Pure delegation: no reshaping or broadcasting of t happens here.
        return self.vf(t, x) 
class ODEFuncMFM(nn.Module):
    """Adapter that lets a ``(z, r, t)`` network be used as an ODE right-hand side.

    The wrapped network is always called with a fixed reference time
    ``r_value`` and with the elapsed time ``t - r_value`` as its third
    argument.
    """

    def __init__(self, vf, r_value: float = 0.0):
        super().__init__()
        self.vf = vf
        self.r_value = r_value

    def forward(self, t, x):
        """Evaluate the wrapped network at solver time ``t`` for states ``x``."""
        batch = x.shape[0]
        # Plain Python numbers are promoted to a tensor matching x.
        # Tensors are used as they are.
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)
        t_col = t.reshape(1, 1).repeat(batch, 1)
        r_col = torch.full((batch, 1), self.r_value, device=x.device, dtype=x.dtype)
        return self.vf(x, r_col, t_col - r_col)

# Load the ALAE model from its config and training artifacts, then switch it
# to evaluation mode.
# NOTE(review): generate_samples() loads its own copy on every call, so this
# module-level instance is not referenced by the visible code below.
model = load_model("../ALAE/configs/ffhq.yaml", training_artifacts_dir="../ALAE/training_artifacts/ffhq/")
model.eval()



# %% UTILS
def integrate_ode(x0, model, steps, beta=0.7):
    """Integrate ``dx/dt = model(t, x)`` from t=0 to t=1 with explicit Euler.

    Alongside the trajectory end point, a potential term is accumulated: the
    batch mean of ``||v||**beta`` at every step, multiplied by the step size.

    Args:
        x0: Initial states, one row per sample.
        model: Callable ``model(t, x)`` returning the velocity for every row;
            ``t`` is passed as a ``(batch, 1)`` tensor.
        steps: Number of Euler steps; must be positive.
        beta: Exponent applied to the per-sample velocity norm.

    Returns:
        Tuple ``(x_final, potential)`` where ``potential`` is a scalar tensor
        that stays attached to the autograd graph.

    Raises:
        ValueError: If ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    batch = x0.shape[0]
    dt = 1 / steps
    x = x0
    total_potential = 0.0
    for k in range(steps):
        # Left-endpoint time of the k-th sub-interval, one copy per sample.
        t = torch.full((batch, 1), k / steps, device=x.device, dtype=x.dtype)
        v = model(t, x)
        total_potential = total_potential + (v.norm(dim=1) ** beta).mean()
        x = x + dt * v
    return x, total_potential * dt

def plot_trajectories(x0, odefunc, steps, device, TYPE, it=None, SAVE_DIR=None):
    """Integrate the ODE from ``x0`` and save a 2-D PCA plot of the trajectories.

    Every trajectory is projected onto the first two principal components of
    all visited states. Start points are drawn as crosses, end points as
    circles, and the paths in between as lines coloured by time.

    Args:
        x0: Initial states, one row per trajectory.
        odefunc: Right-hand side passed to ``odeint``.
        steps: Number of time points in ``[0, 1]`` at which the solution is
            evaluated.
        device: Device on which the time grid is created.
        TYPE: Method tag embedded in the output file name.
        it: Optional iteration number. When given, the file is written below
            ``SAVE_DIR`` and the number is appended to the file name.
        SAVE_DIR: Output directory used when ``it`` is given; the current
            directory is used when it is None.
    """
    t_span = torch.linspace(0., 1., steps, device=device)
    # The solution is only plotted, so no autograd graph is needed.
    with torch.no_grad():
        sol = odeint(odefunc, x0, t_span)
    T, B, D = sol.shape
    v_sol = sol.reshape(T * B, D).detach().cpu().numpy()
    pca = PCA(n_components=2)
    pca.fit(v_sol)
    v_sol_pca = pca.transform(v_sol).reshape(T, B, 2)

    SOURCE_COLOR = '#87CEEB'       # Sky Blue
    SOURCE_OUTLINE = '#2E4B8F'    # Dark Blue
    TARGET_COLOR = '#191970'       # Midnight Blue
    TARGET_OUTLINE = '#FFFFFF'    # White for contrast
    PLOT_BG_COLOR = (248/255, 248/255, 248/255, 0.5) # Very light grey

    line_seed = 42
    num_lines = 300

    cmap = LinearSegmentedColormap.from_list(
        "source_to_target",
        [SOURCE_COLOR, TARGET_COLOR],
        N=256,
    )

    norm = plt.Normalize(0.0, 1.0)
    # One colour value per line segment (there are T - 1 segments per path).
    t_colors = np.linspace(0.0, 1.0, max(T - 1, 0))

    # Choose which trajectories to draw fully with lines
    rng = np.random.default_rng(line_seed)
    if num_lines is None or num_lines < 0 or num_lines >= B:
        line_indices = np.arange(B)
    else:
        line_indices = rng.choice(B, size=num_lines, replace=False)
    line_index_set = {int(i) for i in np.atleast_1d(line_indices)}

    fig, ax = plt.subplots(figsize=(14, 14), dpi=400)
    fig.patch.set_facecolor('white')
    ax.set_facecolor(PLOT_BG_COLOR)
    for spine in ax.spines.values():
        spine.set_visible(False)
    for b in range(B):
        coords = v_sol_pca[:, b, :]
        if b in line_index_set:
            segments = np.stack([coords[:-1], coords[1:]], axis=1)  # (T-1, 2, 2)
            lc = LineCollection(
                segments, cmap=cmap, norm=norm, linewidth=2.0, alpha=0.9
            )
            lc.set_array(t_colors)
            ax.add_collection(lc)
        # Only the first trajectory carries legend labels, so each marker
        # type appears once in the legend.
        ax.scatter(
            coords[0:1, 0],
            coords[0:1, 1],
            c=[SOURCE_COLOR],
            s=80,
            marker="x",
            linewidths=1.5,
            edgecolors=SOURCE_OUTLINE,
            alpha=0.8,
            label=("Source" if b == 0 else None),
        )
        ax.scatter(
            coords[-1:, 0],
            coords[-1:, 1],
            c=[TARGET_COLOR],
            s=80,
            marker="o",
            edgecolors=TARGET_OUTLINE,
            linewidths=1.5,
            label=("Target" if b == 0 else None),
        )

    ax.autoscale()
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(True, alpha=0.3)
    ax.legend(frameon=False, fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=16)
    fig.tight_layout()
    if it is not None:
        # Fall back to the current directory instead of a literal "None/".
        out_dir = "." if SAVE_DIR is None else SAVE_DIR
        fig.savefig(f"{out_dir}/traj_pca_lines_{TYPE}_{it}.png")
    else:
        fig.savefig(f"traj_pca_lines_{TYPE}.png")
    plt.close(fig)

def _decode_latent_to_uint8(alae_model, latent):
    """Decode one latent vector into a uint8 image array ready for ``imshow``.

    The decoder output is mapped from ``[-1, 1]`` to ``[0, 255]``.
    assumes decode() returns a (1, C, H, W) tensor — TODO confirm
    """
    img = decode(alae_model, latent.unsqueeze(0).to("cpu"))
    return ((img * 0.5 + 0.5) * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]


def _save_sample_figure(fig, TYPE, it, SAVE_DIR):
    """Tighten, save and close a figure of decoded samples."""
    fig.tight_layout()
    if it is not None:
        # Fall back to the current directory instead of a literal "None/".
        out_dir = "." if SAVE_DIR is None else SAVE_DIR
        fig.savefig(f"{out_dir}/ode_{TYPE}_{it}.png", bbox_inches="tight", pad_inches=0)
    else:
        fig.savefig(f"ode_{TYPE}.png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def generate_samples(x0, odefunc, steps, device, TYPE, batch_size, it=None, SAVE_DIR=None, train_mode=None):
    """Integrate the ODE from ``x0``, decode the end points and save an image grid.

    Args:
        x0: Initial latents, one row per sample.
        odefunc: Right-hand side passed to ``odeint``.
        steps: Number of time points in ``[0, 1]`` used for the integration.
        device: Device for the time grid and the decoder model.
        TYPE: Method tag embedded in the output file name.
        batch_size: Number of samples to decode and display.
        it: Optional iteration number. When given, the file is written below
            ``SAVE_DIR`` and the number is appended to the file name.
        SAVE_DIR: Output directory used when ``it`` is given; the current
            directory is used when it is None.
        train_mode: Case-insensitive layout selector.
            - ``"generation"`` (the default) shows the decoded end points in
              a near-square grid.
            - ``"traslation"`` (the spelling used in this file; the correctly
              spelled ``"translation"`` is accepted as well) shows each
              decoded input directly above its decoded end point, 16 pairs
              per block.
            - Any other value produces no figure.
    """
    print("inside generate_samples")
    # Everything here is inference only; without this guard the ODE solve and
    # every decode call would record an autograd graph.
    with torch.no_grad():
        t_span = torch.linspace(0., 1., steps, device=device)
        sol = odeint(odefunc, x0, t_span)
        x_1_pred = sol[-1]

        # NOTE(review): the decoder is reloaded from disk on every call.
        decoder = load_model("../ALAE/configs/ffhq.yaml", training_artifacts_dir="../ALAE/training_artifacts/ffhq/")
        decoder.eval()
        decoder = decoder.to(device)
        # NOTE(review): latents are moved to the CPU before decode() while the
        # model sits on `device` — confirm decode() handles the transfer.

        mode = (train_mode or "generation").lower()
        print("mode", mode)

        if mode == "generation":
            print("inside generation")
            # Rectangular grid for generated images
            cols = int(np.ceil(np.sqrt(batch_size)))
            rows = int(np.ceil(batch_size / cols))
            fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
            flat_axes = axes.ravel()
            # Hide every cell up front so unused grid slots stay blank.
            for ax in flat_axes:
                ax.axis('off')

            for i in range(batch_size):
                flat_axes[i].imshow(_decode_latent_to_uint8(decoder, x_1_pred[i]))
            _save_sample_figure(fig, TYPE, it, SAVE_DIR)

        elif mode in ("traslation", "translation"):
            print("inside traslation")
            # Tiled blocks: for each block of columns, show two rows (x0 on top, x1_pred bottom)
            max_cols = min(16, batch_size)
            num_blocks = int(np.ceil(batch_size / max_cols))
            rows = 2 * num_blocks

            # squeeze=False always yields a 2-D axes array, whatever the grid size.
            fig, axes = plt.subplots(rows, max_cols, figsize=(max_cols * 2, rows * 2), squeeze=False)
            for ax in axes.ravel():
                ax.axis('off')

            for i in range(batch_size):
                block, col = divmod(i, max_cols)
                top_row = 2 * block
                axes[top_row, col].imshow(_decode_latent_to_uint8(decoder, x0[i]))
                axes[top_row + 1, col].imshow(_decode_latent_to_uint8(decoder, x_1_pred[i]))
            _save_sample_figure(fig, TYPE, it, SAVE_DIR)




# %% TRAINING
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")


# Y-flows experimental params
# STEPS: Euler steps used by integrate_ode during YF training; also the number
#        of time points used when plotting / generating samples.
# LAMBDA_SINKHORN: weight of the Sinkhorn term in the YF loss.
# BETA: exponent applied to the velocity norm in the YF potential term.
STEPS = 10
LAMBDA_SINKHORN = 3.5 
BETA = 0.7

# Training params
# hidden: width of the MLP layers and of the time embedding.
# ratio: in MFM training, r is overwritten with t with probability 1 - ratio.
n_iters = 100000 
batch_size = 256 
lr = 1e-4
hidden = 1024 
ratio = 0.75


# TYPE selects the training objective; TRAIN_MODE selects how source/target
# batches are drawn. "TRASLATION" is spelled this way throughout the file
# (generate_samples compares against the lower-cased string), so the spelling
# must stay consistent.
TYPE = "FM"
TRAIN_MODE = "TRASLATION"
assert TYPE in ["FM", "MFM", "YF"]
assert TRAIN_MODE in ["GENERATION", "TRASLATION"]

## set labels via mapping if available
# Falls back to female=1 / male=0 when the mapping lacks these keys or the
# lookup fails for any reason.
# NOTE(review): the fallback ids are only right if they match the sorted label
# order printed during data loading — confirm.
try:
    female_idx = int(label_to_int.get('female', 1))
    male_idx = int(label_to_int.get('male', 0))
except Exception:
    female_idx = 1
    male_idx = 0
label_1 = torch.tensor(female_idx, dtype=torch.long)  # source class
label_2 = torch.tensor(male_idx, dtype=torch.long)    # target class

# One output directory per run, keyed by method, mode and start time.
# NOTE(review): the timestamp contains ':' characters, which some file systems
# (e.g. on Windows) do not accept in path names.
SAVE_DIR = f"{TYPE}/{TRAIN_MODE}/{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
os.makedirs(SAVE_DIR, exist_ok=True)

# MFM needs the (z, r, t) network; FM and YF share the plain (t, x) field.
if TYPE == "MFM":
    fm = MeanFlowMLP(dim=data_dim, hidden=hidden).to(device)
else:
    fm = VectorField(dim=data_dim, hidden=hidden).to(device)

fm.train()
optimizer = torch.optim.Adam(fm.parameters(), lr=lr)
# criterion is the regression loss for FM / MFM; sinkhorn_loss is used by YF.
criterion = nn.MSELoss()
sinkhorn_loss = SamplesLoss(
    loss="sinkhorn", p=2, blur=0.05, backend="tensorized")


print(f"Starting {TYPE} training...")
for it in range(n_iters):

    optimizer.zero_grad()

    # 1. Sample source and target batches
    if TRAIN_MODE == "GENERATION":
        x_0 = sample_source(batch_size, device)
        x_1 = sample_target(batch_size, device)
    elif TRAIN_MODE == "TRASLATION":
        x_0 = sample_source_class(batch_size, device, label_1)
        x_1 = sample_target_class(batch_size, device, label_2)
    else:
        raise ValueError(f"Invalid train mode: {TRAIN_MODE}")

    # 2. Sample time (and, per method, the extra quantities it needs)
    if TYPE == "FM":
        ### TODO: change this to logit_normal_timestep_sample
        t = torch.rand(batch_size, 1, device=device)
    elif TYPE == "MFM":
        t = logit_normal_timestep_sample(-0.6, 1.6, batch_size, device)
        r = logit_normal_timestep_sample(-4.0, 1.6, batch_size, device)
        # Order the pair so that r <= t.
        t, r = torch.maximum(t, r), torch.minimum(t, r)
        # Collapse r onto t with probability 1 - ratio, so the two differ
        # with probability ratio.
        prob = torch.rand(batch_size, device=device)
        mask = prob < 1 - ratio
        r = torch.where(mask, t, r)
        t = t.unsqueeze(1)
        r = r.unsqueeze(1)
    elif TYPE == "YF":
        t = torch.rand(batch_size, 1, device=device)
        xT, pot_term = integrate_ode(x_0, fm, steps=STEPS, beta=BETA)
    else:
        raise ValueError(f"Invalid type: {TYPE}")

    # 3. Point on the linear interpolation path and its constant velocity
    x_t = (1 - t) * x_0 + t * x_1
    v = x_1 - x_0

    # 4. Model prediction and loss
    if TYPE == "FM":
        pred = fm(t, x_t)
        loss = criterion(pred, v)
    elif TYPE == "MFM":
        def u_func(z, t, r):
            return fm(z, r, (t - r))
        # The JVP along (v, 1, 0) gives the total time derivative of u along
        # the path, which enters the regression target below.
        dtdt = torch.ones_like(t)
        drdt = torch.zeros_like(r)
        pred, dudt = torch.func.jvp(u_func, (x_t, t, r), (v, dtdt, drdt))
        u_tgt = (v - (t - r) * dudt).detach()
        loss = criterion(pred, u_tgt)
    elif TYPE == "YF":
        loss_sink = sinkhorn_loss(xT, x_1)
        loss = LAMBDA_SINKHORN * loss_sink + pot_term

    loss.backward()
    optimizer.step()

    if it % 500 == 0:
        with torch.no_grad():
            # Vector-field alignment metric
            if TYPE == "MFM":
                # Same calling convention as u_func above: (z, r, t - r).
                pred_v = fm(x_t, r, t - r)
            else:
                pred_v = fm(t, x_t)

            if TYPE == "FM":
                target_v = v
            elif TYPE == "MFM":
                target_v = u_tgt
            else:
                # NOTE(review): for YF the reference is the integrated end
                # point xT rather than a velocity — confirm this is intended.
                target_v = xT
            cos_sim = F.cosine_similarity(pred_v, target_v, dim=1).mean()
            ang = torch.arccos(cos_sim.clamp(-1.0 + 1e-6, 1.0 - 1e-6))
            ang_deg = torch.rad2deg(ang)
            torch.save(fm.state_dict(), f"{TYPE}.pth")
        log_line = f"Method: {TYPE} | Iter {it:05d} | Loss: {loss.item():.6f} | cos={cos_sim.item():.4f} | ang_deg={ang_deg.item():.2f}"
        print(log_line)
        with open(f"{SAVE_DIR}/args.txt", "a") as log_file:
            log_file.write(log_line + "\n")

    if it % 2000 == 0:
        fm.eval()
        if TYPE == "MFM":
            odefunc = ODEFuncMFM(fm, r_value=0.0)
        else:
            odefunc = ODEFunc(fm)

        if TRAIN_MODE == "GENERATION":
            x0 = sample_source(batch_size, device)
        elif TRAIN_MODE == "TRASLATION":
            x0 = sample_source_class(batch_size, device, label_1)
        else:
            raise ValueError(f"Invalid train mode: {TRAIN_MODE}")

        # Visualisation only: keep autograd off while integrating and decoding.
        with torch.no_grad():
            plot_trajectories(x0, odefunc, STEPS, device, TYPE, it=it, SAVE_DIR=SAVE_DIR)
            generate_samples(x0, odefunc, STEPS, device, TYPE, batch_size, it=it, SAVE_DIR=SAVE_DIR, train_mode=TRAIN_MODE)
        torch.save(fm.state_dict(), f"{SAVE_DIR}/model_iter_{it}.pth")
        fm.train()



