import os
import math
import arff
from copy import deepcopy

import vowpalwabbit
import random
import numpy as np
from tqdm.auto import tqdm
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils import load_rcv1, load_rcv1_full, get_graph, calc_ind

def to_vw_graph(graph):
	"""Serialise a square feedback matrix into a VW ``shared graph`` prefix.

	Every ``(row, col)`` pair of the ``k x k`` grid (``k = graph.shape[0]``)
	is written as ``row,col,value `` with the value cast to ``int``. The
	result starts with ``"shared graph "`` and ends with ``"|"`` so that
	feature text can be appended directly after it.

	Args:
		graph: 2-D indexable object exposing ``shape`` and ``graph[i, j]``.

	Returns:
		The prefix string described above.
	"""
	size = graph.shape[0]
	triples = [
		'{},{},{} '.format(row, col, int(graph[row, col]))
		for row in range(size)
		for col in range(size)
	]
	return "shared graph " + ''.join(triples) + '|'

def to_vw_context(context, out_string, n_actions, cb_label=None):
	"""Build a multi-line VW example: shared features plus one line per action.

	The shared line is ``out_string`` followed by namespace ``S`` and the
	features ``ft1 .. ftd`` taken from ``context``. Each action then gets a
	line ``|A action=<idx> ``. When ``cb_label`` is given, the lines of the
	actions listed in ``cb_label['observe_actions']`` are prefixed with
	``<selected>:<cost of that action>:<prob of selected> ``.

	Args:
		context: 1-D feature vector exposing ``shape`` and integer indexing.
		out_string: text placed in front of the shared features.
		n_actions: number of action lines to emit.
		cb_label: optional dict with keys ``selected``, ``cost``, ``prob``
			and ``observe_actions``.

	Returns:
		The example text with its final character (the trailing newline)
		removed.
	"""
	dim = context.shape[0]

	pieces = [out_string, 'S ']
	pieces.extend(
		'ft{}:{} '.format(pos + 1, context[pos]) for pos in range(dim - 1)
	)
	# The last feature closes the shared line.
	pieces.append('ft{}:{}\n'.format(dim, context[dim - 1]))

	has_label = cb_label is not None
	for arm in range(n_actions):
		if has_label and arm in cb_label['observe_actions']:
			chosen = cb_label['selected']
			pieces.append("{}:{}:{} ".format(
				chosen, cb_label['cost'][arm], cb_label['prob'][chosen]))
		pieces.append('|A action={} \n'.format(arm))

	text = ''.join(pieces)
	return text[:-1]

def work(graph, out_string, c):
	"""Run one online VW graph-feedback pass and record the cumulative cost.

	For each of the first ``min(X.shape[0], 10000)`` rows the learner
	predicts a distribution over actions, one action is sampled from it,
	the costs of all actions that the sampled action reveals according to
	``graph`` are fed back, and the cost of the sampled action is added to
	the running total.

	Reads the module-level names ``X``, ``y``, ``n_actions`` and ``reg_min``.
	They are globals rather than parameters so the call signature stays
	unchanged.

	Args:
		graph: square matrix; a truthy ``graph[a, j]`` means playing action
			``a`` also reveals the cost of action ``j``.
		out_string: shared-graph prefix, as produced by ``to_vw_graph``.
		c: value passed to VW's ``--gamma_scale`` option.

	Returns:
		List with the cumulative cost after each processed round. The list
		is shorter than the number of rounds when the run is stopped early
		because the cumulative cost already exceeds ``reg_min``.
	"""
	global reg_min

	n = min(X.shape[0], 10000)
	# NOTE(review): alpha is not used anywhere below. The call is kept
	# because calc_ind is defined in utils and may have side effects.
	alpha = len(calc_ind(graph))
	vw = vowpalwabbit.Workspace("--cb_explore_adf -q SA --quiet --graph_feedback --coin --gamma_scale {}".format(c))
	regret_list = []
	tot_regret = 0
	try:
		for i in tqdm(range(n)):
			vw_input = to_vw_context(X[i, :], out_string, n_actions)
			pmf = vw.predict(vw_input)

			# Cost is 0 for a correct label and 1 otherwise. The
			# probability field is always written as 1.0.
			cb_label = {
				'cost': {j: 1 - int(y[i, j]) for j in range(n_actions)},
				'prob': {j: 1.0 for j in range(n_actions)},
			}

			pmf = torch.FloatTensor(pmf)
			pred_arm = torch.multinomial(pmf, num_samples=1).item()
			tot_regret += cb_label['cost'][pred_arm]
			regret_list.append(tot_regret)

			cb_label['selected'] = pred_arm
			# Only the actions revealed by the played arm receive a label.
			cb_label['observe_actions'] = [
				j for j in range(n_actions) if graph[pred_arm, j]
			]
			vw_train_input = to_vw_context(X[i, :], out_string, n_actions, cb_label)
			vw_train_format = vw.parse(vw_train_input, vowpalwabbit.LabelType.CONTEXTUAL_BANDIT)
			try:
				vw.learn(vw_train_format)
			finally:
				# Parsed examples must be handed back to VW, otherwise
				# they accumulate for the whole run.
				vw.finish_example(vw_train_format)

			if (i + 1) % 2000 == 0:
				print("Regret:{}".format(tot_regret))
			# Stop as soon as this run can no longer beat the best so far.
			if tot_regret > reg_min:
				return regret_list
	finally:
		# Release the workspace on normal exit, early return and errors.
		vw.finish()

	print(tot_regret)
	return regret_list

