from functools import reduce
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import math
from utils import preprocess_gradients
from layer_norm_lstm import LayerNormLSTMCell
from layer_norm import LayerNorm1D

class MetaOneStageOptimizer(nn.Module):
    """Learned optimizer that proposes update steps for ``meta_xadv``.

    The network maps per-element inputs (built from the gradient of ``x_adv``
    and, for ``inputdim == 3``, the values of ``x_adv`` itself) through
    ``linear1 -> ln1 -> tanh -> lstms1 -> lstms2 -> linear2`` to one scalar
    per element. ``meta_update`` squashes that scalar with ``tanh`` and
    subtracts ``step_size`` times it from ``meta_xadv``.

    ``meta_xadv``, ``hx`` and ``cx`` are plain attributes (not registered
    buffers), so they are neither part of ``state_dict`` nor moved by
    ``Module.to()`` / ``Module.cuda()``; ``reset_lstm`` decides their device.

    Args:
        num_layers: Number of hidden/cell state pairs kept in ``hx`` / ``cx``.
            ``forward`` always runs exactly two LSTM cells, so this must be
            at least 2 for ``forward`` to work.
        hidden_size: Width of the hidden representation.
        batch_size: First dimension of the placeholder ``meta_xadv``.
        inputdim: Number of input features per element (1, 2 or 3); see
            ``meta_update``.
    """

    def __init__(self, num_layers, hidden_size, batch_size, inputdim=3):
        super(MetaOneStageOptimizer, self).__init__()
        # Placeholder that reset_lstm() overwrites. Fall back to the CPU when
        # CUDA is unavailable instead of failing at construction time.
        init_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.meta_xadv = torch.zeros([batch_size, 3, 32, 32], device=init_device)

        self.hidden_size = hidden_size
        self.inputdim = inputdim
        self.num_layers = num_layers
        # default inputdim = 3
        self.linear1 = nn.Linear(self.inputdim, hidden_size)
        self.ln1 = LayerNorm1D(hidden_size)

        self.lstms1 = LayerNormLSTMCell(hidden_size, hidden_size)
        self.lstms2 = LayerNormLSTMCell(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)
        # Shrink the output layer so the initial proposed steps are small.
        self.linear2.weight.data.mul_(0.1)
        self.linear2.bias.data.fill_(0.0)

    def reset_lstm(self, keep_states=False, xadv=None, use_cuda=True, device=torch.device('cuda')):
        """Load a new ``meta_xadv`` and reset or detach the LSTM states.

        Args:
            keep_states: If True and states already exist, keep their values
                but detach them from the autograd graph. Otherwise create
                fresh zero states of shape ``(1, hidden_size)``.
            xadv: Tensor whose detached values become ``meta_xadv``. Required.
            use_cuda: If False, everything is placed on the CPU.
            device: Target device used when ``use_cuda`` is True.

        Raises:
            ValueError: If ``xadv`` is None.
        """
        if xadv is None:
            raise ValueError("reset_lstm() requires `xadv`")
        target = device if use_cuda else torch.device('cpu')

        self.meta_xadv = xadv.detach().to(target)

        have_states = (getattr(self, 'hx', None) is not None
                       and getattr(self, 'cx', None) is not None)
        if keep_states and have_states:
            # Cut the graph so back-propagation does not reach earlier steps.
            for i in range(self.num_layers):
                self.hx[i] = self.hx[i].detach().to(target)
                self.cx[i] = self.cx[i].detach().to(target)
        else:
            # Also reached when keep_states is requested before any state
            # exists: there is nothing to keep, so start from zeros.
            self.hx = [torch.zeros(1, self.hidden_size, device=target)
                       for _ in range(self.num_layers)]
            self.cx = [torch.zeros(1, self.hidden_size, device=target)
                       for _ in range(self.num_layers)]

    def _lstm_step(self, layer, cell, x):
        """Run ``cell`` on ``x`` with the stored state of ``layer``; return the new hidden state."""
        rows = x.size(0)
        if rows != self.hx[layer].size(0):
            # Fresh states have a single row; broadcast them to one per input row.
            self.hx[layer] = self.hx[layer].expand(rows, self.hx[layer].size(1))
            self.cx[layer] = self.cx[layer].expand(rows, self.cx[layer].size(1))
        self.hx[layer], self.cx[layer] = cell(x, (self.hx[layer], self.cx[layer]))
        return self.hx[layer]

    def forward(self, x):
        """Map inputs of shape ``(N, inputdim)`` to one raw output per row.

        Updates ``hx[0..1]`` / ``cx[0..1]`` as a side effect.

        Returns:
            The ``linear2`` output with all singleton dimensions squeezed.
        """
        # Gradients preprocessing (torch.tanh replaces the deprecated F.tanh).
        x = torch.tanh(self.ln1(self.linear1(x)))

        for layer, cell in enumerate((self.lstms1, self.lstms2)):
            x = self._lstm_step(layer, cell, x)

        x = self.linear2(x)
        return x.squeeze()

    def _build_inputs(self, x_adv, grads):
        """Assemble the per-element network input selected by ``inputdim``."""
        if self.inputdim == 3:
            # presumably preprocess_gradients yields two features per element,
            # giving three with the raw value appended -- confirm in utils.
            flat_grads = preprocess_gradients(grads.view(-1).unsqueeze(-1))
            flat_xadv = x_adv.view(-1).unsqueeze(-1)
            return torch.cat((flat_grads, flat_xadv), 1)
        if self.inputdim == 2:
            return preprocess_gradients(grads.view(-1).unsqueeze(-1))
        if self.inputdim == 1:
            return grads.clone().view(-1).unsqueeze(-1)
        # IOError is kept (rather than ValueError) for backward compatibility.
        raise IOError("inputdim must be 1, 2 or 3, got {!r}".format(self.inputdim))

    def meta_update(self, x_adv, mode, idx, step_size, x_min, x_max, clip_by_tensor):
        """Apply one learned update step to ``meta_xadv``.

        Args:
            x_adv: Tensor whose ``.grad`` (and values, for ``inputdim == 3``)
                feed the network. Must have as many elements as ``meta_xadv``.
            mode: ``"eval"`` runs the step under ``torch.no_grad()``; any
                other value also computes the regularizer.
            idx: Unused in this class.
            step_size: Multiplier applied to the proposed step.
            x_min, x_max: Bounds forwarded to ``clip_by_tensor``.
            clip_by_tensor: Callable ``(tensor, x_min, x_max) -> tensor``.

        Returns:
            ``meta_xadv`` in eval mode; otherwise ``(meta_xadv, reg)`` where
            ``reg`` is the mean squared difference between the proposed step
            and ``sign(grad)``.

        Raises:
            ValueError: If ``x_adv.grad`` is None.
            IOError: If ``inputdim`` is not 1, 2 or 3.
        """
        grads = x_adv.grad
        if grads is None:
            raise ValueError("x_adv.grad is None; call backward() before meta_update()")
        inputs = self._build_inputs(x_adv, grads)

        evaluating = mode == "eval"
        reg = None
        # Meta update itself
        if evaluating:
            with torch.no_grad():
                delta = torch.tanh(self(inputs).unsqueeze(-1)).view(self.meta_xadv.size())
                self.meta_xadv = self.meta_xadv - step_size * delta
        else:
            delta = torch.tanh(self(inputs).unsqueeze(-1)).view(self.meta_xadv.size())
            self.meta_xadv = self.meta_xadv - step_size * delta
            sign_gap = delta - grads.detach().sign()
            reg = torch.mean(sign_gap * sign_gap)

        self.meta_xadv = clip_by_tensor(self.meta_xadv, x_min, x_max)
        if evaluating:
            return self.meta_xadv
        return self.meta_xadv, reg


