import torch
import torch.nn as nn
import torch.nn.functional as F

# Compute device for this module: the default CUDA device when one is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# float32 dtype constant exposed at module level (torch.float is an alias of it).
dtype = torch.float32


def param_init(model):
    """Initialise the supported layers of ``model`` in place.

    - ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear``: weights are
      drawn from N(0, 0.01^2) and biases (when present) are set to 0.
    - ``nn.BatchNorm1d`` and ``nn.BatchNorm2d``: scale set to 1 and shift set
      to 0. Both are skipped when the layer was built with ``affine=False``
      (its weight/bias are then ``None``).

    Every other module type is left untouched.

    Args:
        model: any ``nn.Module``; all of its submodules, including itself,
            are visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            # nn.init functions already run under torch.no_grad(), so the
            # parameters can be passed directly instead of via ``.data``.
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # affine=False batch norms have no learnable scale/shift.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class conv_layer_module(nn.Module):
    """Conv2d -> BatchNorm2d -> in-place ReLU building block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
        bias: whether the convolution carries its own bias term (default
            ``False``; the following BatchNorm2d already adds a learned shift).
    """

    def __init__(self, in_ch, out_ch, k, s, p, bias=False):
        super().__init__()
        # The attribute names ``conv`` / ``bat`` / ``act`` become state_dict
        # keys, so they are kept stable.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s,
                              padding=p, bias=bias)
        self.bat = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in sequence."""
        return self.act(self.bat(self.conv(x)))


