import argparse
import os

import torch
import numpy as np
from data import *
from network import *
from sde_utils import *
import matplotlib.pyplot as plt
from geomloss import SamplesLoss

from write_results import *
# --- Command-line configuration ------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--rho', type=float, default=0.03)
parser.add_argument('--sigma', type=float, default=0.3)
parser.add_argument('--path', type=str, default="sde_rho")
parser.add_argument('--no_epochs', type=int, default=300)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--data_size', type=int, default=128*1000)
parser.add_argument('--val_size', type=int, default=5000)
parser.add_argument('--no_timesteps', type=int, default=50)
parser.add_argument('--disc_steps', type=int, default=10)
parser.add_argument('--memory_length', type=int, default=10)
parser.add_argument('--manual_seed', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--smoothing_factor', type=float, default=1e-3)
parser.add_argument('--data_set', type=str, default="toy")
parser.add_argument('--start_sig', type=float, default=1)
parser.add_argument('--loss_function', type=str, default="ikl")
parser.add_argument("--subsample_time", type=int, default=10)
# NOTE(review): parsed as a *string*, so "False" is truthy if ever used in a
# boolean context; it is not consumed anywhere in this file.
parser.add_argument("--equidist", type=str, default="False")
args = parser.parse_args()
print(args.equidist)
torch.manual_seed(args.manual_seed)

device = "cuda"

# Module-level hyper-parameters; train_sde_model() reads several of these as
# globals (memory_length, sigma, rho, val_size, no_timesteps, disc_steps).
# Each is assigned exactly once here (the original assigned some of them twice).
no_epochs = args.no_epochs
batch_size = args.batch_size
data_size = args.data_size
val_size = args.val_size
no_timesteps = args.no_timesteps
subsample_time = args.subsample_time
disc_steps = args.disc_steps
memory_length = args.memory_length
sigma = args.sigma
rho = args.rho

# --- Output directory ----------------------------------------------------------
# One directory per (dataset, rho, sigma, subsampling, loss, seed) combination.
short_id = f"{args.path}/{args.data_set}_rho{args.rho}_sig{args.sigma}_sub{args.subsample_time}_loss{args.loss_function}_{args.manual_seed}"
output_path = short_id
os.makedirs(output_path, exist_ok=True)
write_results = WriteResults(args, output_path)
print(write_results.path)

def train_sde_model(net, opti, no_epochs, dataloader,val_set,times_eval):
    """Train ``net`` and keep the checkpoint with the lowest validation score.

    Each epoch makes one optimisation pass over ``dataloader`` using
    ``loss_calc_sde``. It then simulates ``val_size`` trajectories with
    ``euler`` and scores them against ``val_set`` using the module-level
    ``mmd`` metric. Whenever the score improves, the weights are saved to
    ``<output_path>/model_jump.pt`` and a trajectory image is written.

    Args:
        net: network passed to ``loss_calc_sde`` and ``euler``.
        opti: optimizer over ``net``'s parameters.
        no_epochs: number of training epochs.
        dataloader: iterable of batches. Each batch is indexed as
            ``batch[:, :, 0]`` (values) and ``batch[:, :, 1]`` (times).
        val_set: validation trajectories compared against generated samples.
        times_eval: time grid forwarded to ``euler``.

    Returns:
        The lowest validation score observed, or ``inf`` if no epoch ran.

    Side effects:
        Writes the checkpoint, images, a loss plot, the best score and
        ``mmd_list.txt`` under ``output_path`` / ``write_results``.
    """
    # Start from +inf (not a finite sentinel) so the first epoch always counts
    # as an improvement and a checkpoint is guaranteed to be saved.
    best_mmd = float("inf")
    loss_list = []
    mmd_list = []
    for epoch in range(no_epochs):
        avg_loss = 0

        for k, batch in enumerate(dataloader):
            time = batch[:, :, 1]
            x = batch[:, :, 0]
            loss = loss_calc_sde(x, time, net, memory_length, sigma=sigma, rho=rho)

            opti.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.)
            opti.step()
            # Incremental running mean of the per-batch loss within this epoch.
            avg_loss = (k / (k + 1)) * avg_loss + (1 / (k + 1)) * loss.item()

        with torch.no_grad():
            # NOTE(review): the test-set evaluation passes
            # initial_std=args.start_sig to euler but validation does not;
            # confirm whether that asymmetry is intended.
            samples = euler(net, val_size, no_timesteps, times_eval, disc_steps, memory_length, sigma=sigma, rho=rho).detach().cpu().numpy()

            # Keep every disc_steps-th simulated step and drop axis 2.
            # Presumably this matches val_set's time grid; confirm.
            eval_trj = torch.tensor(samples[:, ::disc_steps].squeeze(2), device=device)
            print("EPOCH" + str(epoch))
            mmd_val = mmd(val_set, eval_trj).item()
            mmd_list.append(mmd_val)
            print("\033[34m MMD: ", mmd_val, "\033[0m")
            if mmd_val < best_mmd:
                torch.save(net.state_dict(), os.path.join(output_path, "model_jump.pt"))
                best_mmd = mmd_val

                write_results.write_image_traj(samples, disc_steps, 100, "samples_jump_val_mmd_{}_".format(round(best_mmd, 2)))
        loss_list.append(avg_loss)

    write_results.plot_loss(loss_list)
    write_results.write_value(best_mmd, "mmd_val")
    np.savetxt(os.path.join(output_path, "mmd_list.txt"), mmd_list)

    return best_mmd

