import os
import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as pl
import time
import wandb
import shutil
import sys
sys.path.append('../code/')

from utils_GF import load_data, w2
import gradient_flow
from TW import generate_trees_frames
import cfg
args = cfg.parse_args()
from tqdm import tqdm
# Configuration
dataset_name = args.dataset_name
nofiterations = args.num_iter
# Seeds are 1-based; the run for seed s is stored in column s - 1 of the result arrays.
seeds = range(1, args.num_seeds + 1)
# Display names of the compared methods. Throughout the script a method is
# identified by its index k in this list (0 = SW, ..., 5 = SWGG).
titles = ['SW', 'TSW-SL-distance-based', 'TSW-SL-uniform', 'TSW-SL-orthogonal', 'LCVSW', 'SWGG']
# Projection type handed to gradient_flow.GF for each method (all linear here).
modes = ['linear'] * len(titles)

# Arrays to store results: W2 distance per iteration (rows) and per seed (columns).
_results_shape = (nofiterations, len(seeds))
sw_results = np.zeros(_results_shape)
tsw_sl_distance_based_results = np.zeros(_results_shape)
tsw_sl_uniform_results = np.zeros(_results_shape)
tsw_sl_orthogonal_results = np.zeros(_results_shape)
lcvsw_results = np.zeros(_results_shape)
swgg_results = np.zeros(_results_shape)



# Quantities that do not depend on the seed are computed once, before the loop.
args_dict = vars(args)
# Encode every CLI argument into the results folder name.
folder_info = '-'.join([f"{key.replace('_', '')}{value}" for key, value in args_dict.items()])

# Use GPU if available, CPU otherwise. The string form is what generate_trees_frames
# receives, so the tree frames follow the same device choice as X and Y.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

N = 100  # Number of samples from p_X
n_trees = int(args.L / args.n_lines)  # ntrees / nofprojections used by the TSW-SL variants

# Per-method hyperparameters, indexed like `titles`:
# SW, TSW-SL-distance-based, TSW-SL-uniform, TSW-SL-orthogonal, LCVSW, SWGG.
lear_rates = [args.lr_sw, args.lr_tsw_sl, args.lr_tsw_sl, args.lr_tsw_sl, args.lr_sw, args.lr_sw]
n_proj = [args.L, n_trees, n_trees, n_trees, args.L, args.L]

# File-name fragment of each method's per-seed W2 curve, and the cross-seed
# array that curve is accumulated into (same method order as above).
_method_file_suffixes = ['SW', 'TSW_SL_distance_based', 'TSW_SL_uniform', 'TSW_SL_orthogonal', 'LCV_SW', 'SWGG']
_method_result_arrays = [sw_results, tsw_sl_distance_based_results, tsw_sl_uniform_results,
                         tsw_sl_orthogonal_results, lcvsw_results, swgg_results]


def _method_loss(k, X_dev, Y_k, gf, theta, mean_X):
    """Return the loss of method ``k`` between target ``X_dev`` and particles ``Y_k``.

    Args:
        k: method index, in the order of ``titles``.
        X_dev: target samples on ``device``.
        Y_k: the particles optimised by method ``k``.
        gf: the ``gradient_flow.GF`` object built for method ``k``.
        theta: per-method direction tensor; for SWGG (k == 5) ``theta[k]`` is
            the initial direction and is overwritten with the one returned by
            ``get_minSWGG_smooth``.
        mean_X: mean of the target samples, passed to ``generate_trees_frames``.

    Raises:
        ValueError: if ``k`` is not one of the six known methods.
    """
    if k == 0:  # SW
        return gf.sw(X_dev, Y_k, theta=None)
    if k in (1, 2, 3):  # TSW-SL variants: fresh random tree frames at every step
        gen_mode = 'gaussian_orthogonal' if k == 3 else 'gaussian_raw'
        theta_twd, intercept_twd = generate_trees_frames(
            ntrees=n_trees, nlines=args.n_lines, d=X_dev.shape[1], mean=mean_X,
            std=args.std, gen_mode=gen_mode, device=device_name)
        if k == 2:  # uniform mass division (called without delta)
            return gradient_flow.TWD(X=X_dev, Y=Y_k, theta=theta_twd, intercept=intercept_twd,
                                     mass_division='uniform', p=args.p)
        return gradient_flow.TWD(X=X_dev, Y=Y_k, theta=theta_twd, intercept=intercept_twd,
                                 mass_division='distance_based', p=args.p, delta=args.delta)
    if k == 4:  # LCVSW
        return gradient_flow.LCVSW(X_dev, Y_k, L=args.L)
    if k == 5:  # SWGG
        loss, theta[k], _ = gf.get_minSWGG_smooth(X_dev, Y_k, s=100, std=0.5, init=theta[k])
        return loss
    raise ValueError(f"Unknown method index: {k}")


