import os
import math
import arff
from copy import deepcopy

import random
import cvxpy as cp
import numpy as np
from tqdm.auto import tqdm
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from utils import load_rcv1, load_rcv1_full, get_graph, calc_ind, calc_ind_rand
from models import *
from method_list import *

# Expose only physical GPU 2 to CUDA. Note that all tensors/models below are
# placed on `device`, which is the CPU, so this setting only matters if CUDA
# is used somewhere else.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})
device = 'cpu'

def work(X, y, graph, graph_T, model_name, train_scheme, loss_type='l1', method = 'closed-form', c=10):
	"""Run one online pass of graph-feedback contextual bandit learning.

	For each row i of X (in order) the current model scores the k arms, the
	chosen exploration `method` turns those scores into a sampling
	distribution `pt`, one arm is sampled from `pt`, and the model receives a
	single OGD step using the feedback mask given by the sampled arm's row of
	`graph`.

	Args:
		X: feature matrix indexed as X[i, :] and fed to torch.from_numpy
			(presumably a float32 np.ndarray of shape (n, d) — TODO confirm).
		y: label matrix with n rows and k = y.shape[1] columns; y[i, a] is
			truthy when arm a is a correct label for row i.
		graph: square (k, k) matrix; after playing arm a, the feedback mask
			marks every arm l with graph[a, l] > 0.
		graph_T: graph type name; only forwarded to `closed_form`.
		model_name: unused (kept for interface compatibility).
		train_scheme: unused (kept for interface compatibility).
		loss_type: 'square' or 'l1'; how model outputs become per-arm losses.
		method: 'log-barrier', 'greedy', 'closed-form' or 'squarecb'.
		c: multiplier of the exploration parameter gamma.

	Returns:
		List of length n; entry i is the cumulative count of rounds (up to and
		including i) in which the sampled arm was not a correct label.

	Raises:
		ValueError: if `loss_type` or `method` is not recognised.

	NOTE(review): the OGD step size is read from the module-level `lr` and
	the device from the module-level `device`.
	"""
	# Fail fast instead of hitting an unbound `yhat` / `pt` inside the loop.
	if loss_type not in ('square', 'l1'):
		raise ValueError("Unknown loss_type: {!r}".format(loss_type))
	if method not in ('log-barrier', 'greedy', 'closed-form', 'squarecb'):
		raise ValueError("Unknown method: {!r}".format(method))

	model = Linear_Cls(X.shape[1], y.shape[1]).to(device)

	n, k = X.shape[0], y.shape[1]
	alpha = len(calc_ind(graph))
	print("alpha: ", alpha)
	print("n: \tc: ", n, c)

	tot_regret = 0
	regret_list = []
	tf = time()
	gamma = np.zeros((n,))
	assert graph.shape[0] == graph.shape[1] == k
	# closed-form scales exploration by alpha (size of calc_ind(graph)); the
	# other methods scale by the number of arms k.
	gamma_dim = alpha if method == 'closed-form' else k
	for i in tqdm(range(n)):
		gamma[i] = math.sqrt(i * gamma_dim) * c

		model.eval()
		with torch.no_grad():
			x = torch.from_numpy(X[i, :]).view(1, -1).to(device)
			lbl = torch.from_numpy(y[i, :]).view(1, -1).to(device)
			pred = model(x)

			if loss_type == 'square':
				yhat = (pred - 1) ** 2
			else:  # 'l1'
				yhat = - pred

			if method == 'log-barrier':
				pt = log_barrier(yhat, k, gamma[i])
			elif method == 'greedy':
				pt = greedy(yhat, k, gamma[i])
			elif method == 'closed-form':
				pt = closed_form(yhat, k, alpha, gamma[i], graph, graph_T)
			else:  # 'squarecb'
				pt = squarecb(yhat, k, gamma[i], graph)

			# Sanity check: pt must be a probability distribution over the k arms.
			assert abs(torch.sum(pt).item() - 1) < 1e-4
			assert all(pt[p] >= 0 for p in range(k))

			pred_arm = torch.multinomial(pt, num_samples=1).item()
			if not y[i, pred_arm]:
				tot_regret += 1

			# Mark every arm observed by playing pred_arm (its out-neighbours in graph).
			feedback_arm = torch.tensor(
				[1.0 if graph[pred_arm, j].item() > 0 else 0.0 for j in range(k)]
			)

		# One OGD step on this round only; no history is retained between rounds.
		train_onestep_OGD(model, [x], [feedback_arm.view(1, -1)], [lbl], lr=lr, device=device)

		regret_list.append(tot_regret)
		if (i+1) % 10000 == 0:
			print("Time:{:.2f}\tIters:{}\tReg:{}".format(time()-tf, i+1, tot_regret))
			tf = time()

	print(tot_regret)
	return regret_list


