import os
import numpy as np
import torch
from torch import optim
import time

from core.utils_GF import load_data, w2
import core.gradient_flow as gradient_flow
from core.TW import generate_trees_frames
import cfg
# Parse command-line hyperparameters (flags are defined in cfg.parse_args).
args = cfg.parse_args()
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
dataset_name = args.dataset_name
nofiterations = args.num_iter
seeds = range(1, args.num_seeds + 1)

# One entry per compared method; every per-method list below is index-aligned
# with `titles`.
titles = ["TSW-Root-Path", "TSW-Root", 'TSW-SL-orthogonal', 'TSW-SL-distance-based',
          'TSW-SL-uniform', 'SW', 'LCV-SW', 'SWGG']
modes = ['linear'] * len(titles)
colors = ['blue', 'red', 'green', 'purple', 'brown', 'pink', 'gray', 'orange']

# results[name]['raw_w2'] holds one column of W2 values per seed and one row
# per iteration.
os.makedirs("logs", exist_ok=True)
results = {name: {'raw_w2': np.zeros((nofiterations, len(seeds)))} for name in titles}

# Draw one target point cloud per seed and center it at the origin.
N = 500  # number of samples drawn from p_X
Xs = []
for seed in seeds:
    np.random.seed(seed)
    samples = load_data(name=dataset_name, n_samples=N, dim=2)
    samples -= samples.mean(dim=0)[np.newaxis, :]  # normalization (zero mean)
    Xs.append(samples)

# The five tree-sliced variants share one learning rate and use L / n_lines
# trees. The three sliced baselines use their own learning rate and L projections.
lear_rates = [args.lr_tsw_sl] * 5 + [args.lr_sw] * 3
n_projs = [int(args.L / args.n_lines)] * 5 + [args.L] * 3
assert len(titles) == len(lear_rates) == len(n_projs) == len(modes)

print(f"Data {dataset_name}: {Xs[0].shape}")
# ---------------------------------------------------------------------------
# Gradient-flow experiment.
#
# For every method and seed: start from a Gaussian particle cloud Y, then move
# it towards the centered data X by minimizing the method's loss with Adam.
# The exact 2-Wasserstein distance is logged periodically.
# ---------------------------------------------------------------------------

# Use GPU if available, CPU otherwise. The same device is handed to the tree
# sampler; it was previously hard-coded to 'cuda', which crashed CPU-only runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Exact W2 is evaluated after the first update and then every EVAL_EVERY
# iterations. These are the only non-NaN rows of results[...]['raw_w2'], and
# exactly the rows written to the log file.
EVAL_EVERY = 500
eval_steps = [t for t in range(nofiterations) if t == 0 or (t + 1) % EVAL_EVERY == 0]
eval_step_set = set(eval_steps)

# Number of trees for the tree-sliced variants (each tree has args.n_lines lines).
ntrees = int(args.L / args.n_lines)


def _tsw_loss(x_dev, y, mean, gen_mode, std, mass_division, *, use_delta=True, frame_kwargs=None):
    """Sample fresh tree frames and return the tree-sliced Wasserstein loss.

    x_dev: target samples, already on `device`.
    y: the particles being optimized.
    mean, std and gen_mode (plus any extra `frame_kwargs`) are forwarded to
    generate_trees_frames. mass_division and args.p (plus args.delta when
    `use_delta` is true) are forwarded to gradient_flow.TWD.
    """
    theta_twd, intercept_twd = generate_trees_frames(
        ntrees=ntrees,
        nlines=args.n_lines,
        d=x_dev.shape[1],
        mean=mean,
        std=std,
        gen_mode=gen_mode,
        device=device.type,  # 'cuda' or 'cpu', matching where Y lives
        **(frame_kwargs or {}),
    )
    twd_kwargs = {'delta': args.delta} if use_delta else {}
    return gradient_flow.TWD(X=x_dev, Y=y, theta=theta_twd, intercept=intercept_twd,
                             mass_division=mass_division, p=args.p, **twd_kwargs)


