import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from geomloss import SamplesLoss

from data import *
from network import *
from jump_utils import *
from write_results import *

# ---------------------------------------------------------------------------
# Command-line configuration and output directory setup.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--rho', type=float, default=0.03)
parser.add_argument('--sigma', type=float, default=0.3)
parser.add_argument('--path', type=str, default="jump_rho")
parser.add_argument('--no_epochs', type=int, default=300)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--data_size', type=int, default=128*1000)
parser.add_argument('--val_size', type=int, default=5000)
parser.add_argument('--no_timesteps', type=int, default=50)
parser.add_argument('--disc_steps', type=int, default=10)
parser.add_argument('--memory_length', type=int, default=10)
parser.add_argument('--manual_seed', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--smoothing_factor', type=float, default=1e-3)
parser.add_argument('--data_set', type=str, default="toy")
parser.add_argument('--start_sig', type=float, default=1)
parser.add_argument('--loss_function', type=str, default="ikl")
parser.add_argument("--subsample_time", type=int, default=10)
# NOTE(review): parsed as a string ("False"/"True"), not a bool -- any consumer
# must compare against the string; confirm against its use elsewhere.
parser.add_argument("--equidist", type=str, default="False")
# Parse once (the original parsed the same argv twice).
args = parser.parse_args()
torch.manual_seed(args.manual_seed)

# Prefer the GPU; fall back to CPU so the script can still start without one.
device = "cuda" if torch.cuda.is_available() else "cpu"
no_timesteps = args.no_timesteps
disc_steps = args.disc_steps
memory_length = args.memory_length

# One output directory per hyper-parameter combination and seed.
short_id = f"{args.path}/{args.data_set}_rho{args.rho}_sig{args.sigma}_sub{args.subsample_time}_loss{args.loss_function}_{args.manual_seed}"
output_path = short_id
os.makedirs(output_path, exist_ok=True)
write_results = WriteResults(args, output_path)



def train_jump_model(net, opti, no_epochs, dataloader, val_set, times_eval):
    """Train ``net`` and keep the checkpoint with the lowest validation MMD.

    Each epoch makes one pass over ``dataloader``. Per batch it computes the
    loss with ``loss_calc_jump``, clips the gradient norm to 1 and takes an
    optimizer step. It then draws ``val_set.shape[0]`` trajectories with
    ``euler`` under ``torch.no_grad()`` and scores them against ``val_set``
    with the module-level ``mmd``. Whenever the validation MMD improves, the
    weights are saved to ``<output_path>/model_jump.pt`` and a sample plot is
    written.

    Reads module-level globals: ``args``, ``memory_length``, ``no_timesteps``,
    ``disc_steps``, ``device``, ``mmd``, ``output_path`` and ``write_results``.

    Args:
        net: model passed to ``loss_calc_jump`` and ``euler``.
        opti: optimizer stepping ``net``'s parameters.
        no_epochs: number of passes over ``dataloader``.
        dataloader: iterable of batches. Each batch is split along its last
            axis into ``x[:, :, 0]`` (path values) and ``x[:, :, 1]`` (times).
        val_set: validation trajectories. Its first dimension sets how many
            samples are generated.
        times_eval: time grid forwarded to ``euler``.

    Returns:
        The lowest validation MMD observed, or ``inf`` if no epoch produced a
        finite value. A checkpoint file is always written, so reloading
        ``model_jump.pt`` afterwards cannot fail for lack of one.
    """
    ckpt_path = os.path.join(output_path, "model_jump.pt")
    # inf rather than a large constant: the first finite MMD always counts as
    # an improvement, regardless of its scale.
    best_mmd = float("inf")
    loss_list = []
    mmd_list = []
    val_size = val_set.shape[0]
    for i in range(no_epochs):
        avg_loss = 0.0

        for k, x in enumerate(dataloader):
            time = x[:, :, 1]
            x = x[:, :, 0]
            loss = loss_calc_jump(x, time, net, memory_length, sigma=args.sigma, rho=args.rho, loss_function=args.loss_function)
            opti.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.)
            opti.step()
            # Incremental (running) mean of this epoch's batch losses.
            avg_loss = (k / (k + 1)) * avg_loss + (1 / (k + 1)) * loss.item()

        with torch.no_grad():
            samples = euler(net, val_size, no_timesteps, times_eval, disc_steps, memory_length, sigma=args.sigma, rho=args.rho, initial_std=args.start_sig).cpu().data.numpy()
            # Keep every disc_steps-th simulated step (presumably the coarse
            # evaluation grid -- confirm against euler) and drop singleton axis 2.
            eval_trj = torch.tensor(samples[:, ::disc_steps].squeeze(2), device=device)
            print("EPOCH" + str(i))
            mmd_val = mmd(val_set, eval_trj).item()
            mmd_list.append(mmd_val)
            print("\033[34m MMD: ", mmd_val, "\033[0m")
            if mmd_val < best_mmd:
                torch.save(net.state_dict(), ckpt_path)
                best_mmd = mmd_val

                write_results.write_image_traj(samples, disc_steps, 100, "samples_jump_val_mmd_{}_".format(round(best_mmd, 2)))
        loss_list.append(avg_loss)

    if best_mmd == float("inf"):
        # No finite validation MMD was ever seen (all NaN/inf, or no_epochs == 0).
        # Save the final weights anyway so the checkpoint file exists.
        print("\033[31m No finite validation MMD; saving final weights. \033[0m")
        torch.save(net.state_dict(), ckpt_path)

    write_results.plot_loss(loss_list)
    write_results.write_value(best_mmd, "mmd_val")
    np.savetxt(os.path.join(output_path, "mmd_list.txt"), mmd_list)
    return best_mmd