for seed in seeds:
    # Seed NumPy and PyTorch so that each seed labels a reproducible run of the
    # random draws made below (initial particles, and whatever randomness
    # load_data / the loss functions draw from these generators).
    np.random.seed(seed)
    torch.manual_seed(seed)

    X = load_data(name=dataset_name, n_samples=N, dim=2)
    X -= X.mean(dim=0)[np.newaxis, :]  # Center the target samples
    _, d = X.shape
    X_dev = X.to(device)  # X is not modified during the flow, so move it once
    X_np = X.detach().cpu().numpy()  # NumPy copy of the target for w2

    results_folder = f"./Results/Gradient_Flow_{folder_info}_seed{seed}"
    os.makedirs(results_folder, exist_ok=True)
    foldername = os.path.join(results_folder, 'Gifs', dataset_name + '_Comparison')
    os.makedirs(foldername, exist_ok=True)

    # Initial particles, shared by every method so all flows start from the same cloud.
    meanX = 0
    temp = np.random.normal(loc=meanX, scale=.25, size=(N, d))

    # W2 distance between X and each method's particles after every iteration.
    w2_dist = np.full((nofiterations, len(modes)), np.nan)

    Y = [torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True) for _ in modes]
    optimizer = [optim.Adam([Y_k], lr=lr) for Y_k, lr in zip(Y, lear_rates)]
    gsw_res = [gradient_flow.GF(ftype=mode, nofprojections=n_p, device=device)
               for mode, n_p in zip(modes, n_proj)]

    mean_X = torch.mean(X, dim=0, keepdim=True).to(device)

    for i in tqdm(range(nofiterations)):
        # NOTE(review): theta is re-initialised every iteration, so the SWGG
        # direction written back by _method_loss is not reused at the next
        # step — confirm this is intended.
        theta = torch.ones(len(modes), d)
        for k in range(len(modes)):
            loss = _method_loss(k, X_dev, Y[k], gsw_res[k], theta, mean_X)
            optimizer[k].zero_grad()
            loss.backward()
            optimizer[k].step()
            w2_dist[i, k] = w2(X_np, Y[k].detach().cpu().numpy())

    # Persist this seed's curves and store them for cross-seed averaging.
    for k, (suffix, results) in enumerate(zip(_method_file_suffixes, _method_result_arrays)):
        np.savetxt(f"{results_folder}/{dataset_name}_{suffix}_seed{seed}.txt", w2_dist[:, k])
        results[:, seed - 1] = w2_dist[:, k]


# Per-iteration average of each method's W2 curve across all seeds.
(sw_mean,
 tsw_sl_distance_based_mean,
 tsw_sl_uniform_mean,
 tsw_sl_orthogonal_mean,
 lcvsw_mean,
 swgg_mean) = (per_seed.mean(axis=1) for per_seed in (
    sw_results, tsw_sl_distance_based_results, tsw_sl_uniform_results,
    tsw_sl_orthogonal_results, lcvsw_results, swgg_results))

# Write one text file per method. NOTE(review): results_folder is the value left
# by the seed loop, i.e. the last seed's folder — confirm this is where the
# cross-seed means should live.
for _method_tag, _mean_curve in (('SW', sw_mean),
                                 ('TSW_SL_distance_based', tsw_sl_distance_based_mean),
                                 ('TSW_SL_uniform', tsw_sl_uniform_mean),
                                 ('TSW_SL_orthogonal', tsw_sl_orthogonal_mean),
                                 ('LCV_SW', lcvsw_mean),
                                 ('SWGG', swgg_mean)):
    np.savetxt(f"{results_folder}/{dataset_name}_{_method_tag}_mean.txt", _mean_curve)


