import numpy as np
import torch
import torch.nn as nn
import mo_gymnasium as mo_gym
from gymnasium import spaces
from copy import deepcopy
import random
import sys
import matplotlib.pyplot as plt
from deap.tools._hypervolume import hv as deap_hv
from concurrent.futures import ProcessPoolExecutor
import os

from MPMORL.evaluate.pusher_v4 import env_evaluate
from Safe_Multi_Objective_MuJoCo.config import get_config
from Safe_Multi_Objective_MuJoCo.mujoco_config import get_env_mujoco_config

# Parse the shared CLI configuration and force the Pusher-v4 environment.
args = get_config().parse_args()
args.env_name = "Pusher-v4"
# Snapshot of the config taken *before* get_env_mujoco_config(args) runs. This copy is
# what the __main__ block passes to multi_party_nsde as env_config, which forwards it
# to evaluate_single_offspring so each worker process can build its own environment.
# NOTE(review): presumably deep-copied in case get_env_mujoco_config mutates args — confirm.
env_config_args = deepcopy(args)

# Main-process environment: sizes the Policy network below and is used by
# multi_party_nsde to evaluate the initial population.
env = get_env_mujoco_config(args)
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

# Width of both hidden layers in Policy.
NEURON_COUNT = 64


class Policy(nn.Module):
	"""Two-hidden-layer MLP policy whose forward pass returns only the action mean.

	Observations go through two ReLU layers of width ``NEURON_COUNT`` and a
	linear head of size ``num_outputs``. ``action_log_std`` is registered as a
	parameter (so it is part of the flat parameter vector handled by
	get_flat_params / set_flat_params), but ``forward`` does not use it.
	"""

	def __init__(self, num_inputs, num_outputs):
		super().__init__()
		# Attribute assignment order fixes the order of self.parameters(), which
		# the flat-parameter helpers depend on. Keep it stable.
		self.affine1 = nn.Linear(num_inputs, NEURON_COUNT)
		self.affine2 = nn.Linear(NEURON_COUNT, NEURON_COUNT)
		self.action_mean = nn.Linear(NEURON_COUNT, num_outputs)

		# Scale the output layer down so initial actions start near zero.
		with torch.no_grad():
			self.action_mean.weight.mul_(0.1)
			self.action_mean.bias.mul_(0.0)

		self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))

	def forward(self, x):
		hidden = self.affine1(x).relu()
		hidden = self.affine2(hidden).relu()
		return self.action_mean(hidden)


def compute_mphv_mpsp(rewards, ref_point=None, weights=None):
	"""Compute the multi-party hypervolume (MPHV) and multi-party spacing (MPSP).

	``rewards`` is split column-wise into two parties: columns ``[:2]`` belong
	to decision maker 1 and columns ``[2:]`` to decision maker 2. For each party
	the hypervolume (DEAP ``hv.hypervolume``) and the spacing metric are
	computed and combined as a weighted sum.

	Args:
		rewards: Array-like of reward vectors, one row per solution (assumed
			4 columns; everything after column 1 goes to party 2). A single
			1-D row is treated as one solution.
		ref_point: Hypervolume reference point used for both parties.
			Defaults to ``[1.0, 1.0]``.
			NOTE(review): DEAP's hypervolume treats objectives as minimised —
			confirm this matches the sign convention of the rewards.
		weights: ``[w_party1, w_party2]``. Defaults to ``[0.5, 0.5]``.

	Returns:
		``(mphv, mpsp)``. ``(0.0, 0.0)`` for empty input. Either metric falls
		back to ``0.0`` if its computation raises.
	"""
	if ref_point is None:
		ref_point = [1.0, 1.0]
	if weights is None:
		weights = [0.5, 0.5]

	rewards = np.atleast_2d(np.asarray(rewards, dtype=float))
	if rewards.size == 0:
		return 0.0, 0.0

	party1 = rewards[:, :2]
	party2 = rewards[:, 2:]

	def spacing(front):
		# For every objective: squared gaps between consecutive sorted values,
		# summed over all objectives and divided by (M - 1).
		front = np.asarray(front, dtype=float)
		num_points = front.shape[0]
		if num_points <= 1:
			return 0.0
		gaps = np.diff(np.sort(front, axis=0), axis=0)
		return float(np.sum(gaps ** 2) / (num_points - 1))

	try:
		hv1 = deap_hv.hypervolume(party1.tolist(), list(ref_point))
		hv2 = deap_hv.hypervolume(party2.tolist(), list(ref_point))
		mphv = weights[0] * hv1 + weights[1] * hv2
	except Exception:
		# Best effort: report 0.0 rather than abort the run, but do not
		# swallow KeyboardInterrupt/SystemExit as the bare except did.
		mphv = 0.0

	try:
		mpsp = weights[0] * spacing(party1) + weights[1] * spacing(party2)
	except Exception:
		mpsp = 0.0

	return mphv, mpsp


