import os
import random

import numpy as np
import models
import torch.optim as optim
import torch
import torch.nn as nn
import math

# Torch device string; main() moves every input and label tensor to it.
device = 'cpu'

def convert_data(path='datafiles/', model_type='xgboost', buffers=False, eplen=10, explen=1000):
    """Convert raw recording files into the arrays consumed by make_batches().

    Every regular file in ``path`` whose name contains ``model_type`` and
    ``'exploration'`` (plus, when ``buffers`` is true, ``'buffer'``) is loaded.
    The text between the last underscore of the file name and its extension
    is the dataset name; the sorted unique dataset names define the integer
    class labels.

    Written to ``prepared_data/`` (created if missing):

    * ``fulldata_<model_type>_X/_Y`` -- the states of all episodes whose
      length is exactly ``eplen``, concatenated, with one label per state.
    * ``sequences_<model_type>_X/_Y`` -- only when ``buffers`` is true: every
      loaded episode (as an object array) with one label per episode.

    Args:
        path: Directory holding the raw ``.npy`` files.
        model_type: Substring a file name must contain to be used.
        buffers: Also read ``'buffer'`` files and emit the sequence arrays.
        eplen: Length an episode must have to count as "full".
        explen: Rescaled to ``int(explen / eplen * (eplen - 1))``; that many
            leading rows of each buffer file are skipped.

    Raises:
        ValueError: If no selected file contains an episode of length
            ``eplen`` (this includes the case where no file matched).
    """
    explen = int((explen / eplen) * (eplen - 1))
    bufferfiles = []
    datanames = []
    for name in os.listdir(path):
        filename = os.path.join(path, name)
        if not (os.path.isfile(filename) and model_type in name):
            continue
        if (buffers and 'buffer' in name) or 'exploration' in name:
            bufferfiles.append(filename)
            # Dataset name = text after the last underscore, extension removed.
            stem = os.path.splitext(name)[0]
            datanames.append(stem.rsplit('_', 1)[-1])

    print(bufferfiles)
    print(datanames)

    _, y_cat = np.unique(datanames, return_inverse=True)
    buffer = []
    buffer_y = []
    all_data = []
    all_data_y = []

    for label, filename in zip(y_cat, bufferfiles):
        # NOTE: allow_pickle=True unpickles arbitrary objects; only use this on
        # trusted, locally produced files.
        file = np.load(filename, allow_pickle=True)
        if 'buffer' in os.path.basename(filename):
            states = file[explen:, 0]
        else:
            kept = [entry[0] for entry in file if entry is not None]
            # Fill a 1-D object array element by element so that episodes of
            # different lengths never trigger NumPy's ragged-array error and
            # equal-length episodes are not merged into one dense array.
            states = np.empty(len(kept), dtype=object)
            for k, episode in enumerate(kept):
                states[k] = episode

        if buffers:
            buffer.append(states)
            buffer_y.append([label] * len(states))

        lengths = np.array([len(episode) for episode in states], dtype=int)
        # flatnonzero always yields a 1-D index, even for a single match.
        full_idx = np.flatnonzero(lengths == eplen)
        if full_idx.size == 0:
            continue

        fullstates = np.concatenate(list(states[full_idx]))
        all_data.append(fullstates)
        all_data_y.append([label] * len(fullstates))

    if not all_data:
        raise ValueError(
            'no episodes of length {} found for model_type {!r} in {!r}'.format(eplen, model_type, path))

    os.makedirs('prepared_data', exist_ok=True)

    if buffers:
        buffer = np.concatenate(buffer)
        np.save('prepared_data/sequences_{}_X'.format(model_type), buffer)
        buffer_y = np.concatenate(buffer_y)
        np.save('prepared_data/sequences_{}_Y'.format(model_type), buffer_y)

    all_data = np.concatenate(all_data)
    np.save('prepared_data/fulldata_{}_X'.format(model_type), all_data)
    all_data_y = np.concatenate(all_data_y)
    np.save('prepared_data/fulldata_{}_Y'.format(model_type), all_data_y)

    return

