import torch
import torch.nn as nn
import sys

def transform(sequential):
    """Split ``sequential`` into single-layer ``nn.Sequential`` stages.

    Each layer is wrapped in its own ``nn.Sequential`` and the layer kind is
    printed (``conv2d``, ``maxpool``, ``relu`` or ``error`` for anything else).

    Returns:
        A pair ``(stages, flags)``. ``stages`` holds the wrappers in order and
        ``flags`` holds one int per layer: 0 for ``nn.Conv2d``, 1 for every
        other layer type (including unrecognised ones).
    """
    # (layer type, label to print, control flag) — checked in this order.
    known_kinds = (
        (nn.Conv2d, "conv2d", 0),
        (nn.MaxPool2d, "maxpool", 1),
        (nn.ReLU, "relu", 1),
    )
    stages, flags = [], []
    for layer in sequential:
        # Plain attribute assignments on the modules, kept as in the original.
        layer.requires_grad = True
        stage = nn.Sequential(layer)
        stage.requires_grad = True
        stages.append(stage)

        label, flag = next(
            ((text, value) for kind, text, value in known_kinds if isinstance(layer, kind)),
            ("error", 1),
        )
        print(label)
        flags.append(flag)
    return stages, flags

def group_zero_grad(models):
    """Call ``zero_grad()`` on every module in a nested iterable of modules.

    ``models`` is iterated two levels deep (groups, then the modules inside
    each group). Returns ``None``.
    """
    for module in (member for group in models for member in group):
        module.zero_grad()

def blind(mixture, bm):
    """Mix the tensors in ``mixture`` with ``bm`` after swapping in noise.

    The last entry of ``mixture`` is discarded and replaced by fresh Gaussian
    noise shaped like ``mixture[0]`` (same device and dtype, gradient enabled).
    The entries are concatenated along dim 0, viewed as ``(1, 3, -1)``,
    left-multiplied by ``bm`` and reshaped back.

    NOTE(review): the hard-coded ``3`` presumably means exactly three
    equally-shaped inputs are expected — confirm against callers.

    Returns:
        The tuple produced by ``torch.split``, with chunks of
        ``mixture[0].size(0)`` rows each.
    """
    reference = mixture[0]
    rows_per_chunk = reference.size(0)

    noise = torch.randn(reference.size(), device=reference.device, dtype=reference.dtype)
    noise.requires_grad = True

    # Everything except the final entry is kept; the noise takes its place.
    stacked = torch.cat([*mixture[:-1], noise], dim=0)
    mixed = bm.matmul(stacked.reshape(1, 3, -1)).reshape(stacked.size())
    return torch.split(mixed, rows_per_chunk, dim=0)

def unblind(mixture, um):
    """Apply the matrix ``um`` across the tensors in ``mixture``.

    The entries are concatenated along dim 0, viewed as ``(1, 3, -1)``,
    left-multiplied by ``um``, reshaped back and split again into chunks of
    ``mixture[0].size(0)`` rows.

    NOTE(review): the hard-coded ``3`` presumably means exactly three
    equally-shaped inputs are expected — confirm against callers.

    Returns:
        The tuple produced by ``torch.split``.
    """
    rows_per_chunk = mixture[0].size(0)
    stacked = torch.cat(mixture, dim=0)
    original_shape = stacked.size()
    grouped = stacked.reshape(1, 3, -1)
    return torch.split(um.matmul(grouped).reshape(original_shape), rows_per_chunk, dim=0)

def group_detach(x):
    """Return a new list holding ``.detach()`` of every item in ``x``."""
    return [item.detach() for item in x]

def collect_weights(tlist, is_blind=False):
    """Gather optimizers (plain mode) or grouped parameters (blind mode).

    Args:
        tlist: One entry per model; each entry is an iterable of modules
            (anything exposing ``parameters()``).
        is_blind: Selects the output form. Evaluated by truthiness, matching
            how ``forward``/``forward_val`` in this file test the same flag.

    Returns:
        Plain mode: a list with one inner list per model, containing a
        ``torch.optim.SGD`` (``lr=1e-2``) for every module that has at least
        one parameter; parameter-free modules are skipped.

        Blind mode: a list with one inner list per parameter position, where
        inner list ``k`` holds the ``k``-th parameter of every model in model
        order. The number of positions is taken from the first model. An
        empty ``tlist`` yields ``[]``.

    Raises:
        IndexError: In blind mode, if a later model has fewer parameters than
            the first one.
    """
    if not is_blind:
        return [
            [
                torch.optim.SGD(module.parameters(), lr=1e-2)
                for module in modules
                # SGD rejects an empty parameter list, so skip such modules.
                if next(module.parameters(), None) is not None
            ]
            for modules in tlist
        ]

    per_model = [
        [param for module in modules for param in module.parameters()]
        for modules in tlist
    ]
    if not per_model:
        # No models at all: nothing to group (previously an IndexError).
        return []

    # Transpose from [model][position] to [position][model].
    return [
        [model_params[position] for model_params in per_model]
        for position in range(len(per_model[0]))
    ]

