import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import print

# w/ softmax policy
class TabularSoccer(nn.Module):
    """Tabular policy/value model: one row of action logits and one scalar
    value per discrete state.

    Args:
        nstate: number of table rows (default 9*6*9*6*2).
        nact: number of actions, i.e. logits stored per state.
        nvf: accepted for signature parity with the other models; unused here.
    """

    def __init__(self, nstate=9*6*9*6*2, nact=5, nvf=1):
        super().__init__()
        # Random initial logits, zero initial values.
        self.policy = nn.Parameter(torch.randn((nstate, nact)))
        self.value = nn.Parameter(torch.zeros((nstate,)))

    @staticmethod
    def _lookup(table, idx):
        """Gather rows of `table` for `idx`, keeping idx's leading shape."""
        if idx.dim() <= 1:
            return torch.index_select(table, 0, idx)
        # index_select only takes a 1-D index: flatten, gather, then restore
        # the batch dimensions in front of the table's trailing dimensions.
        rows = torch.index_select(table, 0, idx.reshape(-1))
        return rows.reshape(tuple(idx.shape) + tuple(table.shape[1:]))

    def forward(self, s, av=3):
        """Look up logits and/or values for a batch of states.

        Args:
            s: tensor whose last dimension holds the five state components;
                they are combined with the mixed radix (9, 6, 9, 6, 2) into
                a single row index. Any number of leading batch dimensions
                is accepted.
            av: 1 -> return logits only, 2 -> return values only,
                anything else -> return the tuple (logits, values).

        Returns:
            Logits of shape (*batch, nact) and/or values of shape (*batch).
        """
        # The index is pure bookkeeping; keep it out of the autograd graph.
        with torch.no_grad():
            idx = s[...,0] + 9*(s[...,1] + 6*(s[...,2] + 9*(s[...,3] + 6*s[...,4])))
            idx = idx.to(torch.long)
        if av == 1:
            return self._lookup(self.policy, idx)
        if av == 2:
            return self._lookup(self.value, idx)
        return self._lookup(self.policy, idx), self._lookup(self.value, idx)


class MLPSoccer(nn.Module):
    """MLP policy/value model over a concatenated one-hot state encoding.

    Args:
        nstate: width of the one-hot input vector (default 32 = 9+6+9+6+2).
        nact: number of action logits.
        nhidden: sizes of the hidden layers.
        activation: module class instantiated after every hidden layer.
        shared: if True, policy and value heads share one hidden trunk;
            otherwise each gets its own full network.
        nvf: accepted for signature parity with the other models; unused here.
    """

    # Start offset of each of the five state components inside the one-hot
    # vector (component sizes 9, 6, 9, 6, 2).
    _FIELD_OFFSETS = (0, 9, 15, 24, 30)

    def __init__(self, nstate=32, nact=5, nhidden=(256, 256),
            activation=nn.LeakyReLU, shared=True, nvf=1):
        super().__init__()

        self.nstate = nstate
        self.shared = shared

        layers, h = self._hidden_stack(nstate, nhidden, activation)
        if shared:
            self.mlp = nn.Sequential(*layers)
            self.policy = self._head(h, nact, gain=1.0)
            self.value = self._head(h, 1, gain=0.01)
        else:
            layers.append(self._head(h, nact, gain=1.0))
            self.policy = nn.Sequential(*layers)

            layers, h = self._hidden_stack(nstate, nhidden, activation)
            layers.append(self._head(h, 1, gain=0.01))
            self.value = nn.Sequential(*layers)

    @staticmethod
    def _hidden_stack(nin, nhidden, activation):
        """Build the Linear+activation hidden layers; return (layers, width)."""
        layers = []
        h = nin
        for nh in nhidden:
            layers.append(nn.Linear(h, nh))
            layers.append(activation())
            h = nh
        return layers, h

    @staticmethod
    def _head(nin, nout, gain):
        """Output Linear layer with orthogonal weights and zero bias."""
        head = nn.Linear(nin, nout)
        nn.init.orthogonal_(head.weight, gain=gain)
        nn.init.constant_(head.bias, 0)
        return head

    def forward(self, s, av=3):
        """Evaluate the policy and/or value network.

        Args:
            s: tensor whose last dimension holds the five integer state
                components; any leading batch dimensions are preserved.
            av: 1 -> return logits only, 2 -> return values only,
                anything else -> return the tuple (logits, values).

        Returns:
            Logits of shape (*batch, nact) and/or values of shape (*batch).
        """
        # The encoding is not differentiable; build it outside autograd.
        with torch.no_grad():
            s = s.to(torch.long)
            # Allocate on the input's device so scatter_ does not hit a
            # device mismatch when the model runs off-CPU.
            s_onehot = torch.zeros(list(s.shape[:-1]) + [self.nstate],
                                   device=s.device)
            for k, offset in enumerate(self._FIELD_OFFSETS):
                s_onehot.scatter_(-1, offset + s[..., k:k+1], 1)

        if self.shared:
            x = self.mlp(s_onehot)
        else:
            # Unshared: each network consumes the raw encoding itself.
            x = s_onehot

        if av == 1:
            return self.policy(x)
        if av == 2:
            return self.value(x).squeeze(-1)
        return self.policy(x), self.value(x).squeeze(-1)


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # Keep dim 0 as-is; -1 lets view() infer the merged trailing size.
        leading = input.size(0)
        flat = input.view(leading, -1)
        return flat

