import torch
import sys
import argparse
import ot
import time

import scipy.io as sio
import numpy as np
import pandas as pd

from tqdm.auto import trange

from sliceduot.sliced_uot import reweighted_sliced_ot, sliced_unbalanced_ot
from sliced_opt import reprocess_support, opt_cost_from_plans, opt_plans_64
# from sliceduot.sliced_uot2 import reweighted_sliced_ot2, sliced_unbalanced_ot2

# from dataset import get_movie_review
from mUOT import muot


# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

# Command-line interface of the experiment script.
# NOTE(review): in the code visible in this file, --inner_iter, --reg_sinkhorn,
# --njobs, --size_batch and --num_batch are parsed but never read; --loss is
# only used for the start-up print, the save condition and the output file name.
parser = argparse.ArgumentParser()
parser.add_argument("--loss", type=str, default="sw", help="Which loss to use")
parser.add_argument(
    "--dataset",
    type=str,
    default="Twitter",
    # Restrict to the datasets the main section knows how to load, so a typo
    # is rejected at parse time instead of failing later on an undefined name.
    choices=["Twitter", "BBC", "movie", "goodreads"],
    help="Which dataset to use",
)
parser.add_argument("--n_projs", type=int, default=500, help="Number of projections")
parser.add_argument("--inner_iter", type=int, default=50, help="Number of inner iter of suot or rsot")
# parser.add_argument("--rho1", type=float, default=1, help="rho1")
# parser.add_argument("--rho2", type=float, default=1, help="rho2")
parser.add_argument("--reg_sinkhorn", type=float, default=0.1, help="Epsilon sinkhorn")
parser.add_argument("--pbar", action="store_true", help="If yes, plot pbar")
# parser.add_argument("--draw_once", action="store_true", help="If yes, draw once the projections for RSW")
# store_false: args.unnormalize is True by default and becomes False when the
# flag is given. The main section only saves results while it is True.
parser.add_argument(
    "--unnormalize",
    action="store_false",
    help="Passing this flag sets unnormalize to False (default: True); "
         "results are only saved while it is True",
)
parser.add_argument("--njobs", type=int, default=5, help="Number of jobs in parallel")
parser.add_argument("--ntry", type=int, default=5, help="Number of try")
parser.add_argument("--size_batch", type=int, default=50, help="Size of batchs")
parser.add_argument("--num_batch", type=int, default=10, help="Number of batchs")
args = parser.parse_args()


def compute_ot(i):
    """Fill row ``i`` and column ``i`` of the global ``dist_mat`` for all ``j > i``.

    For every later document ``j``, both documents are turned into float64
    tensors on ``device``, passed through ``reprocess_support`` and compared
    with ``opt_plans_64``; the first value returned by ``opt_plans_64`` is
    written symmetrically into ``dist_mat``. The wall-clock time of each
    pair is appended to the global list ``ts``.

    The function relies on module-level names set up by the main section:
    ``X_train``, ``w_train``, ``rho``, ``dist_mat``, ``ts``, ``args`` and
    ``device``.

    Args:
        i: Index of the reference document in ``X_train`` / ``w_train``.

    Returns:
        None. Results are stored in ``dist_mat`` and ``ts``.
    """

    def _load(idx):
        # Support is transposed and only the first row of the weights is
        # kept, exactly as stored in X_train / w_train.
        support = torch.tensor(X_train[idx], device=device, dtype=torch.float64).T
        weights = torch.tensor(w_train[idx], device=device, dtype=torch.float64)[0]
        return support, weights

    n_docs = len(X_train)
    j = i + 1
    while j < n_docs:
        # Both documents are rebuilt for every pair (the reference one too).
        pts_i, mass_i = _load(i)
        pts_j, mass_j = _load(j)

        start = time.time()
        pts_i = reprocess_support(mass_i, pts_i)
        pts_j = reprocess_support(mass_j, pts_j)
        # One copy of rho per projection.
        lambdas = np.array([rho] * args.n_projs)
        sopt_dist, _, _, _ = opt_plans_64(pts_i, pts_j, lambdas)
        loss = torch.tensor([sopt_dist])
        ts.append(time.time() - start)

        # The matrix is symmetric: mirror the value across the diagonal.
        value = loss.item()
        dist_mat[i, j] = value
        dist_mat[j, i] = value

        j += 1


