import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pykep as pk
from matplotlib.mlab import phase_spectrum
from torch.distributions import Normal
from torch.optim.lr_scheduler import StepLR, LambdaLR
from tqdm import trange

from utils3 import generate_1_collision, plot_rewards, calculate_distances, to_serializable, generate_collisions
from simulator8_multi import Simulator, propagate
from torch.distributions import MultivariateNormal
# Shared device for every tensor and module created in this file: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsAwareAttention(nn.Module):
    """Single-head self-attention over fragment encodings with an additive physics bias.

    Each fragment j gets a scalar score ``-sum_k gamma_k * p_jk`` from its
    physics feature vector ``p_j``. The score, multiplied by ``bias_scale``, is
    added to every attention logit whose *key* is fragment j. The attended
    values are then mean-pooled into one vector.
    """

    def __init__(self, feature_dim, num_heads=1,
                 bias_mlp_hidden_dim=32, bias_mlp_output_dim=1,
                 bias_scale=0.1, learned_bias_scale=True, num_physics_features=4):
        """
        Args:
            feature_dim (int): Size of each encoded fragment vector (e.g. 64).
            num_heads (int): Stored on the module but not used; attention is
                always computed with a single head.
            bias_mlp_hidden_dim (int): Unused (the MLP bias is disabled).
            bias_mlp_output_dim (int): Unused (the MLP bias is disabled).
            bias_scale (float): Initial multiplier applied to the physics bias.
            learned_bias_scale (bool): If True, ``bias_scale`` is a trainable
                ``nn.Parameter``; otherwise it is a fixed registered buffer.
            num_physics_features (int): Length of each fragment's physics
                feature vector, and of the learnable per-feature weights ``gamma``.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads

        # Q, K and V projections, created (and registered) in that order.
        self.query, self.key, self.value = (
            nn.Linear(feature_dim, feature_dim) for _ in range(3)
        )

        # One learnable weight per physics feature for the linear bias.
        self.gamma = nn.Parameter(torch.ones(num_physics_features))

        # Global multiplier on the physics bias: trainable or frozen.
        initial_scale = torch.tensor(bias_scale)
        if not learned_bias_scale:
            self.register_buffer('bias_scale', initial_scale)
        else:
            self.bias_scale = nn.Parameter(initial_scale)

    def forward(self, encoded_features, physics_features):
        """Fuse fragment encodings into a single pooled vector.

        Args:
            encoded_features (Tensor): Fragment encodings, shape [N, D].
            physics_features (Tensor): Per-fragment physics features,
                shape [N, num_physics_features].

        Returns:
            tuple: (pooled vector of shape [D], attention weights of shape
            [N, N] whose rows sum to 1).
        """
        n_items, dim = encoded_features.size()

        queries = self.query(encoded_features)
        keys = self.key(encoded_features)
        values = self.value(encoded_features)

        # Scaled dot-product scores, [N, N].
        scores = queries @ keys.transpose(-2, -1) / (dim ** 0.5)

        # Per-fragment physics score laid out along the key (column) axis, so
        # every query sees the same bias for a given fragment.
        key_bias = (self.gamma * physics_features).sum(dim=-1).neg()
        key_bias = key_bias.unsqueeze(0).expand(n_items, n_items)

        weights = F.softmax(scores + self.bias_scale * key_bias, dim=-1)

        # Attend, then mean-pool across fragments.
        pooled = (weights @ values).mean(dim=0)
        return pooled, weights


#Improved encoding layer with residual connections added
class ResidualEncoder(nn.Module):
    """Two-layer tanh MLP with a residual (skip) connection.

    Computes ``tanh(fc2(tanh(fc1(x))) + skip(x))``. ``skip`` is a learned
    linear projection when ``input_dim != output_dim``, and the identity otherwise.
    """

    def __init__(self, input_dim, output_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim)
        self.tanh = nn.Tanh()
        # The skip path only needs a projection when the widths differ.
        self.adjust_dim = (
            nn.Linear(input_dim, output_dim) if input_dim != output_dim else None
        )

    def forward(self, x):
        """Encode ``x`` (last dim ``input_dim``) to last dim ``output_dim``."""
        skip = x if self.adjust_dim is None else self.adjust_dim(x)
        hidden = self.fc2(self.tanh(self.fc1(x)))
        return self.tanh(hidden + skip)


# ==== MultiHead Actor With Done ====
class Actor(nn.Module):
    """Policy network over a variable number of debris fragments.

    Each debris state is concatenated with the ship state and encoded with a
    ``ResidualEncoder``. ``PhysicsAwareAttention`` fuses the encodings into one
    64-d vector. From that vector the network predicts a tanh-bounded mean
    action and a "done" probability.
    """

    def __init__(self, action_std, state_dim_ship=9, state_dim_debris=8, action_dim=3, num_physics_features=2):
        super().__init__()
        self.input_dim = state_dim_ship + state_dim_debris
        self.num_physics_features = num_physics_features

        self.encoder = ResidualEncoder(state_dim_ship + state_dim_debris, 64)

        # Not used by forward(). It is kept so saved state_dicts keep loading,
        # and because the empty-debris fallback reads the action size from it.
        self.head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, action_dim)
        )

        self.attention_fusion = PhysicsAwareAttention(
            feature_dim=64, num_physics_features=num_physics_features
        )

        # Mean-action head; the final Tanh bounds it to [-1, 1] before scaling.
        self.action_layer = nn.Sequential(
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Tanh()
        )

        # Produces the logit of the "done" probability.
        self.done_layer = nn.Sequential(
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

        # Diagonal variance for a MultivariateNormal policy; act() does not read it.
        self.action_var = torch.full((action_dim,), action_std * action_std).to(device)
        # Std of the Gaussian sampled in act().
        self.action_std = action_std

    def _build_inputs(self, ship_state, debris_states):
        """Pair the ship state with every debris state.

        Returns:
            tuple: ([N, ship_dim + debris_dim] float32 inputs,
            [N, num_physics_features] physics features sliced from the tail
            of each debris state).
        """
        # Cast both sides to float32 so torch.cat and the float32 encoder never
        # see mixed dtypes (e.g. a float64 numpy ship state).
        ship = torch.tensor(ship_state, device=device).float()
        debris = torch.stack([torch.tensor(d, device=device).float() for d in debris_states])
        inputs = torch.cat([ship.unsqueeze(0).expand(debris.size(0), -1), debris], dim=1)
        # Assumes the physics features are the last num_physics_features entries
        # of every debris state -- confirm against the simulator's state layout.
        physics_features = debris[:, -self.num_physics_features:]
        return inputs, physics_features

    def forward(self, state_list, action_limit=1.0):
        """Compute the mean action and done probability for one observation.

        Args:
            state_list: Sequence whose element 0 is the ship state (length 9 by
                default) and whose remaining elements are per-fragment debris
                states (length 8 by default). Elements may be lists or arrays.
            action_limit (float): Scale applied to the tanh-bounded action.

        Returns:
            tuple: (action tensor of shape [action_dim], done probability as
            a 0-d tensor). With no debris, returns a zero action and a done
            probability of 1.
        """
        debris_states = state_list[1:]
        if len(debris_states) == 0:
            dummy_action = torch.zeros(self.head[-1].out_features, device=device)
            return dummy_action, torch.tensor(1.0, device=device)

        inputs, physics_features = self._build_inputs(state_list[0], debris_states)

        encoded = self.encoder(inputs)  # [N, 64]
        fused, _ = self.attention_fusion(encoded, physics_features)  # [64]

        output_action = self.action_layer(fused) * action_limit  # [action_dim]
        done_prob = torch.sigmoid(self.done_layer(fused).squeeze(0))  # 0-d
        return output_action, done_prob

    def act(self, state_list, action_limit=1.0):
        """Sample an exploratory action from N(mean, action_std).

        Returns:
            tuple: (sampled action clamped to [-action_limit, action_limit],
            per-dimension log-probability of the *unclamped* sample,
            done probability).
        """
        action_mean, done_prob = self.forward(state_list, action_limit)
        std = torch.ones_like(action_mean) * self.action_std
        dist = Normal(action_mean, std)

        action_sample = dist.sample()
        action_logprob = dist.log_prob(action_sample)

        action = torch.clamp(action_sample, -action_limit, action_limit)
        return action, action_logprob, done_prob


# ==== MultiHead Critic ====
class Critic(nn.Module):
    """State-value network using the same encode-and-fuse architecture as Actor."""

    def __init__(self, state_dim_ship=9, state_dim_debris=8, num_physics_features=2):
        super().__init__()
        self.input_dim = state_dim_ship + state_dim_debris
        self.num_physics_features = num_physics_features

        self.encoder = ResidualEncoder(state_dim_ship + state_dim_debris, 64)

        self.attention_fusion = PhysicsAwareAttention(
            feature_dim=64, num_physics_features=num_physics_features
        )

        self.value_layer = nn.Sequential(
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

    def _build_inputs(self, ship_state, debris_states):
        """Pair the ship state with every debris state.

        Returns:
            tuple: ([N, ship_dim + debris_dim] float32 inputs,
            [N, num_physics_features] physics features sliced from the tail
            of each debris state).
        """
        # Cast both sides to float32 so torch.cat and the float32 encoder never
        # see mixed dtypes (e.g. a float64 numpy ship state).
        ship = torch.tensor(ship_state, device=device).float()
        debris = torch.stack([torch.tensor(d, device=device).float() for d in debris_states])
        inputs = torch.cat([ship.unsqueeze(0).expand(debris.size(0), -1), debris], dim=1)
        physics_features = debris[:, -self.num_physics_features:]
        return inputs, physics_features

    def forward(self, state_list):
        """Estimate the value of one observation.

        Args:
            state_list: Sequence whose element 0 is the ship state and whose
                remaining elements are debris states.

        Returns:
            Tensor: value estimate of shape [1]. With no debris a zero of
            shape [1] is returned, matching the normal output shape.
        """
        debris_states = state_list[1:]
        if len(debris_states) == 0:
            # `device` is keyword-only for tensor factories. Shape [1] keeps
            # torch.stack over a batch of values working.
            return torch.zeros(1, device=device)

        inputs, physics_features = self._build_inputs(state_list[0], debris_states)

        encoded = self.encoder(inputs)  # [N, 64]
        fused, _ = self.attention_fusion(encoded, physics_features)  # [64]

        return self.value_layer(fused)  # [1]


# ==== PPO Trainer ====
# ==== PPO Memory Class ====
class Memory:
    """On-policy rollout buffer for PPO.

    Holds index-aligned per-step lists: states, actions, action log-probs,
    rewards, done probabilities and done labels. ``running_rewards`` holds
    logged averages. ``clear()`` does not empty it, so it accumulates across
    updates.
    """

    def __init__(self):
        # Per-step buffers; entry k of each list belongs to step k.
        self.states, self.actions, self.action_logprobs = [], [], []
        self.rewards, self.done_logprobs, self.done_labels = [], [], []
        # Logged running-average rewards; survives clear().
        self.running_rewards = []

    def store(self, state, reward, action, action_logprob, done_logprob, done_label):
        """Record one step; tensor fields are detached from the autograd graph."""
        for bucket, value, needs_detach in (
            (self.states, state, False),
            (self.actions, action, True),
            (self.action_logprobs, action_logprob, True),
            (self.rewards, reward, False),
            (self.done_logprobs, done_logprob, True),
            (self.done_labels, done_label, False),
        ):
            bucket.append(value.detach() if needs_detach else value)

    def clear(self):
        """Empty every per-step buffer in place (``running_rewards`` is kept)."""
        for buffer in (self.actions, self.states, self.action_logprobs,
                       self.rewards, self.done_logprobs, self.done_labels):
            buffer.clear()


# ==== PPO Agent with Multi-Epoch Update ====
class PPOAgent:
    """Clipped-PPO learner owning ``actor``, its rollout copy ``actor_old`` and ``critic``.

    Actor and critic share one AdamW optimizer. Each ``update()`` call steps a
    LambdaLR schedule that decays the learning rate linearly to 10% of ``lr``
    over 10000 updates.
    """

    def __init__(self, action_std, dV_limit, lr=3e-4, betas=(0.9, 0.999), gamma=0.99, lr_decay=0.9,
                 clip_eps=0.2, K_epochs=4, batch_size=200):
        """
        Args:
            action_std: Std of the Gaussian used in update() to score stored actions.
            dV_limit: Action scale passed to the actor as ``action_limit``.
            lr, betas: AdamW hyper-parameters.
            gamma: Discount factor for returns.
            lr_decay: Currently unused.
            clip_eps: PPO ratio clipping epsilon.
            K_epochs: Optimisation epochs per update() call.
            batch_size: Mini-batch size.
        """
        self.actor = Actor(action_std).to(device)
        self.actor_old = Actor(action_std).to(device)
        self.critic = Critic().to(device)
        # Rollouts are generated with actor_old and scored with actor. Start both
        # from identical weights so the first importance ratios are meaningful.
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.gamma = gamma
        self.optimizer = torch.optim.AdamW(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr, betas=betas)

        min_lr_ratio = 1e-1
        self.scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda step: max(min_lr_ratio, 1 - step / 10000),
        )

        self.clip_eps = clip_eps
        self.K_epochs = K_epochs
        self.batch_size = batch_size
        self.action_std = action_std
        self.dV_limit = dV_limit
        self.lr = lr
        self.beta = betas

    def _discounted_returns(self, memory):
        """Return normalised discounted returns, one float32 entry per stored step."""
        returns = []
        discounted_reward = 0
        # NOTE(review): the reset flag is memory.done_logprobs, which holds the
        # actor's done *probabilities* as tensors. bool() of a non-zero tensor is
        # True, so in practice the running sum is reset at every step and each
        # return equals the immediate reward. Memory records no
        # episode-termination flag; confirm the intended terminal signal before
        # changing this.
        for reward, done in zip(reversed(memory.rewards), reversed(memory.done_logprobs)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()  # O(n) instead of repeated insert(0, ...)

        rewards = torch.tensor(returns, dtype=torch.float32, device=device)
        return (rewards - rewards.mean()) / (rewards.std() + 1e-5)

    def update(self, memory, env_name='debris_collision'):
        """Run K_epochs of clipped-PPO mini-batch updates on the stored rollout.

        Afterwards copies ``actor`` into ``actor_old`` and steps the LR schedule.
        It then writes a checkpoint and rollout dumps via ``_save_all`` and
        clears ``memory``.

        Returns:
            dict: policy/value/done losses and entropy of the last mini-batch.
        """
        # === Step 1: discounted, normalised returns ===
        returns = self._discounted_returns(memory)

        # === Step 2: stored actions / log-probs / labels as tensors ===
        states = memory.states  # list of state_lists
        actions = torch.stack([
            torch.from_numpy(a).float() if isinstance(a, np.ndarray) else a.float()
            for a in memory.actions
        ]).to(device)
        old_logprobs = torch.stack(memory.action_logprobs).to(device)
        done_targets = torch.tensor(memory.done_labels, dtype=torch.float32, device=device)

        # === Step 3: advantages from the current critic ===
        with torch.no_grad():
            values = torch.tensor([
                self.critic(state).item() for state in states
            ], dtype=torch.float32, device=device)
        advantages = returns - values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        # === Step 4: PPO mini-batch optimisation ===
        dataset_size = len(states)
        indices = list(range(dataset_size))

        for _ in range(self.K_epochs):
            random.shuffle(indices)
            for start in range(0, dataset_size, self.batch_size):
                batch_idx = indices[start:start + self.batch_size]
                b_states = [states[j] for j in batch_idx]
                b_actions = actions[batch_idx]
                b_logprobs_old = old_logprobs[batch_idx].sum(dim=1)
                b_returns = returns[batch_idx]
                b_advs = advantages[batch_idx]

                # Actor forward
                action_preds, done_preds = zip(*[
                    self.actor(b_state, self.dV_limit) for b_state in b_states
                ])
                action_preds = torch.stack(action_preds)
                # reshape(-1) rather than squeeze(): a 1-sample mini-batch must
                # stay 1-D to match its targets.
                done_preds = torch.stack(done_preds).reshape(-1)

                # Gaussian policy with the agent's (possibly decayed) std
                std = torch.ones_like(action_preds) * self.action_std
                dist = Normal(action_preds, std)
                logp = dist.log_prob(b_actions).sum(dim=1)
                entropy = dist.entropy().mean()

                ratios = torch.exp(logp - b_logprobs_old)
                surr1 = ratios * b_advs
                surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * b_advs
                policy_loss = -torch.min(surr1, surr2).mean()

                # Critic forward
                value_preds = torch.stack([
                    self.critic(state) for state in b_states
                ]).reshape(-1)
                value_loss = F.mse_loss(value_preds, b_returns)

                # Done-head loss
                done_loss = F.binary_cross_entropy(done_preds, done_targets[batch_idx])

                loss = policy_loss + 0.5 * value_loss + 0.5 * done_loss - 0.01 * entropy

                self.optimizer.zero_grad()
                loss.backward()
                # Step once per mini-batch. zero_grad() above discards the
                # previous mini-batch's gradients, so a single step after the
                # loop would only apply the last mini-batch.
                self.optimizer.step()

        self.actor_old.load_state_dict(self.actor.state_dict())
        self.scheduler.step()
        self._save_all(memory, env_name)
        memory.clear()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'done_loss': done_loss.item(),
            'entropy': entropy.item()
        }

    def _save_all(self, memory, env_name='debris_collision'):
        """Checkpoint actor/critic/optimizer and dump the rollout buffers as JSON.

        Writes to fixed paths under ./results; those directories must already exist.
        """
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, f'./results/models-trained/PPO_multi_{env_name}.pth')

        dumps = (
            ('results/STAN_rewards.json', [float(r) for r in memory.rewards]),
            ('results/STAN_running_rewards.json', [float(r) for r in memory.running_rewards]),
            ('results/STAN_states.json', to_serializable(memory.states)),
            ('results/STAN_actions.json', [a.tolist() for a in memory.actions]),
        )
        for path, payload in dumps:
            with open(path, 'w') as f:
                json.dump(payload, f)


# ==== train class ====
class TrainPPO:
    """Training loop that drives a PPOAgent through the collision-avoidance simulator."""

    def __init__(self, env, state_dim, initial_oscelement, action_dim=3, action_std=0.08, lr=0.0001, gamma=0.99, K_epochs=100, batch_size=100,
                 clip_eps=0.2, max_episodes=500, max_timesteps=3000, update_timestep=2000, solved_reward=300000, log_interval=20,
                 PROPAGATION_STEP=0.001, betas=(0.95, 0.999), load_path=None):
        """
        Args:
            env: Environment to train in. It is reset or regenerated each episode.
            state_dim, action_dim: Accepted but not used here.
            initial_oscelement: Osculating elements passed to every Simulator.
            action_std: Exploration std expressed as a fraction of ``dV_limit``.
            PROPAGATION_STEP: Simulation time step per training step.
            load_path: Optional checkpoint (as written by PPOAgent._save_all)
                providing 'actor' and 'critic' state dicts.
        """
        self.env = env
        self.initial_oscelement = initial_oscelement
        self.PROPAGATION_STEP = PROPAGATION_STEP
        self.max_episodes = max_episodes
        self.max_timesteps = max_timesteps
        self.update_timestep = update_timestep
        self.solved_reward = solved_reward
        self.log_interval = log_interval
        self.collision_distance = 5
        # Per-step delta-v limit. The / sqrt(3) presumably spreads a total
        # magnitude over three axes; 0.17/750 and 86400 look like thrust/mass
        # and seconds per day -- TODO confirm units.
        self.dV_limit = 0.17 / 750 * PROPAGATION_STEP * 86400 / 3 ** 0.5
        self.ddV_limit = self.dV_limit * 0.1
        self.action_std = action_std * self.dV_limit

        self.memory = Memory()
        self.agent = PPOAgent(self.action_std, self.dV_limit, lr=lr, betas=betas, gamma=gamma, clip_eps=clip_eps,
                              K_epochs=K_epochs, batch_size=batch_size)

        if load_path:
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            checkpoint = torch.load(load_path, map_location=device)
            self.agent.actor.load_state_dict(checkpoint['actor'])
            self.agent.actor_old.load_state_dict(checkpoint['actor'])
            self.agent.critic.load_state_dict(checkpoint['critic'])
            self.agent.actor.train()
            self.agent.critic.train()

    def reconstruct_env(self, collision_time, window_before=0.04, window_after=0.01,
                        quantity=10, stddev=100, distance=100):
        """Replace ``self.env`` with a freshly generated debris scenario.

        The protected object's coordinates at ``collision_time`` seed
        ``generate_collisions``. The new scenario spans
        ``[collision_time - window_before, collision_time + window_after]``.
        ``quantity``, ``stddev`` and ``distance`` are forwarded unchanged.
        """
        position = self.env.coords_by_epoch(collision_time)[0][0:3]
        env_start = collision_time - window_before
        env_end = collision_time + window_after
        self.env = generate_collisions(position, collision_time, env_start, env_end,
                                       quantity=quantity, stddev=stddev, distance=distance)
        self.env.init_params = dict(protected=self.env.protected, debris=self.env.debris,
                                    start_time=env_start, end_time=env_end)
        # Osculating elements at the scenario start recorded just above.
        self.initial_oscelement = self.env.protected.osculating_elements(pk.epoch(env_start))

    def run(self, random_seed=None, env_name='debris_collision', reconstruct_eachturn=False):
        """Train for up to ``max_episodes`` episodes, calling agent.update every ``update_timestep`` steps.

        Args:
            random_seed: If not None (0 included), seeds torch, numpy and Python's ``random``.
            env_name: Suffix used in checkpoint file names.
            reconstruct_eachturn: If True, regenerate the scenario on every 5th
                episode instead of resetting it.
        """
        if random_seed is not None:
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)
            # PPOAgent.update shuffles mini-batches with the `random` module.
            random.seed(random_seed)

        time_step = 0
        running_reward = 0
        avg_length = 0
        for i_episode in trange(self.max_episodes, desc="Training PPO Episodes"):
            time_tick = 0
            action = [0, 0, 0]
            action_history = [
                torch.zeros(3, device=device) for _ in range(3)
            ]

            if reconstruct_eachturn and i_episode % 5 == 0:
                # NOTE(review): 8901 is a hard-coded collision epoch.
                self.reconstruct_env(8901)
            else:
                self.env.reset()

            # Sized after the reset/reconstruction so there is one entry per
            # debris object of the environment actually used this episode.
            min_distance_pre = [self.collision_distance] * len(self.env.debris)

            simulator_outer = Simulator(self.env, self.initial_oscelement, step=self.PROPAGATION_STEP,
                                        min_distance=min_distance_pre)

            for t in range(self.max_timesteps):
                time_step += 1
                env_outer_epoch = self.env.state['epoch']
                state_outer = np.array(self.env.get_state()['coord'], dtype=float)
                collision_now = calculate_distances(simulator_outer.env.get_state()['coord'])

                # Run the policy on a throw-away copy of the environment.
                env_tmp = self.env.copy()
                simulator = Simulator(env_tmp, self.initial_oscelement, step=self.PROPAGATION_STEP,
                                      time_now=(time_tick * self.PROPAGATION_STEP + self.env.init_params["start_time"]),
                                      min_distance=min_distance_pre)

                state, reward, action_gpu, action_logprob, done_prob, done_label, Fail = simulator.run(
                    self.agent.actor_old, env_outer_epoch, state_outer, self.dV_limit, action,
                    action_history, collision_now=collision_now)
                action = action_gpu.detach().cpu().numpy().flatten()

                # Advance the outer environment by one step.
                simulator_outer.curr_time += self.PROPAGATION_STEP
                propagate(simulator_outer.env, simulator_outer.curr_time)
                if simulator_outer.curr_time >= simulator_outer.end_time:
                    reward += 100
                # Scale the applied action by (1 - predicted done probability).
                action_final = action * (1 - done_prob.detach().cpu().numpy())
                simulator_outer.doact(action_final)

                self.memory.store(state, reward, action_gpu, action_logprob, done_prob, done_label)
                action_history.pop(0)
                action_history.append(action_gpu.detach())

                for i, debris in enumerate(env_tmp.debris):
                    min_distance_pre[i] = debris.min_distance

                del env_tmp

                if time_step % self.update_timestep == 0:
                    loss_info = self.agent.update(self.memory, env_name)
                    print(f"\n[Episode {i_episode}] Losses: {loss_info}")
                    if self.agent.action_std > 0.001 * self.dV_limit:
                        self.agent.action_std *= 0.99
                        # Actor.act samples with the actor's own action_std. Keep it
                        # equal to the std update() scores with, so the PPO ratio
                        # compares the same distribution.
                        self.agent.actor.action_std = self.agent.action_std
                        self.agent.actor_old.action_std = self.agent.action_std
                    time_step = 0

                running_reward += reward
                time_tick += 1

                if simulator_outer.curr_time >= simulator_outer.end_time or Fail:
                    break

            # time_tick equals the number of steps executed this episode
            # (the loop index t is 0-based, so t alone undercounts by one).
            avg_length += time_tick

            if running_reward > (self.log_interval * self.solved_reward):
                print("########## Solved! ##########")
                torch.save(self.agent.actor.state_dict(), f'./results/models-trained/PPO_solved{env_name}.pth')
                break

            if i_episode % self.log_interval == 0 and i_episode != 0:
                avg_length = int(avg_length / self.log_interval)
                running_reward /= self.log_interval
                self.memory.running_rewards.append(running_reward)

                print(f'\nEpisode {i_episode} \t Avg Length: {avg_length} \t Avg Reward: {running_reward}')
                running_reward = 0
                avg_length = 0

        plot_rewards(self.memory.rewards, 'rewards')
        plot_rewards(self.memory.running_rewards, 'running rewards')
        print("FINISHED")
        print("FINISHED")
