import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
import math
from data import *
from network import *
from jump_utils import *
import argparse
# Command-line configuration. Some flags (e.g. --lr, --no_epochs, --equidist) are not
# read by this evaluation script; presumably they are shared with the training scripts.
parser = argparse.ArgumentParser()
parser.add_argument('--rho', type=float, default=0.03)
parser.add_argument('--sigma', type=float, default=0.3)
parser.add_argument('--path', type=str, default="jump_rho")
parser.add_argument('--no_epochs', type=int, default=300)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--data_size', type=int, default=128*1000)
parser.add_argument('--val_size', type=int, default=5000)
parser.add_argument('--no_timesteps', type=int, default=50)
parser.add_argument('--disc_steps', type=int, default=10)
parser.add_argument('--memory_length', type=int, default=10)
parser.add_argument('--manual_seed', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--smoothing_factor', type=float, default=1e-3)
parser.add_argument('--data_set', type=str, default="toy")
parser.add_argument('--start_sig', type=float, default=1)
parser.add_argument('--loss_function', type=str, default="ikl")
parser.add_argument("--subsample_time", type=int, default=10)
parser.add_argument("--equidist", type=str, default="False")
args = parser.parse_args()

# Fall back to CPU so the script still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level aliases read by the simulation and evaluation code below.
no_timesteps = args.no_timesteps
disc_steps = args.disc_steps
memory_length = args.memory_length
val_size = args.val_size
sigma = args.sigma
rho = args.rho
dataset = args.data_set
def create_memory_sampling(index, times, memory_length, traj, disc_steps):
    """Build the memory features fed to the networks at coarse step ``index``.

    Keeps the last ``memory_length`` entries of ``times`` that are ``<= index``. It reads
    the trajectory at those coarse steps (``traj[:, ::disc_steps]``). When fewer than
    ``memory_length`` times are available, it left-pads with the earliest selected state
    and a time of 0.

    Args:
        index: Current coarse step.
        times: 1-D integer tensor of observation steps. It is presumably sorted
            ascending (so "last" means "most recent") and must contain at least one
            value ``<= index`` — TODO confirm with callers.
        memory_length: Number of memory slots.
        traj: Trajectory tensor ``(n, steps, 1)``. The squeeze below assumes a trailing
            state dimension of 1.
        disc_steps: Number of fine steps per coarse step.

    Returns:
        Tensor of shape ``(n, 2 * memory_length)``: the remembered states followed by
        the observation times they were taken at.
    """
    indices = times[times <= index][-memory_length:]
    mem = traj[:, ::disc_steps, :][:, indices]
    # From here on the times are features, so give them the state's dtype/device
    # explicitly instead of relying on torch.cat type promotion.
    time_feats = indices.to(device=traj.device, dtype=mem.dtype)

    pad_length = memory_length - mem.shape[1]
    if pad_length > 0:
        # Repeat the oldest remembered state; its time slot is filled with 0.
        padding = mem[:, :1, :].expand(-1, pad_length, -1)
        mem = torch.cat((padding, mem), dim=1)
        time_feats = torch.cat((time_feats.new_zeros(pad_length), time_feats), dim=0)

    # squeeze(-1) (not squeeze()) keeps the batch and memory axes when n == 1 or
    # memory_length == 1; a bare squeeze() would make the cat below fail.
    states = mem.squeeze(-1)
    return torch.cat((states, time_feats.repeat(traj.shape[0], 1)), dim=1)

def euler_markov(alpha, net, net_jump, no_samples, no_timesteps, times, disc_steps,
                 memory_length, sigma=0.1, rho=0.1, initial_std=1., dimension=1):
    """Simulate paths of the learned jump-diffusion with an Euler scheme.

    Each of the ``no_timesteps`` coarse steps is split into ``disc_steps`` sub-steps
    of size ``h = 1 / disc_steps``. The input at every sub-step is
    ``[next observation time, current time, x, memory features]``, and both networks
    are evaluated on it. ``alpha`` scales the drift/diffusion part, and ``1 - alpha``
    scales the jump intensity.

    Args:
        alpha: Mixing weight, presumably in [0, 1] (1 = pure diffusion, 0 = pure jumps).
        net: Drift network; contributes ``h * alpha * net(input)``.
        net_jump: Network whose output columns 0, 1, 2 are used as log jump intensity,
            jump mean and log jump std.
        no_samples: Number of paths simulated in parallel.
        no_timesteps: Number of coarse time steps.
        times: Observation steps used for the memory; 0 and ``no_timesteps`` are
            added as sentinels.
        disc_steps: Euler sub-steps per coarse step.
        memory_length: Number of remembered observations fed to the networks.
        sigma: Diffusion scale; per sub-step noise std is ``sqrt(h * alpha * sigma)``.
        rho: Unused; kept for interface compatibility.
        initial_std: Standard deviation of the Gaussian initial state.
        dimension: State dimension (jump proposals below draw a single column).

    Returns:
        Tensor ``(no_samples, disc_steps * no_timesteps, dimension)`` with the state at
        the start of every sub-step.
    """
    x = initial_std * torch.randn(no_samples, dimension, device=device)
    traj = torch.zeros((no_samples, disc_steps * no_timesteps, dimension), device=device)
    # Sentinels: time 0 is always in the memory, and every step j has a later time.
    times = torch.cat((torch.zeros(1, device=device), times,
                       torch.ones(1, device=device) * no_timesteps)).to(torch.int64)
    h = 1 / disc_steps * torch.ones(no_samples, 1, device=device)

    for j in range(no_timesteps):
        # First observation time strictly after coarse step j.
        t2 = times[(j < times)][0] * torch.ones(no_samples, 1, device=device)
        for i in range(disc_steps):
            traj[:, j * disc_steps + i, :] = x
            if i == 0:
                # The memory only reads traj at multiples of disc_steps up to step j.
                # All of them are fixed once the write above is done, so the memory can
                # be built once per coarse step.
                mem = create_memory_sampling(j, times, memory_length, traj, disc_steps)
                mem = mem.reshape(no_samples, -1)
            t = (i / disc_steps) * torch.ones(no_samples, 1, device=device)
            time = t + j
            inpu = torch.cat((t2, time, x, mem), 1)

            out = net_jump(inpu)
            lambd_t = torch.exp(out[:, 0].unsqueeze(1)).clamp(0., 1000)
            mean = out[:, 1].unsqueeze(1)
            sig = torch.exp(out[:, 2]).unsqueeze(1)
            z = mean + sig * torch.randn(no_samples, 1, device=device)
            # Jump with probability 1 - exp(-(1 - alpha) * lambda * h).
            jump_prob = 1 - torch.exp(-lambd_t * (1 - alpha) * h)
            m = torch.bernoulli(jump_prob)
            x = x + h * alpha * net(inpu) + torch.sqrt(h * alpha) * math.sqrt(sigma) * torch.randn_like(x)
            x = (1 - m) * x + m * z
    return traj

mmd = SamplesLoss("energy")
subsample_times = [50, 25, 10, 5]
n_seeds = 5
alphas = [k / 20 for k in range(21)]


def _run_tag(ss):
    """Return the run-name stem shared by checkpoint and result directories."""
    return f"{args.data_set}_rho{rho}_sig{sigma}_sub{ss}_loss{args.loss_function}"


def _model_paths(ss, seed):
    """Return the ``(sde_path, jump_path)`` checkpoint paths for one run."""
    tag = f"{_run_tag(ss)}_{seed}"
    return f"sde_rho/{tag}/model_jump.pt", f"jump_rho/{tag}/model_jump.pt"


def _load_run(ss, seed):
    """Load data splits and both trained networks for one (subsample time, seed) run.

    Returns ``(net, net_jump, val_set, test_set, times_eval)``.
    """
    sde_path, jump_path = _model_paths(ss, seed)
    data_creater = load_data(seed)
    _, val_set, test_set, times_eval = data_creater.get_data(
        ss, no_timesteps, batch_size=args.batch_size, device=device)
    net = create_mlp_3(memory_length, 256, out=1).to(device)
    net_jump = create_mlp_jump_gauss_3(memory_length, 256).to(device)
    # map_location lets checkpoints saved on a GPU be restored on the current device.
    net.load_state_dict(torch.load(sde_path, map_location=device))
    net_jump.load_state_dict(torch.load(jump_path, map_location=device))
    return net, net_jump, val_set, test_set, times_eval


def _sample_paths(alpha, net, net_jump, times_eval):
    """Simulate ``val_size`` paths and keep only the states at the coarse time steps."""
    samples = euler_markov(alpha, net, net_jump, val_size, no_timesteps, times_eval,
                           disc_steps, memory_length, sigma, rho)
    return samples[:, ::disc_steps].squeeze()


def _save_sample_plot(samples, save_dir, alpha, max_paths=128):
    """Plot up to ``max_paths`` sampled paths and save the figure under ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    plt.figure()
    for path in samples[:max_paths].cpu().numpy():
        plt.plot(path.squeeze())
    plt.tight_layout()
    plt.savefig(f"{save_dir}/samples_alpha_{alpha:.2f}.png")
    plt.close()


# Seed the simulation RNGs from --manual_seed so evaluations are reproducible.
torch.manual_seed(args.manual_seed)
np.random.seed(args.manual_seed)

for ss in subsample_times:
    print(f"\n=== Subsample time: {ss} ===")
    all_mmds = {alpha: [] for alpha in alphas}
    available_seeds = []

    # Validation: pick the alpha with the lowest mean validation MMD over seeds.
    for seed in range(n_seeds):
        print(f"\n--- Seed {seed} ---")
        if not all(os.path.exists(p) for p in _model_paths(ss, seed)):
            print(f"Missing model for seed {seed}, skipping...")
            continue
        available_seeds.append(seed)

        net, net_jump, val_set, _, times_eval = _load_run(ss, seed)
        with torch.no_grad():
            for alpha in alphas:
                samples = _sample_paths(alpha, net, net_jump, times_eval)
                val_mmd = mmd(samples, val_set).item()
                all_mmds[alpha].append(val_mmd)
                print(f"Alpha {alpha:.2f}: Val MMD = {val_mmd:.6f}")

    if not available_seeds:
        # Nothing to average: np.mean([]) would be nan and the "best" alpha arbitrary.
        print(f"No trained models for subsample time {ss}, skipping...")
        continue

    averaged_mmds = {alpha: np.mean(vals) for alpha, vals in all_mmds.items()}
    std_mmds = {alpha: np.std(vals) for alpha, vals in all_mmds.items()}
    best_alpha = min(averaged_mmds, key=averaged_mmds.get)
    best_val_mmd = averaged_mmds[best_alpha]
    best_val_std = std_mmds[best_alpha]

    print(f"\n>>> Best alpha: {best_alpha:.2f}, Avg Val MMD: {best_val_mmd:.6f}, Std: {best_val_std:.6f}")

    # Test: evaluate alpha = 0, 1 and the selected alpha, only on seeds that have models.
    # dict.fromkeys drops duplicates (best_alpha may be exactly 0.0 or 1.0), keeping order.
    test_alphas = list(dict.fromkeys([0.0, 1.0, best_alpha]))
    test_mmds = {alpha: [] for alpha in test_alphas}
    for seed in available_seeds:
        # Load each run once and evaluate every test alpha on it.
        net, net_jump, _, test_set, times_eval = _load_run(ss, seed)
        with torch.no_grad():
            for alpha in test_alphas:
                samples = _sample_paths(alpha, net, net_jump, times_eval)
                test_mmds[alpha].append(mmd(test_set, samples).item())
                if seed == 0:
                    _save_sample_plot(samples, f"results_final/{_run_tag(ss)}_{seed}", alpha)
    test_results = {alpha: (np.mean(vals), np.std(vals), vals) for alpha, vals in test_mmds.items()}

    # Save results
    save_dir = f"results_final/{_run_tag(ss)}_summary"
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "results.txt"), "w", encoding="utf-8") as f:
        f.write(f"Best alpha: {best_alpha:.2f}\n")
        f.write(f"Best validation MMD avg: {best_val_mmd:.6f}, std: {best_val_std:.6f}\n\n")
        for alpha in test_alphas:
            mean_val, std_val, all_vals = test_results[alpha]
            f.write(f"Test MMD alpha={alpha:.2f}: avg={mean_val:.6f}, std={std_val:.6f}\n")
            f.write(f"  Values: {', '.join(f'{v:.6f}' for v in all_vals)}\n")

        f.write("\nAll Validation MMDs:\n")
        for alpha in sorted(all_mmds):
            vals = all_mmds[alpha]
            f.write(f"{alpha:.2f}: avg={np.mean(vals):.6f}, std={np.std(vals):.6f}, vals={', '.join(f'{v:.6f}' for v in vals)}\n")
