from torch import nn
from torch.nn.utils import spectral_norm

from federatedscope.autotune.utils import arm2dict


class EncNet(nn.Module):
    """Small MLP encoder made of two spectrally-normalised linear layers.

    The network is ``Linear -> ReLU -> Linear -> ReLU``; both linear layers
    are created without a bias term and wrapped with ``spectral_norm``.
    Because the last operation is a ReLU, every output element is >= 0.

    Args:
        in_channel: number of input features of the first linear layer.
        out_channel: number of output features of the second linear layer.
        hid_dim: number of features between the two linear layers.
    """
    def __init__(self, in_channel, out_channel, hid_dim=64):
        super().__init__()

        # Build the layers front-to-back so parameter initialisation draws
        # random numbers in input-to-output order.
        to_hidden = spectral_norm(nn.Linear(in_channel, hid_dim, bias=False))
        to_output = spectral_norm(nn.Linear(hid_dim, out_channel, bias=False))

        # Attribute name and layer positions determine the state-dict keys
        # (``fc_layer.0.*`` and ``fc_layer.2.*``).
        self.fc_layer = nn.Sequential(
            to_hidden,
            nn.ReLU(inplace=True),
            to_output,
            nn.ReLU(inplace=True),
        )

    def forward(self, client_enc):
        """Return ``fc_layer`` applied to ``client_enc``."""
        return self.fc_layer(client_enc)


class HyperNet(nn.Module):
    """Policy network that maps an encoding to one categorical distribution
    per entry of ``sizes``.

    The input is first passed through an ``EncNet`` producing a 32-feature
    representation; each output head is ``Linear(32, num_cate)`` followed by
    a softmax over the last dimension.

    Args:
        input_dim: number of features of the encoding fed to ``forward``.
        sizes: iterable of ints; one head with ``num_cate`` outputs is
            created for each element.
        n_clients: accepted for interface compatibility; not used by this
            class.
        device: accepted for interface compatibility; not used by this
            class (the caller is responsible for moving the module).
    """
    def __init__(
        self,
        input_dim,
        sizes,
        n_clients,
        device,
    ):
        super(HyperNet, self).__init__()
        # Width of the shared representation consumed by every head.
        enc_dim = 32
        self.EncNet = EncNet(input_dim, enc_dim)
        self.out = nn.ModuleList()
        for num_cate in sizes:
            # ``dim=-1`` is spelled out: the implicit-dim form of Softmax is
            # deprecated, warns on every call, and would normalise over
            # dim 0 for 3-D inputs. For 1-D and 2-D inputs the last
            # dimension is what the implicit form selected as well.
            self.out.append(
                nn.Sequential(nn.Linear(enc_dim, num_cate, bias=True),
                              nn.Softmax(dim=-1)))

    def forward(self, encoding):
        """Return a list with one probability tensor per output head.

        Each tensor has ``num_cate`` entries along its last dimension, and
        those entries sum to 1.
        """
        client_enc = self.EncNet(encoding)
        return [head(client_enc) for head in self.out]


if __name__ == "__main__":
    # Command-line entry point: load a saved policy network together with
    # the client encodings, pick the highest-probability arm for every
    # client, and write the resulting per-client configurations to YAML.
    import yaml
    import argparse
    import torch

    parser = argparse.ArgumentParser(description='Interpret learned policy')
    # Every path below is used unconditionally, so a missing one is rejected
    # by argparse up front instead of failing later inside open('').
    parser.add_argument('--ss_path',
                        type=str,
                        required=True,
                        help="YAML file holding the 'arm<idx>' entries")
    parser.add_argument('--log_path',
                        type=str,
                        required=True,
                        help="YAML file containing a 'stop' entry")
    parser.add_argument('--pt_path',
                        type=str,
                        required=True,
                        help="torch checkpoint with 'client_encodings' "
                        "and 'policy_net'")
    parser.add_argument('--save_path',
                        type=str,
                        required=True,
                        help='where the client-wise YAML is written')
    args = parser.parse_args()

    # NOTE(security): yaml.FullLoader can construct more than plain data
    # types; only point this script at files you trust (or switch to
    # yaml.safe_load if the files contain plain mappings only).
    with open(args.ss_path, 'r') as ips:
        arms = yaml.load(ips, Loader=yaml.FullLoader)
        print(arms)
    with open(args.log_path, 'r') as ips:
        ckpt = yaml.load(ips, Loader=yaml.FullLoader)
    stop_exploration = ckpt['stop']
    print("stop: {}".format(stop_exploration))

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    psn_pi = torch.load(args.pt_path, map_location='cpu')
    client_encodings = psn_pi['client_encodings']
    policy_net = HyperNet(
        client_encodings.shape[-1],
        [len(arms)],
        client_encodings.shape[0],
        'cpu',
    ).to('cpu')
    policy_net.load_state_dict(psn_pi['policy_net'])
    policy_net.eval()

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        prbs = policy_net(client_encodings)
    # Only one head was requested above, hence index 0.
    prbs = prbs[0].detach().numpy()

    clientwise_configs = dict()
    for i, client_prbs in enumerate(prbs):
        arm_idx = client_prbs.argmax()
        # Clients are numbered from 1 in the output file.
        clientwise_configs['client_{}'.format(i + 1)] = arm2dict(
            arms['arm{}'.format(arm_idx)])

    class _NoAliasDumper(yaml.Dumper):
        """Dumper that never emits YAML anchors/aliases.

        Used instead of patching ``yaml.Dumper`` itself so other users of
        PyYAML in the same process are unaffected.
        """
        def ignore_aliases(self, data):
            return True

    with open(args.save_path, 'w') as ops:
        yaml.dump(clientwise_configs, ops, Dumper=_NoAliasDumper)