class MetaTwoStageOptimizer(nn.Module):
    """Learned optimizer with two parameter sets switched at ``change_point``.

    Steps with ``idx < change_point`` use the ``one_*`` sub-network, later
    steps use the ``two_*`` sub-network. Each sub-network is
    ``linear1 -> ln1 -> tanh -> lstms1 -> lstms2 -> linear2``. Both stages
    read and write the same ``hx`` / ``cx`` state lists, so stage 2 continues
    from whatever state stage 1 left behind.

    ``meta_update`` squashes the network output with ``tanh`` and adds
    ``step_size`` times it to ``meta_xadv``.

    ``meta_xadv``, ``hx`` and ``cx`` are plain attributes (not registered
    buffers), so they are neither part of ``state_dict`` nor moved by
    ``Module.to()`` / ``Module.cuda()``; ``reset_lstm`` decides their device.

    Args:
        num_layers: Number of hidden/cell state pairs kept in ``hx`` / ``cx``.
            ``forward`` always runs exactly two LSTM cells, so this must be
            at least 2 for ``forward`` to work.
        hidden_size: Width of the hidden representation.
        batch_size: First dimension of the placeholder ``meta_xadv``.
        inputdim: Number of input features per element (1, 2 or 3); see
            ``meta_update``.
        change_point: First step index handled by the second stage.
    """

    def __init__(self, num_layers, hidden_size, batch_size, inputdim=3, change_point=20):
        super(MetaTwoStageOptimizer, self).__init__()
        # Placeholder that reset_lstm() overwrites. Fall back to the CPU when
        # CUDA is unavailable instead of failing at construction time.
        init_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.meta_xadv = torch.zeros([batch_size, 3, 32, 32], device=init_device)

        self.hidden_size = hidden_size
        self.inputdim = inputdim
        self.num_layers = num_layers
        self.change_point = change_point

        # parameter set 1 (used while idx < change_point)
        self.one_linear1 = nn.Linear(self.inputdim, hidden_size)
        self.one_ln1 = LayerNorm1D(hidden_size)

        self.one_lstms1 = LayerNormLSTMCell(hidden_size, hidden_size)
        self.one_lstms2 = LayerNormLSTMCell(hidden_size, hidden_size)
        self.one_linear2 = nn.Linear(hidden_size, 1)
        # Shrink the output layer so the initial proposed steps are small.
        self.one_linear2.weight.data.mul_(0.1)
        self.one_linear2.bias.data.fill_(0.0)

        # parameter set 2 (used once idx >= change_point)
        self.two_linear1 = nn.Linear(self.inputdim, hidden_size)
        self.two_ln1 = LayerNorm1D(hidden_size)

        self.two_lstms1 = LayerNormLSTMCell(hidden_size, hidden_size)
        self.two_lstms2 = LayerNormLSTMCell(hidden_size, hidden_size)

        self.two_linear2 = nn.Linear(hidden_size, 1)
        self.two_linear2.weight.data.mul_(0.1)
        self.two_linear2.bias.data.fill_(0.0)

    def reset_lstm(self, keep_states=False, xadv=None, use_cuda=True, device=torch.device('cuda')):
        """Load a new ``meta_xadv`` and reset or detach the LSTM states.

        Args:
            keep_states: If True and states already exist, keep their values
                but detach them from the autograd graph. Otherwise create
                fresh zero states of shape ``(1, hidden_size)``.
            xadv: Tensor whose detached values become ``meta_xadv``. Required.
            use_cuda: If False, everything is placed on the CPU.
            device: Target device used when ``use_cuda`` is True.

        Raises:
            ValueError: If ``xadv`` is None.
        """
        if xadv is None:
            raise ValueError("reset_lstm() requires `xadv`")
        target = device if use_cuda else torch.device('cpu')

        self.meta_xadv = xadv.detach().to(target)

        have_states = (getattr(self, 'hx', None) is not None
                       and getattr(self, 'cx', None) is not None)
        if keep_states and have_states:
            # Cut the graph so back-propagation does not reach earlier steps.
            for i in range(self.num_layers):
                self.hx[i] = self.hx[i].detach().to(target)
                self.cx[i] = self.cx[i].detach().to(target)
        else:
            # Also reached when keep_states is requested before any state
            # exists: there is nothing to keep, so start from zeros.
            self.hx = [torch.zeros(1, self.hidden_size, device=target)
                       for _ in range(self.num_layers)]
            self.cx = [torch.zeros(1, self.hidden_size, device=target)
                       for _ in range(self.num_layers)]

    def _lstm_step(self, layer, cell, x):
        """Run ``cell`` on ``x`` with the stored state of ``layer``; return the new hidden state."""
        rows = x.size(0)
        if rows != self.hx[layer].size(0):
            # Fresh states have a single row; broadcast them to one per input row.
            self.hx[layer] = self.hx[layer].expand(rows, self.hx[layer].size(1))
            self.cx[layer] = self.cx[layer].expand(rows, self.cx[layer].size(1))
        self.hx[layer], self.cx[layer] = cell(x, (self.hx[layer], self.cx[layer]))
        return self.hx[layer]

    def forward(self, x, stage):
        """Map inputs of shape ``(N, inputdim)`` to one raw output per row.

        Updates ``hx[0..1]`` / ``cx[0..1]`` as a side effect.

        Args:
            x: Network input.
            stage: 1 selects the ``one_*`` modules, 2 the ``two_*`` modules.

        Returns:
            The selected output layer's result with all singleton dimensions
            squeezed.

        Raises:
            IOError: If ``stage`` is neither 1 nor 2.
        """
        if stage == 1:
            linear_in, norm, linear_out = self.one_linear1, self.one_ln1, self.one_linear2
            cells = (self.one_lstms1, self.one_lstms2)
        elif stage == 2:
            linear_in, norm, linear_out = self.two_linear1, self.two_ln1, self.two_linear2
            cells = (self.two_lstms1, self.two_lstms2)
        else:
            # IOError is kept (rather than ValueError) for backward compatibility.
            raise IOError("stage must be 1 or 2, got {!r}".format(stage))

        # Gradients preprocessing (torch.tanh replaces the deprecated F.tanh).
        x = torch.tanh(norm(linear_in(x)))

        for layer, cell in enumerate(cells):
            x = self._lstm_step(layer, cell, x)

        x = linear_out(x)
        return x.squeeze()

    def _build_inputs(self, x_adv, grads):
        """Assemble the per-element network input selected by ``inputdim``."""
        if self.inputdim == 3:
            # presumably preprocess_gradients yields two features per element,
            # giving three with the raw value appended -- confirm in utils.
            flat_grads = preprocess_gradients(grads.view(-1).unsqueeze(-1))
            flat_xadv = x_adv.view(-1).unsqueeze(-1)
            return torch.cat((flat_grads, flat_xadv), 1)
        if self.inputdim == 2:
            return preprocess_gradients(grads.view(-1).unsqueeze(-1))
        if self.inputdim == 1:
            return grads.clone().view(-1).unsqueeze(-1)
        # IOError is kept (rather than ValueError) for backward compatibility.
        raise IOError("inputdim must be 1, 2 or 3, got {!r}".format(self.inputdim))

    def meta_update(self, x_adv, mode, idx, step_size, x_min, x_max, clip_by_tensor):
        """Apply one learned update step to ``meta_xadv``.

        Args:
            x_adv: Tensor whose ``.grad`` (and values, for ``inputdim == 3``)
                feed the network. Must have as many elements as ``meta_xadv``.
            mode: ``"eval"`` runs the step under ``torch.no_grad()``; any
                other value also computes the regularizer.
            idx: Step index; stage 1 is used while ``idx < change_point``,
                stage 2 afterwards.
            step_size: Multiplier applied to the proposed step.
            x_min, x_max: Bounds forwarded to ``clip_by_tensor``.
            clip_by_tensor: Callable ``(tensor, x_min, x_max) -> tensor``.

        Returns:
            ``meta_xadv`` in eval mode; otherwise ``(meta_xadv, reg)`` where
            ``reg`` is the mean squared sum of the proposed step and
            ``sign(grad)`` (zero when the step equals ``-sign(grad)``).

        Raises:
            ValueError: If ``x_adv.grad`` is None.
            IOError: If ``inputdim`` is not 1, 2 or 3.
        """
        stage = 1 if idx < self.change_point else 2
        grads = x_adv.grad
        if grads is None:
            raise ValueError("x_adv.grad is None; call backward() before meta_update()")
        inputs = self._build_inputs(x_adv, grads)

        evaluating = mode == "eval"
        reg = None
        # Meta update itself
        if evaluating:
            with torch.no_grad():
                delta = torch.tanh(self(inputs, stage).unsqueeze(-1)).view(self.meta_xadv.size())
                self.meta_xadv = self.meta_xadv + step_size * delta
        else:
            delta = torch.tanh(self(inputs, stage).unsqueeze(-1)).view(self.meta_xadv.size())
            self.meta_xadv = self.meta_xadv + step_size * delta
            sign_sum = delta + grads.detach().sign()
            reg = torch.mean(sign_sum * sign_sum)

        self.meta_xadv = clip_by_tensor(self.meta_xadv, x_min, x_max)
        if evaluating:
            return self.meta_xadv
        return self.meta_xadv, reg

