import os
import dsp
import math
import arff

import random
import cvxpy as cp
import numpy as np
from tqdm.auto import tqdm
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils import *
from models import *
from method_list import *

# Torch device that models and per-round input tensors are moved to.
device = 'cpu'
# Second constructor argument of MLP_Cls in work()
# (presumably the hidden-layer width -- confirm in models).
dim = 100

# Forwarded to Inventory(...) and inventory_loss(...).
# NOTE(review): presumably the holding cost and the stock-out penalty -- confirm
# against the helper definitions.
h, pnt = 0.25, 1

def work(seed, X, y, graph, graph_T, model_name, train_scheme, loss_type='square', method='cvx', b_size=64, c=10, k=100, regret_ref=None):
	"""Run one online pass over the rows of ``X`` and return the running loss.

	For every row the model is evaluated, an arm distribution ``pt`` is built
	by the sampler selected with ``method``, one arm is drawn from it, the
	loss of that arm is added to a running total, and the model is updated
	with ``train_onestep_OGD`` every ``train_iter`` rounds.

	Args:
		seed: Passed to ``torch.manual_seed`` before the loop starts (the
			model has already been constructed at that point).
		X: NumPy array indexed as ``X[i, :]``; ``X.shape[0]`` rounds are
			played and ``X.shape[1]`` is the model input size.
		y: Per-round targets, indexed as ``y[i]``. ``y.shape[1]`` is only
			read for the non-inventory models.
		graph: Unused by this function.
		graph_T: Forwarded to ``closed_form``.
		model_name: ``'linear'``, ``'inventory'``; anything else builds
			an ``MLP_Cls``.
		train_scheme: Unused by this function.
		loss_type: Unused by this function.
		method: ``'log-barrier'``, ``'greedy'`` or ``'closed-form'``.
		b_size: Unused by this function.
		c: Multiplier applied to the per-round ``gamma`` value.
		k: Number of arms. It is recomputed every round for
			``'closed-form'`` and ``'log-barrier'``; ``'greedy'`` uses the
			value as passed and leaves ``gamma`` at zero.
		regret_ref: Optional sequence; the run stops early once the running
			total exceeds ``regret_ref[-1] + 0.1``.

	Returns:
		list: the running total after each completed round.

	Raises:
		ValueError: if ``method`` is not one of the supported samplers.
			(Previously this surfaced later as an unbound ``pt``.)

	NOTE(review): this function reads the module-level name ``train_iter``,
	which this file only assigns inside its ``__main__`` block.
	"""
	supported_methods = ('log-barrier', 'greedy', 'closed-form')
	if method not in supported_methods:
		# Fail fast: without this check no branch below would bind `pt`.
		raise ValueError(
			"Unsupported method {!r}; expected one of {}".format(method, supported_methods)
		)

	n = X.shape[0]
	if model_name == 'linear':
		model = Linear_Cls(X.shape[1], y.shape[1]).to(device)
	elif model_name == 'inventory':
		model = Inventory(X.shape[1], k, h, pnt)
	else:
		model = MLP_Cls(X.shape[1], dim, y.shape[1]).to(device)

	alpha = 1
	print("n: \tc: ", n, c)

	tot_regret = 0
	regret_list = []
	tf = time()
	gamma = np.zeros((n,))

	torch.manual_seed(seed)

	for i in tqdm(range(n), ncols=10):
		# Adaptive discretization: the arm count grows with the round index.
		if method == 'closed-form':
			k = int(4 * ((i + 1) ** (1 / 2))) + 1
			gamma[i] = math.sqrt(i * alpha) * c
		elif method == 'log-barrier':
			k = int(4 * ((i + 1) ** (1 / 3))) + 1
			gamma[i] = math.sqrt(i * k) * c

		# NOTE(review): arange with a float step can yield k or k+1 points
		# depending on rounding -- confirm which length the trainer expects.
		fixed_arm_set = np.arange(0, 1 + 1 / (k - 1), 1 / (k - 1))

		model.eval()
		with torch.no_grad():
			x = torch.from_numpy(X[i, :]).view(1, -1).to(device)
			# The model is told the current arm count before the forward pass.
			model.N = k
			yhat = model(x)
			if method == 'log-barrier':
				pt = log_barrier(yhat, k, gamma[i])
			elif method == 'greedy':
				pt = greedy(yhat, k, gamma[i])
			else:  # 'closed-form' (guaranteed by the validation above)
				pt = closed_form(yhat, k, gamma[i], graph_T)

			# Sanity checks: pt must sum to one and its first k entries must be >= 0.
			assert abs(torch.sum(pt).item() - 1) < 1e-4
			assert bool(torch.all(pt[:k] >= 0))

			pred_arm = torch.multinomial(pt, num_samples=1).item()

			# Loss of the arm actually played, mapped onto the [0, 1] grid.
			loss_ins = inventory_loss(h, pnt, pred_arm / (k - 1), y[i])
			tot_regret += loss_ins[0]

			# Mask with ones at indices 0..pred_arm (inclusive), zeros elsewhere.
			feedback_arm = torch.zeros(k + 1)
			feedback_arm[0:pred_arm + 1] = 1

			# Loss of every grid point for this round, with a leading batch axis.
			arm_losses = torch.unsqueeze(torch.tensor(inventory_loss(h, pnt, fixed_arm_set, y[i])), 0)

			X_train_single = [x]
			arm_train_single = [feedback_arm.view(1, -1)]
			y_train_single = [arm_losses]

		if (i + 1) % train_iter == 0:
			train_onestep_OGD(model, X_train_single, arm_train_single, y_train_single, lr=0.01)

		regret_list.append(tot_regret)

		if regret_ref is not None and tot_regret > regret_ref[-1] + 0.1:
			print("\n Already worse than the opt")
			break
		if (i + 1) % 2000 == 0:
			print("Time:{:.2f}\tIters:{}\tLoss:{}".format(time() - tf, i + 1, tot_regret))
			tf = time()

	return regret_list