# Calculate, per iteration, the mean and standard deviation of log10(W2) across seeds.
def _log10_stats(per_seed):
    """Return (log10 of per_seed, its mean over axis 1, its std over axis 1)."""
    logged = np.log10(per_seed)
    return logged, logged.mean(axis=1), logged.std(axis=1)


sw_results_log_10, sw_mean_log, sw_std_log = _log10_stats(sw_results)
(tsw_sl_distance_based_results_log_10,
 tsw_sl_distance_based_mean_log,
 tsw_sl_distance_based_std_log) = _log10_stats(tsw_sl_distance_based_results)
(tsw_sl_uniform_results_log_10,
 tsw_sl_uniform_mean_log,
 tsw_sl_uniform_std_log) = _log10_stats(tsw_sl_uniform_results)
(tsw_sl_orthogonal_results_log_10,
 tsw_sl_orthogonal_mean_log,
 tsw_sl_orthogonal_std_log) = _log10_stats(tsw_sl_orthogonal_results)
lcvsw_results_log_10, lcvsw_mean_log, lcvsw_std_log = _log10_stats(lcvsw_results)
swgg_results_log_10, swgg_mean_log, swgg_std_log = _log10_stats(swgg_results)


# Plot the cross-seed mean of log10(W2) for every method, with a +/- 1 std band.
pl.figure(figsize=(10, 6))
_iterations = range(nofiterations)

# (legend label, mean curve, std curve, colour); every method gets a distinct colour.
_plot_series = [
    ('SW', sw_mean_log, sw_std_log, 'blue'),
    ('TSW-SL-distance-based', tsw_sl_distance_based_mean_log, tsw_sl_distance_based_std_log, 'orange'),
    ('TSW-SL-uniform', tsw_sl_uniform_mean_log, tsw_sl_uniform_std_log, 'red'),
    ('TSW-SL-orthogonal', tsw_sl_orthogonal_mean_log, tsw_sl_orthogonal_std_log, 'green'),
    ('LCVSW', lcvsw_mean_log, lcvsw_std_log, 'purple'),
    ('SWGG', swgg_mean_log, swgg_std_log, 'brown'),
]
for _label, _mean_log, _std_log, _color in _plot_series:
    pl.plot(_mean_log, label=_label, color=_color)
    pl.fill_between(_iterations, _mean_log - _std_log, _mean_log + _std_log, color=_color, alpha=0.2)

# Text box listing every command-line argument (dataset_name included), one per line.
args_info = [f'{key.replace("_", " ").capitalize()}: {value}' for key, value in args_dict.items()]
textstr = '\n'.join(args_info)
pl.gca().text(0.05, 0.95, textstr, transform=pl.gca().transAxes, fontsize=10,
              verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# The title reports the actual number of seeds averaged over.
pl.title(f'Log Wasserstein Distance over {len(seeds)} Runs ({dataset_name})')
pl.xlabel('Iterations')
pl.ylabel(r'$W_2$ Distance (log scale)')
pl.legend()
pl.grid(True)

plot_filename = os.path.join(results_folder, f'{folder_info}_log.png')
pl.savefig(plot_filename)
# Close (rather than only clear) the figure so it is released once saved.
pl.close()

# Log the last-iteration cross-seed mean W2 of every method, plus the saved plot.
wandb.init(project="twd_gradient_flow", name=folder_info)
wandb.config.update(args)

# Keys match the method labels used in the plot legend.
log_dict = {
    'SW': sw_mean[-1],
    'TSW-SL-distance-based': tsw_sl_distance_based_mean[-1],
    'TSW-SL-uniform': tsw_sl_uniform_mean[-1],
    'TSW-SL-orthogonal': tsw_sl_orthogonal_mean[-1],
    'LCVSW': lcvsw_mean[-1],
    'SWGG': swgg_mean[-1],
}
wandb.log(log_dict)

wandb.log({'Image': [wandb.Image(plot_filename)]})

wandb.finish()
