import torch
import math

import numpy as np
import torch.nn as nn

from sgcrl.utils.imports import instantiate_class, get_class
from sgcrl.gym_helpers import Bot

class HIQL_overlay(Bot):
    """Container bundling the networks of a HIQL agent.

    Holds two value functions (``vf1``, ``vf2``), a target copy of each, and
    a low-level and a high-level policy. Device moves, train/eval switches
    and compilation are forwarded to the wrapped networks.
    """

    def __init__(self, vf1, vf2, policy_low, policy_high):
        """Build every network from its spec via ``instantiate_class``.

        Args:
            vf1: spec of the first value function.
            vf2: spec of the second value function.
            policy_low: spec of the low-level policy.
            policy_high: spec of the high-level policy.
        """
        super().__init__()

        self.vf1 = instantiate_class(vf1)
        self.vf2 = instantiate_class(vf2)
        self.policy_low = instantiate_class(policy_low)
        self.policy_high = instantiate_class(policy_high)

        # Target networks are built from the same specs and then synchronised
        # with the online value functions.
        self.target_vf1 = instantiate_class(vf1)
        self.target_vf2 = instantiate_class(vf2)

        self.target_vf1.load_state_dict(self.vf1.state_dict())
        self.target_vf2.load_state_dict(self.vf2.state_dict())

    def to(self, device):
        """Move all six networks to ``device`` and return ``self``."""
        self.vf1 = self.vf1.to(device)
        self.vf2 = self.vf2.to(device)
        self.target_vf1 = self.target_vf1.to(device)
        self.target_vf2 = self.target_vf2.to(device)
        self.policy_low = self.policy_low.to(device)
        self.policy_high = self.policy_high.to(device)
        return self

    def cpu(self):
        """Move all networks to the CPU and return ``self``."""
        self.to('cpu')
        return self

    def train(self):
        """Put the online value functions and both policies in train mode.

        The target value functions are left untouched.
        """
        self.vf1.train()
        self.vf2.train()
        self.policy_low.train()
        self.policy_high.train()

    def eval(self):
        """Put the online value functions and both policies in eval mode.

        The target value functions are left untouched.
        """
        self.vf1.eval()
        self.vf2.eval()
        self.policy_low.eval()
        self.policy_high.eval()

    def compile(self):
        """Wrap the online value functions and both policies with ``torch.compile``."""
        self.vf1 = torch.compile(self.vf1)
        self.vf2 = torch.compile(self.vf2)
        self.policy_low = torch.compile(self.policy_low)
        self.policy_high = torch.compile(self.policy_high)

    def reset(self, seed):
        """No-op: this container keeps no per-episode state."""
        pass

    def parameters(self):
        """Return the trainable parameters grouped by network name.

        Returns:
            dict mapping ``'vf1'``, ``'vf2'``, ``'policy_low'`` and
            ``'policy_high'`` to the result of that network's
            ``parameters()`` call. Target networks are not included.
        """
        return {
            'vf1': self.vf1.parameters(),
            'vf2': self.vf2.parameters(),
            'policy_low': self.policy_low.parameters(),
            'policy_high': self.policy_high.parameters(),
        }

    @torch.no_grad()
    def _action(self, frame, **kwargs):
        """Return the action chosen by the high-level policy for ``frame``.

        Args:
            frame: dict that has at least
                'observation': torch.Tensor of size (obs_dim,) and
                'goal': torch.Tensor of size (goal_dim,).
            **kwargs: forwarded unchanged to ``policy_high._action``.

        Returns:
            Whatever ``policy_high._action`` returns.
        """
        # Only the high-level policy is queried. Chaining through policy_low
        # is not implemented here: pi_low learns pi(a|w,s) and pi_high learns
        # pi(w-s|s,g), so the observation would have to be added back to the
        # predicted waypoint before it could be used as the low-level goal.
        return self.policy_high._action(frame, **kwargs)

