import json
import os
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from config import Args
from ipl_iql import IPL_IQL
from collections import deque
from policy import Actor, Critic
from tensordict import TensorDict
from torch.utils.tensorboard import SummaryWriter
from concurrent.futures import ThreadPoolExecutor as Pool

from utils import create_context


class Trajectory:
    """Tensor view of one recorded episode.

    ``data`` is a sequence of 6-item entries
    ``(obs, state, avails, actions, reward, done)``. The number of
    transitions is taken to be ``len(data) - 1``; observations, states and
    availability masks are stored for every entry, while actions, rewards and
    done flags are stored only for entries whose ``reward`` is not ``None``.
    Presumably the last entry is the terminal observation and carries
    ``reward=None`` — verify against the code that records trajectories.

    Attributes (all CPU tensors):
        obs:     float32, (traj_len + 1, n_agents, ob_dim)
        states:  float32, (traj_len + 1, st_dim)
        avails:  bool,    (traj_len + 1, n_agents, ac_dim)
        actions: int64,   (traj_len, n_agents)
        rewards: float32, (traj_len,)
        dones:   bool,    (traj_len,)
    """

    def __init__(self, data):
        # Sizes are inferred from the first entry only.
        first_obs, first_state, first_avails = data[0][0], data[0][1], data[0][2]
        n_steps = len(data) - 1
        n_agents = len(first_obs)
        ob_dim = len(first_obs[0])
        st_dim = len(first_state)
        ac_dim = len(first_avails[0])

        self.traj_len = n_steps
        self.obs = torch.zeros((n_steps + 1, n_agents, ob_dim), dtype=torch.float32)
        self.states = torch.zeros((n_steps + 1, st_dim), dtype=torch.float32)
        self.avails = torch.zeros((n_steps + 1, n_agents, ac_dim), dtype=torch.bool)
        self.actions = torch.zeros((n_steps, n_agents), dtype=torch.int64)
        self.rewards = torch.zeros((n_steps,), dtype=torch.float32)
        self.dones = torch.zeros((n_steps,), dtype=torch.bool)

        for step, entry in enumerate(data):
            step_obs, step_state, step_avails, step_actions, step_reward, step_done = entry
            self.obs[step] = torch.tensor(np.array(step_obs), dtype=torch.float32)
            self.states[step] = torch.tensor(np.array(step_state), dtype=torch.float32)
            self.avails[step] = torch.tensor(np.array(step_avails), dtype=torch.bool)
            if step_reward is None:
                # Observation-only entry: nothing to record on the transition side.
                continue
            self.actions[step] = torch.tensor(np.array(step_actions), dtype=torch.int64)
            self.rewards[step] = step_reward
            self.dones[step] = step_done

    def __repr__(self):
        """Return a dict-style string mapping each field name to its tensor shape."""
        field_names = ("obs", "states", "avails", "actions", "rewards", "dones")
        return str({name: getattr(self, name).shape for name in field_names})


