import os
import dsp
import math
import arff

import random
import cvxpy as cp
import numpy as np
from tqdm.auto import tqdm
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils import *
from models import *
from method_list import *

# Target device for the `.to(device)` calls in work().
device: str = 'cpu'
# NOTE(review): not referenced in the visible code; kept for compatibility.
dim: int = 100

# Scalars forwarded to Inventory(...) and inventory_loss(...) in work().
# Presumably the holding cost (h) and shortage penalty (pnt) -- confirm in utils/models.
h: float = 0.25
pnt: int = 1

def work(seed, X, y, graph, graph_T, model_name, train_scheme, loss_type='square',
		method='cvx', b_size=64, c=10, k=100, regret_ref=None, train_iter=None):
	"""Run one online inventory experiment and return its cumulative-loss curve.

	Each round ``i``:
	- the model predicts from ``X[i]``;
	- ``method`` turns the prediction into a distribution over ``k`` arms;
	- one arm is sampled and its inventory loss against ``y[i]`` is accumulated;
	- the model takes one OGD step.

	The OGD targets are the inventory loss of every arm in ``fixed_arm_set``.

	Args:
		seed: passed to ``torch.manual_seed`` *after* the model is built. It
			controls arm sampling; model initialisation depends on the
			caller's RNG state instead.
		X: 2-D array of contexts; row ``i`` is fed to the model in round ``i``.
		y: per-round targets passed to ``inventory_loss``.
		graph: square array; only its shape is checked (must be ``(k, k)``).
		graph_T: forwarded to ``closed_form`` when ``method == 'closed-form'``.
		model_name: ``'inventory'`` or ``'linear'``.
		train_scheme, loss_type, b_size: accepted for interface compatibility;
			not used in this function.
		method: ``'log-barrier'``, ``'greedy'`` or ``'closed-form'``. The
			default ``'cvx'`` is not supported and raises ``ValueError``.
		c: scale of the exploration parameter ``gamma``.
		k: number of arms (>= 2); arm ``j`` is evaluated at ``j / (k - 1)``.
		regret_ref: optional cumulative-loss curve of an earlier run. The run
			stops early once its cumulative loss exceeds
			``regret_ref[-1] + 0.1``.
		train_iter: take an OGD step every ``train_iter`` rounds. ``None``
			uses the module-level ``train_iter`` set by the script, or 1 when
			that global does not exist.

	Returns:
		list holding the cumulative loss after each completed round.

	Raises:
		ValueError: unsupported ``method`` or ``model_name``, or ``k < 2``.
	"""
	if method not in ('log-barrier', 'greedy', 'closed-form'):
		raise ValueError("unsupported method: {!r}".format(method))
	if k < 2:
		# pred_arm / (k - 1) below would divide by zero.
		raise ValueError("k must be >= 2, got {}".format(k))
	if train_iter is None:
		# The script assigns `train_iter` at module level; honour it so existing
		# callers behave the same, and default to 1 when the module is imported.
		train_iter = globals().get('train_iter', 1)

	# Arm j is evaluated at j / (k - 1) in [0, 1].
	# np.linspace guarantees exactly k points.
	# A float-step np.arange can return k + 1 points because of rounding.
	fixed_arm_set = np.linspace(0.0, 1.0, k)

	if model_name == 'linear':
		model = Linear_Cls(X.shape[1], y.shape[1]).to(device)
	elif model_name == 'inventory':
		print("len_arm: ", len(fixed_arm_set))
		model = Inventory(X.shape[1], k, h, pnt)
	else:
		raise ValueError("unsupported model_name: {!r}".format(model_name))

	n = X.shape[0]
	alpha = 1
	print("n: \tc: ", n, c)

	tot_regret = 0
	regret_list = []
	tf = time()
	gamma = np.zeros((n,))
	assert graph.shape[0] == graph.shape[1] == k

	# Seeded only after the model exists: affects sampling, not initialisation.
	torch.manual_seed(seed)

	for i in tqdm(range(n), ncols=10):
		# Exploration scale sqrt(i * k) * c; the closed-form method uses alpha instead of k.
		gamma[i] = math.sqrt(i * (alpha if method == 'closed-form' else k)) * c

		model.eval()
		with torch.no_grad():
			x = torch.from_numpy(X[i, :]).view(1, -1).to(device)
			yhat = model(x)
			pt = _arm_distribution(method, yhat, k, gamma[i], graph_T)

			# Sanity checks: pt must be a probability vector.
			assert abs(torch.sum(pt).item() - 1) < 1e-4
			assert bool(torch.all(pt >= 0))

			pred_arm = torch.multinomial(pt, num_samples=1).item()

			loss_ins = inventory_loss(h, pnt, pred_arm / (k - 1), y[i])
			tot_regret += loss_ins[0]

			# Length-(k + 1) mask with ones at indices 0..pred_arm.
			feedback_arm = torch.zeros(k + 1)
			feedback_arm[:pred_arm + 1] = 1

			# NOTE(review): these hold only the current sample, so with
			# train_iter > 1 the samples of skipped rounds are never trained on.
			# Confirm that this is intended.
			X_train_single = [x]
			arm_train_single = [feedback_arm.view(1, -1)]
			y_train_single = [torch.unsqueeze(torch.tensor(inventory_loss(h, pnt, fixed_arm_set, y[i])), 0)]

		if (i + 1) % train_iter == 0:
			train_onestep_OGD(model, X_train_single, arm_train_single, y_train_single, lr=0.01)

		regret_list.append(tot_regret)

		# Early exit: this parameterisation is already worse than the reference run.
		if regret_ref is not None and len(regret_ref) > 0 and tot_regret > regret_ref[-1] + 0.1:
			print("Already worse than the opt")
			break
		if (i + 1) % 2000 == 0:
			print("Time:{:.2f}\tIters:{}\tLoss:{}".format(time() - tf, i + 1, tot_regret))
			tf = time()

	return regret_list


