from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightNet(nn.Module):
    """Base class for modules that turn an input tensor into reward-component weights.

    Every subclass in this module implements ``forward``; this base class only
    records ``device`` and a default ``name`` so all subclasses share a single
    constructor signature.

    Args:
        n_reward_components: number of reward components (not used by the base class).
        features_dim: size of the input feature vector (not used by the base class).
        device: device string, stored unchanged on ``self.device``.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__()
        self.name = 'weight_net'
        self.device = device

    def clip_coeff(self):
        """Project coefficients back into their valid range; a no-op by default."""
        return None

    def get_train_data(self, x):
        """Return the tensor used as training data; defaults to the ``forward`` output."""
        weights = self.forward(x)
        return weights

class OracleRewardWeight(WeightNet):
    """Fixed, non-learnable weights: 0 for component 0 and 1 for every other component.

    ``coeff`` is registered as a *non-persistent buffer*: it now follows the module
    through ``.to()`` / ``.cuda()`` (a plain tensor attribute was silently left on
    its original device), while staying out of ``state_dict`` exactly as before,
    so existing checkpoints still load.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        coeff = torch.ones(n_reward_components, device=self.device)
        coeff[0] = 0
        self.register_buffer('coeff', coeff, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``coeff`` tiled to ``x.shape[0]`` rows (x's values are ignored)."""
        return self.coeff.unsqueeze(0).expand(x.shape[0], -1)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the fixed weights in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Same as ``forward``: the fixed weights tiled to ``x.shape[0]`` rows."""
        return self.forward(x)

class TanhWeightv1(WeightNet):
    """State-independent weights squashed into [0, 1] via ``(tanh(coeff) + 1) / 2``.

    ``x`` only sets how many rows are returned (``x.size(0)``); its values are ignored.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        # Learnable raw coefficients, initialised uniformly in [0, 1).
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.name = 'tanhv1'

    def _tiled_coeff(self, x):
        # One (stride-0) copy of the coefficient vector per entry along x's first dim.
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shifted = torch.tanh(self._tiled_coeff(x)) + 1
        return shifted / 2

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the raw coefficients in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return the raw (pre-squash) coefficients tiled to ``x.size(0)`` rows."""
        return self._tiled_coeff(x)

class TanhWeightv2(WeightNet):
    """State-independent signed weights ``tanh(coeff)``, each in [-1, 1].

    ``x`` only sets how many rows are returned (``x.size(0)``); its values are ignored.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        # Learnable raw coefficients, initialised uniformly in [0, 1).
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.name = 'tanhv2'

    def _tiled_coeff(self, x):
        # One (stride-0) copy of the coefficient vector per entry along x's first dim.
        return self.coeff[None, :].expand(x.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self._tiled_coeff(x))

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the raw coefficients in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return the raw (pre-tanh) coefficients tiled to ``x.size(0)`` rows."""
        return self._tiled_coeff(x)

class SoftmaxWeightv1(WeightNet):
    """State-independent weights ``softmax(coeff)``: non-negative and summing to 1.

    ``x`` only sets how many rows are returned (``x.size(0)``); its values are ignored.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        # Learnable logits, initialised uniformly in [0, 1).
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.name = 'softmaxv1'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise once, then broadcast the same distribution to every row.
        probs = F.softmax(self.coeff, dim=-1)
        rows = x.size(0)
        return probs.unsqueeze(0).expand(rows, -1)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the logits in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return the raw (pre-softmax) logits tiled to ``x.size(0)`` rows."""
        rows = x.size(0)
        return self.coeff.unsqueeze(0).expand(rows, -1)


class SoftmaxWeightv2(WeightNet):
    """State-independent signed weights: ``softmax(|coeff|)`` times the sign of ``coeff``.

    The magnitudes sum to 1. The sign factor is ``coeff / (|coeff| + 1e-8)``,
    which is approximately ``sign(coeff)`` and exactly 0 where ``coeff == 0``.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        # Learnable logits, initialised uniformly in [0, 1).
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.name = 'softmaxv2'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        magnitudes = F.softmax(self.coeff.abs(), dim=-1)
        signs = self.coeff / (self.coeff.abs() + 1e-8)
        signed = magnitudes * signs
        return signed.unsqueeze(0).expand(x.size(0), -1)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the logits in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return the raw logits tiled to ``x.size(0)`` rows."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class SimpleRewardWeightv1(WeightNet):
    """State-independent weights equal to the learnable ``coeff`` vector itself.

    ``coeff`` starts uniform in [0, 1), and ``clip_coeff`` clamps it to [0, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.name = 'simplev1'

    def _tiled_coeff(self, x):
        # One (stride-0) copy of the coefficient vector per entry along x's first dim.
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._tiled_coeff(x)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the coefficients in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Identical to ``forward``: the coefficients tiled to ``x.size(0)`` rows."""
        return self._tiled_coeff(x)

    def clip_coeff(self):
        """Clamp the coefficients in place to [0, 1]."""
        self.coeff.data.clamp_(min=0.0, max=1.0)

class SimpleRewardWeightv2(WeightNet):
    """State-independent signed weights equal to the learnable ``coeff`` vector.

    ``coeff`` starts uniform in [-1, 1), and ``clip_coeff`` clamps it to [-1, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        initial = torch.rand(n_reward_components, device=self.device) * 2 - 1
        self.coeff = nn.Parameter(initial)
        self.name = 'simplev2'

    def _tiled_coeff(self, x):
        # One (stride-0) copy of the coefficient vector per entry along x's first dim.
        return self.coeff[None, :].expand(x.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._tiled_coeff(x)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite the coefficients in place with ``coeff``."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Identical to ``forward``: the coefficients tiled to ``x.size(0)`` rows."""
        return self._tiled_coeff(x)

    def clip_coeff(self):
        """Clamp the coefficients in place to [-1, 1]."""
        self.coeff.data.clamp_(min=-1.0, max=1.0)

class DirectRewardWeight(WeightNet):
    """State-dependent weights ``1 + 0.5 * tanh(Linear(x))``, each in [0.5, 1.5].

    The linear layer maps the last dimension of ``x`` (``features_dim``) to one
    logit per reward component.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        layer = nn.Linear(features_dim, n_reward_components)
        self.weight_net = layer.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.weight_net(x)
        return 1 + torch.tanh(logits) * 0.5

    def reset_weight(self, coeff: torch.Tensor):
        """Re-initialise the linear layer; ``coeff`` is accepted for API parity but unused."""
        self.weight_net.reset_parameters()

    def get_weight(self, x):
        """Return the raw linear-layer logits for ``x`` (before the tanh squash)."""
        logits = self.weight_net(x)
        return logits

class RegulatedRewardWeightv2(WeightNet):
    """State-dependent weights ``(1 + 0.5 * tanh(Linear(x))) * coeff``.

    The state-dependent gate lies in [0.5, 1.5] and scales a learnable
    per-component ``coeff`` that starts at all ones.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        self.coeff = nn.Parameter(torch.ones(n_reward_components, device=self.device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = 1 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place; the linear layer is left untouched."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv1(WeightNet):
    """State-dependent weights ``tanh(Linear(x)) * coeff``.

    The gate lies in [-1, 1] and scales a learnable per-component ``coeff``
    that starts at all ones.
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        self.coeff = nn.Parameter(torch.ones(n_reward_components, device=self.device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.tanh(self.weight_net(x))
        return gate * self.coeff

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place; the linear layer is left untouched."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv3(WeightNet):
    """State-dependent weights ``(0.5 + 0.5 * tanh(Linear(x))) * coeff``.

    The gate lies in [0, 1]. ``coeff`` starts uniform in [0, 1), and
    ``clip_coeff`` clamps it to [0, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = 0.5 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [0, 1]."""
        self.coeff.data.clamp_(min=0.0, max=1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place; the linear layer is left untouched."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv4(WeightNet):
    """Signed variant of v3: ``(0.5 + 0.5 * tanh(Linear(x))) * coeff``.

    The gate lies in [0, 1]. ``coeff`` starts uniform in [-1, 1), and
    ``clip_coeff`` clamps it to [-1, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        initial = torch.rand(n_reward_components, device=self.device) * 2 - 1
        self.coeff = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = 0.5 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [-1, 1]."""
        self.coeff.data.clamp_(min=-1.0, max=1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place; the linear layer is left untouched."""
        self.coeff.data.copy_(coeff)

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv5(WeightNet):
    """State-dependent weights ``(1 + 0.5 * tanh(Linear(x))) * coeff``, starting as ``coeff``.

    The linear layer is zero-initialised (and re-zeroed by ``reset_weight``), so
    the gate starts at exactly 1 and the initial weights equal ``coeff``. The gate
    lies in [0.5, 1.5]. ``coeff`` starts uniform in [0, 1), and ``clip_coeff``
    clamps it to [0, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        self.coeff = nn.Parameter(torch.rand(n_reward_components, device=self.device))
        self.set_weight_net_weight()

    def set_weight_net_weight(self):
        """Zero the linear layer's weight and bias in place (outside autograd)."""
        with torch.no_grad():
            for tensor in (self.weight_net.weight, self.weight_net.bias):
                tensor.fill_(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = 1 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [0, 1]."""
        self.coeff.data.clamp_(min=0.0, max=1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place and zero the linear layer again."""
        self.coeff.data.copy_(coeff)
        self.set_weight_net_weight()

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv6(WeightNet):
    """Signed variant of v5: ``(1 + 0.5 * tanh(Linear(x))) * coeff``, starting as ``coeff``.

    The linear layer is zero-initialised (and re-zeroed by ``reset_weight``), so
    the gate starts at exactly 1. The gate lies in [0.5, 1.5]. ``coeff`` starts
    uniform in [-1, 1), and ``clip_coeff`` clamps it to [-1, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components).to(self.device)
        initial = torch.rand(n_reward_components, device=self.device) * 2 - 1
        self.coeff = nn.Parameter(initial)
        self.set_weight_net_weight()

    def set_weight_net_weight(self):
        """Zero the linear layer's weight and bias in place (outside autograd)."""
        with torch.no_grad():
            for tensor in (self.weight_net.weight, self.weight_net.bias):
                tensor.fill_(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = 1 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [-1, 1]."""
        self.coeff.data.clamp_(min=-1.0, max=1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` in place and zero the linear layer again."""
        self.coeff.data.copy_(coeff)
        self.set_weight_net_weight()

    def get_weight(self, x):
        """Return ``coeff`` tiled to ``x.size(0)`` rows (without the state-dependent gate)."""
        return self.coeff.unsqueeze(0).expand(x.size(0), -1)

class RegulatedRewardWeightv7(WeightNet):
    """Weights with component 0 pinned to 1 and the rest regulated by the state.

    The returned weights are ``[1, (1 + 0.5 * tanh(Linear(x))) * coeff]`` along
    the last dimension. Only the ``n_reward_components - 1`` trailing weights are
    learnable. The linear layer is zero-initialised (and re-zeroed by
    ``reset_weight``), so the gate starts at 1. ``coeff`` starts uniform in
    [0, 1), and ``clip_coeff`` clamps it to [0, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components-1).to(self.device)
        self.register_parameter("coeff", nn.Parameter(torch.rand(n_reward_components-1, device=self.device)))
        self.set_weight_net_weight()

    def set_weight_net_weight(self):
        """Zero the linear layer's weight and bias in place (outside autograd)."""
        with torch.no_grad():
            self.weight_net.weight.fill_(0)
            self.weight_net.bias.fill_(0)

    def _regulated_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``n_reward_components - 1`` learnable weights for input ``x``."""
        gate = 1 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return all ``n_reward_components`` weights, with a constant 1 in front.

        Works for any input rank: unbatched ``(features_dim,)``, batched
        ``(batch, features_dim)`` and higher-rank inputs.
        """
        w2 = self._regulated_weights(x)
        # The constant column is built from w2's shape rather than w2[..., 0:1],
        # so it is still width 1 when there are no learnable components.
        w1 = w2.new_ones(w2.shape[:-1] + (1,))
        # The components live on the last axis (dim=-1), which holds for every input rank;
        # a fixed dim=1 only works for 2-D inputs.
        return torch.cat([w1, w2], dim=-1)

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [0, 1]."""
        self.coeff.data.clamp_(0.0, 1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` (length ``n_reward_components - 1``) and re-zero the linear layer."""
        self.coeff.data.copy_(coeff)
        self.set_weight_net_weight()

    def get_weight(self, x):
        """Return ``coeff`` (without the pinned component) tiled to ``x.shape[0]`` rows."""
        return self.coeff.unsqueeze(0).expand(x.shape[0], -1)

    def get_train_data(self, x):
        """Return only the learnable weights (the pinned component 0 is excluded)."""
        return self._regulated_weights(x)

class RegulatedRewardWeightv8(WeightNet):
    """Signed variant of v7: component 0 pinned to 1, the rest regulated by the state.

    The returned weights are ``[1, (1 + 0.5 * tanh(Linear(x))) * coeff]`` along
    the last dimension. The linear layer is zero-initialised (and re-zeroed by
    ``reset_weight``), so the gate starts at 1. ``coeff`` starts uniform in
    [-1, 1), and ``clip_coeff`` clamps it to [-1, 1].
    """

    def __init__(self, n_reward_components: int, features_dim: int, device: str = 'cpu'):
        super().__init__(n_reward_components, features_dim, device)
        self.weight_net = nn.Linear(features_dim, n_reward_components-1).to(self.device)
        self.register_parameter("coeff", nn.Parameter(torch.rand(n_reward_components-1, device=self.device)*2-1))
        self.set_weight_net_weight()

    def set_weight_net_weight(self):
        """Zero the linear layer's weight and bias in place (outside autograd)."""
        with torch.no_grad():
            self.weight_net.weight.fill_(0)
            self.weight_net.bias.fill_(0)

    def _regulated_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``n_reward_components - 1`` learnable weights for input ``x``."""
        gate = 1 + torch.tanh(self.weight_net(x)) * 0.5
        return gate * self.coeff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return all ``n_reward_components`` weights, with a constant 1 in front.

        Works for any input rank: unbatched ``(features_dim,)``, batched
        ``(batch, features_dim)`` and higher-rank inputs.
        """
        w2 = self._regulated_weights(x)
        # The constant column is built from w2's shape rather than w2[..., 0:1],
        # so it is still width 1 when there are no learnable components.
        w1 = w2.new_ones(w2.shape[:-1] + (1,))
        # The components live on the last axis (dim=-1), which holds for every input rank;
        # a fixed dim=1 only works for 2-D inputs.
        return torch.cat([w1, w2], dim=-1)

    def clip_coeff(self):
        """Clamp ``coeff`` in place to [-1, 1]."""
        self.coeff.data.clamp_(-1.0, 1.0)

    def reset_weight(self, coeff: torch.Tensor):
        """Overwrite ``coeff`` (length ``n_reward_components - 1``) and re-zero the linear layer."""
        self.coeff.data.copy_(coeff)
        self.set_weight_net_weight()

    def get_weight(self, x):
        """Return ``coeff`` (without the pinned component) tiled to ``x.shape[0]`` rows."""
        return self.coeff.unsqueeze(0).expand(x.shape[0], -1)

    def get_train_data(self, x):
        """Return only the learnable weights (the pinned component 0 is excluded)."""
        return self._regulated_weights(x)

def check_parameter_exists(model, parameter_name):
    """Return True if ``model`` owns a parameter or a parameterised submodule with this name.

    ``named_parameters()`` yields fully qualified names such as
    ``'weight_net.weight'``. An exact comparison alone therefore never finds a
    submodule such as ``'weight_net'``. The check also accepts any parameter
    whose name starts with ``parameter_name + '.'``. Exact parameter names
    (e.g. ``'coeff'``) behave exactly as before.

    Args:
        model: an ``nn.Module`` whose ``named_parameters()`` is searched.
        parameter_name: a parameter name or a submodule name.

    Returns:
        True if a matching parameter is found, otherwise False.
    """
    submodule_prefix = parameter_name + '.'
    return any(
        name == parameter_name or name.startswith(submodule_prefix)
        for name, _ in model.named_parameters()
    )

if __name__ == "__main__":
    # Smoke tests. Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Two optimizer groups: coeff learns fast, the weight_net layer learns slowly.
    model = RegulatedRewardWeightv3(3, 16, device)
    small_lr_group = []
    large_lr_group = []

    # named_parameters() yields qualified names ('weight_net.weight', 'weight_net.bias',
    # 'coeff'), so the linear layer is matched by its 'weight_net.' prefix. An exact
    # comparison with 'weight_net' never matches and would leave small_lr_group empty.
    for name, param in model.named_parameters():
        if name.startswith('weight_net.'):
            small_lr_group.append(param)
        elif name == 'coeff':
            large_lr_group.append(param)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)
    optimizer = torch.optim.Adam(
        [
            {'params': large_lr_group, 'lr': 5e-2},
            {'params': small_lr_group, 'lr': 1e-3},
        ]
    )

    # Test of gradient: d(sum(ret * w)) / d(coeff) for SimpleRewardWeightv1.
    model = SimpleRewardWeightv1(3, 16, device)
    # NOTE(review): obs_i is 1-D, so forward() treats its 3 entries as a batch of 3
    # (weights of shape (3, 3)); confirm this is intended rather than shape (1, 16).
    obs_i = torch.ones(3, device=device)
    ret_i = torch.ones((1, 3), device=device)
    reward_weights_i = model(obs_i)
    weighted_return_i = torch.sum(ret_i * reward_weights_i, dim=1)

    grad_phi_i = torch.autograd.grad(
        outputs=weighted_return_i,
        inputs=list(model.parameters()),
        grad_outputs=torch.ones_like(weighted_return_i),
        retain_graph=True,
        create_graph=True,
    )