# Build the data splits and the evaluation time grid.
data_creater = load_data(args.manual_seed)
dataloader, val_set, test_set, times_eval = data_creater.get_data(args.subsample_time, args.no_timesteps, batch_size=args.batch_size, device=device)

# Sanity check: print the grid and the batch shape, then plot up to 100
# trajectories (channel 0) from the first training batch.
data = next(iter(dataloader)).cpu()
print(times_eval)
print(data.shape)
data_plot = data[:100, :, 0]
print(data[33:35, :, 0])
fig = plt.figure()
for trajectory in data_plot:
    # `data` is already on the CPU; detach only guards against autograd tensors.
    plt.plot(trajectory.detach().numpy())
fig.savefig(os.path.join(output_path, "data_true.png"))
# Close the figure so later pyplot calls cannot draw onto it, and free its memory.
plt.close(fig)

# geomloss "energy" distance between two sample clouds. The rest of the script
# calls it `mmd` (train_jump_model reads it as a module global).
mmd = SamplesLoss("energy")

# Seed the CUDA generators right before the network's weights are initialised.
torch.cuda.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed)

_hidden_width = 256
net = create_mlp_jump_gauss_3(memory_length, _hidden_width).to(device)
opti = torch.optim.Adam(net.parameters(), lr=args.lr)

val_mmd = train_jump_model(net, opti, args.no_epochs, dataloader, val_set, times_eval)
print(f"FINAL MMD VAL\n{val_mmd}")
# Reload the checkpoint with the best validation MMD and score it on the test split.
net_best_val = create_mlp_jump_gauss_3(memory_length, 256).to(device)
net_best_val.load_state_dict(torch.load(os.path.join(output_path, "model_jump.pt"), map_location=device))
test_size = test_set.shape[0]
# Inference only: run without autograd, exactly like the validation sampling
# in train_jump_model, so no computation graph is recorded over the rollout.
with torch.no_grad():
    samples = euler(net_best_val, test_size, no_timesteps, times_eval, disc_steps, memory_length, sigma=args.sigma, rho=args.rho, initial_std=args.start_sig).cpu().data.numpy()
    test_trj = torch.tensor(samples[:, ::disc_steps].squeeze(2), device=device)
    mmd_test = mmd(test_set, test_trj).item()
write_results.write_image_traj(samples, disc_steps, 100, "samples_jump_test_mmd_{}_".format(round(mmd_test, 2)))
write_results.write_value(mmd_test, "mmd_test")
print("MMD TEST")
print(mmd_test)