# --- Data ------------------------------------------------------------------------
data_creater = load_data(args.manual_seed)
dataloader, val_set, test_set, times_eval = data_creater.get_data(
    args.subsample_time,
    args.no_timesteps,
    batch_size=args.batch_size,
    device=device,
)

# Pull one training batch onto the CPU for a quick sanity printout.
data = next(iter(dataloader)).cpu()
data_plot = data[:100, :]
print(data[33:35, :, 0])

# Energy-distance sample loss; used as the "MMD" validation metric during training.
mmd = SamplesLoss("energy")

# --- Model & training ------------------------------------------------------------
torch.cuda.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed)
net = create_mlp_3(memory_length, 256, out=1).to(device)
opti = torch.optim.Adam(params=net.parameters(), lr=args.lr)
mmd_val = train_sde_model(
    net, opti, no_epochs, dataloader, val_set, times_eval
)
print("FINAL MMD VAL", mmd_val, sep="\n")
def evaluate_on_test_set():
    """Score the best-validation checkpoint on the held-out test set.

    This function is not invoked by the script; call it explicitly after
    training to compute the test score.

    The model is loaded from ``os.path.join(output_path, "model_jump.pt")``,
    the same path ``train_sde_model`` saves it to. Sampling runs under
    ``torch.no_grad()``. A trajectory image and the ``mmd_test`` value are
    written through ``write_results``.

    Returns:
        The test-set score as a Python float.
    """
    checkpoint_path = os.path.join(output_path, "model_jump.pt")
    net_best_val = create_mlp_3(memory_length, 256, out=1).to(device)
    # map_location keeps loading robust if the checkpoint was written on another device.
    net_best_val.load_state_dict(torch.load(checkpoint_path, map_location=device))

    test_size = test_set.shape[0]
    with torch.no_grad():
        samples = euler(net_best_val, test_size, no_timesteps, times_eval, disc_steps, memory_length, sigma=args.sigma, rho=args.rho, initial_std=args.start_sig).detach().cpu().numpy()
        mmd_test = mmd(test_set, torch.tensor(samples[:, ::disc_steps].squeeze(2), device=device)).item()
    write_results.write_image_traj(samples, disc_steps, 100, "samples_test_mmd_{}_".format(round(mmd_test, 2)))
    write_results.write_value(mmd_test, "mmd_test")
    print("MMD TEST:")
    print(mmd_test)
    return mmd_test
