import numpy as np
from typing import Sequence
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm

from opelab.core.baseline import Baseline
from opelab.core.data import DataType, to_numpy


class MLP(nn.Module):
    """
    Fully connected feed-forward network assembled into one ``nn.Sequential``.

    Every hidden width in ``hidden_dims`` contributes an ``nn.Linear``
    followed by ``activation``; a final ``nn.Linear`` maps to ``output_dim``
    and is followed by ``output_activation``.

    Note that the *same* ``activation`` instance is inserted after every
    hidden layer, so an activation with learnable parameters would share
    them across layers.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dims: Sequence[int] = (128, 128),
        activation: nn.Module = nn.ReLU(),
        output_dim: int = 1,
        output_activation: nn.Module = nn.Identity(),
    ):
        super().__init__()
        # Consecutive entries of ``widths`` are the (in, out) sizes of the
        # hidden linear layers; the last entry feeds the output layer.
        widths = [input_dim, *hidden_dims]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), activation]
        modules += [nn.Linear(widths[-1], output_dim), output_activation]

        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the stacked layers and return the result."""
        out = self.net(x)
        return out

def init_weights(m):
    """
    Initialise an ``nn.Linear`` module in place; other modules are ignored.

    Weights get Xavier-uniform initialisation and the bias is zeroed.
    Intended for use with ``nn.Module.apply``.

    Args:
        m: Any module; only instances of ``nn.Linear`` are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # A layer built with ``bias=False`` has ``m.bias is None``; zeroing it
    # unconditionally would raise.
    if m.bias is not None:
        nn.init.zeros_(m.bias)


class BestDiceV2(Baseline):
    """
    A PyTorch-based version of BestDiceV2 mirroring the JAX logic.

    Two MLPs over concatenated (state, action) inputs are trained against
    each other inside ``evaluate``:

    * ``Q``    -- raw output squashed to (-100, 100) via ``100 * tanh``.
    * ``Zeta`` -- raw output passed through softplus, hence always positive.

    A scalar ``lambda`` is updated by hand next to the two networks. The
    value returned by ``evaluate`` is the dataset mean of ``zeta * reward``,
    divided by ``(1 - gamma)`` when ``gamma < 1``.
    """

    # Step sizes of the manual lambda update: the larger one applies while the
    # batch mean of zeta is below 1, the smaller one otherwise.
    _LAMBDA_STEP_BELOW_ONE = 0.001
    _LAMBDA_STEP_ABOVE_ONE = 0.0001

    def __init__(
        self,
        alpha_zeta: float = 0.0,
        alpha_q: float = 1.0,
        lr: float = 0.00003,
        layers: Sequence[int] = (256, 256),
        epochs: int = 500,
        iter_per_epoch: int = 500,
        batch_size: int = 2048,
        seed: int = 0,
        weight_decay: float = 1e-7,
        verbose: bool = True,
    ):
        """
        Store hyper-parameters and seed the global torch / numpy RNGs.

        Args:
            alpha_zeta: Coefficient of ``mean(zeta ** 2)`` in the zeta loss.
            alpha_q: Coefficient of ``mean(q ** 2)`` in the Q loss.
            lr: Learning rate shared by both AdamW optimizers.
            layers: Hidden-layer widths shared by both MLPs.
            epochs: Number of outer training epochs.
            iter_per_epoch: Number of mini-batch updates per epoch.
            batch_size: Number of samples drawn (with replacement) per update.
            seed: Seed passed to ``torch.manual_seed`` and ``np.random.seed``.
            weight_decay: Weight decay passed to both AdamW optimizers.
            verbose: Show tqdm progress bars and print per-epoch statistics.
        """
        super().__init__()
        self.alpha_zeta = alpha_zeta
        self.alpha_q = alpha_q
        self.lr = lr
        self.layers = layers
        self.epochs = epochs
        self.iter_per_epoch = iter_per_epoch
        self.batch_size = batch_size
        self.seed = seed
        self.weight_decay = weight_decay
        self.verbose = verbose

        # Torch reproducibility
        # NOTE(review): seeding happens at construction time, not when
        # evaluate() builds the networks -- any RNG use in between affects
        # the initial weights and the sampled batches.
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        # We will build Q and Zeta inside evaluate() once we know
        # state_dim + action_dim from the data.
        self.Q = None
        self.Zeta = None
        self.q_optimizer = None
        self.zeta_optimizer = None

    def _build_models(self, state_dim: int, action_dim: int):
        """
        Create fresh Q / Zeta networks and their AdamW optimizers.

        Both networks take ``state_dim + action_dim`` inputs, use
        ``self.layers`` as hidden widths with ReLU, and emit one scalar.
        Linear layers are re-initialised with ``init_weights``.
        """
        input_dim = state_dim + action_dim

        self.Q = MLP(
            input_dim=input_dim,
            hidden_dims=self.layers,
            activation=nn.ReLU(),
            output_dim=1,
            output_activation=nn.Identity()
        )

        self.Zeta = MLP(
            input_dim=input_dim,
            hidden_dims=self.layers,
            activation=nn.ReLU(),
            output_dim=1,
            output_activation=nn.Identity()
        )

        self.Q.apply(init_weights)
        self.Zeta.apply(init_weights)

        # Pass the configured weight decay explicitly; without it AdamW falls
        # back to its own default and the constructor argument has no effect.
        self.q_optimizer = optim.AdamW(
            self.Q.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        self.zeta_optimizer = optim.AdamW(
            self.Zeta.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

    def forward_q(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Return ``100 * tanh(Q([s, a]))``, i.e. a value in (-100, 100)."""
        Q = self.Q(torch.cat([s, a], dim=-1))
        return torch.tanh(Q) * 100

    def forward_zeta(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Return ``softplus(Zeta([s, a]))``, i.e. a strictly positive value."""
        raw_out = self.Zeta(torch.cat([s, a], dim=-1))
        return F.softplus(raw_out)

    def _compute_losses(
        self,
        s0: torch.Tensor,
        a0: torch.Tensor,
        s: torch.Tensor,
        a: torch.Tensor,
        r: torch.Tensor,
        s2: torch.Tensor,
        a2: torch.Tensor,
        gamma: float,
        lambda_val: float
    ):
        """
        Compute the two training objectives for one mini-batch.

        Args:
            s0, a0: Initial states and the actions sampled for them.
            s, a, r: Dataset states, actions and rewards.
            s2, a2: Next states and the actions sampled for them.
            gamma: Discount factor.
            lambda_val: Current value of the scalar multiplier.

        Returns:
            ``(q_loss, zeta_loss)`` where::

                bellman   = mean(zeta * (r + gamma * q2 - q))
                q_loss    = (1 - gamma) * mean(q0) + bellman + alpha_q * mean(q ** 2)
                zeta_loss = bellman - alpha_zeta * mean(zeta ** 2)
                            + lambda_val * (1 - mean(zeta))
        """
        q = self.forward_q(s, a)
        q2 = self.forward_q(s2, a2)
        q0 = self.forward_q(s0, a0)
        zeta = self.forward_zeta(s, a)

        initial_term = (1.0 - gamma) * q0.mean()
        bellman_term = (zeta * (r + gamma * q2 - q)).mean()
        lambda_term = lambda_val * (1.0 - zeta.mean())
        zeta_reg_term = self.alpha_zeta * (zeta ** 2).mean()
        q_reg_term = self.alpha_q * (q ** 2).mean()

        q_loss = initial_term + bellman_term + q_reg_term
        zeta_loss = bellman_term - zeta_reg_term + lambda_term
        return q_loss, zeta_loss

    @staticmethod
    def _sample_target_actions(target, states: torch.Tensor, device) -> torch.Tensor:
        """Sample ``target`` actions for ``states`` as a 2-D float32 tensor on ``device``."""
        sampled = target.sample(states.cpu().numpy())
        actions = torch.tensor(sampled, dtype=torch.float32, device=device)
        if actions.dim() == 1:
            actions = actions.unsqueeze(1)
        return actions

    def _estimate_value(self, states_t, actions_t, rewards_t, gamma, min_reward, max_reward):
        """
        Return ``(value_estimate, zeta_vals)`` computed over the whole dataset.

        The reward shift / scale applied before training is undone here, so
        the estimate is expressed in the original reward units.
        """
        with torch.no_grad():
            zeta_vals = self.forward_zeta(states_t, actions_t)
            value = (zeta_vals * (rewards_t * max_reward + min_reward)).mean().item()
        if gamma < 1.0:
            value /= (1.0 - gamma)
        return value, zeta_vals

    def evaluate(self, data, target, behavior, gamma=1.0, reward_estimator=None):
        """
        Train Q / Zeta on ``data`` and return the resulting value estimate.

        Args:
            data: Iterable of trajectories; each must support ``tau['states'][0]``
                (taken as the initial state). Also forwarded to ``to_numpy``.
            target: Policy object providing ``.device`` and ``.sample(states)``
                on numpy inputs.
            behavior: Only forwarded to ``to_numpy``.
            gamma: Discount factor; the estimate is divided by ``1 - gamma``
                only when ``gamma < 1``.
            reward_estimator: Unused by this method.

        Returns:
            The value estimate (a Python float) produced by the Zeta
            checkpoint selected during training.
        """
        init_states = np.stack([tau['states'][0].reshape((-1,)) for tau in data], axis=0)
        _, states_un, actions, _, next_states_un, rewards, _ = to_numpy(data, target, behavior)
        rewards = rewards.reshape((-1, 1))
        # Sampled target actions are promoted to 2-D below; give the dataset
        # actions the same treatment so a 1-D action array does not crash.
        if actions.ndim == 1:
            actions = actions.reshape((-1, 1))

        min_reward = 0
        max_reward = 1  # reward scaling is disabled: rewards are only shifted
        if np.min(rewards) < 0:
            min_reward = np.min(rewards)
            rewards = rewards - min_reward
            print("Rewards are negative, shifting them to be non-negative by: ", min_reward)
            print("Max reward: ", np.max(rewards))
            rewards = rewards / max_reward

        state_dim = states_un.shape[1]
        action_dim = actions.shape[1]
        print(f"State dim: {state_dim}, action dim: {action_dim}")
        self._build_models(state_dim, action_dim)

        device = target.device
        self.Q.to(device)
        self.Zeta.to(device)

        init_states_t = torch.tensor(init_states, dtype=torch.float32, device=device)
        states_t = torch.tensor(states_un, dtype=torch.float32, device=device)
        actions_t = torch.tensor(actions, dtype=torch.float32, device=device)
        rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
        next_states_t = torch.tensor(next_states_un, dtype=torch.float32, device=device)

        # Print all the shapes.
        print(f"init_states_t: {init_states_t.shape}")
        print(f"states_t: {states_t.shape}")
        print(f"actions_t: {actions_t.shape}")
        print(f"rewards_t: {rewards_t.shape}")
        print(f"next_states_t: {next_states_t.shape}")

        lambda_val = 0.0

        best_value = float('inf')
        best_zeta_dict = None

        epoch_range = range(self.epochs)
        if self.verbose:
            epoch_range = tqdm(epoch_range, desc="Epochs")

        for epoch in epoch_range:
            mean_q_loss = 0.0
            mean_zeta_loss = 0.0

            inner_iter = range(self.iter_per_epoch)
            if self.verbose:
                inner_iter = tqdm(inner_iter, desc=f"Epoch {epoch}", leave=False)

            for _ in inner_iter:
                # Sample initial states and transitions independently, with replacement.
                idx0 = np.random.randint(0, init_states_t.shape[0], size=self.batch_size)
                idx = np.random.randint(0, states_t.shape[0], size=self.batch_size)

                s0 = init_states_t[idx0]
                s = states_t[idx]
                a = actions_t[idx]
                r = rewards_t[idx]
                s2 = next_states_t[idx]

                a0 = self._sample_target_actions(target, s0, device)
                a2 = self._sample_target_actions(target, s2, device)

                # Q step: minimise q_loss.
                self.q_optimizer.zero_grad()
                q_loss, zeta_loss = self._compute_losses(s0, a0, s, a, r, s2, a2, gamma, lambda_val)
                q_loss.backward(retain_graph=False)
                self.q_optimizer.step()

                # Lambda step: manual descent on lambda * (1 - mean(zeta)).
                # Only the scalar is needed, so skip building a graph.
                with torch.no_grad():
                    zeta_t = self.forward_zeta(s, a).mean().item()

                if zeta_t < 1:
                    lambda_val -= self._LAMBDA_STEP_BELOW_ONE * (1.0 - zeta_t)
                else:
                    lambda_val -= self._LAMBDA_STEP_ABOVE_ONE * (1.0 - zeta_t)

                # Zeta step: maximise zeta_loss (minimise its negation), with
                # losses recomputed using the updated Q and lambda.
                self.zeta_optimizer.zero_grad()
                q_loss, zeta_loss = self._compute_losses(s0, a0, s, a, r, s2, a2, gamma, lambda_val)
                (-zeta_loss).backward()
                self.zeta_optimizer.step()

                q_loss_val = q_loss.item()
                zeta_loss_val = zeta_loss.item()
                mean_q_loss += q_loss_val / float(self.iter_per_epoch)
                mean_zeta_loss += zeta_loss_val / float(self.iter_per_epoch)

                if self.verbose:
                    inner_iter.set_postfix({
                        'q_loss': f"{q_loss_val:.4f}",
                        'zeta_loss': f"{zeta_loss_val:.4f}",
                        'lambda': f"{lambda_val:.4f}"
                    })

            val_est, zeta_vals = self._estimate_value(
                states_t, actions_t, rewards_t, gamma, min_reward, max_reward
            )
            zeta_mean = zeta_vals.mean().item()
            zeta_max = zeta_vals.max().item()

            # NOTE(review): the checkpoint kept is the one whose estimate is
            # closest to zero in absolute value -- confirm this selection rule
            # is intended.
            if abs(val_est) < abs(best_value):
                best_value = val_est
                best_zeta_dict = {k: v.clone() for k, v in self.Zeta.state_dict().items()}

            if self.verbose:
                if isinstance(epoch_range, tqdm):
                    epoch_range.set_postfix({
                        'mean_q_loss': f"{mean_q_loss:.4f}",
                        'mean_zeta_loss': f"{mean_zeta_loss:.4f}",
                        'lambda': f"{lambda_val:.4f}",
                        'value': f"{val_est:.4f}",
                        'zeta_mean': f"{zeta_mean:.4f}",
                        'zeta_max': f"{zeta_max:.4f}"
                    })
                print(
                    f"Epoch {epoch}, value: {val_est:.4f}, "
                    f"zeta_mean: {zeta_mean:.4f}, zeta_max: {zeta_max:.4f}"
                )

        # Restore the selected Zeta weights before producing the final estimate.
        if best_zeta_dict is not None:
            self.Zeta.load_state_dict(best_zeta_dict)

        final_est, _ = self._estimate_value(
            states_t, actions_t, rewards_t, gamma, min_reward, max_reward
        )
        return float(final_est)
