import argparse
import os
import torch
from data.ghg import getGHG
from data.gas import getGas
from data.electric import getElectric

from evaluate import getbest_regression
from pathlib import Path
import sys
import time
from misc_utils import *
import warnings
from tqdm import tqdm
import numpy as np
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer


# NOTE(review): module-level flag that nothing in this script reads; possibly vestigial
# — confirm no importer depends on it before removing.
lasso = 1

def make_parser_reg():
    """Create the argument parser for the learned-sketch regression script.

    Options are registered in a fixed order from a spec table. Nothing is
    required; ``--save_fldr`` and ``--save_file`` have no explicit default and
    therefore parse to None.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    option_table = [
        # Dataset / problem size / optimisation.
        ("--data", dict(type=str, default="ghg", help="ghg|gas|electric")),
        ("--m", dict(type=int, default=10, help="m for S")),
        ("--iter", dict(type=int, default=1000, help="total iterations")),
        ("--random", dict(default=False, action='store_true',
                          help="don't learn S! Just compute error on random S")),
        ("--size", dict(type=int, default=-1, help="dataset size")),
        ("--lr", dict(type=float, default=1.0, help="learning rate for gradient descent")),
        ("--raw", dict(dest='raw', default=False, action='store_true', help="generate raw?")),
        ("--bestonly", dict(dest='bestonly', default=False, action='store_true',
                            help="only compute best?")),
        ("--device", dict(type=str, default="cuda:0")),
        # Sketch structure and experiment bookkeeping.
        ("--k_sparse", dict(type=int, default=1,
                            help="number of values in a column of S, sketching mat")),
        ("--num_exp", dict(type=int, default=1,
                           help="number of times to rerun the experiment (for avg'ing results)")),
        ("--bs", dict(type=int, default=1, help="batch size")),
        ("--initalg", dict(type=str, default="random",
                           help="random|kmeans|lev|gs|lev_cluster|load")),
        ("--load_file", dict(type=str, default="", help="if initalg=load, where to get S?")),
        # Output locations (default: None).
        ("--save_fldr", dict(type=str,
                             help="folder to save experiment results into; if None, then general folder")),
        ("--save_file", dict(type=str, help="append to runtype, if not None")),
        ("--S_init_method", dict(type=str, default="pm1", help="pm1|gaussian|gaussian_pm1")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


if __name__ == '__main__':
    runtype = "train_regression"
    # --initalg value -> sketch initializer (presumably provided by `from misc_utils import *`).
    initalg_name2fn_dict = {"kmeans": init_w_kmeans, "lev": init_w_lev, "gs": init_w_gramschmidt,
                            "lev_cluster": init_w_lev_cluster, "load": init_w_load}

    parser = make_parser_reg()
    args = parser.parse_args()

    # Host-specific roots: rawdir for input data, rltdir for results.
    on_linux_box = get_hostname() == "owner-ubuntu"
    rawdir = ("/home/me/research/big-regression/" if on_linux_box
              else "C:\\Users\\pass3010e\\Desktop\\learned_sketch-master\\lowrank\\data")
    rltdir = ("/home/me/research/big-regression/" if on_linux_box
              else "C:\\Users\\pass3010e\\Desktop\\12")

    print(args)
    m = args.m

    # Every supported dataset keeps its results under <rltdir>/rlt/<dataset name>.
    supported_data = ("ghg", "gas", "electric", "newgas", "co", "gaussian", "number", "w4a", "flocking")
    if args.data not in supported_data:
        print("Wrong data option!")
        sys.exit(1)  # non-zero status: a bad option is a failure, not a successful run
    save_dir_prefix = os.path.join(rltdir, "rlt", args.data)

    if args.save_file:
        runtype = runtype + "_" + args.save_file
    if args.save_fldr:
        save_dir = os.path.join(save_dir_prefix, args.save_fldr, args_to_fldrname(runtype, args))
    else:
        save_dir = os.path.join(save_dir_prefix, args_to_fldrname(runtype, args))

    # Folder holding the cached best (optimal least-squares) error files for this dataset.
    best_fl_save_dir = os.path.join(save_dir_prefix, "best_files")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(best_fl_save_dir, exist_ok=True)

    # A non-empty result folder means this configuration already ran (unless --bestonly).
    if not args.bestonly and os.listdir(save_dir):
        print("This experiment is already done! Now exiting.")
        sys.exit()

    lr = args.lr

    # Datasets produced by the data/* loaders; each returns (AB_train, AB_test, n, d_a, d_b)
    # where AB_* holds (A, B) at indices 0 and 1.
    # NOTE(review): `getnewGas` is not imported explicitly in this file; presumably it comes
    # from `from misc_utils import *` — confirm.
    loader_calls = {
        "ghg": lambda: getGHG(args.raw, 500, rawdir, 100),  # 500 used in place of args.size
        "gas": lambda: getGas(args.raw, rawdir, 100),
        "electric": lambda: getElectric(args.raw, rawdir, 100),
        "newgas": lambda: getnewGas(args.raw, rawdir, 100),
    }

    # Root of the pre-saved tensor files on the Windows workstation.
    lhs_raw_root = "C:\\Users\\pass3010e\\Desktop\\learned_lhs\\lowrank\\data\\raw"

    # Datasets with only A on disk; B is all zeros with a single column (d_b = 1).
    # name -> (sub-folder, train file, test file, N_train, N_test, n, d_a)
    zero_rhs_specs = {
        "gaussian": ("gaussian", "train270.dat", "test30.dat", 270, 30, 1030, 30),
        # older gaussian variant: ("gaussian", "train170.dat", "test50.dat", 170, 50, 300, 10)
        "number": ("number", "train160.dat", "test40.dat", 160, 40, 5030, 30),
        "w4a": ("w4a", "train300.dat", "test60.dat", 300, 60, 300, 20),
        "flocking": ("swarm", "train160.dat", "test40.dat", 160, 40, 2430, 30),
    }

    if args.data in loader_calls:
        AB_train, AB_test, n, d_a, d_b = loader_calls[args.data]()
        A_train, B_train = AB_train[0], AB_train[1]
        A_test, B_test = AB_test[0], AB_test[1]
    elif args.data == "co":
        # Each file stores an (A, B) pair.
        A_train, B_train = torch.load(lhs_raw_root + "\\co\\train96.dat")
        A_test, B_test = torch.load(lhs_raw_root + "\\co\\test24.dat")
        n, d_a, d_b = 300, 9, 1
    elif args.data in zero_rhs_specs:
        folder, train_name, test_name, n_tr, n_te, n, d_a = zero_rhs_specs[args.data]
        A_train = torch.load("\\".join((lhs_raw_root, folder, train_name)))
        A_test = torch.load("\\".join((lhs_raw_root, folder, test_name)))
        B_train = torch.zeros(n_tr, n, 1)
        B_test = torch.zeros(n_te, n, 1)
        d_b = 1

    print("A_train", A_train.size())
    print("B_train", B_train.size())
    print("Working on data ", args.data)

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    N_train = len(A_train)
    N_test = len(A_test)
    print("Dim= ", n, d_a, d_b)
    print("N train=", N_train, "N test=", N_test)

    # Cached optimal (unsketched) least-squares errors; recomputed when missing or with --raw.
    # NOTE(review): the cache key uses args.size, while the ghg loader is called with a fixed
    # size of 500 — confirm the key matches the data actually loaded.
    best_file = os.path.join(best_fl_save_dir, "N=" + str(args.size) + '_best')
    if args.raw or not os.path.isfile(best_file):
        print("computing optimal least squares (linear) approximations", best_file)
        getbest_regression(A_train, B_train, A_test, B_test, best_file)

    # presumably getbest_regression writes (best_train, best_test) to best_file — confirm
    best_train, best_test = torch.load(best_file)
    print("Best: %f , %f" % (best_train, best_test))

    start = time.time()
    print_freq = 50

    # Save the run configuration. `pickle` is imported explicitly because the file's
    # top-level imports do not include it.
    import pickle
    args_save_fpath = os.path.join(save_dir, "args_it_0.pkl")
    with open(args_save_fpath, "wb") as f:
        pickle.dump(vars(args), f)

    avg_over_exps = 0

    _solution = []

    # ---- Sketched sub-problem, solved through a differentiable convex layer ----
    # Its structure depends only on m and the number of columns of A, so it is built once
    # here rather than on every training step (re-canonicalizing a CvxpyLayer per step is
    # pure overhead). The same layer is reused for the final full-training-set solve.
    if B_train.size(2) != 1:
        # _b below is (n_cols, 1), so this formulation supports a single response column only.
        raise ValueError("train_regression expects B with exactly one column, got %d" % B_train.size(2))
    n_cols = A_train.size(2)
    x = cp.Variable((n_cols, 1))
    _SA = cp.Parameter((m, n_cols))  # S @ A
    _SB = cp.Parameter((m, 1))       # S @ A @ x_prev
    _b = cp.Parameter((n_cols, 1))   # A^T (B - A @ x_prev)
    # x >= 0 together with sum(x) == 1 already implies ||x||_1 <= 1; kept unchanged.
    constraints = [x >= 0, cp.sum(x) == 1, cp.norm(x, 1) <= 1]
    objective = cp.Minimize(0.5 * cp.sum_squares(_SA @ x - _SB) - cp.sum(_b.T @ x))
    problem = cp.Problem(objective, constraints)
    assert problem.is_dpp()
    cvxpylayer = CvxpyLayer(problem, parameters=[_SA, _b, _SB], variables=[x])

    def _build_sketch(row_idx, col_idx, values):
        """Return the dense (m, n) sketch on args.device with `values` scattered at (row_idx, col_idx)."""
        sketch = torch.zeros(m, n, device=args.device)
        sketch[row_idx, col_idx] = values.reshape(-1)
        return sketch

    for exp_num in range(args.num_exp):
        it_save_dir = os.path.join(save_dir, "exp_%d" % exp_num)
        it_print_freq = print_freq
        it_lr = lr
        print(exp_num)
        os.makedirs(it_save_dir, exist_ok=True)

        # Per-step timings (collected but not persisted by this script).
        fp_times = []
        bp_times = []

        # ---- Sparsity pattern: which row of S each column maps to ----
        if args.initalg == "random":
            # Sampled with repeats, so a column may map to < k_sparse distinct rows of S.
            sketch_vector = torch.randint(m, [args.k_sparse, n]).int()
        elif args.initalg == "load":
            initalg = initalg_name2fn_dict[args.initalg]
            sketch_vector, sketch_value_cpu, active_ind = initalg(args.load_file, exp_num, n, m)
            sketch_value = sketch_value_cpu.detach().to(args.device)
        else:
            raise NotImplementedError(
                "--initalg %r is not supported by this script (use 'random' or 'load')" % args.initalg)
        sketch_vector.requires_grad = False

        # ---- Initial nonzero values of S (trainable) ----
        if args.initalg != "load":
            if args.S_init_method == "pm1":
                sketch_value = ((torch.randint(2, [args.k_sparse, n]).float() - 0.5) * 2).to(args.device)
            elif args.S_init_method == "gaussian":
                sketch_value = torch.from_numpy(
                    np.random.normal(size=[args.k_sparse, n]).astype("float32")).to(args.device)
            elif args.S_init_method == "gaussian_pm1":
                sketch_value = ((torch.randint(2, [args.k_sparse, n]).float() - 0.5) * 2).to(args.device)
                sketch_value = sketch_value + torch.from_numpy(
                    np.random.normal(size=[args.k_sparse, n]).astype("float32")).to(args.device)
            else:
                raise ValueError("unknown --S_init_method %r" % args.S_init_method)
        sketch_value.requires_grad = True

        # Coordinates of S's nonzeros: sketch_vector[r, j] is the row that column j maps to.
        # The pattern is fixed during training, so the indices are computed once.
        row_idx = sketch_vector.long().reshape(-1).to(args.device)
        col_idx = torch.arange(n, device=args.device).repeat(args.k_sparse)

        if exp_num == 0:
            prev_xh = torch.zeros((A_train.size()[0], A_train.size()[2], 1)).to(args.device)
        else:
            # Warm start from the previous experiment's full-training-set solution.
            prev_xh = torch.load("solution" + str(exp_num - 1) + ".dat", map_location=args.device)

        for bigstep in tqdm(range(args.iter)):
            # Decay the step size by 0.3 every 1000 steps while it is above 1.
            if (bigstep % 1000 == 0) and it_lr > 1:
                it_lr = it_lr * 0.3
            # Print less often after the first 200 steps.
            if bigstep > 200:
                it_print_freq = 200

            fp_start_time = time.time()

            batch_rand_ind = np.random.randint(0, high=N_train, size=args.bs)
            AM = A_train[batch_rand_ind].to(args.device)
            BM = B_train[batch_rand_ind].to(args.device)
            prev_x = prev_xh[batch_rand_ind].to(args.device)

            S = _build_sketch(row_idx, col_idx, sketch_value)
            SA = torch.matmul(S, AM)
            A_prev = torch.matmul(AM, prev_x)

            b = torch.matmul(AM.permute(0, 2, 1), BM - A_prev)
            SB = SA @ prev_x
            solution, = cvxpylayer(SA, b, SB)
            _Ax = torch.matmul(AM, solution)
            ATbx = torch.matmul(b.permute(0, 2, 1), solution)

            # Unsketched objective evaluated at the sketched solution, averaged over the batch.
            loss = torch.mean(
                0.5 * torch.norm(_Ax - A_prev, dim=(1, 2), p='fro') ** 2
                - ATbx.squeeze()
            )

            if bigstep % it_print_freq == 0 or bigstep == args.iter - 1:
                print(loss)
            fp_times.append(time.time() - fp_start_time)

            bp_start_time = time.time()
            loss.backward()
            bp_times.append(time.time() - bp_start_time)

            # Plain gradient step on the nonzeros of S (only the active ones when loaded).
            with torch.no_grad():
                step = (it_lr / args.bs) * sketch_value.grad
                if args.initalg == "load":
                    sketch_value[active_ind] -= step[active_ind]
                else:
                    sketch_value -= step
                sketch_value.grad.zero_()

            torch.cuda.empty_cache()

        # TODO: per-step evaluation (save_iteration_regression) is disabled; train/test
        # errors are not recorded and avg_over_exps is never updated.

        # Final solve on the full training set with the learned S. No gradients are needed,
        # so run under no_grad: this avoids building a graph over the whole training set and
        # keeps the saved warm start free of autograd history.
        with torch.no_grad():
            S = _build_sketch(row_idx, col_idx, sketch_value)
            print("sketch matrix size", S.size())

            _A_train = A_train.to(args.device)
            _B_train = B_train.to(args.device)
            SA = torch.matmul(S, _A_train)
            b = torch.matmul(_A_train.permute(0, 2, 1), _B_train - torch.matmul(_A_train, prev_xh))
            SB = SA @ prev_xh
            solution, = cvxpylayer(SA, b, SB)

        torch.save(solution, "solution" + str(exp_num) + ".dat")
        _solution.append(solution)

        S = S.cpu().detach().numpy()
        np.save("1number" + str(exp_num), S)

    # NOTE(review): nothing in this loop updates avg_over_exps, so this prints its initial value.
    print(avg_over_exps)