import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pykep as pk
from torch.distributions import Normal
from torch.optim.lr_scheduler import StepLR, LambdaLR
from tqdm import trange
from utils3 import generate_collisions, plot_rewards, calculate_distances, to_serializable
from simulator8_multi import Simulator, propagate
from torch.distributions import MultivariateNormal
# Compute device shared by the code in this file: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# ==== LSTM Actor (done head currently disabled) ====
class LSTMActor(nn.Module):
    """Gaussian policy that summarises a variable number of debris with an LSTM.

    Every debris state is concatenated with the ship state, linearly encoded
    and fed to the LSTM as one step of a sequence. The final hidden state is
    mapped through a tanh-bounded head to the mean of the action distribution;
    the standard deviation is the fixed scalar ``action_std``.

    The separate "done" head is disabled (kept as commented-out code), so
    ``forward`` returns a constant in its place.

    Args:
        action_std: Standard deviation of the Gaussian used by ``act``.
        state_dim_ship: Length of the ship state vector (``state_list[0]``).
        state_dim_debris: Length of each debris state vector (``state_list[1:]``).
        action_dim: Number of action components.
    """

    def __init__(self, action_std, state_dim_ship=9, state_dim_debris=8, action_dim=3):
        super().__init__()
        self.input_dim = state_dim_ship + state_dim_debris
        self.hidden_dim = 64
        # Kept explicitly: the last element of ``action_layer`` is a Tanh, which
        # has no ``out_features`` to read the action size back from.
        self.action_dim = action_dim

        self.encoder = nn.Linear(self.input_dim, self.hidden_dim)
        self.lstm = nn.LSTM(input_size=self.hidden_dim, hidden_size=self.hidden_dim, batch_first=True)

        self.action_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Tanh()
        )

        # self.done_layer = nn.Sequential(
        #     nn.Linear(self.hidden_dim, 32),
        #     nn.Tanh(),
        #     nn.Linear(32, 1)
        # )

        self.action_std = action_std
        # Non-persistent buffer: it follows the module across devices but is
        # left out of ``state_dict``, so existing checkpoints keep loading.
        self.register_buffer(
            "action_var", torch.full((action_dim,), action_std ** 2), persistent=False
        )

    def forward(self, state_list, action_limit=1.0):
        """Return the action mean and the (constant) done probability.

        Args:
            state_list: Sequence whose first item is the ship state and whose
                remaining items are the debris states.
            action_limit: Scale applied to the tanh output, so every component
                of the mean lies in ``[-action_limit, action_limit]``.

        Returns:
            Tuple ``(action_mean, done_prob)``. With at least one debris,
            ``action_mean`` has shape ``[action_dim]`` and ``done_prob`` is the
            scalar 0.0. With no debris the mean is all zeros and ``done_prob``
            is the scalar 1.0.
        """
        # Build inputs on the device the weights live on instead of relying on
        # a module-level global.
        dev = self.encoder.weight.device
        debris_states = state_list[1:]

        if len(debris_states) == 0:
            dummy_action = torch.zeros(self.action_dim, device=dev)
            return dummy_action, torch.tensor(1.0, device=dev)

        # Cast to float32 before concatenating so mixed list / float64 inputs
        # cannot trip a dtype mismatch.
        ship = torch.as_tensor(state_list[0], dtype=torch.float32, device=dev)
        debris = torch.stack([
            torch.as_tensor(d, dtype=torch.float32, device=dev) for d in debris_states
        ])  # [N, state_dim_debris]
        ship_rows = ship.unsqueeze(0).expand(debris.shape[0], -1)  # [N, state_dim_ship]
        sequence = torch.cat([ship_rows, debris], dim=1).unsqueeze(0)  # [1, N, input_dim]

        encoded = self.encoder(sequence)  # [1, N, hidden_dim]
        _, (h_n, _) = self.lstm(encoded)  # h_n: [1, 1, hidden_dim]

        fused = h_n.squeeze(0).squeeze(0)  # [hidden_dim]

        output_action = self.action_layer(fused) * action_limit
        # done_logit = self.done_layer(fused).squeeze(0)
        # done_prob = torch.sigmoid(done_logit)
        done_prob = torch.tensor(0.0, dtype=torch.float32, device=dev)

        return output_action, done_prob

    def act(self, state_list, action_limit=1.0):
        """Sample an action and return it with its per-component log-probability.

        The sample is clamped to ``[-action_limit, action_limit]`` and the
        log-probability is evaluated at that clamped action, so the returned
        pair is consistent: re-evaluating the same action under unchanged
        parameters reproduces the same log-probability.

        Returns:
            Tuple ``(action, action_logprob, done_prob)``.
        """
        action_mean, done_prob = self.forward(state_list, action_limit)
        std = torch.ones_like(action_mean) * self.action_std
        dist = Normal(action_mean, std)

        action = torch.clamp(dist.sample(), -action_limit, action_limit)
        action_logprob = dist.log_prob(action)
        return action, action_logprob, done_prob