def evaluate_single_offspring(task_args):
	"""Build one DE trial vector and evaluate it in a freshly created environment.

	Designed to be mapped over a ProcessPoolExecutor (see multi_party_nsde).

	Args:
		task_args: Tuple ``(theta_target, theta_r1, theta_r2, theta_r3, F, CR,
			config_args)``. The thetas are equal-length flat parameter tensors,
			``F`` is the DE scale factor, ``CR`` the crossover rate, and
			``config_args`` is passed to get_env_mujoco_config.

	Returns:
		``(trial_theta, rewards)`` where ``rewards`` is a numpy array
		``[r1, r2, r3, r4]`` from one evaluation episode.
	"""
	theta_target, theta_r1, theta_r2, theta_r3, F, CR, config_args = task_args

	# Mutation, clipped to the search box [-1, 1].
	mutant_theta = differential_mutation(theta_r1, theta_r2, theta_r3, F=F)
	mutant_theta = torch.clamp(mutant_theta, -1.0, 1.0)

	# Binomial crossover: one random.random() draw per coordinate (same draw
	# order as before), and j_rand guarantees at least one mutant coordinate.
	# NOTE(review): this uses the module-level `random` state; with fork-based
	# process pools, workers may start from identical copies of that state and
	# draw correlated crossover masks — consider per-task seeding.
	dim = len(theta_target)
	j_rand = random.randint(0, dim - 1)
	cross_mask = torch.tensor([random.random() < CR for _ in range(dim)], dtype=torch.bool)
	cross_mask[j_rand] = True
	trial_theta = torch.where(cross_mask, mutant_theta, theta_target)
	trial_theta = torch.clamp(trial_theta, -1.0, 1.0)

	worker_env = get_env_mujoco_config(config_args)
	try:
		trial_model = Policy(worker_env.observation_space.shape[0], worker_env.action_space.shape[0])
		set_flat_params(trial_model, trial_theta)
		r1, r2, r3, r4 = env_evaluate(worker_env, trial_model, episodes=1)
	finally:
		# Release the simulator even when evaluation raises.
		worker_env.close()

	return trial_theta, np.array([r1, r2, r3, r4])


def plot_pareto_fronts_multi_view(population, best_fronts):
	"""Scatter-plot both decision makers' objective spaces side by side.

	Each individual's reward vector is read from ``individual[2]``. Columns
	[0, 1] are drawn in the left panel (DM1) and columns [2, 3] in the right
	panel (DM2). A point is red when its index in ``population`` appears in
	``best_fronts[0]`` (left) or ``best_fronts[1]`` (right), otherwise gray.

	NOTE(review): multi_party_nsde records best_fronts as indices into its
	combined parent+offspring pool, not into the population it returns, so
	feeding its outputs straight in here may highlight the wrong points —
	confirm at the call site.
	"""
	reward_matrix = np.array([individual[2] for individual in population])
	panels = (
		(0, reward_matrix[:, :2], "Decision Maker 1: Objectives [r1, r2]", "r1", "r2"),
		(1, reward_matrix[:, 2:], "Decision Maker 2: Objectives [r3, r4]", "r3", "r4"),
	)

	plt.figure(figsize=(14, 6))
	for panel_idx, points, title, x_label, y_label in panels:
		highlighted = set(best_fronts[panel_idx])
		colors = ['red' if k in highlighted else 'gray' for k in range(len(population))]
		plt.subplot(1, 2, panel_idx + 1)
		plt.scatter(points[:, 0], points[:, 1], c=colors)
		plt.title(title)
		plt.xlabel(x_label)
		plt.ylabel(y_label)

	plt.tight_layout()
	plt.show()

