# -*- coding: utf-8 -*-

import random
import time
import os
import sys
import numpy as np

import math
# -----image-------------
# import cv2
# -----multiprocessing---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable



class obs_model(nn.Module):
    """Convolutional encoder mapping an RGB image to ``out_num`` values in [-1, 1].

    Three conv -> ReLU -> 2x2 max-pool stages are followed by three fully
    connected layers (ReLU between them) and a final ``tanh``.

    ``fc1`` expects 9216 = 256 * 6 * 6 flattened features. The conv stack must
    therefore reduce the input to 6x6 spatially, which holds for e.g. 200x200
    or 224x224 images.
    """

    def __init__(self, out_num):
        super(obs_model, self).__init__()

        def make_stage(in_ch, out_ch, **conv_kwargs):
            # One conv -> ReLU -> max-pool stage. Keeping nn.Sequential keeps the
            # state_dict keys (e.g. "conv1.0.weight") unchanged.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, **conv_kwargs),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            )

        self.conv1 = make_stage(3, 64, kernel_size=6, stride=4)
        self.conv2 = make_stage(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = make_stage(128, 256, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.Linear(9216, 2560)
        self.fc2 = nn.Linear(2560, 512)
        self.fc3 = nn.Linear(512, out_num)

    def forward(self, x, mlt = False):
        """Encode ``x`` and return ``tanh`` activations of size ``out_num``.

        Args:
            x: image tensor fed to ``conv1`` (3 input channels).
            mlt: when True, keep the leading dimension as the batch dimension,
                so the output has shape (x.shape[0], out_num). When False, all
                conv features are flattened into a single row, giving an output
                of shape (1, out_num).
        """
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)

        n_rows = x.shape[0] if mlt else 1
        x = x.reshape(n_rows, -1)

        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return torch.tanh(self.fc3(x))




class combine_model(nn.Module):
    """Fuse a state vector and a goal vector into one ``input_dim``-sized feature.

    Each input goes through its own Linear + ReLU branch: ``fs1`` / ``r1`` for
    the state and ``fa1`` / ``r11`` for the goal. The two branch outputs are
    concatenated along the last dimension and projected back to ``input_dim``
    by ``f2`` followed by ReLU ``r2``.
    """

    def __init__(self, input_dim):
        super(combine_model, self).__init__()
        # Widths of the state branch (h1) and goal branch (h11).
        # h2 is kept as a public attribute but is not referenced by this class.
        self.h1, self.h11, self.h2 = 300, 300, 200

        # Submodules are registered in the original order: fs1, r1, fa1, r11, f2, r2.
        self.fs1 = nn.Linear(input_dim, self.h1)
        self.r1 = nn.ReLU()
        self.fa1 = nn.Linear(input_dim, self.h11)
        self.r11 = nn.ReLU()
        self.f2 = nn.Linear(self.h1 + self.h11, input_dim)
        self.r2 = nn.ReLU()

    def forward(self, state, goal):
        """Return ReLU(f2([branch(state), branch(goal)])) with last dim ``input_dim``."""
        state_feat = self.r1(self.fs1(state))
        goal_feat = self.r11(self.fa1(goal))
        fused = torch.cat((state_feat, goal_feat), dim=-1)
        return self.r2(self.f2(fused))


class time_model(nn.Module):
    """LSTM sequence model with an optional linear pre-projection.

    Args:
        input_dim: size of the last dimension of the input features.
        out_dim: size of the output of the final linear layer ``f2``.
        pre: when True, inputs are first projected to 100 features by ``f0``
            before entering the LSTM ``f1``.
        hid_siz: LSTM hidden size; defaults to 256 when None.

    The LSTM uses PyTorch's default sequence-first layout (``batch_first`` is
    not set).
    """

    def __init__(self, input_dim, out_dim, pre = False, hid_siz = None):
        super(time_model, self).__init__()
        self.input_dim = input_dim

        self.preprocess = pre
        self.h_size = 256 if hid_siz is None else hid_siz

        # Creation order f0 -> f1 -> f2 is kept so parameter initialisation
        # under a fixed seed matches the previous implementation.
        if self.preprocess:
            self.f0 = nn.Linear(self.input_dim, 100)
            lstm_input_size = 100
        else:
            lstm_input_size = self.input_dim
        self.f1 = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=self.h_size,
            num_layers=1,
        )
        self.f2 = nn.Linear(self.h_size, out_dim)

    def hidden_reset(self, batch_size = None, device = None):
        """Zero the stored LSTM states used by the ``ins=True`` path of ``forward``.

        Sets ``h0``/``c0`` with shape (1, 1, h_size). When ``batch_size`` is
        given, also sets ``h0_batch``/``c0_batch`` with shape (1, batch_size, h_size).

        Args:
            batch_size: optional batch size for the batched states.
            device: where to allocate the states. Defaults to the device of the
                module's own parameters (previously hard-coded to CUDA). Call
                this after moving the model to its target device.
        """
        weight = self.f2.weight
        if device is None:
            device = weight.device

        def zeros(n):
            # Match the LSTM weights' dtype so the states are accepted by f1.
            return torch.zeros(1, n, self.h_size, device=device, dtype=weight.dtype)

        self.h0 = zeros(1)
        self.c0 = zeros(1)
        if batch_size is not None:
            self.h0_batch = zeros(batch_size)
            self.c0_batch = zeros(batch_size)

    def forward(self, input, ins = False, lan = False, batch = False):
        """Run the LSTM over ``input`` and project with ``f2``.

        Paths:
          * ``pre`` and ``lan``: ``input`` is projected by ``f0`` and its first
            two dims are swapped before the LSTM. The ``ins`` path pairs this
            with states of shape (1, B, h), so dim 0 of ``input`` is the batch
            and dim 1 the sequence.
              - ``ins=True``: start from the states set by ``hidden_reset``
                (``h0_batch``/``c0_batch`` if ``batch`` else ``h0``/``c0``).
                Store the final states in ``hn``/``cn`` and return ``f2(hn)``.
              - ``ins=False``: start from a zero state and run the whole
                sequence. Store the final states in ``h0``/``c0`` and return
                ``f2`` of the last timestep's output, with shape (1, B, out_dim).
          * ``pre`` without ``lan``: ``input`` is projected and viewed as a
            single (1, 1, 100) step.
          * no ``pre``, ``lan``: ``input.unsqueeze(1)`` is run as a sequence with
            batch 1. The output is squeezed before ``f2``.
          * no ``pre``, no ``lan``: ``input`` is viewed as a single (1, 1, -1)
            step. Output shape is (1, out_dim).

        Every path except ``ins=True`` starts from a zero LSTM state and
        overwrites ``self.h0``/``self.c0`` with its final state.
        """
        if self.preprocess:
            if lan:
                # Swap dims 0/1 to the sequence-first layout the LSTM expects.
                y = torch.transpose(self.f0(input), 0, 1)
                if ins:
                    state = (self.h0_batch, self.c0_batch) if batch else (self.h0, self.c0)
                    _, (self.hn, self.cn) = self.f1(y, state)
                    x = self.hn
                else:
                    # Previously this stepped over range(input.shape[0]), i.e.
                    # the batch size, so only the first timestep was encoded.
                    # Running the full sequence gives the intended final state.
                    outputs, (self.h0, self.c0) = self.f1(y, None)
                    x = outputs[-1:]
            else:
                y = self.f0(input).view(1, 1, -1)
                x, (self.h0, self.c0) = self.f1(y, None)
        else:
            if lan:
                x, (self.h0, self.c0) = self.f1(input.unsqueeze(1), None)
                x = x.squeeze()
            else:
                x, (self.h0, self.c0) = self.f1(input.view(1, 1, -1), None)
                x = x.view(1, -1)
        return self.f2(x)
