import argparse
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
from data.electric import getElectric
from data.gas import getGas
from data.ghg import getGHG
from evaluate import getbest_regression
from misc_utils import *
from tqdm import tqdm


def make_parser_reg():
    """Build the command-line parser for the sketch-learning regression script.

    Returns:
        argparse.ArgumentParser: parser whose defaults run one experiment
        (``--num_exp 1``) on the ``ghg`` data with a randomly initialised,
        1-sparse, +/-1-valued sketch matrix S of ``m=10`` rows.
    """
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # Short alias so the option list below stays compact.
        parser.add_argument(*args, **kwargs)

    # Dataset and problem size.
    aa("--data", type=str, default="ghg", help="ghg|gas|electric")
    aa("--m", type=int, default=10, help="m for S")
    aa("--iter", type=int, default=1000, help="total iterations")

    aa("--random", default=False, action='store_true', help="don't learn S! Just compute error on random S")

    # Optimisation / environment.
    aa("--lr", type=float, default=1.0, help="learning rate for gradient descent")
    aa("--raw", dest='raw', default=False, action='store_true', help="generate raw?")
    aa("--device", type=str, default="cuda:0")

    # Sketch structure and initialisation.
    aa("--k_sparse", type=int, default=1, help="number of values in a column of S, sketching mat")
    aa("--num_exp", type=int, default=1, help="number of times to rerun the experiment (for avg'ing results)")
    aa("--bs", type=int, default=1, help="batch size")
    aa("--initalg", type=str, default="random", help="random|kmeans|lev|gs|lev_cluster|load")
    # This option was commented out, but the __main__ block of this file reads
    # args.load_file when --initalg load is used, so that path raised
    # AttributeError. Restored with an empty default.
    aa("--load_file", type=str, default="", help="if initalg=load, where to get S?")

    aa("--S_init_method", type=str, default="pm1", help="pm1|gaussian|gaussian_pm1")
    return parser


