import os
import arff
import random
import numpy as np
from time import time
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def log_barrier(yhat, k, gamma):
	"""Return inverse-gap weights whose normaliser is found by bisection.

	Searches for a scalar ``lam`` in ``[1, k]`` such that
	``p_i = 1 / (gamma * (yhat_i - min(yhat)) + lam)`` sums to one
	(within ``1e-6``), and returns the resulting ``p``.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``.
		k: upper end of the bisection bracket. NOTE(review): presumably the
			number of actions, which makes ``[1, k]`` a valid bracket for
			non-negative ``gamma`` -- confirm against callers.
		gamma: multiplier applied to every gap ``yhat_i - min(yhat)``.

	Returns:
		1-D tensor with one weight per element of ``yhat``. If the bracket
		becomes narrower than the tolerance before the sum converges, the
		weights from the last evaluated midpoint are returned.
	"""
	yhat_v = yhat.view(-1)
	kmin = 1.0
	kmax = k
	eps = 1e-6
	yhat_b, _ = torch.min(yhat_v, axis=0)
	# Gaps to the smallest prediction; loop-invariant, so compute them once.
	gap = yhat_v - yhat_b.item()

	pt = None
	while kmax - kmin > eps:
		temp = (kmax + kmin) / 2
		pt = 1 / (gamma * gap + temp)
		pt_sum = torch.sum(pt)
		if abs(pt_sum - 1) < eps:
			return pt
		if pt_sum > 1:
			# Too much mass: the normaliser has to grow.
			kmin = temp
		else:
			# Covers pt_sum < 1 and also a NaN sum, so the bracket always
			# shrinks and the loop cannot spin forever on bad input.
			kmax = temp

	if pt is None:
		# The bracket was already narrower than eps (e.g. k == 1), so the
		# loop never ran; evaluate the midpoint instead of failing on an
		# unbound local.
		pt = 1 / (gamma * gap + (kmax + kmin) / 2)
	return pt

def greedy(yhat, k, gamma):
	"""Put all probability mass on the entry with the smallest prediction.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``.
		k: unused; kept so the signature matches the sibling samplers.
		gamma: unused; kept for the same reason.

	Returns:
		1-D tensor with as many entries as ``yhat`` has elements, equal to
		one at the index of the minimum and zero elsewhere.
	"""
	flat = yhat.view(-1)
	_, best = torch.min(flat, dim=0)
	one_hot = torch.zeros(flat.shape[0])
	one_hot[best.item()] = 1
	return one_hot

def L2_FTRL(yhat, k, gamma):
	"""Return clipped linear weights whose offset is found by bisection.

	Searches for a scalar ``lam`` in ``[0, gamma + 1]`` such that
	``p_i = relu(lam - gamma * (yhat_i - min(yhat)))`` sums to one
	(within ``1e-5``), and returns the resulting ``p``.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``.
		k: unused; kept so the signature matches the sibling samplers.
		gamma: multiplier applied to every gap ``yhat_i - min(yhat)``; also
			sets the upper end of the initial bracket.

	Returns:
		1-D tensor with one weight per element of ``yhat``. If the search
		cannot make progress (the bracket can no longer be split, or the sum
		is NaN) the weights from the last evaluated midpoint are returned
		instead of looping forever.
	"""
	yhat_v = yhat.view(-1)
	yhat_b, _ = torch.min(yhat_v, axis=0)
	# Scaled gaps to the smallest prediction; loop-invariant.
	scaled_gap = gamma * (yhat_v - yhat_b)
	kmin = 0
	kmax = gamma + 1
	eps = 1e-5

	while True:
		temp = (kmax + kmin) / 2
		pt = F.relu(temp - scaled_gap)
		pt_sum = torch.sum(pt)
		if abs(pt_sum - 1) < eps:
			return pt
		if not (kmin < temp < kmax):
			# The midpoint no longer lies strictly inside the bracket, so
			# further halving cannot change anything: stop with the best
			# available answer rather than spinning.
			return pt
		if pt_sum > 1:
			kmax = temp
		else:
			# Covers pt_sum < 1 and also a NaN sum, so the bracket always
			# shrinks and the guard above is eventually reached.
			kmin = temp

def invent_graph(yhat, k, gamma):
	"""Build the sampling distribution used for the 'inventory' graph type.

	Walks the indices from ``k - 1`` down to ``1``. Each index receives the
	amount by which ``1 / (1 + gamma * gap)`` exceeds the running maximum of
	that quantity over the indices already visited (never less than zero).
	When the walk reaches the index of the smallest prediction, that index
	receives ``1 - running maximum`` and the walk stops. If the walk finishes
	without meeting it and the mass is below one, index 0 gets the remainder.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``.
		k: length of the returned vector and the start of the downward walk.
		gamma: multiplier applied to every gap ``yhat_i - min(yhat)``.

	Returns:
		1-D tensor of length ``k``.
	"""
	losses = yhat.view(-1)
	best_loss, best_idx = torch.min(losses, dim=0)
	best_idx = best_idx.item()
	probs = torch.zeros([k])

	covered = 0
	idx = k - 1
	while idx > 0:
		if idx == best_idx:
			# The minimiser absorbs whatever mass is still unassigned.
			probs[idx] = 1 - covered
			return probs
		level = 1 / (1 + gamma * (losses[idx] - best_loss))
		probs[idx] = max(level - covered, 0)
		covered = max(level, covered)
		idx -= 1

	total = torch.sum(probs)
	if total < 1:
		probs[0] = 1 - total.item()
	return probs