if __name__ == "__main__":
    # Script entry point: load one document dataset, then for every rho in the
    # sweep and every repetition compute the full pairwise distance matrix
    # with opt_plans_64 and, when the save condition holds, write it to
    # ./results_<dataset>/.

    # Local import: `os` is only needed here (to create the output directory)
    # and is not imported at the top of the file.
    import os

    print(device, args.loss, flush=True)

    n_projs = args.n_projs

    # Datasets stored as .mat files exposing the "X" and "BOW_X" entries.
    mat_datasets = {
        "Twitter": "./data/twitter-emd_tr_te_split.mat",
        "BBC": "./data/bbcsport-emd_tr_te_split.mat",
    }
    # Datasets stored as a pair of .npy files: (supports, weights).
    npy_datasets = {
        "movie": ("./data/X_movie.npy", "./data/w_movie.npy"),
        "goodreads": ("./data/X_goodread.npy", "./data/w_goodread.npy"),
    }

    if args.dataset in mat_datasets:
        mat_contents = sio.loadmat(mat_datasets[args.dataset])
        X = mat_contents["X"][0]
        w = mat_contents["BOW_X"][0]
    elif args.dataset in npy_datasets:
        x_path, w_path = npy_datasets[args.dataset]
        # NOTE: allow_pickle=True unpickles file content; only load trusted files.
        X = np.load(x_path, allow_pickle=True)
        w = np.load(w_path, allow_pickle=True)
    else:
        # Previously an unknown name fell through and crashed later with a
        # NameError on `X`; fail immediately with a clear message instead.
        raise ValueError(
            "Unknown dataset " + repr(args.dataset)
            + "; expected one of "
            + ", ".join(sorted(list(mat_datasets) + list(npy_datasets)))
        )

    n_try = args.ntry

    # Results are only written for these losses and while args.unnormalize is
    # True. The condition does not depend on rho or k, so evaluate it once.
    save_results = args.unnormalize and args.loss in ("rsw", "stochastic_rsw", "suw", "sopt")
    out_dir = "./results_" + str(args.dataset)
    if save_results:
        # Create the directory up front so np.savetxt cannot fail on a missing
        # folder after the whole matrix has been computed.
        os.makedirs(out_dir, exist_ok=True)
    else:
        print(
            "WARNING: distance matrices will be computed but NOT saved "
            "(saving requires --loss in {rsw, stochastic_rsw, suw, sopt} "
            "and --unnormalize not passed).",
            file=sys.stderr,
            flush=True,
        )

    # The same data is used for every rho and every repetition. These names
    # (like rho, dist_mat and ts below) are module-level globals that
    # compute_ot also reads.
    X_train = X
    w_train = w
    n_docs = len(X_train)

    # rho sweep. The original list contained 0.001 twice, which recomputed the
    # same configuration and overwrote its output files.
    # NOTE(review): the duplicate may have been a typo for 0.05 — confirm.
    rhos = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1, 1.0]

    for rho in rhos:
        for k in range(n_try):
            pbar = trange(n_docs) if args.pbar else range(n_docs)

            dist_mat = np.zeros((n_docs, n_docs))
            ts = []

            for i in pbar:
                for j in range(i + 1, n_docs):
                    # NOTE(review): document i is rebuilt for every j; this
                    # could be hoisted out of the inner loop if
                    # reprocess_support does not modify its inputs — confirm.
                    x1 = torch.tensor(X_train[i], device=device, dtype=torch.float64).T
                    w1 = torch.tensor(w_train[i], device=device, dtype=torch.float64)[0]

                    x2 = torch.tensor(X_train[j], device=device, dtype=torch.float64).T
                    w2 = torch.tensor(w_train[j], device=device, dtype=torch.float64)[0]

                    x1, x2 = reprocess_support(w1, x1), reprocess_support(w2, x2)
                    # One copy of rho per projection.
                    sopt_dist, _, _, _ = opt_plans_64(x1, x2, np.array([rho] * args.n_projs))
                    value = torch.tensor([sopt_dist]).item()

                    # Symmetric matrix: mirror the value across the diagonal.
                    dist_mat[i, j] = value
                    dist_mat[j, i] = value

            if save_results:
                # File name is unchanged; rho is written for both rho1 and rho2.
                file_name = (
                    f"d_projs{n_projs}_{args.loss}_unnormalize_"
                    f"{args.dataset}_rho1{rho}_rho2{rho}_k{k}"
                )
                np.savetxt(out_dir + "/" + file_name, dist_mat)
            
        