import torch
from train import train_model
from data_gen import icl_reg_Dataset, icl_reg_data, icl_NNrelu_Dataset
from torch.utils.data import Dataset, DataLoader
import wandb
import os
import numpy as np
import torch.nn as nn
import random
import pickle
import argparse



# Create the parser.  Every option has a default, so the script can be started
# without any command-line arguments.
parser = argparse.ArgumentParser()

# Add arguments
# `choices` makes a misspelled task/model fail immediately with a usage
# message instead of a NameError later in the script.
parser.add_argument('--task', type=str, default="Linear_regression",
                    choices=["Linear_regression", "NN_Relu", "tree", "SparseLinear"],
                    help="data-generating task to train on")
parser.add_argument('--dim', type=int, default=10,
                    help="dimension d handed to the dataset constructors; "
                         "the model is built with d + 1 inputs")
parser.add_argument('--dim_eff', type=int, default=None,
                    help="d_eff forwarded to icl_reg_Dataset "
                         "(Linear_regression task only)")
parser.add_argument('--max_cl', type=int, default=20,
                    help="maximum context length; datasets receive max_cl + 1")
parser.add_argument('--model', type=str, default='gpt2',
                    choices=['gpt2', 'sgpt'],
                    help="model architecture to train")
parser.add_argument('--n_layers', type=int, default=1,
                    help="number of layers of the model")
parser.add_argument('--width', type=int, default=256,
                    help="hidden dimension of the model")
parser.add_argument('--T', type=int, default=5_000,
                    help="size of the training set (number of tasks)")
parser.add_argument('--seed', type=int, default=0,
                    help="seed for the random, numpy and torch generators")
parser.add_argument('--sigma', type=float, nargs='+', default=[0.2],
                    help="noise standard deviation(s); one test set is built "
                         "per value (Linear_regression task only)")
parser.add_argument('--train_on', type=int, nargs='+', default=[20],
                    help="forwarded to train_model as train_on")
parser.add_argument('--test_on', type=int, nargs='+', default=[20],
                    help="forwarded to train_model as test_on")
parser.add_argument('--n_heads', type=int, default=1,
                    help="number of attention heads of the model")
parser.add_argument('--hid_dim_model', type=int, default=100,
                    help="hid_dim forwarded to icl_NNrelu_Dataset (NN_Relu task)")
parser.add_argument('--sparsity', type=int, default=3,
                    help="sparsity s forwarded to SparseLinearDataset "
                         "(SparseLinear task)")
# store_false: read_in_fixed is True unless the flag is given on the command line.
parser.add_argument('--read_in_fixed', action='store_false',
                    help="passing this flag sets read_in_fixed to False "
                         "(default: True); only forwarded to the sgpt model")

args = parser.parse_args()


# Seed every random number generator the script relies on with the same value.
for seed_fn in (random.seed, torch.manual_seed, np.random.seed,
                torch.cuda.manual_seed_all):
    seed_fn(args.seed)

# Device: use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Path of a pickled test set; while it is None the test data is generated.
test_path = None

# Short aliases for the settings used throughout the rest of the script.
T, d = args.T, args.dim
train_on, test_on = args.train_on, args.test_on

# Hyper-parameter summary; it is sent to the wandb run config when a run exists.
config_dict = {
    'T(number of tasks)': args.T,
    'n_layers': args.n_layers,
    'n_heads': args.n_heads,
    'hid_dim': args.width,
    'model': args.model,
    'Maximum context length': args.max_cl,
    'seed': args.seed,
    'train_on': train_on,
    'test_on': test_on,
}

def _load_or_build_test_set(build_fresh):
    """Return the test set: unpickled from ``test_path`` if set, else generated.

    Args:
        build_fresh: zero-argument callable that generates a new test dataset.
    """
    if test_path:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising -- only point test_path at files you trust.
        with open(test_path, 'rb') as file:
            return pickle.load(file)
    return build_fresh()


def _test_loader(dataset, label):
    """Pair a non-shuffled evaluation DataLoader with its display label."""
    return (DataLoader(dataset, batch_size=256, shuffle=False), label)


print("generating data ... ")
max_context = args.max_cl

if args.task == "Linear_regression":
    d_eff = args.dim_eff
    sigmas = args.sigma  # std of epsilon

    train_dataset = icl_reg_Dataset(sigmas, d, max_context + 1, T, d_eff=d_eff)

    # One test set per noise level.
    # NOTE(review): when test_path is set, the same pickled dataset is loaded
    # for every sigma -- confirm that this is intended.
    test_datasets = [
        _load_or_build_test_set(
            lambda sigma=sigma: icl_reg_Dataset(
                [sigma], d, max_context + 1, 10_000, d_eff=d_eff)
        )
        for sigma in sigmas
    ]
    test_dataloaders = [
        _test_loader(test_dataset, f"noise_level={sigma}")
        for sigma, test_dataset in zip(sigmas, test_datasets)
    ]

    run_name = (
        f'model={args.model} -- width={args.width} -- n_layers ={args.n_layers}'
        f' -- T={args.T} -- max_cl={max_context} -- dimension{d}-- dimension_eff{d_eff}'
        f' -- sigma{sigmas} -- n_heads={args.n_heads}'
    )
    config_dict["d_eff"] = d_eff
    config_dict["sigmas"] = sigmas

    wandb_name = args.task + ("_multi_noise" if len(sigmas) > 1 else "_single_noise")