def undirected_graph(yhat, k, alpha, gamma, graph):
	"""Build the sampling distribution used for the 'undirected' graph type.

	Visits indices in order of increasing prediction and selects every index
	not yet marked as covered; selecting index ``j`` marks each ``l`` with
	``graph[j, l] == 1`` as covered. With ``m`` indices selected, each one
	gets weight ``1 / (m + gamma * gap)``. If the weights sum to less than
	one, the shortfall is added to the index of the smallest prediction.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``.
		k: number of indices considered and length of the returned vector.
		alpha: unused by this function.
		gamma: multiplier applied to every gap ``yhat_i - min(yhat)``.
		graph: 2-D object indexable as ``graph[j, l]``; an entry equal to 1
			means selecting ``j`` covers ``l``.

	Returns:
		1-D tensor of length ``k``.
	"""
	losses = yhat.view(-1)
	best_loss, best_idx = torch.min(losses, dim=0)
	best_idx = best_idx.item()
	gaps = losses - best_loss

	order = torch.argsort(losses).tolist()
	covered = set()
	chosen = []
	for rank in range(k):
		node = order[rank]
		if node in covered:
			continue
		chosen.append(node)
		covered.update(nbr for nbr in range(k) if graph[node, nbr] == 1)

	probs = torch.zeros([k])
	count = len(chosen)
	for node in chosen:
		probs[node] = 1.0 / (count + gamma * gaps[node].item())

	total = torch.sum(probs)
	if total < 1:
		probs[best_idx] += 1 - total
	return probs

def robs_cops_graph(yhat, k, gamma):
	"""Build the sampling distribution used for the 'robs_cops' graph type.

	Only the two smallest predictions receive mass: the second smallest gets
	``1 / (2 + gamma * (second - smallest))`` and the smallest gets the rest.

	Args:
		yhat: tensor of predicted losses; it is flattened with ``view(-1)``
			and must hold at least two elements for the top-2 lookup.
		k: length of the returned vector.
		gamma: multiplier applied to the gap between the two smallest
			predictions.

	Returns:
		1-D tensor of length ``k`` with at most two non-zero entries.
	"""
	losses = yhat.view(-1)
	two_smallest, two_idx = torch.topk(losses, k=2, largest=False)
	best, runner_up = two_idx[0].item(), two_idx[1].item()

	scaled_gap = (gamma * (two_smallest[1] - two_smallest[0])).item()

	probs = torch.zeros([k])
	probs[runner_up] = 1 / (2 + scaled_gap)
	probs[best] = 1 - probs[runner_up]
	return probs

def closed_form(yhat, k, alpha, gamma, graph, graph_T):
	"""Dispatch to the sampler that matches the feedback-graph type.

	Args:
		yhat: predicted losses, forwarded unchanged to the chosen sampler.
		k: forwarded unchanged to the chosen sampler.
		alpha: forwarded only to ``undirected_graph``.
		gamma: forwarded unchanged to the chosen sampler.
		graph: forwarded only to ``undirected_graph``.
		graph_T: one of ``'bandit'``, ``'full_info'``, ``'inventory'``,
			``'robs_cops'`` or ``'undirected'``.

	Returns:
		Whatever the selected sampler returns.

	Raises:
		ValueError: if ``graph_T`` is not one of the supported names.
	"""
	if graph_T == 'bandit':
		return log_barrier(yhat, k, gamma)
	if graph_T == 'full_info':
		return greedy(yhat, k, gamma)
	if graph_T == 'inventory':
		return invent_graph(yhat, k, gamma)
	if graph_T == 'robs_cops':
		return robs_cops_graph(yhat, k, gamma)
	if graph_T == 'undirected':
		return undirected_graph(yhat, k, alpha, gamma, graph)
	# Fail with a clear message instead of falling through to an unbound
	# local when the graph type is misspelled or unsupported.
	raise ValueError(
		"unknown graph_T {!r}; expected one of 'bandit', 'full_info', "
		"'inventory', 'robs_cops', 'undirected'".format(graph_T)
	)

def squarecb(yhat, k, gamma, graph_T):
	"""Return inverse-gap weights with the remainder on the minimiser.

	Every entry gets ``1 / (gamma * (yhat_i - min(yhat)) + k)``; the entry
	holding the smallest prediction is then overwritten so that the vector
	sums to one.

	Args:
		yhat: 2-D tensor of predicted losses. The row-wise argmin is read
			with ``.item()``, so it must have exactly one row.
		k: constant added to every scaled gap. NOTE(review): presumably the
			number of actions -- confirm against callers.
		gamma: multiplier applied to every gap.
		graph_T: unused by this function.

	Returns:
		1-D tensor with one weight per element of ``yhat``.
	"""
	best_loss, best_idx = torch.min(yhat, dim=1)
	best_idx = best_idx.item()

	probs = (1 / (gamma * (yhat - best_loss) + k)).view(-1)
	# The minimiser currently holds 1/k; swap it for the leftover mass.
	probs[best_idx] = 1 - torch.sum(probs) + 1.0 / k
	return probs
