import os
import arff
import random
import numpy as np
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from sparselinear import SparseLinear
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class Sparse_Linear_Cls(nn.Module):
	"""Single ``SparseLinear`` layer followed by an element-wise sigmoid.

	Args:
		input_dim: first argument handed to ``SparseLinear``.
		num_label: second argument handed to ``SparseLinear``.
	"""

	def __init__(self, input_dim, num_label):
		super().__init__()
		self.fc1 = SparseLinear(input_dim, num_label)

	def forward(self, x):
		"""Return ``sigmoid(fc1(x))``."""
		logits = self.fc1(x)
		return torch.sigmoid(logits)

class Linear_Cls(nn.Module):
	"""Dense linear layer followed by an element-wise sigmoid.

	Args:
		input_dim: number of input features of the linear layer.
		num_label: number of outputs of the linear layer.
	"""

	def __init__(self, input_dim, num_label):
		super().__init__()
		self.fc1 = nn.Linear(input_dim, num_label)

	def forward(self, x):
		"""Return ``sigmoid(fc1(x))``; every output lies in (0, 1)."""
		logits = self.fc1(x)
		return torch.sigmoid(logits)

class Inventory(nn.Module):
	"""Piecewise-linear cost model driven by one learned scalar per input row.

	A linear layer maps the input to a single value ``r``. The output holds,
	for each of the ``N`` grid points ``i / N`` (``i = 0 .. N-1``), the cost
	``h * max(i/N - r, 0) + p * max(r - i/N, 0)``.

	Args:
		input_dim: number of input features of the linear layer.
		out_dim: number of grid points ``N`` (size of the last output axis).
		h: weight on the ``grid - r`` side (presumably the holding cost;
			confirm with the caller).
		p: weight on the ``r - grid`` side (presumably the shortage penalty;
			confirm with the caller).
	"""

	def __init__(self, input_dim, out_dim, h, p):
		super().__init__()
		self.N = out_dim
		self.h = h
		self.p = p
		self.fc1 = nn.Linear(input_dim, 1, bias=True)

	def forward(self, x):
		"""Return the cost of every grid point for each row of ``x``."""
		# One scalar per row; the trailing size-1 axis broadcasts against the grid.
		r = self.fc1(x)

		# Normalized grid 0, 1/N, ..., (N-1)/N with a leading axis for broadcasting.
		grid = torch.arange(self.N).to(r.device).unsqueeze(0) / self.N

		above = F.relu(grid - r)
		below = F.relu(r - grid)
		return self.h * above + self.p * below


# Two-layer MLP
class MLP_Cls(nn.Module):
	"""Linear -> ReLU -> Linear network with a sigmoid on the output.

	Args:
		input_dim: number of input features.
		dim: width of the hidden layer.
		num_label: number of outputs.
	"""

	def __init__(self, input_dim, dim, num_label):
		super().__init__()
		self.fc1 = nn.Linear(input_dim, dim)
		self.fc2 = nn.Linear(dim, num_label)
		# In-place activation: overwrites the hidden pre-activations.
		self.relu = nn.ReLU(inplace=True)

	def forward(self, x):
		"""Return ``sigmoid(fc2(relu(fc1(x))))``."""
		hidden = self.relu(self.fc1(x))
		logits = self.fc2(hidden)
		return torch.sigmoid(logits)

def inventory_loss(h,pnt,inv,d):
	"""Return ``h * max(inv - d, 0) + pnt * max(d - inv, 0)``, element-wise.

	``h`` weighs the amount by which ``inv`` exceeds ``d`` and ``pnt`` the
	amount by which ``d`` exceeds ``inv``. Scalars and NumPy arrays are both
	accepted, since the computation goes through ``np.maximum``.
	"""
	surplus = np.maximum(inv - d, 0)
	shortfall = np.maximum(d - inv, 0)
	return h * surplus + pnt * shortfall

def train_onestep_OGD(model, X, arm, y, loss_type='square', lr=2.0, batch_size=64, device='cpu'):
	"""Make one shuffled pass of Adam updates using only the selected arms.

	The lists ``X``, ``arm`` and ``y`` are each concatenated with ``torch.cat``
	and iterated in mini-batches. Within a batch, the loss is computed only on
	the entries where ``arm == 1``; all other predictions are ignored.

	Args:
		model: module whose output is indexed like ``y`` (row, arm).
		X: list of feature tensors.
		arm: list of indicator tensors shaped like ``y``; an entry equal to 1
			(after casting to int) marks an observed arm.
		y: list of target tensors (2-D after concatenation).
		loss_type: ``'square'`` selects ``nn.MSELoss``; any other value selects
			``nn.BCELoss``.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size.
		device: device the batches are moved to (the model is not moved).

	Returns:
		The same ``model`` object, updated in place.
	"""
	model.train()
	X = torch.cat(X).float()
	arm = torch.cat(arm).int()
	y = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(TensorDataset(X, arm, y), batch_size=batch_size, shuffle=True)

	# Using Square-loss unless another loss type is requested
	if loss_type == 'square':
		loss_fn = nn.MSELoss().to(device)
	else:
		loss_fn = nn.BCELoss().to(device)

	for xb, armb, yb in dataloader:
		xb, armb, yb = xb.to(device), armb.to(device), yb.to(device)

		# Boolean mask of observed (row, arm) entries; mask indexing yields the
		# entries in row-major order, same as a nested loop over rows and arms.
		chosen = armb == 1
		if not chosen.any():
			# No feedback in this batch: there is nothing to fit, and a loss
			# over zero elements would be undefined.
			continue

		pred = model(xb)
		loss = loss_fn(pred[chosen], yb[chosen])
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	return model