# ==== LSTM Critic ====
class LSTMCritic(nn.Module):
    """State-value network that summarises a variable number of debris with an LSTM.

    Every debris state is concatenated with the ship state, linearly encoded
    and fed to the LSTM as one step of a sequence; the final hidden state is
    mapped to a single value estimate.

    Args:
        state_dim_ship: Length of the ship state vector (``state_list[0]``).
        state_dim_debris: Length of each debris state vector (``state_list[1:]``).
    """

    def __init__(self, state_dim_ship=9, state_dim_debris=8):
        super().__init__()
        self.input_dim = state_dim_ship + state_dim_debris
        self.hidden_dim = 64

        self.encoder = nn.Linear(self.input_dim, self.hidden_dim)
        self.lstm = nn.LSTM(input_size=self.hidden_dim, hidden_size=self.hidden_dim, batch_first=True)

        self.value_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

    def forward(self, state_list):
        """Estimate the value of one state.

        Args:
            state_list: Sequence whose first item is the ship state and whose
                remaining items are the debris states.

        Returns:
            Tensor of shape ``[1]``. When there is no debris the value is a
            constant zero of that same shape, so per-state outputs can always
            be stacked into one batch.
        """
        # Build inputs on the device the weights live on instead of relying on
        # a module-level global.
        dev = self.encoder.weight.device
        debris_states = state_list[1:]

        if len(debris_states) == 0:
            return torch.zeros(1, device=dev)

        # Cast to float32 before concatenating so mixed list / float64 inputs
        # cannot trip a dtype mismatch.
        ship = torch.as_tensor(state_list[0], dtype=torch.float32, device=dev)
        debris = torch.stack([
            torch.as_tensor(d, dtype=torch.float32, device=dev) for d in debris_states
        ])  # [N, state_dim_debris]
        ship_rows = ship.unsqueeze(0).expand(debris.shape[0], -1)  # [N, state_dim_ship]
        sequence = torch.cat([ship_rows, debris], dim=1).unsqueeze(0)  # [1, N, input_dim]

        encoded = self.encoder(sequence)  # [1, N, hidden_dim]
        _, (h_n, _) = self.lstm(encoded)  # h_n: [1, 1, hidden_dim]

        fused = h_n.squeeze(0).squeeze(0)  # [hidden_dim]

        return self.value_layer(fused)  # [1]

# ==== PPO Rollout Memory ====
class Memory:
    """Rollout buffer holding one list per recorded quantity.

    ``store`` appends one transition; ``clear`` empties the per-step lists in
    place. ``running_rewards`` is not touched by ``clear``.
    """

    # Names of the per-step lists that ``clear`` empties.
    _STEP_BUFFERS = (
        "states",
        "actions",
        "action_logprobs",
        "rewards",
        "done_logprobs",
        "done_labels",
    )

    def __init__(self):
        for buffer_name in self._STEP_BUFFERS:
            setattr(self, buffer_name, [])
        self.running_rewards = []

    def store(self, state, reward, action, action_logprob, done_logprob, done_label):
        """Append one transition; tensor arguments are stored detached."""
        detached_action = action.detach()
        detached_logprob = action_logprob.detach()
        detached_done = done_logprob.detach()

        self.states.append(state)
        self.actions.append(detached_action)
        self.action_logprobs.append(detached_logprob)
        self.rewards.append(reward)
        self.done_logprobs.append(detached_done)
        self.done_labels.append(done_label)

    def clear(self):
        """Empty every per-step list in place (existing references see the change)."""
        for buffer_name in self._STEP_BUFFERS:
            getattr(self, buffer_name).clear()