# ---- Experiment configuration (read by the __main__ block below) ----

# Selects which loader runs: 'rcv1' or 'inventory'.
dataset = 'inventory'

# Graph types to iterate over; each entry is passed to get_graph() and work().
graph_T = ['inventory']

# Sampling methods to run, one full sweep per entry.
method = ['log-barrier', 'closed-form']

# Model selector forwarded to work().
model_name = 'inventory'

# Printed at start-up, forwarded to work() and used in the result file name.
train_scheme = 'ogd'

# Seeds 0 .. N_seed-1 are run for every (method, graph) pair.
N_seed = 2
seed_list = range(N_seed)

# Suffix appended to the saved result file name.
tag = 'fixed'
if __name__ == "__main__":
	# Script entry point: load the configured dataset, then for every
	# (method, graph type, seed) combination run work() for each candidate c
	# and save the best cumulative-loss curve as a .npy file.
	print("Train: ", train_scheme)
	# Module-level name read inside work(): update the model every round.
	train_iter = 1

	if dataset == 'rcv1':
		file_path = './rcv1/data_balance_50k.gt'
		X, y = load_rcv1(file_path)
		num_sub_class = 10
	elif dataset == 'inventory':
		file_path = "./inventory/data_inventory.gt"
		X, y = load_inventory(file_path)
		num_sub_class = 10000

	## Adaptive discretization

	# Full search grid; overwritten below by the single tuned value per method.
	c_list = [0.125, 0.25, 0.5, 1, 2, 3, 4]
	# Only the first n rows are used; n also appears in the output file name.
	n = 10000
	X = X.astype(np.float32)
	y = y.astype(np.float32)
	X_s, y_s = X[:n], y[:n]

	for mth in method:
		c_list = [4] if mth == 'log-barrier' else [3]  # after c_list search
		for graph_type in graph_T:
			seed_ind = 0
			for seed in seed_list:
				print("seed: ", seed)
				# NOTE(review): the global RNGs are always seeded with 1 here, not
				# with `seed`; work() re-seeds torch with `seed` after it has built
				# the model. Confirm this is intentional.
				random.seed(1)
				np.random.seed(1)
				torch.manual_seed(1)
				print("=========================================================")
				print("Dataset:{}\tGraph:{}\tMethod:{}".format(dataset, graph_type, mth))
				os.makedirs("./res_{}_{}".format(dataset, graph_type), exist_ok=True)
				print(X_s.shape, y_s.shape)

				# Placeholder arm count; only used as-is for the 'greedy' method
				# and in the output file name.
				k = 'adaptive'
				graph = get_graph(graph_type, 3)

				## Pick different parameterization, pick the best one
				reg_min = 999999999
				attm = 0
				for c in c_list:
					# After the first attempt, hand the best curve so far to work()
					# so that a clearly worse run can stop early.
					regret_list_temp = work(
						seed, X_s, y_s, graph, graph_type, model_name, train_scheme,
						loss_type='l1', method=mth, b_size=256, c=c, k=k,
						regret_ref=regret_ref if attm > 0 else None,
					)
					if regret_list_temp[-1] < reg_min:
						reg_min = regret_list_temp[-1]
						regret_ref = regret_list_temp
						np.save('./res_{}_{}/{}_{}_{}_{}_{}_{}_{}_{}_{}.npy'.format(dataset, graph_type, k, h, pnt, mth, n, model_name, train_scheme,seed,tag), np.array(regret_list_temp))
					if mth == 'greedy':
						break
					attm += 1
				seed_ind += 1
