import torch
import math
import typing as tp
import torch.nn.functional as F
from torch import nn

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

class _L2(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        y = math.sqrt(self.dim) * F.normalize(x, dim=1)
        return y
    

class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation for square images.

    The batch is padded by ``pad`` pixels on every side (edge replication),
    then each image is resampled through a window of the original size whose
    top-left corner is offset by a random integer in [0, 2 * pad] pixels per
    axis. Offset ``pad`` is the un-shifted centre, so this is a shift of up
    to ``pad`` pixels in either direction.
    """

    def __init__(self, pad) -> None:
        super().__init__()
        self.pad = pad

    def forward(self, x) -> torch.Tensor:
        x = x.float()
        batch, _, height, width = x.size()
        assert height == width
        pad = self.pad
        padded = height + 2 * pad
        x = F.pad(x, (pad, pad, pad, pad), 'replicate')

        # Normalised ([-1, 1], align_corners=False) coordinates of the pixel
        # centres of the top-left height x height window in the padded image.
        half_step = 1.0 / padded
        coords = torch.linspace(-1.0 + half_step,
                                1.0 - half_step,
                                padded,
                                device=x.device,
                                dtype=x.dtype)[:height]
        xs = coords.unsqueeze(0).repeat(height, 1).unsqueeze(2)
        # Channel 0 is the x (column) coordinate, channel 1 the y (row) one.
        base_grid = torch.cat([xs, xs.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(batch, 1, 1, 1)

        # Integer pixel offsets, converted to normalised units (one pixel
        # spans 2 / padded).
        offsets = torch.randint(0,
                                2 * pad + 1,
                                size=(batch, 1, 1, 2),
                                device=x.device,
                                dtype=x.dtype)
        offsets *= 2.0 / padded

        return F.grid_sample(x,
                             base_grid + offsets,
                             padding_mode='zeros',
                             align_corners=False)


def _nl(name: str, dim: int) -> tp.List[nn.Module]:
    """Returns a non-linearity given name and dimension"""
    if name == "irelu":
        return [nn.ReLU(inplace=True)]
    if name == "relu":
        return [nn.ReLU()]
    if name == "ntanh":
        return [nn.LayerNorm(dim), nn.Tanh()]
    if name == "layernorm":
        return [nn.LayerNorm(dim)]
    if name == "tanh":
        return [nn.Tanh()]
    if name == "L2":
        return [_L2(dim)]
    raise ValueError(f"Unknown non-linearity {name}")


def mlp(*layers: tp.Sequence[tp.Union[int, str]], reward=False) -> nn.Sequential:
    """Build an ``nn.Sequential`` from a spec of widths and non-linearity names.

    The first entry must be the input width. Each following int adds an
    ``nn.Linear`` from the current width to that int. Each string adds the
    modules returned by ``_nl`` for the current width.

    Example: ``mlp(10, 12, "relu", 15)`` gives
    ``Sequential(Linear(10, 12), ReLU(), Linear(12, 15))``.

    ``reward`` is accepted for interface compatibility but is not used.
    """
    assert len(layers) >= 2
    first, *rest = layers
    assert isinstance(first, int), "First input must provide the dimension"
    width: int = first
    modules: tp.List[nn.Module] = []
    for spec in rest:
        if not isinstance(spec, str):
            assert isinstance(spec, int)
            modules.append(nn.Linear(width, spec))
            width = spec
            continue
        modules.extend(_nl(spec, width))
    return nn.Sequential(*modules)

class RewardNetwork:
    """Binary classifier over (obs, next_obs[, action]) transitions.

    Components:
    - a convolutional encoder from ``make_reward_encoder``;
    - an MLP head over the concatenated encodings (plus a 12-dim action when
      ``use_action`` is set).

    During training the head output is passed through ``tanh`` and used as the
    logit of ``BCEWithLogitsLoss``. The inference methods return
    ``sigmoid(tanh(head_output))``, i.e. the probability implied by that loss.

    NOTE(review): ``tanh`` bounds the logit to [-1, 1], so the implied
    probability always lies in roughly [0.27, 0.73]. This is kept as-is because
    training and inference apply the same squashing.
    """

    _ACTION_DIM = 12  # width of the action vector appended when use_action=True

    def __init__(self, device, hidden_dim = 1024, use_action = False):
        self.encoder = make_reward_encoder(device)
        self.obs_dim = self.encoder.repr_dim
        self.encoder.to(device)
        self.use_action = use_action
        # The head sees both encoded frames, plus the action when requested.
        in_dim = self.obs_dim * 2 + (self._ACTION_DIM if use_action else 0)
        self.reward_net = mlp(in_dim, hidden_dim, "relu", hidden_dim, "relu", 1)
        self.reward_net.to(device)
        self.tanh = nn.Tanh()
        self.device = device

        self.aug = RandomShiftsAug(pad=(64 // 21))
        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.reward_net.parameters()),
            lr=0.0001,
        )

    def _head_input(self, obs_feat, next_obs_feat, action):
        """Concatenate encoded frames (and the action if ``use_action``) along dim 1.

        Raises ValueError when ``use_action`` is set but no action is given.
        """
        parts = [obs_feat, next_obs_feat]
        if self.use_action:
            if action is None:
                raise ValueError("use_action=True requires an action tensor")
            parts.append(action.to(self.device))
        return torch.cat(parts, dim=1)

    @staticmethod
    def _logits_to_labels(logits):
        """Threshold logits at 0 (sigmoid probability 0.5) into {0, 1} numpy labels."""
        return (logits.detach().cpu().numpy() > 0).astype(int)

    def _score(self, obs, next_obs, action):
        """Return sigmoid(tanh(head_output)) for image batches, without gradients."""
        with torch.no_grad():
            obs_feat = self.encoder(obs.to(self.device))
            next_feat = self.encoder(next_obs.to(self.device))
            logits = self.reward_net(self._head_input(obs_feat, next_feat, action))
        return torch.sigmoid(self.tanh(logits))

    def get_preference(self, obs, next_obs, action):
        """Score transitions given as flattened frame stacks.

        Inputs are reshaped to (B, 3, 64, 64, 3) and index 1 along axis 1 is
        kept, then permuted to (B, 3, 64, 64). Presumably this selects the
        middle frame of 3 stacked HWC RGB frames (TODO confirm with caller).
        """
        obs = obs.reshape(-1, 3, 64, 64, 3)[:, 1].permute(0, 3, 1, 2)
        next_obs = next_obs.reshape(-1, 3, 64, 64, 3)[:, 1].permute(0, 3, 1, 2)
        return self._score(obs, next_obs, action)

    def eval_preference(self, obs, next_obs, action):
        """Score transitions whose frames are already encoder-ready image batches."""
        return self._score(obs, next_obs, action)

    def update_reward(self, obs, next_obs, reward, action=None):
        """Run one optimisation step on a labelled batch and compute metrics.

        Args:
            obs, next_obs: image batches. Each is augmented with
                ``RandomShiftsAug`` before encoding.
            reward: binary target labels, flattened to one label per sample.
            action: required when ``use_action`` is True; ignored otherwise.

        Returns:
            ``(loss, metrics)``. ``metrics`` holds the detached loss and the
            accuracy, recall, precision and F1 of predictions thresholded at
            logit 0.

        Raises:
            ValueError: ``use_action`` is set and ``action`` is None.
        """
        obs_feat = self.encoder(self.aug(obs.to(self.device)))
        next_feat = self.encoder(self.aug(next_obs.to(self.device)))
        head_out = self.reward_net(self._head_input(obs_feat, next_feat, action))
        # reshape(-1) (not squeeze()) so a batch of one stays 1-D and matches
        # the target's shape.
        logits = self.tanh(head_out).reshape(-1)
        target = reward.to(self.device).float().reshape(-1)

        loss = self.criterion(logits, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        labels = reward.cpu().numpy().reshape(-1)
        # The loss treats `logits` as logits, so the p = 0.5 decision boundary
        # is at 0 (not 0.5).
        predictions = self._logits_to_labels(logits)

        metrics = {
            'reward_loss': loss.detach(),
            'accuracy': torch.tensor(accuracy_score(labels, predictions)),
            'recall': torch.tensor(recall_score(labels, predictions, zero_division=0)),
            'precision': torch.tensor(precision_score(labels, predictions, zero_division=0)),
            'f1_score': torch.tensor(f1_score(labels, predictions, zero_division=0)),
        }

        return loss, metrics
        
class RewardEncoder(nn.Module):
    """Convolutional image encoder followed by a linear projection.

    Conv depth depends on the input height ``obs_shape[1]``:
    - 4 conv layers when it is >= 64;
    - 3 conv layers when it is >= 48;
    - 2 conv layers otherwise.

    Each conv is 3x3 with 32 output channels; the first has stride 2. The
    projection's input size is measured with a dummy forward pass, so every
    branch and any input resolution large enough for the convs works.
    (Previously it was hard-coded to 20000 = 32 * 25 * 25, which is only
    correct for 64x64 inputs.)

    Args:
        obs_shape: (channels, height, width) of the input images.
        feature_dim: size of the output feature vector (default 512).
    """

    def __init__(self, obs_shape, feature_dim: int = 512) -> None:
        super().__init__()

        assert len(obs_shape) == 3
        self.repr_dim = feature_dim

        if obs_shape[1] >= 64:
            self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU())
        elif obs_shape[1] >= 48:
            self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU())
        else:
            self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                         nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                         nn.ReLU())

        # Infer the flattened conv output size from a single zero image.
        # This consumes no RNG, so parameter initialisation order is unchanged.
        with torch.no_grad():
            flat_dim = self.convnet(torch.zeros(1, *obs_shape)).numel()
        self.projection = nn.Linear(flat_dim, feature_dim)

    def forward(self, obs):
        # Rescale pixels (presumably 0..255) to roughly [-0.5, 0.5].
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.reshape(h.shape[0], -1)
        h = self.projection(h)
        return h
    
def make_reward_encoder(device):
    """Build a RewardEncoder for 3x64x64 inputs on *device*.

    The encoder's output width is measured with one zero-filled probe batch
    and stored on the encoder as ``repr_dim`` before it is returned.
    """
    shape = [3, 64, 64]
    enc = RewardEncoder(shape).to(device)
    probe = torch.zeros(1, *shape, device=device)
    enc.repr_dim = enc(probe).shape[-1]
    return enc

def hard_update_params(net, net_name, target_net, target_net_name, use_preference=False) -> None:
    """Copy parameter values from *net* into *target_net* in place.

    Parameters are paired by position in ``named_parameters()`` order, not by
    name, and each pair is logged with ``print``. Because pairing uses ``zip``,
    it stops at the shorter list, so surplus parameters on either side are
    left untouched. ``use_preference`` is accepted but unused.
    """
    pairs = zip(net.named_parameters(), target_net.named_parameters())
    for src, dst in pairs:
        src_name, src_param = src
        dst_name, dst_param = dst
        print(f"Updating parameter: {net_name} parameter {src_name} to {target_net_name} parameter {dst_name}")
        dst_param.data.copy_(src_param.data)