def get_flat_params(model):
	"""Concatenate every parameter of ``model`` into one 1-D tensor.

	Parameters are taken in ``model.parameters()`` order. The result is a new
	tensor that does not track gradients.
	"""
	pieces = [param.detach().view(-1) for param in model.parameters()]
	return torch.cat(pieces)


def set_flat_params(model, flat_params):
	"""Copy consecutive slices of ``flat_params`` into ``model``'s parameters in place.

	Slices are consumed in ``model.parameters()`` order, matching the layout
	produced by get_flat_params. Trailing extra values are ignored.
	"""
	start = 0
	for param in model.parameters():
		stop = start + param.numel()
		segment = flat_params[start:stop]
		param.data.copy_(segment.view_as(param))
		start = stop


def clone_model(model):
	"""Return an independent deep copy of ``model``.

	``deepcopy`` duplicates every parameter and buffer, so the clone shares no
	storage with ``model``. A follow-up ``load_state_dict`` would only re-copy
	identical values.
	"""
	return deepcopy(model)


def crowding_distance(values):
	"""NSGA-II style crowding distance for each row of ``values``.

	For every objective the rows are sorted. The two extreme rows get an
	infinite distance, and every interior row accumulates the gap between its
	sorted neighbours, normalised by that objective's range plus 1e-8 (which
	avoids division by zero when all values are equal).

	Args:
		values: Sequence of equal-length objective vectors, one per point.

	Returns:
		1-D float array with one distance per point. It is empty when
		``values`` is empty.
	"""
	values = np.asarray(values)
	size = len(values)
	distances = np.zeros(size)
	if size == 0:
		return distances
	for m in range(values.shape[1]):
		order = np.argsort(values[:, m])
		column = values[order, m]
		distances[order[0]] = distances[order[-1]] = float('inf')
		span = column[-1] - column[0] + 1e-8
		# Interior points: (next - previous) / span, vectorised over the sorted order.
		distances[order[1:-1]] += (column[2:] - column[:-2]) / span
	return distances


def differential_mutation(theta_r1, theta_r2, theta_r3, F=0.5):
	"""Return the DE mutant ``theta_r1 + F * (theta_r2 - theta_r3)``.

	The operations are element-wise, so tensors, numpy arrays or plain numbers
	all work.
	"""
	difference_vector = theta_r2 - theta_r3
	scaled_difference = F * difference_vector
	return theta_r1 + scaled_difference


def epsilon_dominates(a, ref, eps, env_signs, dm_idx):
	"""Test whether ``a`` ε-dominates ``ref`` on one decision maker's objectives.

	Decision maker 0 owns objectives 0 and 1; any other ``dm_idx`` owns
	objectives 2 and 3. For each owned objective ``j`` the compared pair
	``(lhs, rhs)`` is ``(a[j], eps[dm_idx] * ref[j])`` when
	``env_signs[j] == 1``, and ``(eps[dm_idx] * a[j], ref[j])`` otherwise.

	Returns:
		True when ``lhs <= rhs`` for every owned objective and ``lhs < rhs``
		for at least one of them; False otherwise.
	"""
	objectives = (0, 1) if dm_idx == 0 else (2, 3)
	scale = eps[dm_idx]

	comparisons = []
	for j in objectives:
		if env_signs[j] == 1:
			comparisons.append((a[j], scale * ref[j]))
		else:
			comparisons.append((scale * a[j], ref[j]))

	weakly_better = all(lhs <= rhs for lhs, rhs in comparisons)
	strictly_better = any(lhs < rhs for lhs, rhs in comparisons)
	return weakly_better and strictly_better


def _build_offspring_tasks(population, mut_num, F, CR, env_config):
	"""Create one evaluate_single_offspring task tuple per (target, mutation) pair."""
	tasks = []
	pop_len = len(population)
	for i in range(pop_len):
		# Donors must differ from the target; random.sample keeps them mutually distinct.
		candidates = [k for k in range(pop_len) if k != i]
		for _ in range(mut_num):
			ran1, ran2, ran3 = random.sample(candidates, 3)
			tasks.append((
				population[i][1],
				population[ran1][1],
				population[ran2][1],
				population[ran3][1],
				F, CR, env_config,
			))
	return tasks