elif args.task == "NN_Relu":
    hid_dim_model = args.hid_dim_model
    train_dataset = icl_NNrelu_Dataset(d, max_context + 1, T, hid_dim=hid_dim_model)
    test_dataset = _load_or_build_test_set(
        lambda: icl_NNrelu_Dataset(d, max_context + 1, 10_000, hid_dim=hid_dim_model)
    )

elif args.task == "tree":
    from data_gen import DecisionTreeDataset
    train_dataset = DecisionTreeDataset(max_context + 1, d, T)
    test_dataset = _load_or_build_test_set(
        lambda: DecisionTreeDataset(max_context + 1, d, 10_000)
    )

elif args.task == "SparseLinear":
    from data_gen import SparseLinearDataset
    s = args.sparsity
    train_dataset = SparseLinearDataset(max_context + 1, d, s, T)
    test_dataset = _load_or_build_test_set(
        lambda: SparseLinearDataset(max_context + 1, d, s, 10_000)
    )

else:
    # Fail here with a clear message; otherwise the first symptom would be a
    # NameError on train_dataset further down.
    raise ValueError(
        f"Unknown --task {args.task!r}; expected one of 'Linear_regression', "
        f"'NN_Relu', 'tree', 'SparseLinear'."
    )

if args.task != "Linear_regression":
    # The three non-regression tasks share the same run name, config entry and
    # single unlabelled test loader.
    run_name = (
        f'model={args.model} -- width={args.width} -- n_layers ={args.n_layers}'
        f' -- T={args.T} -- max_cl={max_context} -- dimension{d}'
        f' -- hid_dim_model{args.hid_dim_model} -- n_heads={args.n_heads}'
    )
    config_dict["hid_dim_model"] = args.hid_dim_model
    test_dataloaders = [_test_loader(test_dataset, '')]
    wandb_name = args.task


dataloader_train = DataLoader(train_dataset, batch_size=256, shuffle=True)


print("data ready ... ")
print(f'T is :{T} -- dimension:{d}')


#wandb
# To enable wandb logging, uncomment the block below and fill in your own
# project name, API key and organisation.

# wandb_init = {}
# wandb_init["project_name"] = wandb_name
# wandb_init["mode"] = 'online'
# wandb_init["key"] = ""
# wandb_init["org"] = ""
# os.environ["WANDB_API_KEY"] = wandb_init['key']
# run = wandb.init(project=wandb_init['project_name'], entity=wandb_init['org'])
#
#
#
# #wandb config
# run.config.update({f'dimension': f'{d}',
#                    f'T(number of training sequences)': f'{T}',
#                    f"read_in_fixed":args.read_in_fixed,
#                    })
#
# run.name = run_name

# wandb run handle.  It stays None while the wandb setup above is commented
# out, which skips the run.config.update call before training.
run = None

#Loss
criterion = nn.MSELoss(reduction='mean')

#Model
n_layers = args.n_layers
hid_dim = args.width


if args.model == 'gpt2':
    from models.gpt2 import TransformerModel
    # NOTE(review): 101 is presumably the maximum number of positions the
    # model supports -- confirm it still covers args.max_cl + 1 for large
    # context lengths.
    model = TransformerModel(d+1 , 101, n_embd=hid_dim, n_layer=n_layers,n_head = args.n_heads)
elif args.model == 'sgpt':
    from models.sgpt import sgpt
    model = sgpt(d+1, hid_dim, n_layers=n_layers, n_heads = args.n_heads,
                         read_in_fixed = args.read_in_fixed)
else:
    # Without this branch an unsupported name leaves `model` undefined and the
    # script fails later with a NameError.
    raise ValueError(
        f"Unknown --model {args.model!r}; expected 'gpt2' or 'sgpt'."
    )




# Print the size of every trainable tensor and add the sizes up along the way.
total_params = 0
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    n_elements = param.numel()
    print(f"{name}: {n_elements}")
    total_params += n_elements

# Move the model to the selected device.
model.to(device)

print(f"TF Total number of parameters: {total_params}")
config_dict['model trainable params'] = total_params




# Train the model for a fixed number of epochs.
epochs = 1500

# Attach the hyper-parameter summary to the wandb run when one is active.
if run:
    run.config.update(config_dict)


train_model(
    model,
    criterion,
    dataloader_train,
    test_dataloaders,
    device,
    run,
    n_epochs=epochs,
    train_on=train_on,
    test_on=test_on,
)