# ==== LSTM Agent with Multi-Epoch Update ====
class LSTMAgent:
    """PPO agent holding the trained actor, a behaviour copy of it, and the critic.

    ``actor_old`` is the copy used to collect rollouts; it is synchronised with
    ``actor`` at construction and again after every ``update``.

    Args:
        action_std: Standard deviation of the Gaussian policy.
        dV_limit: Action bound passed to the actor as ``action_limit``.
        lr: Initial learning rate of the shared AdamW optimizer.
        betas: AdamW beta coefficients.
        gamma: Discount factor for the returns.
        lr_decay: Unused; kept for interface compatibility (it belonged to the
            commented-out ``StepLR`` schedule).
        clip_eps: PPO ratio clipping range.
        K_epochs: Number of passes over the rollout per ``update``.
        batch_size: Mini-batch size within each pass.
    """

    def __init__(self, action_std, dV_limit, lr=3e-4, betas=(0.9, 0.999), gamma=0.99, lr_decay=0.9,
                 clip_eps=0.2, K_epochs=4, batch_size=200):
        self.actor = LSTMActor(action_std).to(device)
        self.actor_old = LSTMActor(action_std).to(device)
        # Both actors are randomly initialised; copy the weights so the first
        # rollout is sampled from the same policy that the first update trains.
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic = LSTMCritic().to(device)
        self.gamma = gamma
        self.optimizer = torch.optim.AdamW(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr, betas=betas)
        # self.optimizer = torch.optim.SGD(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr)

        # self.scheduler = StepLR(self.optimizer, step_size=20, gamma=lr_decay)
        min_lr_ratio = 1e-1

        def linear_lambda(step):
            # Linear decay of the LR multiplier over 10000 scheduler steps,
            # floored at ``min_lr_ratio``.
            return max(min_lr_ratio, 1 - step / 10000)

        self.scheduler = LambdaLR(self.optimizer, lr_lambda=linear_lambda)

        self.clip_eps = clip_eps
        self.K_epochs = K_epochs
        self.batch_size = batch_size
        self.action_std = action_std
        self.dV_limit = dV_limit
        self.lr = lr
        self.beta = betas

    @staticmethod
    def _standardize(x, eps=1e-5):
        """Return ``x`` shifted to zero mean and scaled to unit standard deviation.

        With fewer than two elements the standard deviation is undefined (NaN),
        so only the mean is removed in that case.
        """
        centered = x - x.mean()
        if x.numel() < 2:
            return centered
        return centered / (x.std() + eps)

    def update(self, memory, env_name='debris_collision'):
        """Run the PPO update on the stored rollout, then save and clear it.

        Args:
            memory: Rollout buffer exposing ``states``, ``actions``,
                ``action_logprobs``, ``rewards``, ``done_logprobs`` and
                ``done_labels``; it is cleared before returning.
            env_name: Suffix used in the checkpoint file name.

        Returns:
            Dict with the ``policy_loss``, ``value_loss``, ``done_loss`` and
            ``entropy`` of the last mini-batch processed.
        """
        # === Step 1: Discounted rewards ===
        # NOTE(review): the reset flag below is the stored done *probability*
        # (truthy only when non-zero), not an episode-termination flag, so
        # returns appear to carry across episode boundaries -- confirm intent.
        rewards = []
        discounted_reward = 0
        for reward, done in zip(reversed(memory.rewards), reversed(memory.done_logprobs)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            rewards.append(discounted_reward)
        rewards.reverse()  # built back-to-front; restore chronological order

        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        returns = self._standardize(rewards)

        # === Step 2: Process actions/log_probs ===
        states = memory.states  # list of state_lists
        actions = torch.stack([
            torch.from_numpy(a).float() if isinstance(a, np.ndarray) else a.float()
            for a in memory.actions
        ]).to(device)
        old_logprobs = torch.stack(memory.action_logprobs).to(device)

        # === Step 3: Estimate values and advantages ===
        with torch.no_grad():
            values = torch.tensor([
                self.critic(state).item() for state in states
            ], dtype=torch.float32, device=device)
        advantages = self._standardize(returns - values)

        # === Step 4: PPO update ===
        dataset_size = len(states)
        indices = list(range(dataset_size))

        for _ in range(self.K_epochs):
            random.shuffle(indices)
            for start in range(0, dataset_size, self.batch_size):
                batch_idx = indices[start:start + self.batch_size]
                b_states = [states[j] for j in batch_idx]
                b_actions = actions[batch_idx]
                # Per-component log-probs are summed into one joint log-prob.
                b_logprobs_old = old_logprobs[batch_idx].sum(dim=1)
                b_returns = returns[batch_idx]
                b_advs = advantages[batch_idx]

                # === Actor forward ===
                action_preds, done_preds = zip(*[
                    self.actor(b_state, self.dV_limit) for b_state in b_states
                ])
                action_preds = torch.stack(action_preds)
                # reshape(-1) rather than squeeze(): keeps the batch axis even
                # when the mini-batch holds a single sample.
                done_preds = torch.stack(done_preds).reshape(-1)

                # === Build distribution ===
                std = torch.ones_like(action_preds) * self.action_std
                dist = Normal(action_preds, std)
                logp = dist.log_prob(b_actions).sum(dim=1)
                entropy = dist.entropy().mean()

                ratios = torch.exp(logp - b_logprobs_old)
                surr1 = ratios * b_advs
                surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * b_advs
                policy_loss = -torch.min(surr1, surr2).mean()

                # === Critic forward ===
                # Flatten each estimate first so scalar and one-element outputs
                # can share a batch.
                value_preds = torch.cat([
                    self.critic(state).reshape(-1) for state in b_states
                ])

                value_loss = F.mse_loss(value_preds, b_returns)

                # === Done loss (reported only; not part of the optimised loss) ===
                done_labels = torch.tensor([memory.done_labels[j] for j in batch_idx],
                                           dtype=torch.float32, device=device)
                done_loss = F.binary_cross_entropy(done_preds, done_labels)

                # === Total loss and backprop ===
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
                # loss = policy_loss + 0.5 * value_loss + 0.5 * done_loss

                self.optimizer.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1)
                # Step once per mini-batch: gradients are zeroed before every
                # backward pass, so stepping only after the inner loop would
                # apply the last mini-batch's gradients alone.
                self.optimizer.step()

        self.actor_old.load_state_dict(self.actor.state_dict())
        self.scheduler.step()
        self._save_all(memory, env_name)
        memory.clear()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'done_loss': done_loss.item(),
            'entropy': entropy.item()
        }

    def _save_all(self, memory, env_name='debris_collision'):
        """Write the checkpoint and dump the current rollout to JSON files.

        Saves actor, critic and optimizer state to
        ``./results/models-trained/PPO_multi_{env_name}.pth`` and overwrites
        the ``results/LSTM_*.json`` files with the rewards, running rewards,
        states and actions currently held in ``memory``.
        """
        import os  # local import keeps this class self-contained

        # Creates ./results as well, which the JSON dumps below write into.
        os.makedirs('./results/models-trained', exist_ok=True)

        # torch.save(self.actor.state_dict(), f'./results/models-trained/PPO_multi_{env_name}.pth')
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # 'step': current_step,
        }, f'./results/models-trained/PPO_multi_{env_name}.pth')

        with open('results/LSTM_rewards.json', 'w') as f:
            json.dump([float(r) for r in memory.rewards], f)
        with open('results/LSTM_running_rewards.json', 'w') as f:
            json.dump([float(r) for r in memory.running_rewards], f)
        memory_actions_list = [a.tolist() for a in memory.actions]
        with open('results/LSTM_states.json', 'w') as f:
            json.dump(to_serializable(memory.states), f)
        with open('results/LSTM_actions.json', 'w') as f:
            json.dump(memory_actions_list, f)


