import os
import torch
import numpy as np
import argparse
from geomloss import SamplesLoss


# local imports
from data_stock import load_data
#from sde_utils import solve_sde
from write_results import WriteResults
import sde_matching_modules
from sde_matching_modules import (
    PriorSDE,
    PosteriorEncoder,
    PosteriorAffine,
    MatchingSDE,
    solve_sde,
    mmd_metric,
    sinkhorn_dist
)
from sde_matching_modules_tunable import (
    PriorInitDistribution,
    IdentityObservation
)
# -------------------------------
# Command-line args
# -------------------------------
# Command-line interface, declared as one table so every option can be read at
# a glance. Each row is (flag, type, default, help); help is None when the
# option carries no help text. Rows are registered in the order listed.
_CLI_OPTIONS = (
    ('--rho', float, 0.001, None),
    ('--sigma', float, 0.3, None),
    ('--path', str, "sde_rho", None),
    ('--no_epochs', int, 300, None),
    ('--batch_size', int, 64, None),
    ('--data_size', int, 128 * 10, None),
    ('--val_size', int, 5, None),
    ('--no_timesteps', int, 101, None),
    ('--disc_steps', int, 10, None),
    ('--memory_length', int, 20, None),
    ('--manual_seed', int, 0, None),
    ('--lr', float, 1e-4, None),
    ('--smoothing_factor', float, 1e-3, None),
    ('--data_set', str, "BlackScholes", None),
    ('--start_sig', float, 1.0, None),
    ('--loss_function', str, "ikl", None),
    ("--subsample_time", int, 25, None),
    # Kept as a string option: any consumer has to compare against "True"/"False".
    ("--equidist", str, "False", None),
    # values for sde_matching
    ("--sigma_a", float, -2.0, "log std of init prior"),
    ("--sigma_o", float, 0.03, "obs noise std"),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in _CLI_OPTIONS:
    # help=None is argparse's own default, so rows without help text behave
    # exactly like an add_argument call that omits the keyword.
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

args = parser.parse_args()

# -------------------------------
# Setup
# -------------------------------
# Seed the default torch generator and the CUDA generators from the CLI seed.
_seed = args.manual_seed
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(_seed)

# Run on the GPU whenever torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The run directory name encodes the data set, rho, sigma, the time
# subsampling, the loss function and the seed, nested under --path.
_run_name = (
    f"{args.data_set}_rho{args.rho}_sig{args.sigma}"
    f"_sub{args.subsample_time}_loss{args.loss_function}_{args.manual_seed}"
)
short_id = f"{args.path}/{_run_name}"
os.makedirs(short_id, exist_ok=True)
write_results = WriteResults(args, short_id)

print("Saving results to:", write_results.path)

# -------------------------------
# Data
# -------------------------------
# Build the data helper from the run seed, request the data for the configured
# time subsampling / number of timesteps, then unpack the four returned objects.
data_creater = load_data(args.manual_seed)
_data_parts = data_creater.get_data(
    args.subsample_time, args.no_timesteps, batch_size=args.batch_size, device=device
)
dataloader, val_set, test_set, times_eval = _data_parts


# -------------------------------
# Model components
# -------------------------------
# Dimensions shared by every component below; `latent_size` and `data_size`
# are also read as module-level globals by the training and test code.
# NOTE(review): the modules are constructed after torch.manual_seed(...), so
# their construction order presumably determines the initial weights --
# keep the order stable if runs must stay comparable.
data_size = 1   # stock paths are 1D
latent_size = 1
hidden_size = 500

# Prior side: initial-state distribution (log-std initialised from --sigma_a),
# the prior SDE, and the observation model (eps taken from --sigma_o).
p_init_distr = PriorInitDistribution(latent_size, log_s_init=args.sigma_a)
p_sde = PriorSDE(latent_size, hidden_size)
# Earlier observation-model variants, left here for reference:
#p_observe = PriorObservation(latent_size, data_size, noise_std=args.sigma)
#p_observe = PriorObservation(latent_size, data_size, noise_std=0.001)
p_observe = IdentityObservation(eps=args.sigma_o)


# Posterior side: encoder and affine module handed to MatchingSDE.
q_enc = PosteriorEncoder(latent_size,hidden_size)
q_affine = PosteriorAffine(latent_size, hidden_size)

# Assemble the full model on the selected device and optimise all of its
# parameters with Adam at the CLI learning rate.
sde_matching = MatchingSDE(p_init_distr, p_sde, p_observe, q_enc, q_affine).to(device)
opti = torch.optim.Adam(sde_matching.parameters(), lr=args.lr)

# geomloss sample-based loss with the "energy" kernel; used as the validation
# and test metric between real and generated trajectories.
mmd = SamplesLoss("energy")

# -------------------------------
# Training loop
# -------------------------------
def _sample_prior_trajectories(model, n_samples, disc_steps, no_timesteps):
    """Generate ``n_samples`` trajectories from the prior on the observation grid.

    Draws initial states from ``model.p_init_distr``, integrates
    ``model.p_sde`` from 0.0 to 1.0 with ``disc_steps * no_timesteps`` solver
    steps, maps the latent states through ``model.p_observe`` (using the mean
    of the returned distribution) and keeps every ``disc_steps``-th point.

    Relies on the module-level ``device``, ``latent_size``, ``data_size`` and
    ``solve_sde``. Call it under ``torch.no_grad()`` when no gradients are
    needed.

    Returns:
        Tensor indexed as (sample, time) holding the first data channel.
    """
    z0 = model.p_init_distr().rsample([n_samples]).to(device)
    z0 = z0.squeeze(1)
    # Carried over from the original comments: (steps+1, batch, latent).
    zs = solve_sde(model.p_sde, z0, ts=0.0, tf=1.0, n_steps=disc_steps * no_timesteps)
    # Batch-first layout; the last solver state is dropped.
    zs = zs.permute(1, 0, 2)[:, :-1, :]

    # Apply the observation model to all (sample, time) points at once.
    obs_dist = model.p_observe(zs.reshape(-1, latent_size))
    generated = obs_dist.mean.reshape(n_samples, -1, data_size)

    # Subsample the fine solver grid down to the observation grid.
    return generated[:, ::disc_steps, 0]


def train_sde_matching(model, opti, dataloader, val_set, times_eval,
                       no_epochs, no_timesteps, disc_steps, output_path):
    """Train ``model`` and checkpoint the weights with the best validation metric.

    Each epoch runs one pass over ``dataloader`` (gradient norm clipped to 1),
    then generates as many prior trajectories as there are validation samples
    and scores them against ``val_set`` with the module-level ``mmd`` loss.
    Whenever that score improves, the model's state dict is written to
    ``<output_path>/model_sde_matching.pt``.

    Besides its arguments the function reads the module-level ``device``,
    ``mmd`` and ``write_results`` (loss plot and best-value bookkeeping).

    Args:
        model: Callable module returning ``(total, prior, diff, recon)`` loss
            tensors for ``(xs, ts)``; must expose ``p_init_distr``, ``p_sde``
            and ``p_observe``.
        opti: Optimizer over ``model``'s parameters.
        dataloader: Iterable of batches; channel 0 of the last axis is used as
            the values and channel 1 as the times.
        val_set: Reference trajectories for the validation metric.
        times_eval: Unused; kept so existing call sites stay valid.
        no_epochs: Number of training epochs.
        no_timesteps: Number of points on the observation grid.
        disc_steps: Solver steps per observation interval.
        output_path: Directory receiving the checkpoint and ``mmd_list.txt``.

    Returns:
        The lowest validation metric seen (``inf`` if it never improved, in
        which case no checkpoint has been written).
    """
    checkpoint_path = os.path.join(output_path, "model_sde_matching.pt")
    n_batches = len(dataloader)
    # Move the reference set once instead of relying on the caller's device.
    val_ref = val_set.to(device)
    n_val = val_ref.shape[0]

    best_sinkhorn = float("inf")
    loss_list, sinkhorn_list = [], []

    for epoch in range(no_epochs):
        sum_loss, sum_prior, sum_diff, sum_recon = 0.0, 0.0, 0.0, 0.0

        for x in dataloader:
            xs = x[:, :, 0:1].to(device)  # (batch, T, 1)
            ts = x[:, :, 1:2].to(device)  # (batch, T, 1)

            loss_total, loss_prior, loss_diff, loss_recon = model(xs, ts)
            loss = loss_total.mean()

            opti.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            opti.step()

            # The component losses are tracked for logging only.
            sum_loss += loss.item()
            sum_prior += loss_prior.mean().item()
            sum_diff += loss_diff.mean().item()
            sum_recon += loss_recon.mean().item()

        avg_loss = sum_loss / n_batches
        avg_loss_pri = sum_prior / n_batches
        avg_loss_diff = sum_diff / n_batches
        avg_loss_rec = sum_recon / n_batches

        # -------------------------------
        # Validation: generate from prior
        # -------------------------------
        with torch.no_grad():
            eval_trj = _sample_prior_trajectories(model, n_val, disc_steps, no_timesteps)

            # Align the reference with the generated length, mirroring the
            # test-time evaluation.
            mmd_val = mmd(val_ref[:, :eval_trj.shape[1]], eval_trj).item()
            sinkhorn_list.append(mmd_val)

            prior_mean = model.p_init_distr.m.detach().cpu().item()
            prior_var = (torch.exp(model.p_init_distr.log_s)**2).detach().cpu().item()

            # The "ValSinkhorn" label is kept for log compatibility; the value
            # is the energy-kernel metric computed by `mmd`.
            print(
                    f"[Epoch {epoch:03d}] "
                    f"TrainLoss={avg_loss:.4f} | "
                    f"LossApriori={avg_loss_pri:.4f} | "
                    f"LossDiff={avg_loss_diff:.4f} | "
                    f"LossRecon={avg_loss_rec:.4f} | "
                    f"ValSinkhorn={mmd_val:.4f} | "
                    f"Prior(mean={prior_mean:.3f}, var={prior_var:.3f})"
                )

            if mmd_val < best_sinkhorn:
                torch.save(model.state_dict(), checkpoint_path)
                best_sinkhorn = mmd_val

        loss_list.append(avg_loss)

    write_results.plot_loss(loss_list)
    write_results.write_value(best_sinkhorn, "mmd_val")
    # Written next to the checkpoint, i.e. under the directory the caller asked for.
    np.savetxt(os.path.join(output_path, "mmd_list.txt"), sinkhorn_list)
    return best_sinkhorn

# -------------------------------
# Run training
# -------------------------------
# Train with the CLI-configured schedule; the checkpoint is written into the
# run directory managed by `write_results`.
sinkhorn_val = train_sde_matching(
    model=sde_matching,
    opti=opti,
    dataloader=dataloader,
    val_set=val_set,
    times_eval=times_eval,
    no_epochs=args.no_epochs,
    no_timesteps=args.no_timesteps,
    disc_steps=args.disc_steps,
    output_path=write_results.path,
)
print("FINAL VALIDATION Sinkhorn:", sinkhorn_val)



# -------------------------------
# Test evaluation
# -------------------------------

# Reload the best checkpoint written during training and score prior samples
# against the test set.
# NOTE(review): net_best is built from the same submodule objects as
# sde_matching, so (assuming MatchingSDE keeps them by reference) loading the
# checkpoint also overwrites sde_matching's weights.
net_best = MatchingSDE(p_init_distr, p_sde, p_observe, q_enc, q_affine).to(device)
best_checkpoint = os.path.join(write_results.path, "model_sde_matching.pt")
# map_location keeps the load valid even if the tensors were saved from another device.
net_best.load_state_dict(torch.load(best_checkpoint, map_location=device))

with torch.no_grad():
    # Generate exactly one trajectory per test sample (previously the
    # validation-set size was used here by mistake).
    test_size = test_set.shape[0]
    z0 = net_best.p_init_distr().rsample([test_size]).to(device)
    z0 = z0.squeeze(1)
    zs = solve_sde(net_best.p_sde, z0, ts=0.0, tf=1.0, n_steps=args.disc_steps * args.no_timesteps)  # (steps+1, batch, latent)
    zs = zs.permute(1, 0, 2)[:, :-1, :]                # (batch, steps, latent)

    # ---- apply observation model ----
    zs_flat = zs.reshape(-1, latent_size)
    obs_dist = net_best.p_observe(zs_flat)
    xs = obs_dist.mean.reshape(test_size, -1, data_size)

    # ---- subsample to observation grid ----
    eval_trj = xs[:, ::args.disc_steps, 0]                  # (batch, no_timesteps)

    # Trim the test set to the generated length and put both on one device.
    test_sub = test_set[:, :eval_trj.shape[1]].to(device)

    # Compute scalar MMD
    mmd_test = mmd(test_sub, eval_trj).mean().item()

    # NOTE(review): 101 equals the default of --no_timesteps; confirm whether
    # this argument should follow args.no_timesteps instead of being fixed.
    write_results.write_image_traj(xs.cpu().numpy(),
                                   args.disc_steps, 101,
                                   f"samples_test_mmd_{round(mmd_test, 2)}_")
    write_results.write_value(mmd_test, "mmd_test")
    print("FINAL TEST mmd:", mmd_test)