# ---- Experiment configuration -------------------------------------------
dataset = 'rcv1_50full'
edge_prob = 0.75  # forwarded to get_graph when building the feedback graph
# RCV1 variants use a much larger OGD step size than any other dataset.
lr = 2.0 if dataset.startswith('rcv1') else 0.005

graph_T = ['bandit', 'full_info', 'robs_cops']  # graph types to sweep
method = ['closed-form']  # exploration methods to sweep

model_name = 'linear'
train_scheme = 'ogd'
seed_list = list(range(1, 6))  # seeds 1..5

if __name__ == "__main__":
	# Sweep methods x graph types x seeds; for each seed keep the regret curve
	# of the c value with the lowest final regret and save all seeds' curves.
	print("Train: ", train_scheme)
	print("Device: ", device)
	print("Lr:{:.2f}".format(lr))

	# Dataset name -> (file path, loader). An unknown name fails here instead
	# of surfacing later as a NameError on X / y.
	dataset_sources = {
		'rcv1_10': ('./rcv1/data_pca_10class.gt', load_rcv1),
		'rcv1_50': ('./rcv1/data_pca_50class.gt', load_rcv1),
		'rcv1_10full': ('./rcv1/data_full_10class.gt', load_rcv1_full),
		'rcv1_20full': ('./rcv1/data_full_20class.gt', load_rcv1_full),
		'rcv1_50full': ('./rcv1/data_full_50class.gt', load_rcv1_full),
	}
	if dataset not in dataset_sources:
		raise ValueError("Unknown dataset: {!r}".format(dataset))
	file_path, loader = dataset_sources[dataset]
	X, y = loader(file_path)

	print(X.shape, y.shape)
	c_list = [32, 64, 128]
	X = X.astype(np.float32)
	y = y.astype(np.float32)
	for mth in method:
		for graph_type in graph_T:
			res_list = []
			out_dir = "./res_{}_{}".format(dataset, graph_type)
			print("=========================================================")
			print("Dataset:{}\tGraph:{}\tMethod:{}".format(dataset, graph_type, mth))
			os.makedirs(out_dir, exist_ok=True)
			for seed in seed_list:
				random.seed(seed)
				np.random.seed(seed)
				torch.manual_seed(seed)

				# One shuffled stream and one feedback graph per seed, shared by all c.
				permutation = np.random.permutation(X.shape[0])
				X_s = X[permutation]
				y_s = y[permutation]
				graph = get_graph(graph_type, y.shape[1], edge_prob)

				# Reset per seed so a previous seed's curve can never leak in.
				reg_min = math.inf
				best_res = None
				for c in c_list:
					regret_list = work(X_s, y_s, graph, graph_type, model_name, train_scheme, loss_type = 'l1', method = mth, c=c)
					if regret_list[-1] < reg_min:
						reg_min = regret_list[-1]
						best_res = deepcopy(regret_list)
					if mth == 'greedy':
						# presumably greedy does not depend on c, so one run suffices — TODO confirm
						break
				assert best_res is not None and best_res[-1] == reg_min
				res_list.append(np.array(best_res).reshape(1, -1))
			res_list = np.concatenate(res_list, axis=0)
			print(res_list.shape)
			np.save('{}/{}_loss.npy'.format(out_dir, mth), res_list)