def _arm_distribution(method, yhat, k, gamma_i, graph_T):
	"""Return the arm-sampling distribution that `method` derives from `yhat`."""
	if method == 'log-barrier':
		return log_barrier(yhat, k, gamma_i)
	if method == 'greedy':
		return greedy(yhat, k, gamma_i)
	if method == 'closed-form':
		return closed_form(yhat, k, gamma_i, graph_T)
	raise ValueError("unsupported method: {!r}".format(method))


dataset = 'inventory'
graph_T = ['inventory']
method = ['log-barrier', 'closed-form']

model_name = 'inventory'
train_scheme = 'ogd'
N_seed = 8
seed_list = range(0, N_seed)
tag = 'fixed'
if __name__ == "__main__":
	print("Train: ", train_scheme)
	# Read by work() as a module-level global; must stay at module scope.
	train_iter = 1
	if dataset == 'rcv1':
		file_path = './rcv1/data_balance_50k.gt'
		X, y = load_rcv1(file_path)
		num_sub_class = 10
	elif dataset == 'rcv1_full':
		file_path = './rcv1/data_full_train.gt'
		X, y = load_rcv1(file_path)
		num_sub_class = 103
	elif dataset == 'inventory':
		file_path = "./inventory/data_inventory.gt"
		X, y = load_inventory(file_path)
		num_sub_class = 10000
	else:
		raise ValueError("unknown dataset: {!r}".format(dataset))

	# Full grid of c values; narrowed per method below.
	c_list = [0.5, 1, 2, 3, 4]
	# Multipliers for the number of arms: k = act_scale * (n // 100) + 1.
	sc = [1, 3, 5]
	X = X.astype(np.float32)
	y = y.astype(np.float32)
	n = 10000
	X_s = X[0:n]
	y_s = y[0:n]
	for mth in method:
		# Best c per method, taken from an earlier search over c_list.
		c_list = [4] if mth == 'log-barrier' else [3]

		for graph_type in graph_T:
			os.makedirs("./res_{}_{}".format(dataset, graph_type), exist_ok=True)

			for act_scale in sc:
				for seed in seed_list:
					print("seed: ", seed)
					# Global RNGs are reset to 1 for every run, so everything drawn
					# before work() seeds torch with `seed` (presumably the model
					# init) is identical across seeds.
					random.seed(1)
					np.random.seed(1)
					torch.manual_seed(1)
					print("=========================================================")
					print("Dataset:{}\tGraph:{}\tMethod:{}".format(dataset, graph_type, mth))
					print(X_s.shape, y_s.shape)

					k = act_scale * (n // 100) + 1
					graph = get_graph(graph_type, k)

					# Try each c and keep (and save) the run with the lowest final
					# cumulative loss. The reference curve is reset per configuration,
					# so a stale curve from a previous method/scale/seed is never used
					# for early stopping.
					reg_min = float('inf')
					regret_ref = None
					for c in c_list:
						regret_list_temp = work(
							seed, X_s, y_s, graph, graph_type, model_name, train_scheme,
							loss_type='l1', method=mth, b_size=256, c=c, k=k,
							regret_ref=regret_ref,
						)
						if regret_list_temp[-1] < reg_min:
							reg_min = regret_list_temp[-1]
							regret_ref = regret_list_temp
							np.save('./res_{}_{}/{}_{}_{}_{}_{}_{}_{}_{}_{}.npy'.format(dataset, graph_type, k, h, pnt, mth, n, model_name, train_scheme,seed,tag), np.array(regret_list_temp))
						if mth == 'greedy':
							break