import os
import math
import arff
from copy import deepcopy

import vowpalwabbit
import random
import numpy as np
from tqdm.auto import tqdm
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils import load_rcv1, load_rcv1_full, get_graph

def to_vw_graph(graph):
	"""Return the VW text header associated with a feedback graph.

	The graph is not encoded yet: the result is always the bare ``shared |``
	line, whatever ``graph`` contains.

	Args:
		graph: Feedback graph. Currently ignored; kept so existing callers
			do not need to change.

	Returns:
		The constant string ``"shared |"``.
	"""
	return "shared |"

def to_vw_context(context, n_actions, cb_label=None):
	"""Render one context as a multi-line VW text example.

	The first line is the shared part, ``shared |S ft1:<v1> ... ftd:<vd>``,
	followed by one ``|A action=<i> `` line per action. When ``cb_label`` is
	given, the line of the matching action is prefixed with
	``<i>:<cost>:<prob> ``.

	Args:
		context: 1-D array-like exposing ``.shape`` and integer indexing; it
			must hold at least one feature.
		n_actions: Number of action lines to emit.
		cb_label: Optional ``(action, prob, cost)`` triple.

	Returns:
		The example text, lines separated by ``\\n``, without a trailing
		newline.
	"""
	labelled = cb_label is not None
	if labelled:
		action, prob, cost = cb_label

	n_features = context.shape[0]
	features = ['ft{}:{}'.format(idx + 1, context[idx]) for idx in range(n_features - 1)]
	# The last feature is indexed explicitly so an empty context fails loudly.
	features.append('ft{}:{}'.format(n_features, context[n_features - 1]))
	lines = ['shared |S ' + ' '.join(features)]

	for arm in range(n_actions):
		if labelled and arm == action:
			prefix = "{}:{}:{} ".format(arm, cost, prob)
		else:
			prefix = ''
		lines.append(prefix + '|A action={} '.format(arm))

	return '\n'.join(lines)

def work(graph, out_string, c):
	"""Run one online SquareCB pass and return the cumulative-cost curve.

	Reads the module-level globals ``X``, ``y``, ``n_actions`` and
	``reg_min``. At most the first 10000 rows of ``X`` are played. In each
	round an arm is sampled from VW's (renormalised) prediction, a 0/1 cost is
	incurred (0 when ``y[i, arm] == 1``), and VW is then trained once for
	every arm ``j`` with a truthy ``graph[arm, j]``, using cost
	``1 - y[i, j]`` and probability ``pmf[j]``.

	Args:
		graph: Feedback graph, indexed as ``graph[played_arm, observed_arm]``.
		out_string: Unused; kept for interface compatibility.
		c: Value passed to VW's ``--gamma_scale`` option.

	Returns:
		List of the running total cost after each round. The list is cut
		short as soon as the total exceeds the global ``reg_min``.
	"""
	vw = vowpalwabbit.Workspace("--cb_explore_adf -q SA --quiet --squarecb --coin --gamma_scale {}".format(c))
	regret_list = []
	tot_regret = 0
	n = min(X.shape[0], 10000)
	try:
		for i in tqdm(range(n)):
			context = X[i, :]
			pmf = vw.predict(to_vw_context(context, n_actions))
			pmf = torch.FloatTensor(pmf)
			pmf = pmf / torch.sum(pmf)
			pred_arm = torch.multinomial(pmf, num_samples=1).item()

			tot_regret += 0 if y[i, pred_arm] == 1 else 1
			regret_list.append(tot_regret)

			for j in range(n_actions):
				if not graph[pred_arm, j]:
					continue
				cost = 1 - y[i, j]
				prob = pmf[j].item()
				vw_train_input = to_vw_context(context, n_actions, (j, prob, cost))
				examples = vw.parse(vw_train_input, vowpalwabbit.LabelType.CONTEXTUAL_BANDIT)
				vw.learn(examples)
				# Parsed examples must be handed back to VW once used,
				# otherwise they accumulate for the whole run.
				vw.finish_example(examples)

			if (i+1) % 5000 == 0:
				print("Regret:{}".format(tot_regret))
			# reg_min is only read here, so no ``global`` declaration is needed.
			if tot_regret > reg_min:
				# Already worse than the best run so far: stop early.
				return regret_list
	finally:
		# Release the native workspace on every exit path.
		vw.finish()

	print(tot_regret)
	return regret_list


# Experiment configuration read by the script entry point below.
dataset = 'rcv1_10full'  # key selecting which RCV1 file is loaded
graph_type = 'random'  # feedback-graph kind handed to get_graph
seed_list = list(range(1, 6))  # one full experiment per seed
c_list = [32, 64, 128]  # values tried for VW's --gamma_scale option

if __name__ == "__main__":
	# Map each supported dataset name to its file and the loader that reads it.
	dataset_sources = {
		'rcv1_10': ('./rcv1/data_pca_10class.gt', load_rcv1),
		'rcv1_50': ('./rcv1/data_pca_50class.gt', load_rcv1),
		'rcv1_10full': ('./rcv1/data_full_10class.gt', load_rcv1_full),
		'rcv1_20full': ('./rcv1/data_full_20class.gt', load_rcv1_full),
		'rcv1_50full': ('./rcv1/data_full_50class.gt', load_rcv1_full),
	}
	if dataset not in dataset_sources:
		# Fail here with a clear message instead of a NameError further down.
		raise ValueError("Unknown dataset: {!r}".format(dataset))
	file_path, loader = dataset_sources[dataset]
	X_full, y_full = loader(file_path)

	X_full = X_full.astype(np.float32)
	y_full = y_full.astype(np.float32)
	n_actions = y_full.shape[1]
	res_list = []

	for seed in seed_list:
		random.seed(seed)
		np.random.seed(seed)
		torch.manual_seed(seed)

		permutation = np.random.permutation(X_full.shape[0])
		# work() reads the module-level X and y, so the shuffled arrays have
		# to be bound to those names for the per-seed permutation to apply.
		X = X_full[permutation]
		y = y_full[permutation]
		graph = get_graph(graph_type, n_actions)
		out_string = to_vw_graph(graph)
		# Best final cost seen for this seed; work() reads it to stop early.
		reg_min = 1e9

		for c in c_list:
			print("Seed:{}\tC:{}".format(seed, c))
			regret_list = work(graph, out_string, c)
			if regret_list[-1] < reg_min:
				reg_min = regret_list[-1]
				best_res = deepcopy(regret_list)

		best_res = np.array(best_res).reshape(1, -1)
		res_list.append(best_res)
		print(reg_min)

	# One row per seed, holding the curve of that seed's best gamma_scale.
	res_list = np.concatenate(res_list, axis=0)
	print(res_list.shape)
	out_dir = "./new_res_vw_{}_{}".format(dataset, graph_type)
	os.makedirs(out_dir, exist_ok=True)
	np.save('./new_res_vw_{}_{}/squarecb_loss.npy'.format(dataset, graph_type), res_list)
	
