import numpy as np
import torch
from torch import nn

import utils


def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along one spatial axis.

    Evaluates ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and returns it as a Python ``int``.

    Args:
        Lin: input length along the axis.
        kernel_size: kernel extent along the axis.
        stride: step between kernel applications.
        padding: zero padding added on each side.
        dilation: spacing between kernel elements.
    """
    numerator = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = numerator / float(stride)
    return int(np.floor(steps + 1))


class PathNetAlexNet(torch.nn.Module):
    """AlexNet-style PathNet with a pool of ``M`` candidate modules per layer.

    The network has ``L = 5`` layers: convolutions ``c1``, ``c2``, ``c3`` and
    linear layers ``fc1``, ``fc2``. Each layer holds ``M`` modules. A path
    ``P`` indexed as ``P[layer, j]`` (``j < N``) selects ``N`` of them per
    layer, and the outputs of the selected modules are summed. Every entry
    of ``taskcla`` gets its own linear head in ``self.last``.
    """

    def __init__(self, inputsize, taskcla, nhid, args):
        """Create the module pools, the per-task heads and the best-path table.

        Args:
            inputsize: ``(ncha, size, _)``. Only the first two entries are used.
                The fc1 input size (``sizec3 * s ** 2``) assumes square
                ``size x size`` inputs with ``ncha`` channels.
            taskcla: sequence of ``(task_id, n_classes)`` pairs, one per task.
            nhid: base width of the two hidden fully connected layers. It is
                scaled by ``args.expand_factor``.
            args: namespace providing ``expand_factor`` (width multiplier),
                ``N`` (modules used per layer), ``M`` (modules available per
                layer), ``pdrop1`` and ``pdrop2`` (dropout probabilities).
                Values recorded in the original config: expand_factor=0.258
                (chosen to match the parameter count), N=3, M=16,
                pdrop1=0.2, pdrop2=0.5.
        """
        super().__init__()

        ncha, size, _ = inputsize
        self.taskcla = taskcla
        self.ntasks = len(self.taskcla)

        expand_factor = args.expand_factor
        self.N = args.N
        self.M = args.M
        self.L = 5

        # Best path per task, remembered between tasks. Initialised to -1;
        # presumably the training loop fills in a task's row once its path is
        # chosen (not visible here). ``np.int`` was removed in NumPy 1.24,
        # so the builtin ``int`` is used (same dtype).
        self.bestPath = np.full((self.ntasks, self.L, self.N), -1, dtype=int)

        pdrop1 = args.pdrop1
        pdrop2 = args.pdrop2

        self.sizec1 = int(expand_factor * 64)
        self.sizec2 = int(expand_factor * 128)
        self.sizec3 = int(expand_factor * 256)
        self.sizefc1 = int(expand_factor * nhid)
        self.sizefc2 = int(expand_factor * nhid)

        # Kernel sizes of the first two convolutions scale with the input size.
        k1 = size // 8
        k2 = size // 10
        # Spatial side length after each conv followed by a 2x2 max-pool. It is
        # the same for every module in a layer, so it is computed only once.
        s = compute_conv_output_size(size, k1) // 2
        s = compute_conv_output_size(s, k2) // 2
        s = compute_conv_output_size(s, 2) // 2

        self.c1 = nn.ModuleList()
        self.c2 = nn.ModuleList()
        self.c3 = nn.ModuleList()
        self.fc1 = nn.ModuleList()
        self.fc2 = nn.ModuleList()
        # Modules are created index by index, in c1, c2, c3, fc1, fc2 order.
        # This order determines how the RNG is consumed during weight init.
        for _ in range(self.M):
            self.c1.append(nn.Conv2d(ncha, self.sizec1, kernel_size=k1))
            self.c2.append(nn.Conv2d(self.sizec1, self.sizec2, kernel_size=k2))
            self.c3.append(nn.Conv2d(self.sizec2, self.sizec3, kernel_size=2))
            self.fc1.append(nn.Linear(self.sizec3 * s ** 2, self.sizefc1))
            self.fc2.append(nn.Linear(self.sizefc1, self.sizefc2))

        self.last = torch.nn.ModuleList(
            [torch.nn.Linear(self.sizefc2, ncls) for _, ncls in self.taskcla]
        )

        self.relu = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout(pdrop1)
        self.drop2 = torch.nn.Dropout(pdrop2)
        self.maxpool = nn.MaxPool2d(2)

        print('AlexNet PathNet')
        print('pdrop1: ', pdrop1)
        print('pdrop2: ', pdrop2)

    def _path_sum(self, layer, P, row, h, post):
        """Return the sum of ``post(layer[P[row, j]](h))`` over ``j`` in ``range(self.N)``."""
        out = post(layer[P[row, 0]](h))
        for j in range(1, self.N):
            # Out-of-place add. ``out += ...`` would modify ``out`` in place.
            # When dropout is a no-op (eval mode or p == 0), ``out`` is the
            # ReLU output that autograd needs for backward, and backward would
            # then raise.
            out = out + post(layer[P[row, j]](h))
        return out

    def forward(self, x, t, P=None):
        """Run ``x`` through the modules selected by path ``P``, then every task head.

        Args:
            x: input batch for the first convolution. Presumably
                ``(batch, ncha, size, size)``; TODO confirm against callers.
            t: task index. It is used only to read ``self.bestPath[t]`` when
                ``P`` is None.
            P: integer array indexed as ``P[layer, j]`` with ``layer < 5`` and
                ``j < self.N``. Defaults to the stored best path of task ``t``.
                NOTE(review): an unset best path holds -1, which ModuleList
                resolves to its last module without error.

        Returns:
            A list with one head output per entry of ``self.taskcla``, all
            computed from the same shared features.
        """
        if P is None:
            P = self.bestPath[t]

        def conv12_post(z):
            return self.maxpool(self.drop1(self.relu(z)))

        def conv3_post(z):
            return self.maxpool(self.drop2(self.relu(z)))

        def fc_post(z):
            return self.drop2(self.relu(z))

        h = self._path_sum(self.c1, P, 0, x, conv12_post)
        h = self._path_sum(self.c2, P, 1, h, conv12_post)
        h = self._path_sum(self.c3, P, 2, h, conv3_post)

        h = h.view(h.shape[0], -1)

        h = self._path_sum(self.fc1, P, 3, h, fc_post)
        h = self._path_sum(self.fc2, P, 4, h, fc_post)

        # Heads are looked up by the task id stored in taskcla. Distinct
        # names avoid shadowing the ``t`` parameter.
        return [self.last[task_id](h) for task_id, _ in self.taskcla]

    def unfreeze_path(self, t, Path):
        """Make only the modules on ``Path`` that no earlier task uses trainable.

        Every module not in ``Path``, and every module appearing in
        ``self.bestPath`` for tasks ``0 .. t-1``, is frozen.

        Args:
            t: current task index. Best paths of tasks before ``t`` are protected.
            Path: integer array indexed as ``Path[layer, j]``, one row per layer.
        """
        layers = (self.c1, self.c2, self.c3, self.fc1, self.fc2)
        for i in range(self.M):
            for layer_idx, layer in enumerate(layers):
                self.unfreeze_module(layer, i, Path[layer_idx, :],
                                     self.bestPath[0:t, layer_idx, :])

    def unfreeze_module(self, layer, i, Path, bestPath):
        """Set ``requires_grad`` on ``layer[i]`` via ``utils.set_req_grad``.

        The flag is True only when ``i`` is in ``Path`` and not in ``bestPath``
        (the earlier tasks' paths for this layer). Otherwise it is False.
        """
        utils.set_req_grad(layer[i], (i in Path) and (i not in bestPath))