def train_cls_batch(model, X, arm, y, num_epochs=20, lr=0.001, batch_size=64, device='cpu'):
	"""Train ``model`` with binary cross-entropy on the selected arms only.

	The lists ``X``, ``arm`` and ``y`` are each concatenated with ``torch.cat``
	and iterated in shuffled mini-batches for up to ``num_epochs`` epochs.
	Within a batch, the loss is computed only on entries where ``arm == 1``.
	Training stops early once the mean batch loss of an epoch is <= 1e-3.

	Args:
		model: module whose output is indexed like ``y`` (row, arm); since
			``nn.BCELoss`` is used, its outputs must lie in [0, 1].
		X: list of feature tensors.
		arm: list of indicator tensors shaped like ``y``; an entry equal to 1
			(after casting to int) marks an observed arm.
		y: list of target tensors (2-D after concatenation).
		num_epochs: maximum number of passes over the data.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size.
		device: device the batches are moved to (the model is not moved).

	Returns:
		Mean loss over the batches of the last epoch run. ``nan`` if no epoch
		ran or no batch contained a selected arm.
	"""
	model.train()
	X = torch.cat(X).float()
	arm = torch.cat(arm).int()
	y = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(TensorDataset(X, arm, y), batch_size=batch_size, shuffle=True)
	loss_fn = nn.BCELoss().to(device)

	epoch_loss = float('nan')
	for _ in range(num_epochs):
		total_loss, used_batches = 0.0, 0
		for xb, armb, yb in dataloader:
			xb, armb, yb = xb.to(device), armb.to(device), yb.to(device)

			# Boolean mask of observed (row, arm) entries; mask indexing yields
			# them in row-major order, same as a nested loop over rows and arms.
			chosen = armb == 1
			if not chosen.any():
				# No feedback in this batch: nothing to fit.
				continue

			pred = model(xb)
			loss = loss_fn(pred[chosen], yb[chosen])
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			total_loss += loss.item()
			used_batches += 1

		if used_batches == 0:
			# No selected arm anywhere in the data; further epochs cannot help.
			return epoch_loss

		epoch_loss = total_loss / used_batches
		if epoch_loss <= 1e-3:
			break

	return epoch_loss

def train_cls_MC_batch(model, X, arm, y, num_epochs=20, lr=0.001, batch_size=64, device='cpu'):
	"""Train ``model`` with mean-squared error on the selected arms only.

	The lists ``X``, ``arm`` and ``y`` are each concatenated with ``torch.cat``
	and iterated in shuffled mini-batches for up to ``num_epochs`` epochs.
	Within a batch, the loss is computed only on entries where ``arm == 1``.
	Training stops early once the mean batch loss of an epoch is <= 1e-3.

	Args:
		model: module whose output is indexed like ``y`` (row, arm).
		X: list of feature tensors.
		arm: list of indicator tensors shaped like ``y``; an entry equal to 1
			(after casting to int) marks an observed arm.
		y: list of target tensors (2-D after concatenation).
		num_epochs: maximum number of passes over the data.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size.
		device: device the batches are moved to (the model is not moved).

	Returns:
		Mean loss over the batches of the last epoch run. ``nan`` if no epoch
		ran or no batch contained a selected arm.
	"""
	model.train()
	X = torch.cat(X).float()
	arm = torch.cat(arm).int()
	y = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(TensorDataset(X, arm, y), batch_size=batch_size, shuffle=True)
	loss_fn = nn.MSELoss().to(device)

	epoch_loss = float('nan')
	for _ in range(num_epochs):
		total_loss, used_batches = 0.0, 0
		for xb, armb, yb in dataloader:
			xb, armb, yb = xb.to(device), armb.to(device), yb.to(device)

			# Boolean mask of observed (row, arm) entries; mask indexing yields
			# them in row-major order, same as a nested loop over rows and arms.
			chosen = armb == 1
			if not chosen.any():
				# No feedback in this batch: nothing to fit.
				continue

			pred = model(xb)
			loss = loss_fn(pred[chosen], yb[chosen])
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			total_loss += loss.item()
			used_batches += 1

		if used_batches == 0:
			# No selected arm anywhere in the data; further epochs cannot help.
			return epoch_loss

		epoch_loss = total_loss / used_batches
		if epoch_loss <= 1e-3:
			break

	return epoch_loss