import torch
import torch.nn as nn

import numpy as np
from collections import defaultdict
#sys.path.insert(0,'../..')
#sys.path.append("../../")
import sys, time
sys.path.append("../fair_ltr/")

from frank_wolfe import compute_Moreau_grad_softsort, compute_owa
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from utils import AverageMeter, customdefaultdict
from models import MLP, init_weights
import copy

from owa_optimization import OWASubgradientLayer, gini_indices, gini_indices_square, MoreauOWALossLayer, owa_optim_lp_norm_cvxpy, owa_pgd_solver_wrapper, owa_fw_solver_wrapper

# Print tensors in full unless they hold more than 10,000 elements.
torch.set_printoptions(threshold=10000)

# Make modules in ./NeurIPSIntopt/Interior/ importable (a relative entry, resolved
# against the current working directory). Because this runs after the imports
# above, it only affects imports performed later.
sys.path.append("./NeurIPSIntopt/Interior/")




class MultiTrainer():
    """Train an ``MLP`` cost predictor for multi-task decision-focused learning.

    ``args.trainer_name`` selects the training loss:

    * ``'OWA2Stage'``: MSE between predicted and true costs (two-stage baseline).
    * ``'OWASurrogateQP'``, ``'OWALP'``, ``'OWALPNorm'``, ``'OWAPGDSubGrad'``,
      ``'OWAFoldedSubGrad'``: the negated ``OWASubgradientLayer`` output on the
      true task values ``C x`` of the decision ``x = solver(predicted costs)``.
    * ``'OWASurrogateMoreauGrad'``, ``'OWAPGDMoreauGrad'``,
      ``'OWALPNormMoreauGrad'``: same, through ``MoreauOWALossLayer`` with
      parameter ``args.beta``.
    * ``'SumSurrogateQP'``: negated uniformly weighted sum of the task values.

    Training batches are 8-tuples ``(x, true_cost, x_opt, y_opt, z_opt,
    x_opt_sum, y_opt_sum, z_opt_sum)``; evaluation uses the full
    ``test_iterator.dataset.tensors`` with the same layout.
    """

    # Keys of the ``train_epoch`` result dict -> keys of the ``evaluate`` dict.
    _VAL_HISTORY_KEYS = {
        "val_loss": "loss_z",
        "val_mae_loss": "mae_loss",
        "val_regret": "regret",
        "val_obj": "owa_obj",
        "val_loss2": "loss_z2",
        "val_mae_loss2": "mae_loss2",
        "val_regret2": "regret2",
        "val_obj2": "owa_obj2",
    }

    def __init__(
        self,
        train_iterator,
        test_iterator,
        model_params,
        solver,
        args
    ):
        """Build the model, optimiser, LR scheduler and optional test-time solver.

        Args:
            train_iterator: iterable of training batches (8-tuples, see class doc).
            test_iterator: object whose ``dataset.tensors`` holds the test set.
            model_params: keyword arguments forwarded to ``MLP``.
            solver: maps predicted costs to decisions; used in training and in
                ``evaluate``.
            args: namespace providing ``use_cuda``, ``n_task``, ``trainer_name``,
                ``beta``, ``lr``, ``add_mse``, ``lamb`` and, depending on the
                trainer, ``num_item``, ``num_iter``, ``gamma``, ``use_subgrad``.
        """
        self.use_cuda = args.use_cuda
        self.train_iterator = train_iterator
        self.test_iterator = test_iterator
        self.n_task = args.n_task
        self.trainer_name = args.trainer_name
        self.beta = args.beta
        self.build_model(model_params)
        self.optimizer = torch.optim.Adam(self.model.parameters(), args.lr)
        # mode='max': the scheduler is stepped with an objective that should grow.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.2, min_lr=1e-7, verbose=True, patience=15)
        self.epochs = 0
        self.iter = 0
        self.loss_func = nn.MSELoss(reduction='none')
        # Mean-reduced MSE for the two-stage loss and the optional MSE term;
        # created here so it exists as soon as the trainer does.
        self.MSELoss = nn.MSELoss()
        self.solver = solver
        # Best test OWA objective seen so far (np.infty was removed in NumPy 2.0).
        self.best = -np.inf
        # Not read inside this class; presumably used by the caller for early stopping.
        self.patience = 30
        # Consecutive evaluations without improving ``self.best``.
        self.step = 0
        self.best_model = None
        ## additional evaluation
        self.w_gini = gini_indices_square(self.n_task)
        self.w_sum = self.w_gini.mean() * torch.ones(self.n_task)
        self.add_mse = args.add_mse
        self.lamb = args.lamb

        # Secondary solver, used only by evaluate() for the "...2" metrics.
        self.solver_test = None
        if self.trainer_name in ['OWALPNorm', 'OWALPNormMoreauGrad']:
            self.solver_test = owa_optim_lp_norm_cvxpy(self.n_task, args.num_item, 0)
        elif self.trainer_name in ['OWAPGDSubGrad', 'OWAPGDMoreauGrad']:
            self.solver_test = owa_pgd_solver_wrapper(
                self.w_gini, args.num_item, 1, args.num_iter, args.gamma, args.use_subgrad)
        elif self.trainer_name == 'OWAFoldedSubGrad':
            self.solver_test = owa_fw_solver_wrapper(self.w_gini, args.num_item, 1, args.num_iter)

    def build_model(self, model_params):
        """Create ``self.model = MLP(**model_params)`` and print it.

        Xavier initialisation is applied unless the trainer is one of
        ``'OWAPGDSubGrad'``, ``'OWAFoldedSubGrad'``, ``'OWA2Stage'``, whose
        weights are left as constructed by ``MLP``.
        """
        self.model = MLP(**model_params)
        if self.trainer_name not in ['OWAPGDSubGrad', 'OWAFoldedSubGrad', 'OWA2Stage']:
            init_weights(self.model, 'xavier')
        print(self.model)

    def train_epoch(self):
        """Run one epoch over ``self.train_iterator`` and return metric histories.

        On the first epoch the model is evaluated once before any update. On
        every iteration where ``self.iter`` is a multiple of the logging interval,
        training-batch metrics are recorded, the model is evaluated on the test
        set, the best model so far (by test OWA objective) is snapshotted, and
        the LR scheduler is stepped on the training-batch objective.

        Returns:
            dict of lists keyed ``train_loss``, ``train_mae_loss``,
            ``train_regret``, ``train_obj``, ``val_loss``, ``val_mae_loss``,
            ``val_regret``, ``val_obj``, ``val_loss2``, ``val_mae_loss2``,
            ``val_regret2``, ``val_obj2``, plus ``val_opt_obj`` (mean optimal
            test objective).
        """
        self.epochs += 1
        batch_time = AverageMeter("Batch time")
        data_time = AverageMeter("Data time")
        avg_loss = AverageMeter("Loss")
        avg_regret = AverageMeter("Regret")
        avg_obj = AverageMeter("Objective")
        write_losses_interval = 20

        train_loss_lst, train_mae_loss, train_regret, train_obj = [], [], [], []
        val_history = {key: [] for key in self._VAL_HISTORY_KEYS}
        # Most recent evaluate() result of this epoch; None until one has run.
        eval_results = None

        if self.epochs == 1:
            # Baseline test metrics before any parameter update.
            eval_results = self.evaluate()
            print("Evaluating on test set: iteration {} of epoch {}: Loss:{} \t Regret: {} \t OWA Obj: {} "
                  .format(self.iter, self.epochs, eval_results['mae_loss'], eval_results['regret'], eval_results['owa_obj']))
            print('=' * 100)
            self._append_val_history(val_history, eval_results)

        self.model.train()
        end = time.time()
        for data in self.train_iterator:
            # Time spent waiting for the data loader.
            data_time.update(time.time() - end)
            self.optimizer.zero_grad()
            with torch.autograd.detect_anomaly():
                if self.use_cuda:
                    data = [tensor.cuda() for tensor in data]
                (x_train, train_true_cost, x_train_opt_sol, y_train_opt_sol, z_train_opt_sol,
                 _x_sum_sol, _y_sum_sol, _z_sum_sol) = data
                # Predicted cost matrices; the "imn" einsums downstream imply B x M x N.
                pred_c = self.model(x_train).squeeze()

                loss, pred_c, x_pred_sol, y_pred_sol = self._compute_task_loss(pred_c, train_true_cost)
                total_loss = loss
                if self.add_mse and self.trainer_name in ['OWALP', 'OWALPNorm', "OWAPGDSubGrad", "OWAPGDMoreauGrad"]:
                    # Optional MSE regulariser on the (possibly rescaled) predicted costs.
                    total_loss = loss + self.lamb * self.MSELoss(pred_c, train_true_cost)
                total_loss.backward()

                self.optimizer.step()
                avg_loss.update(loss.item(), x_train.size(0))

                if self.iter % write_losses_interval == 0:
                    with torch.no_grad():
                        loss_x, loss_y, loss_z, z_pred_sol = self._train_solution_metrics(
                            pred_c, x_pred_sol, y_pred_sol,
                            x_train_opt_sol, y_train_opt_sol, z_train_opt_sol)
                        z_obj = z_pred_sol.mean()
                        avg_regret.update(loss_y.item(), pred_c.size(0))
                        avg_obj.update(z_obj.item(), pred_c.size(0))
                        train_mae_loss.append(loss_x)
                        train_regret.append(loss_y)
                        train_obj.append(z_obj)
                        train_loss_lst.append(loss_z.item())

                        print('*' * 40)
                        print("Evaluating on train set: iteration {} of epoch {}:".format(self.iter, self.epochs))

                        meters = [batch_time, data_time, avg_loss, avg_regret, avg_obj]
                        meter_str = "\t".join([str(meter) for meter in meters])
                        print(f"Epoch: {self.epochs}\t{meter_str}")
                        print('Train: z_obj', z_obj, z_train_opt_sol.mean())

                        print("SGD lr=%.4f" % (self.optimizer.param_groups[0]["lr"]))
                        print('entering testing ')
                        eval_results = self.evaluate()
                        print('*' * 40)
                        print("Evaluating on test set: iteration {} of epoch {}: mae_loss: {} Regret: {}\t Regret2: {}\t Objective: {} \t Objective2: {}".
                              format(self.iter, self.epochs, eval_results['mae_loss'], eval_results['regret'], eval_results['regret2'], eval_results['owa_obj'], eval_results['owa_obj2']))

                        self._append_val_history(val_history, eval_results)

                        # Snapshot the model whenever the test OWA objective improves.
                        if eval_results['owa_obj'] > (self.best + 1e-5):
                            self.best = eval_results['owa_obj']
                            self.step = 0
                            self.best_model = copy.deepcopy(self.model)
                        else:
                            self.step += 1
                            print('patience', self.step)

                        # Plateau detection on the training-batch objective (mode='max').
                        self.scheduler.step(z_obj)

                self.iter += 1
            # Wall time of the whole iteration (loading, step and any logging).
            batch_time.update(time.time() - end)
            end = time.time()

        if eval_results is not None:
            val_opt_obj = eval_results['z_opt']
        else:
            # No evaluation ran this epoch (no iteration hit the logging interval);
            # compute the same value evaluate() reports as 'z_opt'.
            val_opt_obj = self.test_iterator.dataset.tensors[4].mean()

        return {
            "train_loss": train_loss_lst,
            "train_mae_loss": train_mae_loss,
            "train_regret": train_regret,
            "train_obj": train_obj,
            **val_history,
            "val_opt_obj": val_opt_obj,
        }

    def _compute_task_loss(self, pred_c, train_true_cost):
        """Return ``(loss, pred_c, x_pred_sol, y_pred_sol)`` for one training batch.

        ``pred_c`` is returned because ``'OWAPGDSubGrad'``/``'OWAFoldedSubGrad'``
        rescale it before solving, and the caller's MSE term and metrics use the
        rescaled tensor. ``x_pred_sol``/``y_pred_sol`` are ``None`` for
        ``'OWA2Stage'``, whose loss does not call the solver.

        Raises:
            ValueError: if ``self.trainer_name`` is not a known trainer.
        """
        name = self.trainer_name
        if name == 'OWA2Stage':
            return self.MSELoss(pred_c, train_true_cost), pred_c, None, None

        if name in ['OWASurrogateQP', 'OWALP', 'OWALPNorm', 'OWAPGDSubGrad', 'OWAFoldedSubGrad']:
            if name in ['OWAPGDSubGrad', 'OWAFoldedSubGrad']:
                # Scale each sample's cost matrix to unit norm (norm of the flattened matrix).
                inv_norm = 1 / torch.linalg.norm(pred_c.reshape(pred_c.shape[0], -1), dim=-1)
                pred_c = torch.einsum("imn, i-> imn", pred_c, inv_norm)
            x_pred_sol = self.solver(pred_c)
            # y = C x per task, where C is the true cost matrix (B x M).
            y_pred_sol = torch.einsum("imn, in->im", train_true_cost, x_pred_sol)
            loss = (-OWASubgradientLayer.apply(y_pred_sol, self.w_gini)).mean()
        elif name in ['OWASurrogateMoreauGrad', "OWAPGDMoreauGrad", 'OWALPNormMoreauGrad']:
            x_pred_sol = self.solver(pred_c.squeeze())  # B x N
            y_pred_sol = torch.einsum("imn, in->im", train_true_cost, x_pred_sol)
            loss = (-MoreauOWALossLayer.apply(y_pred_sol, self.w_gini, self.beta)).mean()
        elif name == 'SumSurrogateQP':
            x_pred_sol = self.solver(pred_c.squeeze())  # B x N
            y_pred_sol = torch.einsum("imn, in->im", train_true_cost, x_pred_sol)
            loss = (-(torch.einsum("im, m->i", y_pred_sol, self.w_sum))).mean()
        else:
            raise ValueError(f"Unknown trainer_name: {name!r}")
        return loss, pred_c, x_pred_sol, y_pred_sol

    def _train_solution_metrics(self, pred_c, x_pred_sol, y_pred_sol,
                                x_train_opt_sol, y_train_opt_sol, z_train_opt_sol):
        """Compare the batch's predicted decisions with its optimal solutions.

        Returns:
            ``(loss_x, loss_y, loss_z, z_pred_sol)``: mean absolute errors of the
            decision, the per-task values and the objective, plus the per-sample
            predicted objective as a NumPy array.
        """
        if self.trainer_name == 'OWA2Stage':
            # The two-stage loss never solved anything, so solve each predicted
            # cost matrix here; the solver returns (x, y, z) for a NumPy input.
            lst_x_sol, lst_y_sol, lst_z_sol = [], [], []
            for dat in pred_c:
                cur_x_sol, cur_y_sol, cur_z_sol = self.solver(dat.detach().numpy().astype(float))
                lst_x_sol.append(cur_x_sol)
                lst_y_sol.append(cur_y_sol)
                lst_z_sol.append(cur_z_sol)
            x_pred_sol = np.stack(lst_x_sol)
            y_pred_sol = np.stack(lst_y_sol)
            z_pred_sol = np.stack(lst_z_sol)
            loss_x = np.abs(x_pred_sol - x_train_opt_sol.numpy()).mean()
            # Absolute error, like every other regret in this class (was a signed mean).
            loss_y = torch.abs(y_train_opt_sol - torch.as_tensor(y_pred_sol)).mean()
            loss_z = torch.abs(z_train_opt_sol - torch.as_tensor(z_pred_sol)).mean()
        else:
            loss_x = np.abs(x_pred_sol.detach().numpy() - x_train_opt_sol.numpy()).mean()
            y_pred_sol = y_pred_sol.detach()
            z_pred_sol = torch.einsum("m, im-> i", self.w_gini, y_pred_sol.sort(-1).values).numpy()
            loss_y = torch.abs(y_pred_sol - y_train_opt_sol).mean()
            loss_z = torch.abs(z_train_opt_sol - torch.as_tensor(z_pred_sol)).mean()
        return loss_x, loss_y, loss_z, z_pred_sol

    def _append_val_history(self, val_history, eval_results):
        """Append one ``evaluate()`` result to the per-key validation histories."""
        for history_key, eval_key in self._VAL_HISTORY_KEYS.items():
            val_history[history_key].append(eval_results[eval_key])

    def evaluate(self, is_test=False):
        """Evaluate on the full test set stored in ``self.test_iterator.dataset``.

        Args:
            is_test: if True, evaluate the best snapshot ``self.best_model``
                instead of the current model; falls back to the current model
                when no snapshot has been recorded yet.

        Returns:
            dict with ``mae_loss``, ``regret``, ``owa_obj``, ``loss_z`` for
            ``self.solver``; the same metrics suffixed ``2`` for
            ``self.solver_test`` (identical to the first set when the trainer has
            no test solver); and ``z_opt``, the mean optimal test objective.
        """
        avg_metrics = {}
        self.model.eval()
        if is_test and self.best_model is not None:
            model = self.best_model
        else:
            model = self.model
        # The snapshot is deep-copied while the model is in train mode, so switch
        # it to eval mode explicitly (matters for dropout / batch-norm layers).
        model.eval()
        with torch.no_grad():
            (x_test, test_true_cost, x_opt_sol, y_opt_sol, z_opt_sol,
             x_opt_sol_sum, _y_opt_sol_sum, _z_opt_sol_sum) = self.test_iterator.dataset.tensors
            pred_cost = model(x_test)  # B x M x N
            if self.trainer_name == 'OWA2Stage':
                # The solver takes one float64 NumPy cost matrix; keep only its x.
                x_pred_sol = np.stack([self.solver(dat.detach().numpy().astype(float))[0]
                                       for dat in pred_cost])
                y_pred_sol = np.einsum("imn, in->im", test_true_cost.numpy().astype(float), x_pred_sol)
                z_pred_sol = np.einsum("m, im-> i", self.w_gini.numpy(), np.sort(y_pred_sol, -1))
                loss_x = np.abs(x_pred_sol - x_opt_sol.numpy()).mean()
                loss_y = np.abs(y_pred_sol - y_opt_sol.numpy()).mean()
                loss_z = np.abs(z_opt_sol.numpy() - z_pred_sol).mean()
                # No test-time solver here: the "2" metrics mirror the primary ones.
                z_pred_sol2, loss_x2, loss_y2, loss_z2 = z_pred_sol, loss_x, loss_y, loss_z
            else:
                x_pred_sol = self.solver(pred_cost.squeeze())
                y_pred_sol = torch.einsum("imn, in->im", test_true_cost, x_pred_sol)
                z_pred_sol = torch.einsum("m, im-> i", self.w_gini, y_pred_sol.sort(dim=-1).values).numpy()

                if self.trainer_name in ['OWAPGDSubGrad', 'OWAPGDMoreauGrad', 'OWALPNorm', 'OWAFoldedSubGrad', 'OWALPNormMoreauGrad']:
                    x_pred_sol2 = self.solver_test(pred_cost.squeeze())
                    print('pred_cost.shape', pred_cost.shape)
                    y_pred_sol2 = torch.einsum("imn, in->im", test_true_cost, x_pred_sol2)
                    z_pred_sol2 = torch.einsum("m, im-> i", self.w_gini, y_pred_sol2.sort(dim=-1).values).numpy()
                else:
                    x_pred_sol2, y_pred_sol2, z_pred_sol2 = x_pred_sol, y_pred_sol, z_pred_sol

                # 'SumSurrogateQP' decisions are scored against the *_sum reference solution.
                x_ref = x_opt_sol_sum if self.trainer_name == 'SumSurrogateQP' else x_opt_sol
                loss_x = np.abs(x_pred_sol.detach().numpy() - x_ref.numpy()).mean()
                loss_x2 = np.abs(x_pred_sol2.detach().numpy() - x_ref.numpy()).mean()

                loss_y = torch.abs(y_opt_sol - y_pred_sol).mean()
                loss_z = torch.abs(z_opt_sol - torch.as_tensor(z_pred_sol))
                loss_y2 = torch.abs(y_opt_sol - y_pred_sol2).mean()
                loss_z2 = torch.abs(z_opt_sol - torch.as_tensor(z_pred_sol2))

            print('Test: z_pred_sol', z_pred_sol.mean(), z_opt_sol.mean())
            print('check l1 error:')
            print('x mae: ', torch.abs(x_opt_sol - torch.as_tensor(x_pred_sol)).sum(-1).mean())

            z_obj = z_pred_sol.mean()
            avg_metrics["mae_loss"] = loss_x
            avg_metrics["regret"] = loss_y.mean()
            avg_metrics["owa_obj"] = z_obj
            avg_metrics["loss_z"] = loss_z.mean()
            avg_metrics["mae_loss2"] = loss_x2
            avg_metrics["regret2"] = loss_y2.mean()
            avg_metrics["owa_obj2"] = z_pred_sol2.mean()
            avg_metrics["loss_z2"] = loss_z2.mean()
            avg_metrics["z_opt"] = z_opt_sol.mean()

        self.model.train()
        return avg_metrics