class HIQL_optimizer_overlay():
    """One optimizer per HIQL network, stored under the network's name.

    For every network ``<name>`` in (vf1, vf2, policy_low, policy_high) the
    optimizer class is resolved with ``get_class(cfg.optimizer_<name>)`` and
    called with ``parameters[<name>]`` and ``cfg.optimizer_<name>.lr``.
    """

    def __init__(self, parameters, cfg):
        super().__init__()

        for net_name in ('vf1', 'vf2', 'policy_low', 'policy_high'):
            opt_cfg = getattr(cfg, f'optimizer_{net_name}')
            optimizer_cls = get_class(opt_cfg)
            setattr(self, net_name, optimizer_cls(parameters[net_name], opt_cfg.lr))

def apply_variance_scaling_init(net, scale=1.0):
    """Re-initialise a layer in place with a uniform variance-scaling scheme.

    Weights are drawn from ``U(-b, b)`` with ``b = sqrt(scale / fan)`` where
    ``fan`` is the average of the last two weight dimensions. The bias, when
    the layer has one, is set to zero.

    Args:
        net: module exposing a ``weight`` tensor with at least two dimensions
            and, optionally, a ``bias`` tensor (e.g. ``nn.Linear``).
        scale: numerator of the bound; smaller values give smaller weights.
    """
    fan = (net.weight.size(-2) + net.weight.size(-1)) / 2
    bound = math.sqrt(scale / fan)
    with torch.no_grad():
        net.weight.uniform_(-bound, bound)
        # Layers built with bias=False carry ``bias is None``; skip them
        # instead of failing on the attribute access.
        bias = getattr(net, "bias", None)
        if bias is not None:
            bias.zero_()

def custom_init_mlp(
    sizes: list[int],
    activation: nn.Module,
    output_activation: nn.Module = nn.Identity,
    output_init_scaling: float = 1.0,
    dropout: float = None,
    layer_norm: bool = False,
):
    """Build an MLP whose linear layers use the variance-scaling init.

    Args:
        sizes (list of int): layer widths, input first and output last.
        activation: activation class instantiated after every hidden layer.
        output_activation: activation class instantiated after the output
            layer. Defaults to `nn.Identity`.
        output_init_scaling: init scale of the output layer; hidden layers
            always use a scale of 1.0.
        dropout: if not None, an `nn.Dropout` with this probability follows
            every hidden activation (never the output layer).
        layer_norm: if True, an `nn.LayerNorm` over the layer's input is placed
            in front of every linear layer except the first hidden one. The
            output layer is always normalised, even when it is the only layer.

    Returns:
        torch.nn.Module: the MLP as an `nn.Sequential`.
    """
    output_idx = len(sizes) - 2
    modules = []
    for idx in range(len(sizes) - 1):
        is_output = idx == output_idx
        width_in, width_out = sizes[idx], sizes[idx + 1]

        linear = nn.Linear(width_in, width_out)
        apply_variance_scaling_init(
            linear, scale=output_init_scaling if is_output else 1.0
        )

        if layer_norm and (is_output or idx > 0):
            modules.append(nn.LayerNorm(width_in, eps=1e-6))
        modules.append(linear)
        modules.append(output_activation() if is_output else activation())

        if not is_output and dropout is not None:
            modules.append(nn.Dropout(dropout))
    return nn.Sequential(*modules)

class GCMLPValue(nn.Module):
    """Goal-conditioned value function: an MLP mapping (obs, goal) to one scalar.

    When both ``embedding_dim`` and ``num_embedding`` are non-zero, the first
    feature of the concatenated input is treated as an integer index and
    replaced by a learned embedding of size ``embedding_dim``.
    """

    def __init__(self, input_dim, hidden_sizes, layer_norm, embedding_dim=0, num_embedding=0):
        super().__init__()

        use_embedding = embedding_dim != 0 and num_embedding != 0
        if use_embedding:
            self.embedding = nn.Embedding(
                num_embeddings=num_embedding, embedding_dim=embedding_dim
            )
            # One raw index column is swapped for embedding_dim columns.
            input_dim = input_dim + embedding_dim - 1
        else:
            self.embedding = None

        self.model = custom_init_mlp(
            sizes=[input_dim] + hidden_sizes + [1],
            activation=nn.GELU,
            layer_norm=layer_norm,
        )

    def forward(self, obs, goal):
        """Return the MLP output for the concatenation of ``obs`` and ``goal``."""
        features = torch.cat((obs, goal), dim=-1)
        if self.embedding is None:
            return self.model(features)

        # Column 0 carries the index to embed; the rest is passed through.
        index = features[:, 0].long()
        embedded = self.embedding(index)
        return self.model(torch.cat((embedded, features[:, 1:]), dim=1))