# Experiment configuration. ``dataset`` selects the loader branch in the
# __main__ section; ``graph_type`` is forwarded to ``get_graph``.
dataset, graph_type = 'rcv1_10full', 'random'
print(dataset, graph_type)
# Each value is tried as the ``c`` argument of ``work`` (VW --gamma_scale).
c_list = [32, 64, 128]
# One full experiment is run per seed.
seed_list = list(range(1, 6))

if __name__ == "__main__":
	# Supported dataset names mapped to (file path, loader).
	dataset_sources = {
		'rcv1_10': ('./rcv1/data_pca_10class.gt', load_rcv1),
		'rcv1_50': ('./rcv1/data_pca_50class.gt', load_rcv1),
		'rcv1_10full': ('./rcv1/data_full_10class.gt', load_rcv1_full),
		'rcv1_20full': ('./rcv1/data_full_20class.gt', load_rcv1_full),
		'rcv1_50full': ('./rcv1/data_full_50class.gt', load_rcv1_full),
	}
	if dataset not in dataset_sources:
		# Previously an unknown name fell through to a NameError on X.
		raise ValueError("Unknown dataset: {!r}".format(dataset))
	file_path, load_fn = dataset_sources[dataset]
	X, y = load_fn(file_path)

	# Keep the original order so every seed permutes the same source data.
	X_full = X.astype(np.float32)
	y_full = y.astype(np.float32)
	n_actions = y_full.shape[1]
	res_list = []
	os.makedirs("./new_res_vw_{}_random".format(dataset), exist_ok=True)

	for seed in seed_list:
		random.seed(seed)
		np.random.seed(seed)
		torch.manual_seed(seed)

		permutation = np.random.permutation(X_full.shape[0])
		# work() reads the module-level X and y, so the shuffled data must be
		# bound to those names. The shuffled copies used to be stored in
		# unused variables, so every seed ran on the unshuffled rows.
		X = X_full[permutation]
		y = y_full[permutation]
		graph = get_graph(graph_type, n_actions)
		out_string = to_vw_graph(graph)
		# Best final cumulative cost for this seed; work() reads it to stop
		# runs that are already worse.
		reg_min = 1e9
		best_res = None

		for c in c_list:
			print("Seed:{}\tC:{}".format(seed, c))
			regret_list = work(graph, out_string, c)
			if regret_list and regret_list[-1] < reg_min:
				reg_min = regret_list[-1]
				best_res = list(regret_list)

		if best_res is None:
			raise RuntimeError(
				"No run produced a result for seed {}".format(seed))
		res_list.append(np.array(best_res).reshape(1, -1))

	# One row per seed, holding the best run's cumulative-cost curve.
	res_list = np.concatenate(res_list, axis=0)
	print(res_list.shape)
	np.save('./new_res_vw_{}_random/graph_loss.npy'.format(dataset), res_list)
	
