import numpy as np
import random
import math
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from env.statement import num_statements
from env.operator import num_operators
from model.encoder import DenseEncoder


class BaseModel(nn.Module):
    """``nn.Module`` base with lenient checkpoint loading and state-dict saving."""

    def load(self, path):
        """Copy the tensors stored at *path* into this model's state in place.

        Checkpoint entries whose name is not in this model's ``state_dict()``
        are skipped with a printed warning. Model entries that the checkpoint
        does not contain are left unchanged.

        Args:
            path: file previously written by :meth:`save` (or any
                ``torch.save``-d mapping of name -> tensor).

        Raises:
            RuntimeError: if a checkpoint tensor's shape differs from the
                shape of the model tensor with the same name.
        """
        if use_cuda:
            saved = torch.load(path)
        else:
            # Remap every storage to CPU so GPU-saved checkpoints load without CUDA.
            saved = torch.load(path, map_location=lambda storage, loc: storage)

        state = self.state_dict()
        for name, val in saved.items():
            if name not in state:
                print("WARNING: %s not in model during model loading!" % name)
                continue
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, and `copy_` would then silently broadcast a
            # mismatched tensor into the parameter.
            if state[name].shape != val.shape:
                raise RuntimeError("%s size has changed from %s to %s" %
                                   (name, state[name].shape, val.shape))
            state[name].copy_(val)

    def save(self, path):
        """Write this model's ``state_dict()`` to *path* with ``torch.save``."""
        torch.save(self.state_dict(), path)


class PCCoder(BaseModel):
    """Dense encoder followed by three linear prediction heads.

    The encoder output feeds:
      * ``statement_head`` -> ``num_statements`` raw scores,
      * ``drop_head``      -> ``params.max_program_vars`` sigmoid scores,
      * ``operator_head``  -> ``num_operators`` raw scores.
    """

    def __init__(self):
        super(PCCoder, self).__init__()
        self.encoder = DenseEncoder()
        self.statement_head = nn.Linear(params.dense_output_size, num_statements)
        self.drop_head = nn.Linear(params.dense_output_size, params.max_program_vars)
        self.operator_head = nn.Linear(params.dense_output_size, num_operators)

    def forward(self, x, get_operator_head=True):
        """Encode *x* and apply the prediction heads.

        Args:
            x: input accepted by ``DenseEncoder`` (shape not visible here —
                TODO confirm against the encoder).
            get_operator_head: when False, skip the operator head.

        Returns:
            ``(statement_logits, drop_probs, operator_logits)`` or, when
            ``get_operator_head`` is False, ``(statement_logits, drop_probs)``.
            ``drop_probs`` has had a sigmoid applied; the others are raw.
        """
        x = self.encoder(x)
        statement_logits = self.statement_head(x)
        drop_probs = torch.sigmoid(self.drop_head(x))
        if get_operator_head:
            return statement_logits, drop_probs, self.operator_head(x)
        return statement_logits, drop_probs

    def predict(self, x):
        """Run inference and post-process the statement and drop heads.

        Returns:
            A tuple ``(order, statement_probs, drop_indx)``:
              * ``order``: ndarray from ``np.argsort`` of the statement
                probabilities along the last axis (ascending order).
              * ``statement_probs``: tensor, softmax of the statement scores
                over dim 1, not tracked by autograd.
              * ``drop_indx``: ndarray, argmax of the drop scores along the
                last axis.
        """
        # Inference only: no autograd graph, and the operator head is not
        # needed for any returned value.
        with torch.no_grad():
            statement_pred, drop_pred = self.forward(x, get_operator_head=False)
            statement_probs = F.softmax(statement_pred, dim=1)
        drop_indx = np.argmax(drop_pred.cpu().numpy(), axis=-1)
        return np.argsort(statement_probs.cpu().numpy()), statement_probs, drop_indx


class Generator:
    """Bounds on integer values and list lengths, drawn at random unless given.

    Attributes:
        integer_min, integer_max: lower/upper integer bounds. When drawn, they
            come from two N(0, 10) samples, rounded and clipped to
            ``params.integer_min`` / ``params.integer_max``.
        list_len_min, list_len_max: lower/upper list-length bounds. When drawn,
            they are a random midpoint +/- half a random range, clipped to
            ``params.min_list_len`` / ``params.max_list_len``.

    NOTE(review): how these bounds are consumed (presumably to sample program
    inputs) is not visible here — confirm with callers.
    """

    def __init__(self, integer_min=None, integer_max=None, list_len_min=None, list_len_max=None):
        super().__init__()
        # Always draw, even if every bound is supplied, so the global RNG
        # streams advance exactly as they always have (seeded runs stay stable).
        integer_bounds = np.random.normal(scale=10, size=2)
        list_len_midpoint = random.randint(params.min_list_len, params.max_list_len)
        list_len_range = random.randint(0, params.list_len_range) // 2

        # Only *drawn* bounds are adjusted: a drawn bound must never end up on
        # the wrong side of its (possibly caller-supplied) partner.
        if integer_min is None:
            integer_min = max(params.integer_min, int(round(min(integer_bounds))))
            if integer_max is not None:
                integer_min = min(integer_min, integer_max)
        if integer_max is None:
            integer_max = max(min(params.integer_max, int(round(max(integer_bounds)))), integer_min)
        if list_len_min is None:
            list_len_min = max(params.min_list_len, list_len_midpoint - list_len_range)
            if list_len_max is not None:
                list_len_min = min(list_len_min, list_len_max)
        if list_len_max is None:
            list_len_max = max(min(params.max_list_len, list_len_midpoint + list_len_range), list_len_min)

        self.integer_min = integer_min
        self.integer_max = integer_max
        self.list_len_min = list_len_min
        self.list_len_max = list_len_max

    def mutate(self, i):
        """Randomly perturb all four bounds in place.

        Each bound moves by a uniform integer step in ``[-incr, incr]``, where
        ``incr`` is ``ceil(250 * 0.5**i)`` for integers and ``ceil(20 * 0.5**i)``
        for list lengths, so larger *i* gives smaller moves. Results are clamped
        against the global limits in ``params``, and each minimum is kept no
        greater than its maximum.
        """
        int_incr = math.ceil(250 * 0.5 ** i)
        list_incr = math.ceil(20 * 0.5 ** i)
        self.integer_min = min(max(params.integer_min, self.integer_min + random.randint(-int_incr, int_incr)), self.integer_max)
        self.integer_max = max(min(params.integer_max, self.integer_max + random.randint(-int_incr, int_incr)), self.integer_min)
        self.list_len_min = min(max(params.min_list_len, self.list_len_min + random.randint(-list_incr, list_incr)), self.list_len_max)
        self.list_len_max = max(min(params.max_list_len, self.list_len_max + random.randint(-list_incr, list_incr)), self.list_len_min)

    def save(self, path):
        """Write the four bounds to *path*, one value per line."""
        bounds = np.array([self.integer_min, self.integer_max, self.list_len_min, self.list_len_max])
        np.savetxt(path, bounds, fmt='%s')

    def load(self, path):
        """Restore the four bounds written by :meth:`save`.

        Unlike re-running ``__init__``, this consumes no random numbers.

        Raises:
            ValueError: if the file does not contain exactly four values.
        """
        values = np.loadtxt(path, ndmin=1).tolist()
        integer_min, integer_max, list_len_min, list_len_max = (int(v) for v in values)
        self.integer_min = integer_min
        self.integer_max = integer_max
        self.list_len_min = list_len_min
        self.list_len_max = list_len_max