# ==== train class ====
class TrainLSTM:
    """Training driver: collects rollouts with the simulator and updates the PPO agent.

    Args:
        env: Environment object; this class uses its ``reset``, ``copy``,
            ``get_state``, ``coords_by_epoch``, ``state``, ``debris`` and
            ``init_params`` members.
        state_dim: Unused; kept for interface compatibility.
        initial_oscelement: Value forwarded to every ``Simulator``.
        action_dim: Unused; kept for interface compatibility.
        action_std: Exploration std as a fraction of ``dV_limit``.
        lr, gamma, K_epochs, batch_size, clip_eps, betas: Forwarded to ``LSTMAgent``.
        max_episodes: Number of episodes to run.
        max_timesteps: Maximum number of steps per episode.
        update_timestep: Number of collected steps between agent updates.
        solved_reward: Per-episode reward level used by the "solved" check.
        log_interval: Number of episodes between log lines.
        PROPAGATION_STEP: Simulator step size.
        load_path: Optional checkpoint (with ``'actor'`` and ``'critic'``
            entries) to restore before training.
    """

    def __init__(self, env, state_dim, initial_oscelement, action_dim=3, action_std=0.08, lr=0.0001, gamma=0.99,K_epochs=100, batch_size=100,
                 clip_eps=0.2,max_episodes=500, max_timesteps=3000, update_timestep=2000, solved_reward=300000, log_interval=20,
                 PROPAGATION_STEP=0.001, betas=(0.95, 0.999), load_path=None):

        self.env = env
        self.initial_oscelement = initial_oscelement
        self.PROPAGATION_STEP = PROPAGATION_STEP
        self.max_episodes = max_episodes
        self.max_timesteps = max_timesteps
        self.update_timestep = update_timestep
        self.solved_reward = solved_reward
        self.log_interval = log_interval
        self.collision_distance = 5
        # NOTE(review): presumably thrust / mass * step duration in seconds,
        # split evenly over three axes -- confirm the units of each factor.
        self.dV_limit = 0.17 / 750 * PROPAGATION_STEP * 86400 / 3 ** 0.5
        self.ddV_limit = self.dV_limit * 0.1
        self.action_std = action_std * self.dV_limit

        self.memory = Memory()
        self.agent = LSTMAgent(self.action_std, self.dV_limit, lr=lr, betas=betas, gamma=gamma, clip_eps=clip_eps,
                              K_epochs=K_epochs, batch_size=batch_size)

        if load_path:
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
            checkpoint = torch.load(load_path, map_location=device)
            self.agent.actor.load_state_dict(checkpoint['actor'])
            self.agent.actor_old.load_state_dict(checkpoint['actor'])
            self.agent.critic.load_state_dict(checkpoint['critic'])
            # self.agent.optimizer.load_state_dict(checkpoint['optimizer'])
            self.agent.actor.train()
            self.agent.critic.train()

    def reconstruct_env(self, collision_time):
        """Replace ``self.env`` with a freshly generated collision scenario.

        The new environment is generated around the protected object's
        position at ``collision_time`` and spans from ``collision_time - 0.04``
        to ``collision_time + 0.01``.
        """
        position = self.env.coords_by_epoch(collision_time)[0][0:3]
        env_start = collision_time - 0.04
        env_end = collision_time + 0.01
        self.env = generate_collisions(position, collision_time,  env_start, env_end, quantity=5, stddev=100, distance=5)
        # NOTE(review): this reads init_params["start_time"] from the new env
        # before init_params is overwritten below -- confirm that
        # generate_collisions already sets it to env_start.
        self.initial_oscelement = self.env.protected.osculating_elements(pk.epoch(self.env.init_params["start_time"]))
        self.env.init_params = dict(protected=self.env.protected, debris=self.env.debris,
                                    start_time=env_start, end_time=env_end)

    def run(self, random_seed=None, env_name='debris_collision', reconstruct_eachturn=False):
        """Run the training loop.

        Args:
            random_seed: When not ``None`` (0 is a valid seed), seeds Python's
                ``random``, NumPy and PyTorch.
            env_name: Suffix used in checkpoint file names.
            reconstruct_eachturn: When true, regenerate the environment every
                fifth episode instead of resetting it.
        """
        if random_seed is not None:
            # Python's ``random`` is seeded as well because the agent shuffles
            # its mini-batches with ``random.shuffle``.
            random.seed(random_seed)
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)

        time_step = 0
        running_reward = 0
        avg_length = 0
        for i_episode in trange(self.max_episodes, desc="Training PPO Episodes"):
            time_tick = 0
            action = [0, 0, 0]
            action_history = [
                torch.zeros(3, device=device) for _ in range(3)
            ]

            if reconstruct_eachturn and i_episode % 5==0:
                self.reconstruct_env(8901)
            else:
                self.env.reset()

            # Sized only after the environment is in place, so the list always
            # matches the debris count of the env used in this episode.
            min_distance_pre = [self.collision_distance] * len(self.env.debris)

            simulator_outer = Simulator(self.env, self.initial_oscelement, step=self.PROPAGATION_STEP,
                                        min_distance=min_distance_pre)

            for t in range(self.max_timesteps):
                time_step += 1
                env_outer_epoch = self.env.state['epoch']
                state_outer = np.array(self.env.get_state()['coord'], dtype=float)
                collision_now = calculate_distances(simulator_outer.env.get_state()['coord'])

                # The policy step is evaluated on a copy of the environment.
                env_tmp = self.env.copy()
                simulator = Simulator(env_tmp, self.initial_oscelement, step=self.PROPAGATION_STEP,
                                      time_now=(time_tick * self.PROPAGATION_STEP + self.env.init_params["start_time"]),
                                      min_distance=min_distance_pre)

                state, reward, action_gpu, action_logprob, done_prob, done_label, Fail = simulator.run(
                    self.agent.actor_old, env_outer_epoch, state_outer, self.dV_limit, action, action_history,
                    collision_now=collision_now)
                action = action_gpu.detach().cpu().numpy().flatten()

                simulator_outer.curr_time += self.PROPAGATION_STEP
                propagate(simulator_outer.env, simulator_outer.curr_time)
                if simulator_outer.curr_time >= simulator_outer.end_time:
                    reward += 100  # bonus for reaching the end of the scenario
                elif done_prob.detach().cpu() < 0.5:
                    simulator_outer.doact(action)

                self.memory.store(state, reward, action_gpu, action_logprob, done_prob, done_label)
                action_history.pop(0)
                action_history.append(action_gpu.detach())

                # Updated in place: the list was handed to the simulators above.
                for i, debris in enumerate(env_tmp.debris):
                    min_distance_pre[i] = debris.min_distance

                del env_tmp

                if time_step % self.update_timestep == 0:
                    loss_info = self.agent.update(self.memory, env_name)
                    print(f"\n[Episode {i_episode}] Losses: {loss_info}")
                    if self.agent.action_std > 0.001 * self.dV_limit:
                        self.agent.action_std *= 0.99
                        # The actors keep their own copy of the std; without
                        # this, rollouts would still be sampled with the
                        # original std while the update uses the decayed one.
                        self.agent.actor.action_std = self.agent.action_std
                        self.agent.actor_old.action_std = self.agent.action_std
                    time_step = 0

                running_reward += reward
                time_tick += 1

                if simulator_outer.curr_time >= simulator_outer.end_time or Fail:
                    break

            avg_length += t

            if running_reward > (self.log_interval * self.solved_reward):
                print("########## Solved! ##########")
                torch.save(self.agent.actor.state_dict(), f'./results/models-trained/PPO_solved{env_name}.pth')
                break

            if i_episode % self.log_interval == 0 and i_episode != 0:
                avg_length = int(avg_length / self.log_interval)
                running_reward /= self.log_interval
                self.memory.running_rewards.append(running_reward)

                print(f'\nEpisode {i_episode} \t Avg Length: {avg_length} \t Avg Reward: {running_reward}')
                running_reward = 0
                avg_length = 0

        plot_rewards(self.memory.rewards, 'rewards')
        plot_rewards(self.memory.running_rewards, 'running rewards')
        print("FINISHED")