def update(optimizer_list, unblind_mat, is_blind=False):
    """Apply one update step, through optimizers or directly on parameters.

    Plain mode (``is_blind == False``): ``optimizer_list`` is a nested
    iterable of optimizers. Each one is stepped and then has its gradients
    zeroed. ``unblind_mat`` is not used and ``None`` is returned.

    Blind mode: ``optimizer_list`` is a sequence of parameter groups. For each
    group the ``.grad`` tensors are collected; groups at even positions are
    passed through ``unblind(..., unblind_mat)``, groups at odd positions are
    used unchanged. Every parameter is then decreased in place by ``1e-2``
    times its gradient, under ``torch.no_grad()``.

    NOTE(review): the even/odd split presumably corresponds to alternating
    weight and bias groups — confirm against the caller.

    Returns:
        ``None`` in plain mode; in blind mode a list with the CPU copy of the
        first gradient of every group.
    """
    if is_blind == False:  # noqa: E712 - equality test kept on purpose
        for group in optimizer_list:
            for optimizer in group:
                optimizer.step()
                optimizer.zero_grad()
        return None

    first_grads = []
    with torch.no_grad():
        for position, group in enumerate(optimizer_list):
            raw_grads = [param.grad for param in group]
            if position % 2:
                steps = raw_grads
            else:
                steps = unblind(raw_grads, unblind_mat)
            first_grads.append(steps[0].cpu())

            for slot, param in enumerate(group):
                param -= steps[slot] * 1e-2
    return first_grads
def forward(transformed_list, blind_list, blind_mat, unblind_mat, loss_func, is_blind, image, target):
    """Run the models layer by layer, cutting the autograd graph at each layer.

    Args:
        transformed_list: One sequence of per-layer callables per model;
            ``transformed_list[m][k]`` is layer ``k`` of model ``m``.
        blind_list: One flag per layer. When the flag equals 1 and
            ``is_blind`` is truthy, the layer input is unblinded before the
            layer and the layer output is blinded again afterwards.
        blind_mat: Matrix handed to ``blind``.
        unblind_mat: Matrix handed to ``unblind``.
        loss_func: Indexable collection of loss callables, one per output.
        is_blind: Whether to blind ``image`` up front and unblind the final
            outputs before the loss.
        image: Indexable collection of input tensors, one per model. In plain
            mode these tensors themselves get ``requires_grad = True``.
        target: Indexable collection of targets, one per output.

    Returns:
        ``(losses, unblined_res, node_list, loss_list, grad_list)`` where
        ``grad_list[k]`` are the detached, grad-enabled inputs of layer ``k``,
        ``loss_list[k]`` the raw outputs of layer ``k``, ``node_list[k]`` the
        outputs forwarded to layer ``k + 1`` (re-blinded and detached when the
        layer is flagged), ``unblined_res`` the detached final outputs with
        gradients enabled, and ``losses`` the loss of each final output.
    """
    node_list, loss_list, grad_list = [], [], []
    layer_input = group_detach(blind(image, blind_mat)) if is_blind else image

    for depth in range(len(blind_list)):
        toggles = blind_list[depth] == 1 and is_blind

        # From the second layer on, continue from the previous layer's result.
        if node_list:
            layer_input = group_detach(node_list[-1])
        if toggles:
            layer_input = group_detach(unblind(layer_input, unblind_mat))

        # Make each input a gradient-collecting leaf for this layer.
        for tensor in layer_input:
            tensor.requires_grad = True
        grad_list.append(layer_input)

        raw_out = [
            model[depth](layer_input[slot])
            for slot, model in enumerate(transformed_list)
        ]
        loss_list.append(raw_out)

        final_res = group_detach(blind(raw_out, blind_mat)) if toggles else raw_out
        node_list.append(final_res)

    # Loss is computed on detached copies of the final outputs.
    if is_blind:
        unblined_res = group_detach(unblind(final_res, unblind_mat))
    else:
        unblined_res = group_detach(final_res)

    losses = []
    for slot, plain in enumerate(unblined_res):
        plain.requires_grad = True
        losses.append(loss_func[slot](plain, target[slot]))

    return losses, unblined_res, node_list, loss_list, grad_list


