import torch
# from reg_data import RegDataset
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from data_gen import icl_reg_Dataset, icl_reg_data, icl_reg_data_batch_wise,\
    icl_reg_batch_wise_Dataset, icl_reg_Dataset, icl_reg_data, icl_NNrelu_Dataset
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from utils import gd_feature_map, hilbert_feature_map
import wandb
import ipdb
import random
import pickle
from tqdm import tqdm
import gc
import argparse
from train_mlp import train_mlp

def split_context(y):
    """Separate the query label from the context labels.

    Args:
        y: label tensor whose second axis indexes sequence positions; the last
            position holds the query label.

    Returns:
        (y_context, y_test): ``y_context`` is a new tensor equal to ``y`` except
        that position ``-1`` along axis 1 is zero; ``y_test`` is ``y[:, -1]``.
    """
    query_labels = y[:, -1]
    context_labels = y.clone()
    context_labels[:, -1] = 0
    return context_labels, query_labels

def combine(xs_b, ys_b):
    """Append each label to its token as one extra feature.

    With ``xs_b`` of shape B x (N+1) x d and ``ys_b`` of shape B x (N+1), the
    result is B x (N+1) x (d+1); the (N+1)-th token is the test/query token.
    The query label is NOT masked here.
    """
    label_column = ys_b[:, :, None]
    return torch.cat([xs_b, label_column], dim=2)


