import os
import arff
import random
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def log_barrier(yhat, k, gamma):
	"""Return the distribution p_i = 1 / (gamma * (yhat_i - min(yhat)) + lam).

	The normaliser ``lam`` is found by bisection so that the entries sum
	to one (up to a tolerance of 1e-6).

	Args:
		yhat: tensor of predictions; it is flattened with ``view(-1)``.
		k: upper end of the bisection bracket for ``lam``
			(presumably the number of arms -- confirm against the caller).
		gamma: scale applied to each gap ``yhat_i - min(yhat)``.

	Returns:
		1-D tensor with one entry per element of ``yhat``.
	"""
	yhat_v = yhat.view(-1)
	kmin = 1.0
	kmax = k
	eps = 1e-6
	yhat_b = torch.min(yhat_v, dim=0).values.item()
	# The scaled gaps do not depend on lam, so compute them once.
	gaps = gamma * (yhat_v - yhat_b)
	# Seed pt at the upper end of the bracket so it is defined even when the
	# loop below never runs (k == 1 gives an empty bracket; the answer is
	# then 1 / (0 + 1) = 1 for the single arm).
	pt = 1 / (gaps + kmax)
	while kmax - kmin > eps:
		temp = (kmax + kmin) / 2
		pt = 1 / (gaps + temp)
		pt_sum = torch.sum(pt)
		if abs(pt_sum - 1) < eps:
			return pt
		if pt_sum > 1:
			# Too much mass: lam has to grow.
			kmin = temp
		else:
			# Too little mass.  A NaN sum also lands here, so the bracket
			# keeps shrinking and the loop terminates instead of spinning.
			kmax = temp
	return pt

def greedy(yhat, k, gamma):
	"""Return a one-hot distribution on the entry with the smallest prediction.

	``yhat`` is flattened with ``view(-1)``; ``k`` and ``gamma`` are accepted
	but not used by this function.
	"""
	flat = yhat.view(-1)
	best_idx = torch.min(flat, dim=0).indices.item()
	probs = torch.zeros([flat.shape[0]])
	# All of the mass goes to the arg-min entry.
	probs[best_idx] = 1
	return probs

def L2_FTRL(yhat, k, gamma):
	"""Return p_i = relu(lam - gamma * (yhat_i - min(yhat))), normalised by bisection.

	``lam`` is searched in the bracket ``[0, gamma + 1]`` until the entries
	sum to one within 1e-5.  If that tolerance cannot be reached, the last
	iterate is returned instead of looping forever.

	Args:
		yhat: tensor of predictions; it is flattened with ``view(-1)``.
		k: accepted for a uniform signature; not used here.
		gamma: scale applied to each gap ``yhat_i - min(yhat)``.

	Returns:
		1-D tensor with one entry per element of ``yhat``.
	"""
	yhat_v = yhat.view(-1)
	yhat_b, _ = torch.min(yhat_v, dim=0)
	kmin = 0
	kmax = gamma + 1
	eps = 1e-5
	# The scaled gaps do not depend on lam, so compute them once.
	gaps = gamma * (yhat_v - yhat_b)

	while True:
		temp = (kmax + kmin) / 2
		pt = F.relu(temp - gaps)
		pt_sum = torch.sum(pt)
		if abs(pt_sum - 1) < eps:
			return pt
		if not (kmin < temp < kmax):
			# The midpoint no longer lies strictly inside the bracket, so
			# bisection has run out of floating-point resolution and further
			# iterations would repeat this one.
			return pt
		if pt_sum > 1:
			kmax = temp
		elif pt_sum < 1:
			kmin = temp
		else:
			# The sum is unordered (NaN): neither bound can move.
			return pt