def _dominance_flags(rewards, ref, eps, env_signs):
	"""Return ``(dominates_for_dm1, dominates_for_dm2)`` for every reward vector."""
	return [
		(epsilon_dominates(r, ref, eps, env_signs, 0),
		 epsilon_dominates(r, ref, eps, env_signs, 1))
		for r in rewards
	]


def _by_crowding(individuals):
	"""Return ``individuals`` stably ordered by decreasing crowding distance of their rewards."""
	distances = crowding_distance([ind[2] for ind in individuals])
	ranked = sorted(zip(individuals, distances), key=lambda pair: -pair[1])
	return [ind for ind, _ in ranked]


def _select_survivors(combined, flags, pop_size):
	"""Fill the next population tier by tier: joint elites, single-DM elites, then the rest."""
	joint, single, rest = [], [], []
	for ind, (dom_by_dm1, dom_by_dm2) in zip(combined, flags):
		if dom_by_dm1 and dom_by_dm2:
			joint.append(ind)
		elif dom_by_dm1 or dom_by_dm2:
			single.append(ind)
		else:
			rest.append(ind)

	selected = _by_crowding(joint) if joint else []

	remaining = pop_size - len(selected)
	if remaining > 0 and single:
		selected.extend(_by_crowding(single)[:remaining])

	remaining = pop_size - len(selected)
	if remaining > 0 and rest:
		selected.extend(_by_crowding(rest)[:remaining])

	return selected[:pop_size]