def str2bool(v):
    """Parse a command-line boolean (argparse ``type=`` helper).

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser(
    description='Train an MLP (train_mlp) on flattened / feature-mapped '
                'in-context-learning sequences.')

parser.add_argument('--T', type=int, default=5000,
                    help='number of training sequences')
parser.add_argument('--epochs', type=int, default=10,
                    help='not currently read by this script')
parser.add_argument('--seed', type=int, default=0,
                    help='seed for torch, numpy and random')
parser.add_argument('--method', type=str, default="mlp",
                    help='method label recorded in the run config')
parser.add_argument('--feats_method', type=str, default="gd",
                    choices=["hilbert", "gd", "just_flatten",
                             "gd_feats_plus_flatten", "hilbert_feats_plus_flatten"],
                    help='how each truncated sequence is turned into an MLP input')
parser.add_argument('--task', type=str, default="Linear_regression",
                    choices=["Linear_regression", "NN_Relu", "tree", "SparseLinear"],
                    help='synthetic task to generate data for')
parser.add_argument('--dim', type=int, default=8,
                    help='token (x) dimension d')
parser.add_argument('--dim_eff', type=int, default=None,
                    help='d_eff passed to icl_reg_Dataset (Linear_regression only)')
parser.add_argument('--depth', type=int, default=2,
                    help='not currently read by this script')
parser.add_argument('--max_cl', type=int, default=8,
                    help='maximum context length; data is generated with max_cl + 1 '
                         'tokens per sequence')
parser.add_argument('--sigma', type=float, nargs='+', default=[0.22],
                    help='noise std(s) for Linear_regression; one test set per value')
parser.add_argument('--contexts', type=int, nargs='+', default=[10, 20, 30, 40],
                    help='context lengths the MLP is trained/evaluated on; '
                         'each must be <= --max_cl')
parser.add_argument('--sparsity', type=int, default=3,
                    help='sparsity level passed to SparseLinearDataset')
parser.add_argument('--hid_dim_model', type=int, default=100,
                    help='hid_dim passed to icl_NNrelu_Dataset (NN_Relu task)')

args = parser.parse_args()

# Each context length cl needs cl + 1 tokens (cl examples plus the query), while
# the data generators are called with max_cl + 1 tokens per sequence. Reject
# impossible combinations up front instead of crashing on a mask-shape mismatch
# after all the data has been generated.
if max(args.contexts) > args.max_cl:
    parser.error(f"every value in --contexts must be <= --max_cl ({args.max_cl}); "
                 f"got {args.contexts}")

# Seed every RNG this script (and its data generators) may draw from, so repeated
# runs with the same --seed produce the same data and initialisation.
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Optional Weights & Biases logging. To enable it: fill in the fields below,
# uncomment this section, and delete the `run = None` line at the end.
# NOTE(review): prefer exporting WANDB_API_KEY in the shell over committing a key here.
# wandb_init = {}
# wandb_init["project_name"] = ""
# wandb_init["mode"] = 'online'
# wandb_init["key"] = ""
# wandb_init["org"] = ""
#
# os.environ["WANDB_API_KEY"] = wandb_init['key']
# os.environ["WANDB_MODE"] = wandb_init['mode']  # online or offline
# run = wandb.init(project=wandb_init['project_name'], entity=wandb_init['org'])

# No wandb run: the `if run:` checks in this script skip config/name updates.
# `run` is also passed to train_mlp — presumably it tolerates None (TODO confirm).
run = None




print("generating data ... ")
# Run metadata; pushed to the wandb run (when one exists) once data is ready.
config_dict = {}
# Optional location of pre-generated (pickled) test sets. When falsy, fresh test
# sets of 10_000 sequences are generated in-process instead.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at trusted files.
test_path = None

T = args.T
d = args.dim
context_length = args.max_cl
if args.task == "Linear_regression":
    d_eff = args.dim_eff
    sigmas = args.sigma  # std of epsilon
    Sigma_scale = 1 / d
    Sigma = Sigma_scale * np.eye(d)
    max_context = args.max_cl
    test_dataloaders = []

    # Presumably the third argument is the sequence length: max_context examples + the query.
    train_dataset = icl_reg_Dataset(sigmas, d, max_context + 1, T, d_eff=d_eff)

    # One test set per noise level. Each is generated with that single sigma so the
    # per-level dataloaders below (labelled "noise_level=<sigma>") really evaluate
    # one noise level each; the on-disk file naming uses the same "[sigma]" form.
    test_datasets = []
    for sigma in sigmas:
        if test_path:
            # For this task test_path is a base directory holding one file per noise level.
            filename = os.path.join(
                test_path,
                f'test_sets/{args.task}/d{d}_d_eff{d_eff}/sigma=[{sigma}]_cl={max_context}.pkl',
            )
            with open(filename, 'rb') as file:
                test_dataset = pickle.load(file)
        else:
            test_dataset = icl_reg_Dataset([sigma], d, max_context + 1, 10_000, d_eff=d_eff)
        test_datasets.append(test_dataset)

    config_dict["d_eff"] = d_eff
    config_dict["sigmas"] = sigmas

    for sigma, test_dataset in zip(sigmas, test_datasets):
        extention_name = f"noise_level={sigma}"
        test_dataloaders.append(
            (DataLoader(test_dataset, batch_size=256, shuffle=False), extention_name))

    # The MLP pipeline below evaluates on a single test set: the last noise level's.
    test_dataset = test_datasets[-1]
    XY_all_test = combine(test_dataset.x, test_dataset.y)
    X_test = test_dataset.x
    y_test = test_dataset.y
    XY_all_train = combine(train_dataset.x, train_dataset.y)
    X_train = train_dataset.x
    y_train = train_dataset.y

elif args.task == "NN_Relu":
    hid_dim_model = args.hid_dim_model
    train_on = list(range(5, 101, 10))
    test_on = list(range(10, 101, 5))
    max_context = args.max_cl
    train_dataset = icl_NNrelu_Dataset(d, max_context + 1, T, hid_dim=hid_dim_model)

    if test_path:
        # For this task test_path is the pickle file itself.
        filename = test_path
        with open(filename, 'rb') as file:
            test_dataset = pickle.load(file)
    else:
        test_dataset = icl_NNrelu_Dataset(d, max_context + 1, 10_000, hid_dim=hid_dim_model)

    config_dict["hid_dim_model"] = args.hid_dim_model

    extention_name = ''
    test_dataloaders = [(DataLoader(test_dataset, batch_size=256, shuffle=False), extention_name)]
    XY_all_test = combine(test_dataset.x, test_dataset.y)
    XY_all_train = combine(train_dataset.x, train_dataset.y)

elif args.task == "tree":
    from data_gen import DecisionTreeDataset
    train_on = list(range(5, 101, 5))
    test_on = list(range(10, 101, 10))
    max_context = args.max_cl
    train_dataset = DecisionTreeDataset(max_context + 1, d, T)
    if test_path:
        # For this task test_path is the pickle file itself.
        filename = test_path
        with open(filename, 'rb') as file:
            test_dataset = pickle.load(file)
    else:
        test_dataset = DecisionTreeDataset(max_context + 1, d, 10_000)

    config_dict["hid_dim_model"] = args.hid_dim_model
    config_dict["train_on"] = train_on
    extention_name = ''
    test_dataloaders = [(DataLoader(test_dataset, batch_size=256, shuffle=False), extention_name)]

    XY_all_test = combine(test_dataset.xi, test_dataset.fxi)
    XY_all_train = combine(train_dataset.xi, train_dataset.fxi)

elif args.task == "SparseLinear":
    from data_gen import SparseLinearDataset
    max_context = args.max_cl
    s = args.sparsity
    train_dataset = SparseLinearDataset(max_context + 1, d, s, T)

    if test_path:
        # For this task test_path is the pickle file itself.
        filename = test_path
        with open(filename, 'rb') as file:
            test_dataset = pickle.load(file)
    else:
        test_dataset = SparseLinearDataset(max_context + 1, d, s, 10_000)

    extention_name = ''
    test_dataloaders = [(DataLoader(test_dataset, batch_size=256, shuffle=False), extention_name)]
    XY_all_test = combine(test_dataset.xi, test_dataset.fx)
    XY_all_train = combine(train_dataset.xi, train_dataset.fx)

else:
    # Without this, an unknown task only surfaced later as a NameError on XY_all_train.
    raise ValueError(f"Unknown task {args.task!r}; expected one of "
                     "'Linear_regression', 'NN_Relu', 'tree', 'SparseLinear'.")

print("generating done.")
config_dict['task'] = args.task
config_dict['method'] = args.method
if run:
    run.config.update(config_dict)
name = f'{args.task}: '


# Context lengths the MLP is trained and evaluated on.
contexts = args.contexts

# Largest requested context length; it fixes the zero-padded feature width so the
# inputs for every context length share one size. Use max() rather than
# contexts[-1]: an unsorted --contexts list would otherwise under-size the width,
# and F.pad with a negative amount silently truncates features.
max_cl = max(contexts)
XY_train_all_cls = []
y_train_all_cls = []
XY_test_all_cls = []
y_test_all_cls = []

def _split_off_query_label(XY_all, cl):
    """Cut sequences to ``cl`` context tokens plus the query and hide the query label.

    Keeps tokens ``0..cl`` of every sequence in ``XY_all`` (indexed as
    ``XY_all[sample, token, feature]``; assumed 3-D ``(B, L, d + 1)`` as built by
    ``combine``). Copies the query token's last entry (its label) out as the
    target, and zeroes that entry in the returned inputs so the model cannot
    read the answer.

    Returns:
        (inputs, targets): a fresh ``(B, cl + 1, d + 1)`` tensor and a ``(B,)`` tensor.

    Raises:
        ValueError: if the sequences have fewer than ``cl + 1`` tokens.
    """
    n_tokens = XY_all.shape[1]
    if cl + 1 > n_tokens:
        raise ValueError(
            f"context length {cl} needs {cl + 1} tokens per sequence, but the "
            f"generated sequences only have {n_tokens}; increase --max_cl.")
    inputs = XY_all[:, :cl + 1, :].clone()
    targets = inputs[:, -1, -1].clone()
    inputs[:, -1, -1] = 0
    return inputs, targets


for cl in tqdm(contexts):
    XY_train_tmp, y_train_tmp = _split_off_query_label(XY_all_train, cl)
    XY_test_tmp, y_test_tmp = _split_off_query_label(XY_all_test, cl)

    feats_method = args.feats_method
    if feats_method in ("hilbert", "gd"):
        feature_map = hilbert_feature_map if feats_method == "hilbert" else gd_feature_map
        # NOTE(review): the feature-map output is padded as-is; if it is 3-D with a
        # cl-dependent second axis, the torch.cat after this loop would fail -- confirm.
        XY_train_feats = feature_map(XY_train_tmp)
        XY_test_feats = feature_map(XY_test_tmp)
        full_dim = (args.dim + 1) * (max_cl + 2)
    elif feats_method == "just_flatten":
        XY_train_feats = torch.flatten(XY_train_tmp, start_dim=1)
        XY_test_feats = torch.flatten(XY_test_tmp, start_dim=1)
        # Raw tokens only: at most max_cl + 1 tokens of width d + 1.
        full_dim = (args.dim + 1) * (max_cl + 1)
    elif feats_method in ("gd_feats_plus_flatten", "hilbert_feats_plus_flatten"):
        feature_map = (gd_feature_map if feats_method == "gd_feats_plus_flatten"
                       else hilbert_feature_map)
        train_map = feature_map(XY_train_tmp)
        test_map = feature_map(XY_test_tmp)
        # Flattened raw tokens followed by the feature map's last-position row.
        XY_train_feats = torch.cat(
            [torch.flatten(XY_train_tmp, start_dim=1), train_map[:, -1, :]], dim=1)
        XY_test_feats = torch.cat(
            [torch.flatten(XY_test_tmp, start_dim=1), test_map[:, -1, :]], dim=1)
        full_dim = (args.dim + 1) * (max_cl + 2)
    else:
        # Previously this fell through to a NameError on XY_train_flatten_tmp_zeropad.
        raise ValueError(f"Unknown feats_method {args.feats_method!r}")

    # Zero-pad the last axis to full_dim so every context length yields inputs of
    # one common width (they are concatenated across context lengths after the loop).
    pad_amount = full_dim - XY_train_feats.shape[-1]
    if pad_amount < 0:
        # F.pad would silently crop with a negative amount; refuse instead.
        raise ValueError(
            f"features for context length {cl} have width {XY_train_feats.shape[-1]}, "
            f"larger than the padded width {full_dim}")
    XY_train_flatten_tmp_zeropad = F.pad(XY_train_feats, pad=(0, pad_amount),
                                         mode='constant', value=0)
    XY_test_flatten_tmp_zeropad = F.pad(XY_test_feats, pad=(0, pad_amount),
                                        mode='constant', value=0)

    XY_train_all_cls.append(XY_train_flatten_tmp_zeropad)
    XY_test_all_cls.append(XY_test_flatten_tmp_zeropad)
    y_train_all_cls.append(y_train_tmp)
    y_test_all_cls.append(y_test_tmp)
    # Free per-iteration intermediates before the next (possibly larger) context length.
    gc.collect()

# Training pools every context length into one dataset.
XY_train_all = torch.cat(XY_train_all_cls)
y_train_all = torch.cat(y_train_all_cls)

# All context lengths were zero-padded to one width, so any row gives the input size.
input_size = XY_train_all.shape[-1]
# Create datasets and dataloaders. The label tensors are passed directly:
# re-wrapping an existing tensor in torch.tensor() only copies it and emits a UserWarning.
train_dataset = TensorDataset(XY_train_all, y_train_all)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# Evaluation keeps one [loader, context_length] pair per context length.
test_dataloader = []
for cl, XY_test_cl, y_test_cl in zip(contexts, XY_test_all_cls, y_test_all_cls):
    test_dataset = TensorDataset(XY_test_cl, y_test_cl)
    test_dataloader.append([DataLoader(test_dataset, batch_size=32, shuffle=False), cl])

# MLP architecture and schedule.
# NOTE(review): --depth and --epochs are parsed but not read; the layer count and the
# epoch count (1000) passed to train_mlp below are fixed.
n_layer = 2
hidden_size = 1024
if run:
    run.name = f"{args.task}: MLP_{args.feats_method} -- T:{T} -- dim={d} -- context_length={context_length}" \
               f" -- n_layers:{n_layer} -- hidden_size = {hidden_size}  "

    run.config.update({f'T(number of training sequences)': T,
                       f'n_layers': n_layer,
                       f'dimension:': d,
                       f'context_length': context_length,
                       f'width': hidden_size,
                       'model': "MLP",
                       'seed': args.seed,
                       "contexts": contexts,
                       "feats_method": args.feats_method
                       })
# mlp
train_mlp(train_dataloader, test_dataloader, input_size, run, n_layers=n_layer, hidden_size=hidden_size,
          n_epochs=1000, learning_rate=0.0001)