def make_batches(path='prepared_data/', n_obs=5000, maxlen=10, model_type='xgboost'):
    """Build per-sequence-length sample sets from the prepared data files.

    The ``fulldata`` X/Y files for ``model_type`` are read from ``path``. For
    every class label and every sequence length from 2 to ``maxlen``,
    ``int(n_obs / maxlen)`` sequences are created by drawing states of that
    class at random, with replacement. If both ``sequences`` files are
    present as well, the stored sequences of that class whose length equals
    the current sequence length are appended after the sampled ones.

    Args:
        path: Directory containing the prepared ``.npy`` files.
        n_obs: Divided by ``maxlen`` to get the number of sampled sequences
            per class and sequence length.
        maxlen: Largest sequence length to generate.
        model_type: Substring a file name must contain to be used.

    Returns:
        ``(batches_x, batches_y)``: two dicts keyed by sequence length.
        ``batches_x[L]`` is an array of shape ``(n, L, n_features)`` and
        ``batches_y[L]`` is the list of the ``n`` matching labels.

    Raises:
        FileNotFoundError: If the ``fulldata`` X or Y file is missing.
    """
    datafile_x = datafile_y = None
    seqfile_x = seqfile_y = None

    for name in os.listdir(path):
        filename = os.path.join(path, name)
        if not (os.path.isfile(filename) and model_type in name):
            continue
        # Inspect the file name only, so letters in the directory path cannot
        # change how a file is classified.
        is_label_file = 'Y' in name
        if 'fulldata' in name:
            if is_label_file:
                datafile_y = filename
            else:
                datafile_x = filename
        elif 'sequences' in name:
            if is_label_file:
                seqfile_y = filename
            else:
                seqfile_x = filename

    if datafile_x is None or datafile_y is None:
        raise FileNotFoundError(
            'fulldata X/Y files for model_type {!r} not found in {!r}'.format(model_type, path))

    # Stored sequences are only usable when both halves are available.
    seqfile = seqfile_x is not None and seqfile_y is not None

    batch_size = int(n_obs / maxlen)
    # NOTE: allow_pickle=True unpickles arbitrary objects; only use this on
    # trusted, locally produced files.
    data_x = np.load(datafile_x, allow_pickle=True)
    data_y = np.load(datafile_y, allow_pickle=True)
    batches_x = {}
    batches_y = {}

    if seqfile:
        seqs_x = np.load(seqfile_x, allow_pickle=True)
        seqs_y = np.load(seqfile_y, allow_pickle=True)

    for y in np.unique(data_y):
        # flatnonzero keeps the index 1-D even when a class has one sample.
        data_x_y = data_x[np.flatnonzero(data_y == y)]

        if seqfile:
            seqs_x_y = seqs_x[np.flatnonzero(seqs_y == y)]
            seq_lengths = np.array([len(seq) for seq in seqs_x_y], dtype=int)

        for seqlen in range(2, maxlen + 1):
            batch_idx = np.random.choice(len(data_x_y), (batch_size, seqlen))
            # Fancy indexing with a (batch_size, seqlen) index yields the
            # sampled sequences with shape (batch_size, seqlen, n_features).
            batch = data_x_y[batch_idx]
            parts = [batch]

            if seqfile:
                # Add the stored sequences of exactly this length (maybe none).
                matching = seqs_x_y[np.flatnonzero(seq_lengths == seqlen)]
                seq_arr = np.zeros((len(matching), seqlen, batch.shape[-1]))
                for j, seq in enumerate(matching):
                    seq_arr[j] = seq
                parts.append(seq_arr)

            n_new = sum(len(part) for part in parts)
            if seqlen in batches_x:
                parts.insert(0, batches_x[seqlen])

            batches_x[seqlen] = np.concatenate(parts)
            batches_y[seqlen] = batches_y.get(seqlen, []) + [y] * n_new

    return batches_x, batches_y