def backward(losses, reses, node_list, loss_list, grad_list, grad_mat=None, is_blind=False, c=0):
    """Back-propagate through per-layer graph segments, last layer first.

    Each entry of ``losses`` is back-propagated first, which fills
    ``reses[i].grad``. Those gradients are then pushed through the layer
    outputs in ``loss_list`` from the last layer to the first: the output of
    layer ``k`` is back-propagated with the gradient collected on the input of
    layer ``k + 1`` (or on ``reses`` for the last layer), and the gradient that
    lands on ``grad_list[k]`` becomes the upstream gradient for layer
    ``k - 1``.

    Args:
        losses: Loss tensors, one per model; ``backward()`` is called on each
            without an explicit gradient.
        reses: Leaf tensors the losses were computed from; their ``.grad`` is
            read and then cleared.
        node_list: Only its length is used, to index layers from the end.
        loss_list: ``loss_list[k][m]`` is the output tensor of layer ``k`` for
            model ``m``.
        grad_list: ``grad_list[k][m]`` is the leaf input tensor of layer ``k``
            for model ``m``. Gradients of all but the first layer's inputs are
            cleared once consumed.
        grad_mat: Unused; kept for interface compatibility.
        is_blind: Unused; kept for interface compatibility.
        c: Unused; kept for interface compatibility.

    Returns:
        None. Gradients accumulate on whatever leaves the layer graphs reach;
        ``grad_list[0][m].grad`` is left populated.
    """
    upstream = []
    for slot, loss in enumerate(losses):
        loss.backward()
        upstream.append(reses[slot].grad)

    depth = len(node_list)
    for step in range(len(loss_list)):
        layer_no = depth - step - 1
        downstream = []

        for slot in range(len(loss_list[0])):
            loss_list[layer_no][slot].backward(gradient=upstream[slot])

            # Release the gradient that was just consumed; the tensor object
            # itself stays alive in ``upstream`` until the layer is finished.
            consumed = reses[slot] if step == 0 else grad_list[layer_no + 1][slot]
            consumed.grad = None

            downstream.append(grad_list[layer_no][slot].grad)
        upstream = downstream
def forward_val(transformed_list, blind_list, blind_mat, unblind_mat, is_blind, image):
    """Run the models layer by layer and return the concatenated outputs.

    Follows the same per-layer scheme as ``forward`` but keeps no loss or
    gradient bookkeeping.

    Args:
        transformed_list: One sequence of per-layer callables per model;
            ``transformed_list[m][k]`` is layer ``k`` of model ``m``.
        blind_list: One flag per layer. When the flag equals 1 and
            ``is_blind`` is truthy, the layer input is unblinded before the
            layer and the layer output is blinded again afterwards.
        blind_mat: Matrix handed to ``blind``.
        unblind_mat: Matrix handed to ``unblind``.
        is_blind: Whether to blind ``image`` up front and unblind at the end.
        image: Indexable collection of input tensors, one per model. In plain
            mode these tensors themselves get ``requires_grad = True``.

    Returns:
        The detached final outputs concatenated along dim 0. In blind mode the
        last unblinded output is dropped before concatenation.
    """
    layer_input = group_detach(blind(image, blind_mat)) if is_blind else image
    node_list = []

    for depth in range(len(blind_list)):
        toggles = blind_list[depth] == 1 and is_blind

        # From the second layer on, continue from the previous layer's result.
        if node_list:
            layer_input = group_detach(node_list[-1])
        if toggles:
            layer_input = group_detach(unblind(layer_input, unblind_mat))

        for tensor in layer_input:
            tensor.requires_grad = True

        layer_out = [
            model[depth](layer_input[slot])
            for slot, model in enumerate(transformed_list)
        ]
        if toggles:
            layer_out = group_detach(blind(layer_out, blind_mat))
        node_list.append(layer_out)

    if is_blind:
        plain_out = group_detach(unblind(layer_out, unblind_mat))
        # Drop the final entry before concatenating.
        plain_out.pop()
    else:
        plain_out = group_detach(layer_out)

    return torch.cat(plain_out, dim=0)