def multi_party_nsde(population, env_config, save_dir, pop_size=100, F=0.5, CR=0.9,
		iterations=500, mut_num=1, num_workers=20,
		env_signs=None, eps_decay=0.01, patience=10,
		joint_threshold=10, eps_min=1e-6):
	"""Run multi-party NSDE with ε-dominance-based survivor selection.

	Each generation, every individual spawns ``mut_num`` DE trial vectors that
	are evaluated in parallel by evaluate_single_offspring. Parents and
	offspring are pooled and classified with epsilon_dominates against a fixed
	reference vector ``ref`` (10x the element-wise max of the initial rewards).
	The next population is filled tier by tier (joint elites, single-DM elites,
	rest), and each tier is ranked by crowding distance. When at least
	``joint_threshold`` individuals are jointly ε-dominant, both ε values are
	multiplied by ``eps_decay`` and floored at ``eps_min``. MPHV/MPSP are
	printed every 50 generations.

	Args:
		population: List of ``(model, theta, rewards)`` tuples. If empty, it is
			filled in place with ``pop_size`` random Policy instances evaluated
			on the module-level ``env``.
		env_config: Config forwarded to workers to build their environments.
		save_dir: Currently unused.  NOTE(review): nothing is written here.
		pop_size: Target population size after selection.
		F: DE scale factor.
		CR: DE crossover rate.
		iterations: Number of generations.
		mut_num: Trial vectors per individual per generation.
		num_workers: ProcessPoolExecutor size.
		env_signs: Per-objective sign list passed to epsilon_dominates. Defaults
			to ``[1, 1, 1, 1]``; None previously crashed.
		eps_decay: Multiplicative factor applied to ε when shrinking.
		patience: Currently unused.  NOTE(review): no_improve_count is tracked
			but never compared against it.
		joint_threshold: Joint-candidate count that triggers ε shrinking.
		eps_min: Lower bound for each ε.

	Returns:
		``(population, best_fronts, eps)``:
		- ``population``: the final population.
		- ``best_fronts``: the last recorded ε-dominant index lists for DM1 and
		  DM2. These index the combined parent+offspring pool of the generation
		  that produced them.
		- ``eps``: the final ε pair.
		NOTE(review): the MPHV/MPSP histories are collected but not returned.
	"""
	if env_signs is None:
		env_signs = [1, 1, 1, 1]

	if not population:
		for _ in range(pop_size):
			model = Policy(num_inputs, num_actions)
			theta = get_flat_params(model)
			r1, r2, r3, r4 = env_evaluate(env, model, episodes=1)
			population.append((model, theta, np.array([r1, r2, r3, r4])))

	eps = [0.5, 0.5]
	ref = np.max([ind[2] for ind in population], axis=0)*10

	print(f"初始参考解 ref = {ref}")

	no_improve_count = [0, 0]
	best_fronts = [[], []]

	mphv_history = []
	mpsp_history = []

	with ProcessPoolExecutor(max_workers=num_workers) as executor:
		for gen in range(iterations):
			print(f"\n=== NSDE Generation {gen + 1}/{iterations} (using {num_workers} workers) ===")

			tasks = _build_offspring_tasks(population, mut_num, F, CR, env_config)
			results = list(executor.map(evaluate_single_offspring, tasks))

			base_model = population[0][0]
			offspring = []
			for trial_theta, rewards_array in results:
				trial_model = clone_model(base_model)
				set_flat_params(trial_model, trial_theta)
				offspring.append((trial_model, trial_theta, rewards_array))

			combined = population + offspring
			rewards = [ind[2] for ind in combined]
			# Dominance flags under the eps used for front detection this generation.
			flags = _dominance_flags(rewards, ref, eps, env_signs)

			first_layer = [idx for idx, (d1, d2) in enumerate(flags) if d1 and d2]

			if first_layer:
				print(f"[ε-dominance] Found joint ε-dominant candidates: count = {len(first_layer)}")
				best_fronts[0] = first_layer
				best_fronts[1] = first_layer

				if len(first_layer) >= joint_threshold:
					old_eps = eps.copy()
					eps[0] = max(eps_min, eps[0] * eps_decay)
					eps[1] = max(eps_min, eps[1] * eps_decay)
					no_improve_count = [0, 0]
					print(
						f"[ε-dominance] Joint threshold reached ({joint_threshold}), shrink eps {old_eps} -> {eps}")
					# eps changed: selection and metrics below use the shrunk eps.
					flags = _dominance_flags(rewards, ref, eps, env_signs)
				else:
					print(
						f"[ε-dominance] Joint candidates < threshold ({joint_threshold}), do NOT shrink eps. Current eps = {eps}")
			else:
				for dm_idx in [0, 1]:
					new_front = [idx for idx, f in enumerate(flags) if f[dm_idx]]
					if new_front:
						best_fronts[dm_idx] = new_front
						no_improve_count[dm_idx] = 0
						print(
							f"[ε-dominance] DM{dm_idx + 1} has ε-dominant solutions (not joint), ε={eps[dm_idx]:.4f}, count={len(new_front)}")
					else:
						no_improve_count[dm_idx] += 1
						print(f"[ε-dominance] DM{dm_idx + 1} No improvement, count={no_improve_count[dm_idx]}")

			population = _select_survivors(combined, flags, pop_size)

			if (gen + 1) % 50 == 0:
				# Metrics over the ε-feasible pool (either DM); fall back to everything.
				feasible = [r for r, (d1, d2) in zip(rewards, flags) if d1 or d2]
				feasible_rewards = np.array(feasible if feasible else rewards)
				mphv, mpsp = compute_mphv_mpsp(feasible_rewards, ref_point=[1.0, 1.0], weights=[0.5, 0.5])

				mphv_history.append(mphv)
				mpsp_history.append(mpsp)
				print(f"[MPHV/MPSP] Gen {gen + 1}: MPHV={mphv:.6f}, MPSP={mpsp:.6f}")

	return population, best_fronts, eps


if __name__ == "__main__":
	SEED = 42
	try:
		# Repeat the whole experiment several times.
		# NOTE(review): every repetition re-seeds with the same SEED, so all runs
		# start from identical RNG state; if independent repeats are intended,
		# derive the seed from `exp` (e.g. SEED + exp).
		for exp in range(0, 6):
			np.random.seed(SEED)
			torch.manual_seed(SEED)
			random.seed(SEED)

			# multi_party_nsde requires a save_dir argument (it does not currently
			# write to it); define one per repetition so the call is well-formed.
			save_dir = os.path.join("results", env_config_args.env_name, f"exp_{exp}")

			population = []
			env_signs = [1, 1, 1, 1]
			population, best_front, final_eps = multi_party_nsde(
				population=population,
				env_config=env_config_args,
				save_dir=save_dir,
				pop_size=100,
				iterations=500,
				CR=0.9,
				F=0.5,
				num_workers=20,
				env_signs=env_signs,
				eps_decay=0.99,
				patience=20,
				joint_threshold=60,
				eps_min=1e-6,
			)
	finally:
		# The shared main-process env is reused by every repetition to evaluate
		# its initial population, so close it only once, after all runs.
		env.close()
