import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import sys
from args import parse_args
import os 
import datetime

def one_hot_encoding(x, bs, len, cls):
    """Return ``x`` as a float32 tensor, one-hot encoding it when needed.

    Args:
        x: Array-like input. If it is 3-dimensional it is treated as already
            encoded and is only converted with ``torch.from_numpy``.
        bs: First dimension of the reshaped one-hot result.
        len: Second dimension of the reshaped one-hot result.
        cls: Number of classes (last dimension of the one-hot result).

    Returns:
        A ``torch.float32`` tensor: ``x`` itself (converted) for 3-D input,
        otherwise a tensor of shape ``(bs, len, cls)``.
    """
    if np.ndim(x) != 3:
        # Flatten the class indices, encode them, then restore the
        # (bs, len, cls) layout.
        flat_indices = torch.LongTensor(np.reshape(x, -1))
        encoded = F.one_hot(flat_indices, cls).view(bs, len, cls)
        return encoded.to(torch.float32)

    # 3-D input: no encoding, just a dtype conversion.
    return torch.from_numpy(x).to(torch.float32)

def load_data(args, data, label, b_inds):
    """Select the batch ``b_inds`` from ``data``/``label`` and move it to ``args.device``.

    Args:
        args: Namespace providing ``device`` and, for NumPy input, ``length``
            and ``action_n``.
        data: Either a NumPy array or an object indexable by ``b_inds`` whose
            result supports ``.to(device)`` (e.g. a tensor).
        label: Targets, handled the same way as ``data``.
        b_inds: Batch indices; for NumPy input ``b_inds.shape[0]`` is used as
            the batch size.

    Returns:
        Tuple ``(input, target)`` located on ``args.device``.
    """
    device = args.device

    # Non-NumPy data is used as is: index it and transfer it.
    if not isinstance(data, np.ndarray):
        return data[b_inds].to(device), label[b_inds].to(device)

    # NumPy data goes through one_hot_encoding with a (batch, length + 1, action_n) layout.
    batch_size = b_inds.shape[0]
    seq_len = args.length + 1
    inputs = one_hot_encoding(data[b_inds], batch_size, seq_len, args.action_n)
    targets = one_hot_encoding(label[b_inds], batch_size, seq_len, args.action_n)
    return inputs.to(device), targets.to(device)

def generate_data(args, environ, len):
    """Roll out ``environ`` and build the input/label tensors of a dataset.

    ``args.num_traj`` is rounded UP to a multiple of ``args.num_envs`` and the
    rounded value is written back to ``args`` (side effect).

    Input feature layout (last axis of ``data``, size ``args.input_size``):
        * ``[:args.action_n]`` at time ``step + 1``: the action emitted at
          ``step`` - one-hot encoded, or the raw float values when
          ``args.env`` is ``'LinearRegression'`` or ``'LinearDynamicalSystem'``.
        * ``[-4]``: 1 at time step 0, 0 elsewhere.
        * ``[-3]`` / ``[-2]``: ``sin`` / ``cos`` of
          ``position * pi * 0.25 / args.nn_max_len`` (positional embedding).
        * ``[-1]``: constant 1.

    Args:
        args: Namespace providing ``num_traj``, ``num_envs``, ``device``,
            ``input_size``, ``output_size``, ``nn_max_len``, ``action_n``,
            ``env`` and ``real``.
        environ: Batched environment exposing ``reset()``,
            ``emit_obs(real=...)``, ``step(action)`` and
            ``emit_obs_distribution()``; the latter must return a NumPy array
            (presumably one row per environment - confirm against the
            environment implementation).
        len: Number of environment steps per trajectory; sequences have
            ``len + 1`` time steps.

    Returns:
        ``(data, label)`` with shapes
        ``(args.num_traj, len + 1, args.input_size)`` and
        ``(args.num_traj, len + 1, args.output_size)``.
        When ``args.env == 'AutoRegression'`` the pair
        ``(label, label[:, 1:, :])`` is returned instead.
    """
    # Round the trajectory count up to a whole number of environment batches.
    args.num_traj = (args.num_traj - 1) // args.num_envs * args.num_envs + args.num_envs
    num_traj = args.num_traj
    seq_len = len + 1

    # Store the dataset on args.device only when it is small enough; otherwise
    # keep it on the default device. Allocating directly on the target avoids
    # a temporary CPU copy.
    max_on_device = 1000000
    storage = args.device if num_traj <= max_on_device else None
    data = torch.zeros(num_traj, seq_len, args.input_size, dtype=torch.float32, device=storage)
    label = torch.zeros(num_traj, seq_len, args.output_size, dtype=torch.float32, device=storage)

    # Constant bias channel.
    data[:, :, -1] = 1.0
    # Positional embedding: identical for every trajectory, so compute it once
    # for a single sequence and let the assignment broadcast over the batch.
    phase = torch.arange(0, seq_len, dtype=torch.float32) * torch.pi * 0.25 / args.nn_max_len
    data[:, :, -2] = torch.cos(phase).to(data.device)
    data[:, :, -3] = torch.sin(phase).to(data.device)
    # Start-of-sequence marker.
    data[:, 0, -4] = 1.0

    # Loop-invariant: these two environments emit continuous observations,
    # every other one emits class indices that are one-hot encoded.
    discrete = args.env not in ('LinearRegression', 'LinearDynamicalSystem')

    for start in range(0, num_traj, args.num_envs):
        rows = slice(start, start + args.num_envs)
        environ.reset()
        label[rows, 0] = torch.from_numpy(environ.emit_obs_distribution())
        for step in range(len):
            action = environ.emit_obs(real=args.real)
            if discrete:
                encoded = F.one_hot(torch.LongTensor(action), args.action_n)
            else:
                encoded = torch.FloatTensor(action)
            # The action taken at `step` is the model input at `step + 1`.
            data[rows, step + 1, :args.action_n] = encoded
            environ.step(action)
            label[rows, step + 1] = torch.from_numpy(environ.emit_obs_distribution())

    if args.env == 'AutoRegression':
        # Auto-regressive setting: the label sequence is its own input and
        # the target is the same sequence shifted by one step.
        return label, label[:, 1:, :]
    return data, label

# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__" :
    # Command-line configuration comes from the project's args module.
    args = parse_args()
    # Print whether the dataset path exists.
    print(os.path.exists(args.dataset))
    # Load the dataset file as CSV into a pandas DataFrame.
    pp = pd.read_csv(args.dataset)