def _path_kappa(t):
    """Return the kappa used by TSW-Root-Path at iteration t.

    Interpolates from args.kappa (t = 0) to args.kappa2 (last iteration) with
    weight w = ((T - 1 - t) / (T - 1)) ** args.beta, where T = nofiterations.
    A single-iteration run uses w = 1 instead of dividing by zero.
    """
    span = nofiterations - 1
    w = ((span - t) / span) ** args.beta if span > 0 else 1.0
    return args.kappa * w + args.kappa2 * (1 - w)


for k, title in enumerate(titles):
    for i, seed in enumerate(seeds):
        np.random.seed(seed)
        torch.manual_seed(seed)
        X = Xs[i].detach().clone()
        _, d = X.shape
        X_dev = X.to(device)  # hoisted: X is constant during the flow

        # Initial particles: isotropic Gaussian (std 0.25) around the origin.
        temp = np.random.normal(loc=0, scale=.25, size=(N, d))

        # W2 trace; stays NaN at iterations that are not evaluated.
        w2_dist = np.full(nofiterations, np.nan)

        # Optimizer and gradient-flow helper for this method.
        Y = torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True)
        optimizer = optim.Adam([Y], lr=lear_rates[k])
        gsw_res = gradient_flow.GF(ftype=modes[k], nofprojections=n_projs[k], device=device)

        mean_X = torch.mean(X, dim=0, keepdim=True).to(device)

        for t in tqdm(range(nofiterations)):
            if title == 'SW':
                loss = gsw_res.sw(X_dev, Y, theta=None)
            elif title == 'LCV-SW':
                loss = gradient_flow.LCVSW(X_dev, Y, L=args.L)
            elif title == 'SWGG':
                loss, _ = gsw_res.SWGG_CP(X_dev, Y, theta=None)
            elif title == 'TSW-SL-orthogonal':
                loss = _tsw_loss(X_dev, Y, mean_X, 'gaussian_orthogonal', args.std, 'distance_based')
            elif title == 'TSW-SL-distance-based':
                loss = _tsw_loss(X_dev, Y, mean_X, 'gaussian_raw', args.std, 'distance_based')
            elif title == 'TSW-SL-uniform':
                # Uniform mass division uses a fixed frame std of 0.01 and no delta.
                loss = _tsw_loss(X_dev, Y, mean_X, 'gaussian_raw', 0.01, 'uniform', use_delta=False)
            elif title == 'TSW-Root':
                # Geometric-median roots; note X=Y / Y=X (X is the target distribution).
                loss = _tsw_loss(X_dev, Y, mean_X, 'gaussian_raw', args.std, 'distance_based',
                                 frame_kwargs=dict(intercept_mode='geometric_median', X=Y, Y=X))
            elif title == 'TSW-Root-Path':
                loss = _tsw_loss(X_dev, Y, mean_X, 'random_path', args.std, 'distance_based',
                                 frame_kwargs=dict(intercept_mode='geometric_median', X=Y, Y=X,
                                                   kappa=_path_kappa(t)))
            else:
                raise ValueError(f"Unknown method: {title}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if t in eval_step_set:
                w2_dist[t] = w2(X.detach().cpu().numpy(), Y.detach().cpu().numpy())

        results[title]['raw_w2'][:, i] = w2_dist
        if eval_steps:
            # Last *evaluated* W2; w2_dist[-1] is NaN unless T is a multiple of EVAL_EVERY.
            print(w2_dist[eval_steps[-1]])

    # Save results: mean +- std over seeds at every evaluated step.
    # NOTE: the misspelled file name is kept so existing log consumers still find it.
    with open(f"logs/{title}_resutls.txt", "a", encoding="utf-8") as f:
        f.write(f"lr-{lear_rates[k]} nproj-{n_projs[k]} data-{dataset_name} nsampl-{N}\n")
        for step in eval_steps:
            data = results[title]['raw_w2'][step]
            f.write(f"\t\t{data.mean()} +- {data.std()}\n")
        f.write("\n")