class GCMatrixValue(nn.Module):
    """Tabular value function stored as a learnable ``n_tokens x n_tokens`` matrix."""

    def __init__(self, n_tokens):
        super().__init__()

        # Entry [i, j] is the value read for observation token i and goal token j.
        initial_values = torch.randn(n_tokens, n_tokens)
        self.matrix = nn.Parameter(initial_values)

    def forward(self, obs_token, goal_token):
        """Return the matrix entries indexed by ``(obs_token, goal_token)``."""
        return self.matrix[obs_token, goal_token]

class MLPDiscreteHighPolicy(nn.Module):
    """High-level policy: an MLP mapping (obs, goal) to ``codebook_size`` outputs.

    The outputs are returned raw (identity output activation); presumably they
    are logits over codebook entries -- confirm against the training code.
    When both ``embedding_dim`` and ``num_embedding`` are non-zero, the first
    feature of the concatenated input is treated as an integer index and
    replaced by a learned embedding of size ``embedding_dim``.
    """

    def __init__(self, input_dim, hidden_sizes, layer_norm, codebook_size, embedding_dim=0, num_embedding=0):
        super().__init__()

        use_embedding = embedding_dim != 0 and num_embedding != 0
        if use_embedding:
            self.embedding = nn.Embedding(
                num_embeddings=num_embedding, embedding_dim=embedding_dim
            )
            # One raw index column is swapped for embedding_dim columns.
            input_dim = input_dim + embedding_dim - 1
        else:
            self.embedding = None

        self.model = custom_init_mlp(
            sizes=[input_dim] + hidden_sizes + [codebook_size],
            activation=nn.GELU,
            layer_norm=layer_norm,
        )

    def forward(self, obs, goal):
        """Return the MLP output for the concatenation of ``obs`` and ``goal``."""
        features = torch.cat((obs, goal), dim=-1)
        if self.embedding is None:
            return self.model(features)

        # Column 0 carries the index to embed; the rest is passed through.
        index = features[:, 0].long()
        embedded = self.embedding(index)
        return self.model(torch.cat((embedded, features[:, 1:]), dim=1))
    
