import argparse
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from geomloss import SamplesLoss

from data import *
from network import *
from write_results import *
# Command-line options of the evaluation script.
# The visible part of this script reads: rho, sigma, path, batch_size,
# val_size, no_timesteps, disc_steps, memory_length, data_set and
# loss_function. The remaining options are accepted but not read here
# (presumably shared with the training script -- TODO confirm).
parser = argparse.ArgumentParser()
parser.add_argument('--rho', type=float, default=0.03)
parser.add_argument('--sigma', type=float, default=0.3)
parser.add_argument('--path', type=str, default="tfm")
parser.add_argument('--no_epochs', type=int, default=300)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--data_size', type=int, default=128*1000)
parser.add_argument('--val_size', type=int, default=5000)
parser.add_argument('--no_timesteps', type=int, default=50)
parser.add_argument('--disc_steps', type=int, default=10)
parser.add_argument('--memory_length', type=int, default=10)
parser.add_argument('--manual_seed', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--smoothing_factor', type=float, default=1e-3)
parser.add_argument('--data_set', type=str, default="toy")
parser.add_argument('--start_sig', type=float, default=1)
parser.add_argument('--loss_function', type=str, default="ikl")
parser.add_argument("--subsample_time", type=int, default=10)
# Kept as a string flag ("True"/"False") for command-line compatibility.
parser.add_argument("--equidist", type=str, default="False")
# Parse exactly once; the previous second call re-parsed the same argv for no effect.
args = parser.parse_args()

# Device on which all tensors and the network are placed. Uses CUDA when it is
# available and falls back to the CPU otherwise, so the script can also run on
# machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Short module-level aliases for the command-line options used below.
no_timesteps = args.no_timesteps
disc_steps = args.disc_steps
memory_length = args.memory_length
val_size = args.val_size
sigma = args.sigma
rho = args.rho


# Sample-based discrepancy used as the evaluation metric: geomloss SamplesLoss
# configured with the "energy" loss.
mmd = SamplesLoss("energy")
def create_memory_sampling(index, times, memory_length, traj, disc_steps):
    """Collect the most recent observed states and their time indices.

    Args:
        index: Current coarse time step; only entries of ``times`` that are
            ``<= index`` are eligible as memory.
        times: 1-D integer tensor of coarse time indices (used to index the
            subsampled trajectory). At least one entry must be ``<= index``.
        memory_length: Number of memory slots in the result.
        traj: Trajectory tensor indexed as ``(sample, fine_step, feature)``.
        disc_steps: Stride that maps fine steps to coarse steps
            (``traj[:, ::disc_steps]`` keeps one state per coarse step).

    Returns:
        Tensor of shape ``(traj.shape[0], memory_length * feature + memory_length)``:
        the flattened memory states followed by their time indices, the same
        indices being repeated for every sample.

    Raises:
        ValueError: If no entry of ``times`` is ``<= index``.
    """
    batch_size = traj.shape[0]

    # Times already reached, keeping only the newest `memory_length` of them.
    indices = times[times <= index][-memory_length:]
    if indices.numel() == 0:
        raise ValueError("create_memory_sampling needs at least one time <= index")

    # One state per coarse step, then select the remembered steps.
    mem = traj[:, ::disc_steps, :][:, indices]

    pad_length = memory_length - mem.shape[1]
    if pad_length > 0:
        # Not enough history yet: left-pad by repeating the oldest available
        # state, and use time index 0 for the padded slots.
        oldest = mem[:, :1, :].expand(-1, pad_length, -1)
        mem = torch.cat((oldest, mem), dim=1)
        indices = torch.cat((indices.new_zeros(pad_length), indices), dim=0)

    # Flatten with an explicit reshape. A bare squeeze() would also drop the
    # batch axis for a single sample (or the memory axis for memory_length == 1)
    # and make the concatenation below fail.
    states = mem.reshape(batch_size, -1)
    stamps = indices.to(states.dtype).repeat(batch_size, 1)
    return torch.cat((states, stamps), dim=1)
def random_times(no_timesteps, subsample_time, random_seed = None, device="cuda", equidist="False"):
    """Return a sorted 1-D tensor of time indices drawn from ``range(no_timesteps)``.

    Args:
        no_timesteps: Size of the index range to choose from.
        subsample_time: Controls how many indices are returned. In the
            equidistant mode the stride is ``no_timesteps // subsample_time``;
            in the random mode ``subsample_time - 1`` distinct indices are drawn.
        random_seed: If not ``None``, passed to ``torch.manual_seed`` before
            drawing. Note that this reseeds the global torch generator.
        device: Device on which the result is created.
        equidist: ``"True"`` (or the boolean ``True``) selects equidistant
            indices; anything else selects the random mode.

    Returns:
        Sorted integer tensor of time indices on ``device``.
    """
    # The flag arrives as a string from the command line; comparing its string
    # form also honours a real boolean, which previously fell silently into the
    # random branch.
    if str(equidist) == "True":
        # Clamp the stride to 1 so that subsample_time > no_timesteps returns
        # every index instead of failing on a zero step.
        jump_every = max(1, no_timesteps // subsample_time)
        return torch.arange(0, no_timesteps, jump_every, device=device)

    if random_seed is not None:
        torch.manual_seed(random_seed)
    perm = torch.randperm(no_timesteps, device=device)[:subsample_time - 1]
    return perm.sort().values

def euler(net, no_samples, no_timesteps, times, disc_steps, memory_length, sigma, rho, initial_std=1., dimension=1):
    """Simulate sample paths with an Euler-type stochastic update driven by ``net``.

    Every coarse time step is split into ``disc_steps`` sub-steps of size
    ``1 / disc_steps``. At each sub-step the network receives the next
    observation time, the current time, the current state and the memory
    features from ``create_memory_sampling``. The first output column is
    used as a target state: the drift is ``(target - state)`` divided by the
    time left until the next observation time. Gaussian noise scaled by
    ``sqrt(step) * sqrt(sigma)`` is added at every sub-step.

    All tensors are created on the module-level ``device``.

    Args:
        net: Callable mapping the concatenated input features to a tensor
            whose first column is used as the prediction.
        no_samples: Number of paths simulated in parallel.
        no_timesteps: Number of coarse time steps.
        times: 1-D tensor of observation times; 0 and ``no_timesteps`` are
            added internally and the result is cast to int64.
        disc_steps: Number of sub-steps per coarse time step.
        memory_length: Number of memory slots passed to
            ``create_memory_sampling``.
        sigma: Noise level; its square root scales the Gaussian increments.
        rho: Not used by this function.
        initial_std: Standard deviation of the Gaussian initial state.
        dimension: State dimension.

    Returns:
        Tensor of shape ``(no_samples, disc_steps * no_timesteps, dimension)``
        holding the state as it was before each sub-step update.
    """
    def as_column(value):
        # One copy of `value` per sample, shape (no_samples, 1).
        return value * torch.ones(no_samples, 1, device=device)

    state = torch.randn(no_samples, dimension, device=device) * initial_std
    traj = torch.zeros((no_samples, disc_steps * no_timesteps, dimension), device=device)

    # Observation grid framed by the start (0) and the end (no_timesteps).
    grid_start = torch.zeros(1, device=device)
    grid_end = torch.ones(1, device=device) * no_timesteps
    grid = torch.cat((grid_start, times, grid_end)).to(torch.int64)

    for coarse in range(no_timesteps):
        # First grid time strictly after the current coarse step.
        next_obs = as_column(grid[(coarse < grid)][0])

        for fine in range(disc_steps):
            # Record the state before it is updated.
            traj[:, coarse * disc_steps + fine, :] = state

            frac = as_column(fine / disc_steps)
            step = as_column(1 / disc_steps)

            memory = create_memory_sampling(coarse, grid, memory_length, traj, disc_steps)
            now = torch.ones_like(frac) * coarse + frac
            features = torch.cat((next_obs, now, state, memory.reshape(no_samples, -1)), 1)

            target = net(features)[:, :1]
            drift = (target - state) / (torch.abs(next_obs - (frac + coarse)))
            state = state + drift * step + torch.sqrt(step) * math.sqrt(sigma) * torch.randn_like(state)

    return traj
# Evaluation driver: for every subsampling rate, load the model trained with
# each seed, simulate sample paths, score them against the test set with `mmd`
# and write summary statistics to a per-rate results folder.
subsample_times = [50, 25, 10, 5]
n_seeds = 5


for ss in subsample_times:
    print(f"\n=== Evaluating for subsample_time = {ss} ===")
    mmd_vals = []
    # Seeds that were actually evaluated, parallel to `mmd_vals`, so skipped
    # seeds do not shift the labels in the report.
    evaluated_seeds = []

    save_folder = f"results_tfm/subsample_time={ss}"
    os.makedirs(save_folder, exist_ok=True)

    for seed in range(n_seeds):
        folder = f"{args.path}/{args.data_set}_rho{rho}_sig{sigma}_sub{ss}_loss{args.loss_function}_{seed}"
        net_path = f"{folder}/model_jump.pt"
        if not os.path.exists(net_path):
            print(f"Skipping: {net_path} not found")
            continue

        print(f"\n--- Seed {seed} ---")
        torch.manual_seed(seed)
        data_creater = load_data(seed)
        _, _, test_set, times_eval = data_creater.get_data(ss, no_timesteps, batch_size=args.batch_size, device=device)

        net = create_mlp_3(memory_length, 256, out=1).to(device)
        # map_location lets a checkpoint saved on another device load onto `device`.
        net.load_state_dict(torch.load(net_path, map_location=device))

        with torch.no_grad():
            samples = euler(
                net=net,
                no_samples=val_size,
                no_timesteps=no_timesteps,
                times=times_eval,
                disc_steps=disc_steps,
                memory_length=memory_length,
                sigma=sigma,
                rho=rho
            )
            # Keep one state per coarse time step and drop only the trailing
            # feature axis (euler is called with its default dimension of 1);
            # a bare squeeze() would also drop the sample axis for val_size == 1.
            samples = samples[:, ::disc_steps].squeeze(-1)
            mmd_val = mmd(test_set, samples).item()
            mmd_vals.append(mmd_val)
            evaluated_seeds.append(seed)

            if seed == 0:
                plt.figure()
                for k in range(min(128, samples.shape[0])):
                    plt.plot(samples[k].cpu().numpy().squeeze())
                plt.tight_layout()
                plt.savefig(f"{save_folder}/samples_sde_nr.png")
                plt.close()

    # === Save results ===
    if not mmd_vals:
        # Every seed was skipped: np.min would raise on an empty list, so
        # report it and move on to the next subsampling rate.
        print(f"No trained models found for subsample_time={ss}; no results written.")
        continue

    avg_mmd = np.mean(mmd_vals)
    best_mmd = np.min(mmd_vals)
    std_mmd = np.std(mmd_vals)

    with open(f"{save_folder}/results.txt", "w") as f:
        f.write(f"SDE_NR Evaluation Results (subsample_time={ss})\n")
        f.write(f"Average Test MMD: {avg_mmd:.6f}\n")
        f.write(f"Best Test MMD: {best_mmd:.6f}\n")
        f.write(f"Std Dev: {std_mmd:.6f}\n")
        f.write("Individual Test MMDs:\n")
        for done_seed, val in zip(evaluated_seeds, mmd_vals):
            f.write(f"  Seed {done_seed}: {val:.6f}\n")