class LLMAgent:
    """Two-token greedy responder around a Hugging Face causal language model.

    ``get_response`` formats a chat-style query, runs the model on it and
    greedily picks the next two tokens, returning their decoded text. The
    key/value cache is kept on the instance and reset at the start of every
    ``get_response`` call.
    """

    def __init__(self, model_id: str, device="cuda"):
        """Load tokenizer and model for ``model_id`` and place the model on ``device``.

        Args:
            model_id: Identifier handed to ``from_pretrained``.
            device: Value passed as ``device_map`` when loading the model.
        """
        # Function-scope import: transformers is only needed once an agent is built.
        from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            device_map=device,
        )
        self.past_key_values = DynamicCache(config=self.model.config)

        # Number of get_response calls so far; used to limit debug printing.
        self.count = 0

    def _encode(self, text: str):
        """Tokenize ``text`` into a (1, seq_len) id tensor on the model's device."""
        return self.tokenizer([text], return_tensors="pt").input_ids.to(self.model.device)

    def _next_token(self, input_ids):
        """Feed ``input_ids`` through the model (extending the cache) and return the argmax next-token id."""
        with torch.inference_mode():
            outputs = self.model.forward(
                input_ids=input_ids,
                past_key_values=self.past_key_values,
                use_cache=True,
            )
            self.past_key_values = outputs.past_key_values
        # Greedy choice from the logits at the last position of the single batch row.
        return torch.argmax(outputs.logits[0, -1], dim=-1)

    def _query(self, prompt: str):
        """Tokenize ``prompt``, advance the cached context by it and return the decoded greedy next token."""
        return self.tokenizer.decode(self._next_token(self._encode(prompt)))

    def get_response(self, context: str, prompt: str):
        """Return the text of the two greedily decoded tokens that follow the query.

        Args:
            context: Text placed first in the user turn.
            prompt: Text placed after ``context`` in the user turn.

        Returns:
            Concatenation of the decoded first and second generated tokens.
        """
        from transformers import DynamicCache

        query = f"<start_of_turn>user\n{context}\n{prompt}<end_of_turn>\n<start_of_turn>model\nAnswer:"
        # Start every request from an empty cache so earlier queries cannot leak in.
        self.past_key_values = DynamicCache(config=self.model.config)

        first_id = self._next_token(self._encode(query))
        # Feed the generated id itself back in. Decoding it to text and
        # re-tokenizing (as was done before) goes through the tokenizer's default
        # special-token handling, which may prepend tokens such as BOS, and the
        # decode -> encode round trip is not guaranteed to give back the same id.
        second_id = self._next_token(first_id.reshape(1, 1).to(self.model.device))

        text_1 = self.tokenizer.decode(first_id)
        text_2 = self.tokenizer.decode(second_id)
        response = text_1 + text_2

        self.count += 1
        if self.count < 100:
            # Debug aid: show the whitespace-collapsed tail of the first 99 exchanges.
            s = query + response
            s = " ".join(s.split())
            print(f"#{self.count}: ...{s[-50:]}")
        return response