class MLPGaussianPolicy(nn.Module, Bot):
    """Goal-conditioned MLP policy with a diagonal Gaussian action distribution.

    The MLP predicts the mean of the action. The standard deviation is either
    the fixed ``std`` given at construction or a learned, state-independent
    per-dimension parameter (``log_std_logits``).
    """

    def __init__(
        self,
        input_dim,
        hidden_sizes,
        action_dim,
        std=None,
        dropout=None,
        min_log_std=0,
        max_log_std=1,
        base_obs=None,
        to_numpy=True,
        activation=nn.Identity,
        embedding_dim=0,
        num_embedding=0,
    ):
        """Build the policy network.

        Args:
            input_dim: width of the concatenated (goal, obs) input.
            hidden_sizes: list of hidden layer widths.
            action_dim: width of the action vector.
            std: fixed standard deviation (float or tensor); when None a
                learnable log-std of size ``action_dim`` is created instead.
            dropout: dropout probability for hidden layers, or None.
            min_log_std, max_log_std: bounds used to clip the learned log-std
                (see the note in ``forward`` about their order).
            base_obs: optional values stored as a tensor in ``self.base_obs``.
            to_numpy: whether ``_action`` returns a clipped numpy array.
            activation: activation class applied to the network output.
            embedding_dim, num_embedding: when both are non-zero, the first
                input feature is treated as an index and embedded.
        """
        # Only nn.Module is initialised explicitly; Bot.__init__ is not called.
        nn.Module.__init__(self)
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std
        self.log_std = None
        self.std = std
        self.base_obs = torch.tensor(base_obs) if base_obs else base_obs
        self.to_numpy = to_numpy

        if embedding_dim != 0 and num_embedding != 0:
            self.embedding = torch.nn.Embedding(num_embeddings=num_embedding, embedding_dim=embedding_dim)
            # One raw index column is swapped for embedding_dim columns.
            input_dim += embedding_dim - 1
        else:
            self.embedding = None

        self.model = custom_init_mlp(
            sizes=[input_dim] + hidden_sizes + [action_dim],
            activation=nn.ReLU,
            output_activation=activation,
            output_init_scaling=0.01,
            dropout=dropout,
        )

        self.ghost_params = torch.nn.Parameter(torch.randn(()))

        if std is None:
            self.log_std_logits = nn.Parameter(torch.zeros(action_dim, requires_grad=True))
        else:
            # torch.log only accepts tensors, so a plain float std is wrapped first.
            std_tensor = std if torch.is_tensor(std) else torch.tensor(std, dtype=torch.float32)
            self.log_std = torch.log(std_tensor).to(self.ghost_params.device)

        self.is_eval = False

    def forward(self, obs, goal, is_deterministic=True):
        """Compute an action and the action distribution.

        Args:
            obs: observation tensor; the last dimension is concatenated.
            goal: goal tensor; the last dimension is concatenated.
            is_deterministic: return the mean when True, a sample otherwise.

        Returns:
            tuple ``(action, action_dist)`` where ``action_dist`` is an
            ``Independent(Normal(mean, std), 1)`` distribution.
        """
        obs_goal = torch.cat([goal,obs],dim=-1) # (for old models)
        # obs_goal = torch.cat([obs,goal],dim=-1)
        if self.embedding is not None:
            # Column 0 carries the index to embed; the rest is passed through.
            cluster = obs_goal[:, 0]
            obs_goal = torch.cat((self.embedding(cluster.long()), obs_goal[:, 1:]), dim=1)
        mean = self.model(obs_goal)
        if self.std is None:
            # switching to clipping as in HIQL rather than sigmoid norm as in IQL. Does not look like it changed perfs.
            # NOTE: max_log_std is passed as the lower bound and min_log_std as
            # the upper bound on purpose -- the two values are swapped in the
            # yaml config. Do not "fix" the order without updating the configs.
            log_std = torch.clip(self.log_std_logits, self.max_log_std, self.min_log_std)  # min is max in yaml
            std = torch.exp(log_std)
        else:
            std = self.std

        action_dist = torch.distributions.Independent(torch.distributions.Normal(mean, std), reinterpreted_batch_ndims=1)

        if is_deterministic:
            return mean, action_dist
        else:
            return action_dist.sample(), action_dist

    # Example use during training:
    # mean, dist = pi.forward(
    #     torch.cat((batch["low_goal"], batch["obs"]), dim=1)
    # )
    # logpp = dist.log_prob(batch["action"])

    @torch.no_grad()
    def _action(self, frame, **kwargs):
        """Select an action for a single, unbatched frame.

        Bot._action always takes up a frame which is a dict with at least:
         - 'observation': torch.Tensor of size (obs_dim,)
         - 'goal': torch.Tensor of size (goal_dim,)

        Keyword Args:
            stochastic: sample from the distribution when truthy; the mean is
                used when the flag is falsy or absent.
            eval: when this key is present (whatever its value), the module
                is switched to eval mode once.

        Returns:
            When ``to_numpy`` is True, a numpy array of size (action_dim,)
            clipped to [-1, 1]; otherwise the raw action tensor with its
            leading batch dimension of 1.
        """
        obs = frame['observation']
        goal = frame['goal']
        if (not self.is_eval) and "eval" in kwargs:
            self.eval()
            self.is_eval = True
            if torch.is_tensor(self.base_obs):
                print("using base obs")

        # A missing flag means "act deterministically" instead of a KeyError.
        stochastic = kwargs.get("stochastic", False)
        action, _ = self.forward(obs.unsqueeze(0), goal.unsqueeze(0), is_deterministic=not stochastic)

        if self.to_numpy:
            # .cpu() keeps this working when the policy lives on an accelerator;
            # Tensor.numpy() only supports CPU tensors.
            return np.clip(action.squeeze(0).cpu().numpy(), -1, 1)
        else:
            return action

    def reset(self, seed):
        """Seed the global torch RNG (and all CUDA RNGs when available)."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

def _to_action(batch):
    """Extract the target action from a batch as ``(move, cos_sin_rotation)``.

    ``move`` concatenates the four ``action/move_*`` entries in the order
    right, left, forwards, backwards; ``cos_sin_rotation`` concatenates the
    cosine and sine of ``action/rotation``. Every entry gets a trailing unit
    dimension before being concatenated along dim 1.
    """
    move_keys = (
        "action/move_right",
        "action/move_left",
        "action/move_forwards",
        "action/move_backwards",
    )
    move = torch.cat([batch[key].unsqueeze(-1) for key in move_keys], dim=1)

    cos_sin_r = torch.cat(
        [
            batch["action/rotation"].cos().unsqueeze(-1),
            batch["action/rotation"].sin().unsqueeze(-1),
        ],
        dim=1,
    )
    return move, cos_sin_r

class MLPGaussianPolicyGodot(nn.Module, Bot):
    """MLP policy producing Godot-style actions from a dict of sensor readings.

    The network outputs six values: four movement probabilities (sigmoid) for
    right / left / forwards / backwards and a (cos, sin) pair (tanh) encoding
    the rotation. Movement is modelled with Bernoulli distributions, rotation
    with a diagonal Gaussian.
    """

    def __init__(
        self,
        goal_key,
        obs_dim=8, # 129 with raycasts
        goal_dim=2,
        hidden_sizes=[256, 256],  # never mutated, only concatenated below
        action_dim=6,
        std=None,
        dropout=None,
        min_log_std=0,
        max_log_std=1,
        base_obs=None,
        to_numpy=True,
        flat=True,
        activation=nn.Identity,
        bounds=None
    ):
        """Build the policy network.

        Args:
            goal_key: batch key holding the goal; ``forward`` falls back to
                'sensor/absolute_goal_position' when the key is absent.
            obs_dim: width of the observation part of the input (2 position
                values plus cos and sin of a 3-value rotation by default).
            goal_dim: width of the goal part of the input.
            hidden_sizes: list of hidden layer widths.
            action_dim: width of the network output; ``forward`` reads the
                first six columns.
            std: fixed rotation standard deviation (float or tensor); when
                None a learnable log-std of size 2 is created instead.
            dropout: dropout probability for hidden layers, or None.
            min_log_std, max_log_std: bounds used to clip the learned log-std
                (see the note in ``log_prob`` about their order).
            base_obs: optional values stored as a tensor in ``self.base_obs``.
            to_numpy, flat: stored on the instance; not read by this class.
            activation: activation class applied to the network output.
            bounds: accepted for compatibility; currently unused because
                observation normalisation is disabled.
        """
        # Only nn.Module is initialised explicitly; Bot.__init__ is not called.
        nn.Module.__init__(self)
        self.goal_key = goal_key
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std
        self.log_std = None
        self.std = std
        self.base_obs = torch.tensor(base_obs) if base_obs else base_obs
        self.to_numpy = to_numpy
        self.flat = flat

        # NOTE: raycast inputs and bounds-based observation normalisation are
        # currently disabled; only position, rotation and goal are used.

        self.model = custom_init_mlp(
            sizes=[obs_dim+goal_dim] + hidden_sizes + [action_dim],
            activation=nn.ReLU,
            output_activation=activation,
            output_init_scaling=0.01,
            dropout=dropout,
        )

        self.ghost_params = torch.nn.Parameter(torch.randn(()))

        if std is None:
            self.log_std_logits = nn.Parameter(
                torch.zeros(2, requires_grad=True)  # only the two rot dimensions use Gaussian distribution
            )
        else:
            # torch.log only accepts tensors, so a plain float std is wrapped first.
            std_tensor = std if torch.is_tensor(std) else torch.tensor(std, dtype=torch.float32)
            self.log_std = torch.log(std_tensor).to(self.ghost_params.device)
        self.is_eval = False

    def forward(self, batch, is_deterministic=True):
        """Run the network on a batch of sensor readings.

        Args:
            batch: mapping with 'sensor/position', 'sensor/rotation' and
                either ``self.goal_key`` or 'sensor/absolute_goal_position';
                every entry is batched along dim 0.
            is_deterministic: accepted for interface compatibility; unused.

        Returns:
            tuple ``(move, cos_sin_rotation)``: sigmoid of output columns 0-3
            and tanh of output columns 4-5.
        """
        if self.goal_key in batch.keys():
            goal = batch[self.goal_key] # dim = 2 or 1 (index of token) or representation_dim
        else:
            goal = batch['sensor/absolute_goal_position']
        ball = batch["sensor/position"] # dim = 2
        # Inputs that are not already 2-wide are reduced to columns 0 and 2.
        # NOTE(review): this requires at least three columns; a 1-wide goal
        # (token index) would fail here -- confirm against callers.
        if ball.shape[1] != 2:
            ball = ball[:, [0,2]]
        if goal.shape[1] != 2:
            goal = goal[:, [0,2]]
        cos_rot = batch["sensor/rotation"].cos() # dim = 3
        sin_rot = batch["sensor/rotation"].sin() # dim = 3
        features = torch.cat([ball, cos_rot, sin_rot, goal], dim=1)
        out = self.model(features)
        move = torch.sigmoid(out[:, :4])
        cos_sin_rotation = torch.tanh(out[:, 4:6])
        return (move, cos_sin_rotation)

    def log_prob(self, batch):
        """Return the log-probability of the actions stored in ``batch``.

        The result is the mean Bernoulli log-probability over the four
        movement flags plus the Gaussian log-probability of the target
        (cos, sin) rotation.
        """
        move, rot = self(batch)
        target_move, target_rot = _to_action(batch)
        logpp_move = torch.distributions.Bernoulli(move).log_prob(target_move).mean(-1)

        if self.std is None:
            # NOTE: max_log_std is passed as the lower bound and min_log_std as
            # the upper bound on purpose -- the two values are swapped in the
            # yaml config.
            log_std = torch.clip(
                self.log_std_logits, self.max_log_std, self.min_log_std
            )  # min is max in yaml
            std = torch.exp(log_std)
        else:
            # log_std_logits only exists when std is learned; use the fixed value.
            std = self.std
        logpp_rot = torch.distributions.Independent(
            torch.distributions.Normal(rot, std), reinterpreted_batch_ndims=1
        ).log_prob(target_rot)
        return logpp_move+logpp_rot

    # used when called in the hierarchical policy
    @torch.no_grad()
    def _action(self, event, low_pi=False, **kwargs):
        """Select an action dict for a single, unbatched event.

        Args:
            event: mapping of tensors; each one gets a leading batch dimension
                before being passed to ``forward``.
            low_pi: accepted for interface compatibility; unused.
            **kwargs: when the key 'eval' is present (whatever its value), the
                module is switched to eval mode once.

        Returns:
            dict with 'move_right', 'move_left', 'move_forwards',
            'move_backwards' (sampled 0.0 / 1.0 floats), 'run' (always True),
            'jump' (always False) and 'rotation' (angle in radians recovered
            with atan2 from the predicted sin / cos).
        """
        if (not self.is_eval) and "eval" in kwargs:
            self.eval()
            self.is_eval = True

        event = {k: v.unsqueeze(0) for k, v in event.items()}
        move, rot = self(event)
        # TODO implement stochastic variant where rot is sampled, not just using mean
        move = torch.distributions.Bernoulli(move).sample()
        cos = rot[0][0]
        sin = rot[0][1]
        angle = math.atan2(sin.item(), cos.item())
        action = {}
        action["move_right"] = move[0][0].item()
        action["move_left"] = move[0][1].item()
        action["move_forwards"] = move[0][2].item()
        action["move_backwards"] = move[0][3].item()
        action["run"] = True
        action["jump"] = False
        action["rotation"] = angle
        return action

    def reset(self, seed):
        """Seed the global torch RNG (and all CUDA RNGs when available)."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