def main(input_dim=11, output_dim=15, convert=True, buffer=False, batch_size=50, hidden_size=256, epochs=1000,
         test_size=0.2, savename='EmbedNet', eplen=10, model='xgboost'):
    """Train ``models.DatasetClassifier`` on the prepared data and save the best weights.

    The sample sets returned by make_batches() are shuffled per sequence
    length, split into a test part and a training part, and the training part
    is cut into mini-batches, so each mini-batch holds one sequence length.
    The network is trained with Adam and cross-entropy loss. Every fifth
    epoch the test accuracy is computed; whenever it is at least as high as
    the best seen so far the parameters are saved via save_params(). Training
    stops early after five consecutive evaluations without such a result.

    Args:
        input_dim: Passed to the network as ``in_dim``.
        output_dim: Passed to the network as ``out_dim``.
        convert: Run convert_data() before building the batches.
        buffer: Forwarded to convert_data() as ``buffers``.
        batch_size: Maximum number of samples per training mini-batch.
        hidden_size: Passed to the network as ``hid_dim``.
        epochs: Maximum number of epochs; the learning rate is multiplied by
            0.1 at every 400th epoch below this value.
        test_size: Fraction of each sample set held out for testing.
        savename: File name handed to save_params().
        eplen: Forwarded to make_batches() as ``maxlen``.
        model: Forwarded as ``model_type`` to convert_data() and make_batches().

    Raises:
        ValueError: If the prepared data yields no training mini-batch.
    """
    network = models.DatasetClassifier(in_dim=input_dim, hid_dim=hidden_size, out_dim=output_dim)
    loss_calc = nn.CrossEntropyLoss()

    optimizer = optim.Adam([{'params': network.parameters()}], lr=1e-3, weight_decay=1e-7, eps=1e-3)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(range(400, epochs, 400)),
                                               gamma=0.1)

    if convert:
        # NOTE(review): convert_data() runs with its own default eplen here,
        # independent of this function's eplen -- confirm that is intended.
        convert_data(buffers=buffer, model_type=model)

    data_x, data_y = make_batches(maxlen=eplen, model_type=model)
    train_batches_x = []
    train_batches_y = []
    test_data_x = []
    test_data_y = []

    # One iteration per sequence length.
    for x_values, y_values in zip(data_x.values(), data_y.values()):
        x_data = np.asarray(x_values)
        y_data = np.asarray(y_values)
        n_datapoints = len(x_data)

        # Shuffle samples and labels with the same permutation.
        data_order = list(range(n_datapoints))
        random.shuffle(data_order)
        x_data = x_data[data_order]
        y_data = y_data[data_order]

        test_idx = int(n_datapoints * test_size)
        test_data_x.append(x_data[:test_idx])
        test_data_y.append(y_data[:test_idx])
        x_train = x_data[test_idx:]
        y_train = y_data[test_idx:]

        # Step over the actual training split so no trailing sample is lost
        # and no empty batch is created.
        for start in range(0, len(x_train), batch_size):
            train_batches_x.append(x_train[start:start + batch_size])
            train_batches_y.append(y_train[start:start + batch_size])

    n_batches = len(train_batches_x)
    if n_batches == 0:
        raise ValueError('no training data available for model {!r}'.format(model))

    max_acc = 0
    max_count = 0
    for epoch in range(epochs):

        n_correct = 0
        n = 0
        batch_order = list(range(n_batches))
        random.shuffle(batch_order)

        for batch_idx in batch_order:
            n += len(train_batches_x[batch_idx])
            X = torch.from_numpy(train_batches_x[batch_idx]).float().to(device)
            Y = torch.from_numpy(train_batches_y[batch_idx]).long().to(device)
            optimizer.zero_grad()
            y_pred = network(X)
            loss = loss_calc(y_pred, Y)
            # .item() works on any device, unlike .numpy().
            n_correct += (y_pred.max(dim=1)[1] == Y).sum().item()
            loss.backward()
            optimizer.step()

        train_acc = n_correct / n
        print('Training accuracy of epoch {}: {}'.format(epoch + 1, train_acc))
        scheduler.step()

        if (epoch + 1) % 5 == 0:
            test_correct = 0
            testcount = 0
            # No gradients are needed for evaluation.
            with torch.no_grad():
                for x_test, y_test in zip(test_data_x, test_data_y):
                    testcount += len(x_test)
                    X = torch.from_numpy(x_test).float().to(device)
                    Y = torch.from_numpy(y_test).long().to(device)
                    y_pred = network(X)
                    test_correct += (y_pred.max(dim=1)[1] == Y).sum().item()

            # An empty test split (e.g. test_size=0) counts as accuracy 0.
            test_acc = test_correct / testcount if testcount else 0.0
            print('Test accuracy: {}'.format(test_acc))
            if test_acc >= max_acc:
                max_acc = test_acc
                max_count = 0
                print('Saving network parameters...')
                save_params(network, savename)
            else:
                max_count += 1

            if max_count >= 5:
                break


def save_params(model, name, path='models/'):
    """Save ``model.state_dict()`` to the file ``name`` inside ``path``.

    Args:
        model: Object exposing ``state_dict()``.
        name: File name of the checkpoint.
        path: Target directory, with or without a trailing separator. It is
            created when it does not exist; an empty string means the
            current working directory.
    """
    if path:
        # torch.save does not create missing parent directories.
        os.makedirs(path, exist_ok=True)
    model_path = os.path.join(path, name)
    torch.save(model.state_dict(), model_path)

# Earlier invocations, kept for reference (all disabled):
#convert_data()
#convert_data(buffers=True)
#make_batches()
#main(epochs=1000, eplen=10, savename='EmbedNet10')
# main(epochs=1000, eplen=25, savename='EmbedNet25')
# main(epochs=1000, eplen=50, savename='EmbedNet50')
# main(epochs=1000, eplen=100, savename='EmbedNet100')
#main(epochs=1000, eplen=10, output_dim=10, savename='EmbedNet10_mean_mean')
#main(epochs=1000, eplen=10, output_dim=10, savename='EmbedNet10_max_mean')
#main(epochs=1000, eplen=10, output_dim=10, savename='EmbedNet10_max_max')

# Run with the default model ('xgboost'), disabled:
#main(epochs=1000, eplen=10, output_dim=10, buffer=False, savename='EmbedNet10_mean_max_newtest')

# Active run with model='rf'. NOTE: this executes whenever the module is
# loaded, including on import.
main(
    model='rf',
    input_dim=6,
    output_dim=10,
    eplen=10,
    epochs=1000,
    buffer=False,
    savename='EmbedNet10_mean_max_newtest_rf',
)