# fully convolutional network for board game
class FCN(nn.Module):
    """Fully convolutional policy/value network for a square board game.

    The board is encoded as two planes (cells equal to 1 and cells equal to
    2); cells equal to 0 are treated as the valid moves.

    Args:
        board_size: side length of the board.
        nc: output channels of each conv layer.
        nk: kernel size of each conv layer (paired with `nc`).
        shared: if True, policy and value heads share one conv trunk;
            otherwise each gets its own full network.
        nvf: accepted for signature parity with the other models; unused here.
        pretrain: optional path to a state dict to load (non-strict). A
            shared-architecture checkpoint is remapped onto the policy
            network when `shared` is False.
        noise: when loading `pretrain`, perturb the final policy layer's
            weights with Gaussian noise scaled to half their std.
        freeze: the first `freeze` conv layers of the first-built trunk get
            `requires_grad=False`.
        verbose: print freeze / load messages when truthy.
        pad_m1: pad the first conv's input with the constant -1 instead of
            the conv's own zero padding.
    """

    def __init__(self, board_size=9,
            nc=(16, 32, 64),  # channels
            nk=(5, 5, 3),     # kernels
            shared=True, nvf=1,
            pretrain=None,
            noise=False,
            freeze=0,
            verbose=1,
            pad_m1=False,
            ):
        super().__init__()

        h = 2  # two input planes, one per player
        layers = []
        for i, (nh, k) in enumerate(zip(nc, nk)):
            if pad_m1 and i == 0:
                layers.append(nn.ConstantPad2d(k//2, -1))
                conv = nn.Conv2d(h, nh, kernel_size=k, stride=1, padding=0)
            else:
                conv = nn.Conv2d(h, nh, kernel_size=k, stride=1, padding=k//2)
            if i < freeze:
                # The attribute is `requires_grad`; a misspelt name would
                # only attach an unused attribute and freeze nothing.
                conv.weight.requires_grad = False
                conv.bias.requires_grad = False
                if verbose: print ('freeze ', i)
            layers.append(conv)
            layers.append(nn.ReLU())
            h = nh

        self.shared = shared
        self.board_size = board_size
        if shared:
            self.cnn = nn.Sequential(*layers)
            self.policy = nn.Conv2d(h, 1, kernel_size=1, stride=1, padding=0)
            # Applied to the spatially averaged trunk features in forward().
            self.value = nn.Linear(h, 1)
        else:
            layers.append(nn.Conv2d(h, 1, kernel_size=1, stride=1, padding=0))
            self.policy = nn.Sequential(*layers)

            # Independent value network: plain zero-padded convs, global
            # average pool, then a 1x1 conv producing the scalar.
            h = 2
            layers = []
            for nh, k in zip(nc, nk):
                layers.append(nn.Conv2d(h, nh, kernel_size=k, stride=1, padding=k//2))
                layers.append(nn.ReLU())
                h = nh
            layers.append(nn.AdaptiveAvgPool2d((1,1)))
            layers.append(nn.Conv2d(h, 1, kernel_size=1, stride=1, padding=0))
            self.value = nn.Sequential(*layers)

        # Final layer of each head; with shared=False the heads are
        # Sequentials, which have no .weight/.bias of their own.
        policy_out = self.policy if shared else self.policy[-1]
        value_out = self.value if shared else self.value[-1]

        if pretrain:
            # NOTE: torch.load unpickles its input; only load trusted files.
            # Load onto CPU so GPU-saved checkpoints work on any host;
            # load_state_dict copies values into the existing parameters.
            d = torch.load(pretrain, map_location='cpu')
            if not shared and any(k.startswith('cnn.') for k in d):
                print ('migrate from shared arch to unshared')
                d = self._shared_to_unshared(d, len(self.policy) - 1)
            if verbose: print (f'load pretrain: {pretrain}', d.keys())
            self.load_state_dict(d, strict=False)

            if noise:
                w = policy_out.weight.data
                w += torch.randn(w.shape) * w.std() * 0.5
        else:
            # Start from near-zero value predictions.
            value_out.weight.data *= 0.01
            value_out.bias.data.zero_()

    @staticmethod
    def _shared_to_unshared(state, head_index):
        """Remap a shared-arch state dict onto the unshared policy network.

        `cnn.<i>.*` becomes `policy.<i>.*` and the 1x1 policy head becomes
        `policy.<head_index>.*`; all other entries are dropped.
        """
        renamed = {}
        for key, tensor in state.items():
            if key.startswith('cnn.'):
                renamed['policy.' + key[len('cnn.'):]] = tensor
            elif key in ('policy.weight', 'policy.bias'):
                suffix = key[len('policy.'):]
                renamed['policy.%d.%s' % (head_index, suffix)] = tensor
        return renamed

    def _masked_logits(self, x, valid_move):
        """Policy logits of shape (batch, board_size**2); invalid cells -> -1e30."""
        act_logits = self.policy(x).view(-1, self.board_size**2)
        return act_logits * valid_move - (1-valid_move)*1e30

    def _state_value(self, x):
        """Scalar value per board, shape (batch,)."""
        if self.shared:
            # Global average over the spatial dims, then the linear head.
            return self.value(x.mean((2,3))).view(-1)
        return self.value(x).view(-1)

    def forward(self, s, av=3):
        """Evaluate the policy and/or value network on a batch of boards.

        Args:
            s: board tensor with board_size**2 cells per board, either flat
                or board-shaped; 0 marks an empty (valid) cell, 1 and 2 the
                two players' stones.
            av: 1 -> return logits only, 2 -> return values only,
                anything else -> return the tuple (logits, values).

        Returns:
            Logits of shape (batch, board_size**2) and/or values of shape
            (batch,).
        """
        # Flatten the mask per board so it lines up with the logits whether
        # s arrives as (B, H*W), (B, H, W) or a single unbatched board.
        valid_move = s.eq(0).float().reshape(-1, self.board_size**2)
        s = s.reshape(-1, self.board_size, self.board_size)
        sp0, sp1 = s.eq(1), s.eq(2)
        s = torch.stack((sp0, sp1), dim=1).float()

        # Shared: both heads read trunk features. Unshared: each network
        # consumes the raw planes itself.
        x = self.cnn(s) if self.shared else s

        if av == 1:
            return self._masked_logits(x, valid_move)
        if av == 2:
            return self._state_value(x)
        return self._masked_logits(x, valid_move), self._state_value(x)