class conv_transposed_layer_module(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> in-place ReLU building block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: transposed-convolution kernel size.
        s: stride.
        p: padding.
        op: output padding.
        bias: whether the transposed convolution has a bias term
            (default ``False``).
    """

    def __init__(self, in_ch, out_ch, k, s, p, op, bias=False):
        super().__init__()
        # ``conv_t`` / ``bat`` / ``act`` are state_dict key prefixes.
        self.conv_t = nn.ConvTranspose2d(in_ch, out_ch, k, stride=s, padding=p,
                                         output_padding=op, bias=bias)
        self.bat = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample with the transposed conv, then normalise and rectify."""
        for layer in (self.conv_t, self.bat, self.act):
            x = layer(x)
        return x


class fc_layer_module(nn.Module):
    """Linear -> BatchNorm1d -> in-place ReLU building block.

    Args:
        in_dim: input feature width.
        out_dim: output feature width.
        bias: whether the linear layer has a bias term (default ``False``).

    In training mode BatchNorm1d needs more than one row per batch.
    """

    def __init__(self, in_dim, out_dim, bias=False):
        super().__init__()
        # ``fc`` / ``bat`` / ``act`` are state_dict key prefixes.
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.bat = nn.BatchNorm1d(out_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Project, batch-normalise and rectify ``x``."""
        projected = self.fc(x)
        normalised = self.bat(projected)
        return self.act(normalised)



class FeatNet(nn.Module):
    """Encode an observation into an L2-normalised 128-d feature embedding.

    Architecture:
        * four stride-2 ``conv_layer_module`` blocks widening the channels
          6 -> 9 -> 16 -> 32 -> 64, followed by a flatten;
        * a 256 -> 256 BN+ReLU hidden layer and a 256 -> 128 linear projection;
        * L2 normalisation along dim 1.

    The 256-wide projection input equals 64 channels x 2 x 2, which is what
    the conv stack produces for e.g. 27x27 inputs (27 -> 14 -> 7 -> 4 -> 2).
    """

    def __init__(self):
        super().__init__()
        widths = (6, 9, 16, 32, 64)
        # Built in the same order as a literal Sequential so that submodule
        # indices (state_dict keys) and seeded initialisation are unchanged.
        blocks = [conv_layer_module(c_in, c_out, 3, 2, 1)
                  for c_in, c_out in zip(widths[:-1], widths[1:])]
        self.conv = nn.Sequential(*blocks, nn.Flatten(1))
        self.proj = nn.Sequential(
            fc_layer_module(256, 256),
            nn.Linear(256, 128),
        )

        param_init(self)

    def forward(self, o):
        """Return the unit-norm embedding of observation batch ``o``."""
        features = self.conv(o)
        projected = self.proj(features)
        return F.normalize(projected, dim=1)



class CoordNet(nn.Module):
    """Encode an observation into a 128-d unit-norm embedding plus a 2-value head.

    The conv backbone matches FeatNet's: four stride-2 blocks, 6 -> 64
    channels, flattened to 256. A single BN+ReLU layer then projects to
    128 values, which are L2-normalised. Because of the ReLU the embedding
    entries are non-negative. A Linear(128, 2) + Tanh head maps the
    embedding to two values in (-1, 1). These are presumably normalised
    2-D coordinates, since the Agent calls them ``inferred_coord``.
    """

    def __init__(self):
        super().__init__()
        widths = (6, 9, 16, 32, 64)
        # Same construction order as a literal Sequential: keeps state_dict
        # keys and seeded initialisation identical.
        backbone = [conv_layer_module(c_in, c_out, 3, 2, 1)
                    for c_in, c_out in zip(widths[:-1], widths[1:])]
        self.conv = nn.Sequential(*backbone, nn.Flatten(1))
        self.proj = nn.Sequential(fc_layer_module(256, 128))
        self.head = nn.Sequential(nn.Linear(128, 2), nn.Tanh())

        param_init(self)

    def forward(self, o):
        """Return ``(z, c)``: the unit-norm embedding and the tanh head output."""
        hidden = self.proj(self.conv(o))
        z = F.normalize(hidden, dim=1)
        return z, self.head(z)



class dirPolNet(nn.Module):
    """MLP that produces one score for each of 8 directions.

    The input is the dim-1 concatenation of ``z_ft``, ``z_cd``, ``prior_ft``,
    ``prior_cd``, ``lambda_ft`` and ``lambda_cd``. The first layer is
    514 = 4 * 128 + 1 + 1 wide, so it expects four 128-wide blocks followed
    by two single-column blocks. The output is the raw Linear(256, 8)
    result; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        embed_dim = 128
        in_dim = 4 * embed_dim + 2  # four embeddings + two lambda columns
        self.dirpolnet = nn.Sequential(
            fc_layer_module(in_dim, 512),
            fc_layer_module(512, 256),
            nn.Linear(256, 8),
        )

        param_init(self)

    def forward(self, z_ft, z_cd, prior_ft, prior_cd, lambda_ft, lambda_cd):
        """Return the ``(N, 8)`` unnormalised direction scores."""
        inputs = [z_ft, z_cd, prior_ft, prior_cd, lambda_ft, lambda_cd]
        return self.dirpolnet(torch.cat(inputs, dim=1))



class Agent(nn.Module):
    """Generate direction-labelled training samples from prior-embedding distances.

    Holds a ``FeatNet`` and a ``CoordNet`` together with prior embedding
    tables loaded by ``set_prior``. ``make_training_sample`` measures how far
    the current observations are from their priors, turns that into a step
    size, probes 8 candidate moves through an environment and labels each
    sample with the candidate whose observation is closest to the prior.
    """

    # The 8 signed offsets probed around the current position (every -1/0/+1
    # combination except (0, 0)). A sample's direction label is the row index
    # into this table.
    _DIRECTIONS = (
        (-1, -1), (-1, 0), (-1, 1), (0, -1),
        (0, 1), (1, -1), (1, 0), (1, 1),
    )

    def __init__(self, batch_size):
        super(Agent, self).__init__()
        self.img_size = [255, 255]
        self.view_size = [27, 27]

        # Step sizes produced by stepsize_calc lie in 1..max_stepsize.
        self.max_stepsize = 30
        # Default number of lambda weights drawn by random_lambda_set.
        self.batch_size = batch_size

        self.coordnet = CoordNet()
        self.featnet = FeatNet()

    def set_prior(self, state):
        """Load network weights and prior embedding tables from ``state``.

        ``state`` must provide ``'featnet'`` and ``'coordnet'`` (state dicts
        for the two networks) and ``'prior_z_ft'`` / ``'prior_z_cd'`` (tables
        indexed by ``l_idx`` in ``calc_D``).

        NOTE(review): the networks are not switched to eval mode here, so
        their BatchNorm layers follow whatever mode the Agent is in. Confirm
        this is intended.
        """
        self.featnet.load_state_dict(state['featnet'])
        self.prior_z_ft = state['prior_z_ft']
        self.coordnet.load_state_dict(state['coordnet'])
        self.prior_z_cd = state['prior_z_cd']

    def random_lambda_set(self, batch_size=None):
        """Draw per-sample mixing weights.

        Sets ``self.lambda_ft``, with shape ``(batch_size, 1)`` and values in
        {0.000, 0.001, ..., 0.999}, and ``self.lambda_cd = 1 - lambda_ft``.

        Args:
            batch_size: number of rows to draw; defaults to ``self.batch_size``.
        """
        n = self.batch_size if batch_size is None else batch_size
        self.lambda_ft = 0.001 * torch.randint(0, 1000, (n, 1)).to(device)
        self.lambda_cd = 1 - self.lambda_ft

    def make_training_sample(self, env, o0, l_idx):
        """Build one batch of policy inputs and direction labels.

        Args:
            env: object exposing ``apply_action_and_o_extraction(act)``. It
                receives the ``(N, 8, 2)`` move tensor from
                ``act_to_env_calc`` and returns the candidate observations.
            o0: current observations; dim 0 is the batch.
            l_idx: per-sample index into the prior embedding tables.

        Returns:
            ``(embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft,
            direction)``. ``direction[i]`` is the index (0-7, into
            ``_DIRECTIONS``) of the candidate with the smallest blended
            distance to sample ``i``'s prior.
        """
        # Size the lambdas from the actual batch rather than self.batch_size,
        # so that a short final batch still broadcasts against the distances.
        self.random_lambda_set(o0.shape[0])

        D_ft0, D_cd0, _, embedding_ft0, embedding_cd0, prior_ft, prior_cd = \
            self.calc_D(o0, l_idx)
        D0 = self.calc_distance(D_ft0, D_cd0)

        self.stepsize_calc(D0)
        act_to_env = self.act_to_env_calc()

        o1_cand = env.apply_action_and_o_extraction(act_to_env)
        # The candidate distances only feed an argmin, so no autograd graph
        # is needed for them.
        # NOTE(review): the view below yields 3-channel observations, so
        # calc_D hands FeatNet o[:, :2] and CoordNet o[:, 1:] (2 channels
        # each), while both networks' first conv expects 6 input channels.
        # Confirm the intended observation layout.
        with torch.no_grad():
            D_ft1, D_cd1, _, _, _, _, _ = self.calc_D(
                o1_cand.view(-1, 3, 27, 27), l_idx, cand=True)
            D1_cand = self.calc_distance(D_ft1, D_cd1, cand=True)

        direction = torch.argmin(D1_cand.view(-1, 8), dim=1)
        lambda_ft = self.lambda_ft.clone()

        return embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction

    def calc_D(self, o, l_idx, cand=False):
        """Compute squared L2 distances of the current embeddings to the selected priors.

        FeatNet is fed ``o[:, :2]`` and CoordNet ``o[:, 1:]``. With
        ``cand=True``, each ``l_idx`` entry is repeated 8 times consecutively
        so that it lines up with 8 candidate rows per sample in ``o``.

        Returns:
            ``(D_ft, D_cd, inferred_coord, embedding_ft, embedding_cd,
            selected_prior_ft, selected_prior_cd)``. ``D_ft`` and ``D_cd``
            hold one value per row of ``o``.
        """
        if cand:
            l_idx = l_idx.view(-1, 1).repeat(1, 8).view(-1)

        embedding_ft = self.featnet(o[:, :2])
        selected_prior_ft = self.prior_z_ft[l_idx]
        D_ft = torch.norm(embedding_ft - selected_prior_ft, dim=1)**2

        embedding_cd, inferred_coord = self.coordnet(o[:, 1:])
        selected_prior_cd = self.prior_z_cd[l_idx]
        D_cd = torch.norm(embedding_cd - selected_prior_cd, dim=1)**2

        return D_ft, D_cd, inferred_coord, \
            embedding_ft, embedding_cd, selected_prior_ft, selected_prior_cd

    def calc_distance(self, D_ft, D_cd, cand=False):
        """Blend the two distances as ``lambda_ft * D_ft + lambda_cd * D_cd``.

        Returns an ``(N, 1)`` column. With ``cand=True``, each sample's
        lambdas are repeated 8 times consecutively to match the candidate rows
        produced by ``calc_D(..., cand=True)``.
        """
        lambda_ft = self.lambda_ft.repeat(1, 8).view(-1, 1) if cand else self.lambda_ft
        lambda_cd = self.lambda_cd.repeat(1, 8).view(-1, 1) if cand else self.lambda_cd
        D = lambda_ft * D_ft.unsqueeze(1) + lambda_cd * D_cd.unsqueeze(1)
        return D

    def stepsize_calc(self, D0):
        """Map each blended distance to an integer step in 1..max_stepsize.

        The step is 1 plus the number of thresholds ``0.001 + 0.003 * k``
        (k = 0 .. max_stepsize - 2) that the distance exceeds, so larger
        distances give larger steps. The result is stored in
        ``self.stepsize`` with shape ``(N,)``.
        """
        thresholds = 0.001 + 0.003 * torch.arange(self.max_stepsize - 1).to(device)
        self.stepsize = (D0.view(-1, 1) > thresholds).sum(dim=1) + 1

    def act_to_env_calc(self):
        """Return the ``(N, 8, 2)`` long tensor of candidate moves.

        Row ``i`` holds the 8 offsets of ``_DIRECTIONS`` scaled by
        ``self.stepsize[i]``. ``N`` follows ``self.stepsize``, not
        ``self.batch_size``.
        """
        offsets = torch.tensor(self._DIRECTIONS, dtype=torch.long, device=device)
        # (N, 1, 1) * (8, 2) broadcasts to (N, 8, 2).
        return self.stepsize.view(-1, 1, 1) * offsets
    
    