class Trainer:

    def __init__(self, actor: Actor, critic: Critic, reward_net: IPL_IQL, logdir, config: Args, device="cuda"):
        self.actor = actor.to(device)
        self.critic = critic.to(device)
        self.reward_net = reward_net.to(device)
        self.config = config
        self.device = device
        self.actor_optim = torch.optim.Adam(actor.parameters(), lr=5e-4)
        self.critic_optim = torch.optim.Adam(critic.parameters(), lr=5e-4)
        self.eval_parameters = list(reward_net.eval_mix_net.parameters()) + list(reward_net.eval_Q_net.parameters())
        self.reward_optim = torch.optim.Adam(self.eval_parameters, lr=5e-4)
        self.n_minibatch = 1
        self.noptepochs = 5
        self.clip_ratio = 0.2
        self.entropy_coef = 0.01
        self.max_grad_norm = 0.5
        self.writer = SummaryWriter(logdir)

        self.preference_size = 4096 * 2
        self.preference_buffers: list[tuple[Trajectory, Trajectory]] = []

        self.use_llm = True
        self.llm_model = LLMAgent("google/gemma-3-4b-it", device=device) if self.use_llm else None
    
    def compute_values(self, states, next_states, dones, critic_rnn_state):
        self.critic.eval()
        with torch.no_grad():
            values, _ = self.critic.forward_all(states, critic_rnn_state, dones)
            all_states = torch.cat((states[:1], next_states))
            all_dones = torch.cat((dones, torch.zeros_like(dones[-1:], device=self.device)))
            all_values, _ = self.critic.forward_all(all_states, critic_rnn_state, all_dones)
            next_values = all_values[1:]
        self.critic.train()
        return values, next_values

    def compute_advantages(self, rewards, values, next_values, dones, gamma=0.99, lam=0.95):
        n_steps = rewards.shape[0]
        advantages = torch.zeros_like(rewards, device=self.device)
        masks = 1.0 - dones.float()
        delta = rewards + gamma * masks * next_values - values
        discounts = gamma * lam * masks
        last_gae_lam = 0
        for step in reversed(range(n_steps)):
            advantages[step] = last_gae_lam = delta[step] + discounts[step] * last_gae_lam
        return advantages

    def train_actor(self, obs, avails, actions, old_log_probs, dones, advantages, actor_rnn_state):
        """Run a single optimisation step on the actor.

        Clears the actor gradients, calls ``self.actor.compute_loss`` on the
        batch, clips the actor's gradient norm to ``self.max_grad_norm`` and
        steps the actor optimiser.

        Args:
            obs, avails, actions, old_log_probs, dones, advantages: Batch
                tensors forwarded unchanged to ``self.actor.compute_loss``.
            actor_rnn_state: Recurrent state forwarded to ``compute_loss``.

        Returns:
            The ``(pg_loss, entropy_loss, actor_rnn_state)`` triple returned by
            ``self.actor.compute_loss``.
        """
        self.actor_optim.zero_grad()
        # NOTE(review): no backward() call is visible in this method, so
        # ``compute_loss`` presumably back-propagates internally — confirm in policy.py.
        pg_loss, entropy_loss, actor_rnn_state = self.actor.compute_loss(obs, avails, actions, old_log_probs, dones, advantages, actor_rnn_state)
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.actor_optim.step()
        return pg_loss, entropy_loss, actor_rnn_state

    def train_critic(self, states, critic_rnn_state, dones, value_preds, returns):
        """Run a single optimisation step on the critic.

        Clears the critic gradients, calls ``self.critic.compute_loss`` on the
        batch, clips the critic's gradient norm to ``self.max_grad_norm`` and
        steps the critic optimiser.

        Args:
            states, dones, value_preds, returns: Batch tensors forwarded
                unchanged to ``self.critic.compute_loss``.
            critic_rnn_state: Recurrent state forwarded to ``compute_loss``.

        Returns:
            The ``(vf_loss, critic_rnn_state)`` pair returned by
            ``self.critic.compute_loss``.
        """
        self.critic_optim.zero_grad()
        # NOTE(review): no backward() call is visible in this method, so
        # ``compute_loss`` presumably back-propagates internally — confirm in policy.py.
        vf_loss, critic_rnn_state = self.critic.compute_loss(states, critic_rnn_state, dones, value_preds, returns)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optim.step()
        return vf_loss, critic_rnn_state
    
    def update_actor(self, buffer: TensorDict, actor_rnn_state):
        self.actor.train()

        mb_obs = buffer["obs"]
        mb_avails = buffer["avails"]
        mb_actions = buffer["actions"]
        mb_log_probs = buffer["log_probs"]
        mb_dones = buffer["dones"]
        mb_advantages = buffer["advantages"]

        loss_vals = []
        for _ in range(self.noptepochs):
            pg_loss, entropy_loss, new_actor_rnn_state = self.train_actor(mb_obs, mb_avails, mb_actions, mb_log_probs, mb_dones, mb_advantages, actor_rnn_state)
            loss_vals.append((pg_loss, entropy_loss))
        pg_loss, entropy_loss = np.mean(loss_vals, axis=0)
        return pg_loss, entropy_loss, new_actor_rnn_state
    
    def update_critic(self, buffer: TensorDict, critic_rnn_state):
        self.critic.train()

        mb_states = buffer["states"]
        mb_dones = buffer["dones"]
        mb_values = buffer["values"]
        mb_returns = buffer["returns"]

        loss_vals = []
        for _ in range(self.noptepochs):
            vf_loss, new_critic_rnn_state = self.train_critic(mb_states, critic_rnn_state, mb_dones, mb_values, mb_returns)
            loss_vals.append(vf_loss)
        vf_loss = np.mean(loss_vals)
        return vf_loss, new_critic_rnn_state

    def get_trajectory_info(self, traj: Trajectory):
        n_steps = traj.traj_len
        last_state = traj.states[n_steps - 1]
        reward = traj.rewards.sum().item()

        ally_state_pos = self.config.n_agents * self.config.nf_al
        enemy_state_pos = ally_state_pos + self.config.n_enemies * self.config.nf_en

        ally_state = last_state[:ally_state_pos].view(self.config.n_agents, self.config.nf_al)
        enemy_state = last_state[ally_state_pos:enemy_state_pos].view(self.config.n_enemies, self.config.nf_en)

        ally_healths = ally_state[..., 0].numpy()
        enemy_healths = enemy_state[..., 0].numpy()

        return ally_healths, enemy_healths, reward, n_steps
    
    def get_preference_label(self, traj_x: Trajectory, traj_y: Trajectory):
        info_x = self.get_trajectory_info(traj_x)
        info_y = self.get_trajectory_info(traj_y)

        if self.use_llm:
            context, prompt = create_context(self.config.env_name, info_x, info_y)
            response = self.llm_model.get_response(context, prompt)
            if "#1" in response:
                label = 1
            elif "#2" in response:
                label = 2
            else:
                label = None
        else:
            if info_x[2] > info_y[2]:
                label = 1
            elif info_x[2] < info_y[2]:
                label = 2
            else:
                label = None
        return label

    def add_preference(self, infos, n_samples=128):
        trajectories = []
        for list_info in infos:
            for info in list_info:
                if "trajectory" not in info:
                    continue
                trajectories.append(info["trajectory"])
        
        selected = set()
        ids = list(range(len(trajectories)))

        for _ in range(n_samples):
            id_1, id_2 = random.sample(ids, 2)
            if (id_1, id_2) in selected or (id_2, id_1) in selected:
                continue
            selected.add((id_1, id_2))
            traj_1 = Trajectory(trajectories[id_1])
            traj_2 = Trajectory(trajectories[id_2])

            label = self.get_preference_label(traj_1, traj_2)
            if label is None:
                continue

            if label == 1:
                self.preference_buffers.append((traj_1, traj_2))
            else:
                self.preference_buffers.append((traj_2, traj_1))
        
        self.preference_buffers = self.preference_buffers[-self.preference_size:]

    def train_reward_net(self, sample_size):
        preference_buffers = random.sample(self.preference_buffers, min(len(self.preference_buffers), sample_size))
        
        batch_size = 2 * len(preference_buffers)
        max_len = max([max(traj_1.traj_len, traj_2.traj_len) for traj_1, traj_2 in preference_buffers])
        ob_dim = self.reward_net.ob_dim
        st_dim = self.reward_net.st_dim
        ac_dim = self.reward_net.ac_dim
        n_agents = self.reward_net.n_agents

        mb_obs = torch.zeros((batch_size, max_len + 1, n_agents, ob_dim), dtype=torch.float32)
        mb_states = torch.zeros((batch_size, max_len + 1, st_dim), dtype=torch.float32)
        mb_avails = torch.zeros((batch_size, max_len + 1, n_agents, ac_dim), dtype=torch.bool)
        mb_actions = torch.zeros((batch_size, max_len, n_agents), dtype=torch.int64)
        mb_rewards = torch.zeros((batch_size, max_len), dtype=torch.float32)
        mb_dones = torch.zeros((batch_size, max_len), dtype=torch.bool)
        mb_actives = torch.zeros((batch_size, max_len), dtype=torch.bool)

        for b, (traj_1, traj_2) in enumerate(preference_buffers):
            traj_len = traj_1.traj_len
            mb_obs[2*b, :traj_len + 1] = traj_1.obs
            mb_states[2*b, :traj_len + 1] = traj_1.states
            mb_avails[2*b, :traj_len + 1] = traj_1.avails
            mb_actions[2*b, :traj_len] = traj_1.actions
            mb_rewards[2*b, :traj_len] = traj_1.rewards
            mb_dones[2*b, :traj_len] = traj_1.dones
            mb_actives[2*b, :traj_len] = True

            traj_len = traj_2.traj_len
            mb_obs[2*b+1, :traj_len + 1] = traj_2.obs
            mb_states[2*b+1, :traj_len + 1] = traj_2.states
            mb_avails[2*b+1, :traj_len + 1] = traj_2.avails
            mb_actions[2*b+1, :traj_len] = traj_2.actions
            mb_rewards[2*b+1, :traj_len] = traj_2.rewards
            mb_dones[2*b+1, :traj_len] = traj_2.dones
            mb_actives[2*b+1, :traj_len] = True

        self.reward_optim.zero_grad()
        global_rewards = self.reward_net.compute_all_rewards(mb_obs, mb_states, mb_avails, mb_actions, mb_dones, mb_actives, gamma=0.99)

        traj_1_rewards = global_rewards[0::2]
        traj_2_rewards = global_rewards[1::2]

        logits = traj_1_rewards - traj_2_rewards
        ipl_loss = F.softplus(-logits).mean()
        ipl_loss.backward()
        self.reward_optim.step()
        return ipl_loss.item()

    def update_reward_net(self, infos, sample_size=32):
        self.add_preference(infos)
        if len(self.preference_buffers) < 2:
            return 0.0
        
        self.reward_net.train()

        loss_vals = []
        for _ in range(self.noptepochs):
            ipl_loss = self.train_reward_net(sample_size)
            loss_vals.append(ipl_loss)
        ipl_loss = np.mean(loss_vals)
        return ipl_loss
    
    def predict_rewards(self, obs, states, avails, actions, next_obs, next_states, next_avails, dones):
        self.reward_net.eval()
        with torch.no_grad():
            rewards_local, rewards_global, weights = self.reward_net.predict_rewards(obs, states, avails, actions, next_obs, next_states, next_avails, dones, gamma=0.99)
            rewards_local = rewards_local.sigmoid() * 2.0 - 1.0  # Scale to [-1, 1]
            rewards_global = rewards_global.sigmoid() * 2.0 - 1.0  # Scale to [-1, 1]

            # weights = weights / (weights.mean(-1, keepdim=True) + 1e-8)
            # weights = weights.clamp(0.5, 2.0)

            weights = 1 + (weights - weights.mean()) / (weights.std() + 1e-8)
            weights = weights.clamp(0.8, 1.2)
            
        return rewards_local, rewards_global, weights

    def update(self, data, actor_rnn_state, critic_rnn_state):
        obs, states, avails, next_obs, next_states, next_avails, actions, log_probs, rewards, dones, infos = data

        obs = torch.tensor(obs, dtype=torch.float32).to(self.device)
        states = torch.tensor(states, dtype=torch.float32).to(self.device)
        avails = torch.tensor(avails, dtype=torch.bool).to(self.device)
        next_obs = torch.tensor(next_obs, dtype=torch.float32).to(self.device)
        next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
        next_avails = torch.tensor(next_avails, dtype=torch.float32).to(self.device)
        actions = torch.tensor(actions, dtype=torch.int64).to(self.device)
        log_probs = torch.tensor(log_probs, dtype=torch.float32).to(self.device)
        # rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        dones = torch.tensor(dones, dtype=torch.bool).to(self.device)

        ipl_loss = self.update_reward_net(infos)
        rewards_local, rewards_global, weights = self.predict_rewards(obs, states, avails, actions, next_obs, next_states, next_avails, dones)

        values, next_values = self.compute_values(states, next_states, dones, critic_rnn_state)
        advantages = self.compute_advantages(rewards_global, values, next_values, dones)
        
        returns = advantages + values

        # A_global = R_global + y * V' - V
        # A_global = f(w) * A_local + b
        # R_global = f(w) * R_local + b
        # A_local = R_local + (y * V' - V).unsqueeze(-1) / f(w)
        

        # advantages = advantages.unsqueeze(-1)
        # advantages = (advantages - rewards_global).unsqueeze(-1) / weights + rewards_local
        advantages = advantages.unsqueeze(-1) / weights
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        buffer = {
            "obs": obs,
            "states": states,
            "avails": avails,
            "next_states": next_states,
            "actions": actions,
            "log_probs": log_probs,
            "rewards": rewards,
            "dones": dones,
            "values": values,
            "advantages": advantages,
            "returns": returns
        }

        # with Pool() as p:
        #     actor_task = p.submit(self.update_actor, buffer, actor_rnn_state)
        #     critic_task = p.submit(self.update_critic, buffer, critic_rnn_state)
        #     pg_loss, entropy_loss, actor_rnn_state = actor_task.result()
        #     vf_loss, critic_rnn_state = critic_task.result()
        pg_loss, entropy_loss, actor_rnn_state = self.update_actor(buffer, actor_rnn_state)
        vf_loss, critic_rnn_state = self.update_critic(buffer, critic_rnn_state)

        infos = [i for info in infos for i in info if len(i) > 0]
        dead_allies = np.mean([info["dead_allies"] for info in infos])
        dead_enemies = np.mean([info["dead_enemies"] for info in infos])
        winrates = np.mean([info["won"] for info in infos])
        info = (dead_allies, dead_enemies, winrates)

        return pg_loss, entropy_loss, vf_loss, ipl_loss, actor_rnn_state, critic_rnn_state, info