def invent_graph(yhat, k, gamma):
	"""Build a length-``k`` distribution by sweeping arms from ``k - 1`` down to 1.

	Each arm visited before the arg-min arm receives the amount by which
	``1 / (1 + gamma * (yhat_i - min(yhat)))`` exceeds the running maximum of
	that quantity (clipped at zero).  When the sweep reaches the arg-min arm,
	that arm takes the remaining mass and the result is returned.  If the
	sweep finishes without meeting it, arm 0 receives whatever mass is still
	missing.

	Args:
		yhat: tensor of predictions; it is flattened with ``view(-1)``.
		k: length of the returned vector and starting point of the sweep.
		gamma: scale applied to each gap ``yhat_i - min(yhat)``.
	"""
	flat = yhat.view(-1)
	lowest = torch.min(flat, dim=0)
	best_val = lowest.values
	best_idx = lowest.indices.item()
	probs = torch.zeros([k])
	covered = 0
	for arm in reversed(range(1, k)):
		if arm == best_idx:
			# The arg-min arm absorbs everything not yet handed out.
			probs[arm] = 1 - covered
			return probs
		level = 1 / (1 + gamma * (flat[arm] - best_val))
		probs[arm] = max(level - covered, 0)
		covered = max(level, covered)
	total = torch.sum(probs)
	if total < 1:
		probs[0] = 1 - total.item()
	return probs

def robs_cops_graph(yhat, k, gamma):
	"""Split the probability mass between the two smallest predictions.

	With ``gap = gamma * (second_smallest - smallest)``, the smallest entry
	receives ``2 / (2 - gap + sqrt(4 + gap**2))`` and the second smallest
	receives the rest; every other entry of the length-``k`` result is zero.

	Args:
		yhat: tensor of predictions; it is flattened with ``view(-1)`` and
			must hold at least two values because ``topk`` is called with k=2.
		k: length of the returned vector.
		gamma: scale applied to the gap between the two smallest values.
	"""
	flat = yhat.view(-1)
	lowest_two = torch.topk(flat, k=2, largest=False)
	first = lowest_two.indices[0].item()
	second = lowest_two.indices[1].item()

	gap = (gamma * (lowest_two.values[1] - lowest_two.values[0])).item()

	probs = torch.zeros([k])
	probs[first] = 2 / (2 - gap + np.sqrt(4 + (gap ** 2)))
	# The runner-up gets whatever the best arm did not take.
	probs[second] = 1 - probs[first]
	return probs

## Any closed-form solution that gives a nice bound on the minimax value
def closed_form(yhat, k, gamma, graph_T):
	"""Dispatch to the closed-form distribution for the given feedback graph.

	Args:
		yhat: tensor of predictions, passed through unchanged.
		k: passed through unchanged.
		gamma: passed through unchanged.
		graph_T: one of ``'bandit'``, ``'full_info'``, ``'inventory'`` or
			``'robs_cops'``.

	Returns:
		Whatever the selected routine returns for ``(yhat, k, gamma)``.

	Raises:
		ValueError: if ``graph_T`` is not one of the names listed above.
	"""
	if graph_T == 'bandit':
		return log_barrier(yhat, k, gamma)
	if graph_T == 'full_info':
		return greedy(yhat, k, gamma)
	if graph_T == 'inventory':
		return invent_graph(yhat, k, gamma)
	if graph_T == 'robs_cops':
		return robs_cops_graph(yhat, k, gamma)
	# Fail loudly with the offending value instead of falling through to an
	# unbound local variable.
	raise ValueError(
		f"unknown graph_T {graph_T!r}; expected one of "
		"'bandit', 'full_info', 'inventory', 'robs_cops'"
	)

def squarecb(yhat, k, gamma, graph_T):
	"""Return the distribution p_i = 1 / (gamma * (yhat_i - min(yhat)) + k).

	The arg-min entry is then overwritten so that it takes the mass left over
	by the other entries.

	Args:
		yhat: 2-D tensor with a single row (the arg-min along dim 1 is read
			with ``.item()``).
		k: constant added to every denominator
			(presumably the number of arms -- confirm against the caller).
		gamma: scale applied to each gap ``yhat_i - min(yhat)``.
		graph_T: accepted for a uniform signature; not used here.

	Returns:
		1-D tensor with one entry per element of ``yhat``.
	"""
	lowest = torch.min(yhat, dim=1)
	best_idx = lowest.indices.item()

	probs = (1 / (gamma * (yhat - lowest.values) + k)).view(-1)
	# The best entry currently holds 1 / k (its gap is zero); adding 1.0 / k
	# back after subtracting the full sum leaves it with one minus the mass
	# of the other entries.
	probs[best_idx] = 1 - torch.sum(probs) + 1.0 / k
	return probs
	