if __name__ == '__main__':
    runtype = "train_regression"
    initalg_name2fn_dict = {"kmeans": init_w_kmeans, "lev": init_w_lev, "gs": init_w_gramschmidt,
                            "lev_cluster": init_w_lev_cluster, "load": init_w_load}

    parser = make_parser_reg()
    args = parser.parse_args()

    rawdir = "/home/me/research/big-regression/" if get_hostname() == "owner-ubuntu" else "data"
    rltdir = "/home/me/research/big-regression/" if get_hostname() == "owner-ubuntu" else ""

    print(args)
    m = args.m

    if args.data == 'ghg':
        save_dir_prefix = os.path.join(rltdir, "rlt", "ghg")
    elif args.data == 'gas':
        save_dir_prefix = os.path.join(rltdir, "rlt", "gas")
    elif args.data == 'electric':
        save_dir_prefix = os.path.join(rltdir, "rlt", "electric")

    elif args.data == 'newele':
        save_dir_prefix = os.path.join(rltdir, "rlt", "newele")

    else:
        print("Wrong data option!")
        sys.exit()


    lr = args.lr
    if args.data == "ghg":
        AB_train, AB_test, n, d_a, d_b = getGHG(args.raw, 500, rawdir, 100)
        A_train = AB_train[0]
        B_train = AB_train[1]
        A_test = AB_test[0]
        B_test = AB_test[1]

    elif args.data == "gas":
        AB_train, AB_test, n, d_a, d_b = getGas(args.raw, rawdir, 100)
        A_train = AB_train[0]
        B_train = AB_train[1]
        A_test = AB_test[0]
        B_test = AB_test[1]

    elif args.data == "electric":
        AB_train, AB_test, n, d_a, d_b = getElectric(args.raw, rawdir, 100)
        A_train = AB_train[0]
        B_train = AB_train[1]
        A_test = AB_test[0]
        B_test = AB_test[1]

    # elif args.data == "newele":
    #
    #
    #     A_train, B_train = torch.load("newtrain_400.dat")
    #     A_test, B_test = torch.load("newtest_400.dat")
    #     n = 370
    #     d_a = 9
    #     d_b = 1

    print("Working on data ", args.data)

    N_train = len(A_train)
    N_test = len(A_test)
    print("Dim= ", n, d_a, d_b)
    print("N train=", N_train, "N test=", N_test)
    print(A_train[0].size())
    print(B_train[0].size())
    start = time.time()
    print_freq = 50
    it_print_freq = 50
    for exp_num in range(args.num_exp):

        it_lr = lr
        # Initialize sparsity pattern
        if args.initalg == "random":

            sketch_vector = torch.randint(m, [args.k_sparse, n]).int()

        elif args.initalg == "load":
            initalg = initalg_name2fn_dict[args.initalg]
            sketch_vector, sketch_value_cpu, active_ind = initalg(args.load_file, exp_num, n, m)
            sketch_value = sketch_value_cpu.detach().to(args.device)

        # Note: we sample with repeats, so you may map 1 row of A to <k_sparse distinct locations
        sketch_vector.requires_grad = False

        if args.initalg != "load":
            if args.S_init_method == "pm1":
                sketch_value = ((torch.randint(2, [args.k_sparse, n]).float() - 0.5) * 2).to(args.device)
            elif args.S_init_method == "gaussian":
                sketch_value = torch.from_numpy(np.random.normal(size=[args.k_sparse, n]).astype("float32")).to(
                    args.device)
                # sketch_value_cpu = ((torch.randint(2, [args.k_sparse, n]).float() - 0.5) * 2).to(args.device)
            elif args.S_init_method == "gaussian_pm1":
                sketch_value = ((torch.randint(2, [args.k_sparse, n]).float() - 0.5) * 2).to(args.device)
                sketch_value = sketch_value + torch.from_numpy(
                    np.random.normal(size=[args.k_sparse, n]).astype("float32")).to(args.device)

        sketch_value.requires_grad = True
        # print("ske", sketch_value.size())
        for bigstep in tqdm(range(args.iter)):
            # if (bigstep % 1000 == 0) and it_lr > 1:
            #     it_lr = it_lr * 0.3
            # if bigstep > 200:
            #     it_print_freq = 200

            fp_start_time = time.time()

            batch_rand_ind = np.random.randint(0, high=N_train, size=args.bs)

            AM = A_train[batch_rand_ind].to(args.device)
            BM = B_train[batch_rand_ind].to(args.device)

            S = torch.zeros(m, n).to(args.device)
            S[sketch_vector.type(torch.LongTensor).reshape(-1), torch.arange(n).repeat(
                    args.k_sparse)] = sketch_value.reshape(-1)

            SA = torch.matmul(S, AM)

            Q, iR = torch.qr(SA)
            U, Sig, V = torch.svd(iR)
            R = V.matmul(torch.diag_embed(1.0 / Sig)).matmul(U.permute(0, 2, 1))
            AR = AM.matmul(R)
            siz = AR.size()
            Id = torch.zeros(args.bs, siz[2], siz[2]).to(args.device)

            for p in range(args.bs):
                Id[p] = torch.eye(siz[2], siz[2])

            ARRA = AR.permute(0, 2, 1).matmul(AR) - Id # I

            loss = 0
            for i in range(args.bs):
                loss += torch.linalg.norm(ARRA[i])
            # print(loss
            loss.backward(retain_graph=True)

            if bigstep % it_print_freq == 0:
                print(loss)

            with torch.no_grad():
                if args.initalg == "load":
                    sketch_value[active_ind] -= (it_lr / args.bs) * sketch_value.grad[active_ind]
                    sketch_value.grad.zero_()
                else:
                    sketch_value -= (it_lr / args.bs) * sketch_value.grad
                    sketch_value.grad.zero_()

            # del SA, SB, U, Sig, V, X, ans, loss
            # torch.cuda.empty_cache()

        AM = A_train.to(args.device)
        BM = B_train.to(args.device)
        SA = torch.matmul(S, AM)

        S = S.cpu().detach().numpy()
        np.save("save_name" + str(exp_num), S) # save the trained sketch matrix
