import multiprocessing
import torch.multiprocessing as mp
import time
from copy import deepcopy
import numpy as np
import torch.optim as optim
import torch
import utils
import pdb
from torch.autograd import Variable
import pdb

###Code verified by Shree Atul on 041422-1606
###040522 - 0200 - note: message.find('p2p') != -1 is also true for 'not p2p', so that check alone cannot tell the p2p and non-p2p modes apart in local_comp

# Fixed shapes used by the 'shakespeare' branch of local_comp: the client's
# data/labels are reshaped to _SHAKESPEARE_BATCH rows of _SHAKESPEARE_SEQ_LEN steps.
_SHAKESPEARE_BATCH = 32
_SHAKESPEARE_SEQ_LEN = 200


def _flatten_detached(tensors):
    """Concatenate *tensors* into a single detached, flattened tensor.

    ``squeeze(0)`` is kept for parity with the original code; it only has an
    effect when the result has exactly one element.
    """
    return torch.cat([t.detach().reshape(-1) for t in tensors], dim=0).squeeze(0)


def _is_p2p(message):
    """Return True when *message* selects the p2p (share-trained-weights) mode.

    A bare substring test is not enough: ``'not p2p'`` contains ``'p2p'`` and
    would otherwise be routed to the p2p branch.
    NOTE(review): assumes the non-p2p mode is spelled 'not p2p' -- confirm
    against callers if other spellings (e.g. 'not_p2p') are used.
    """
    return 'p2p' in message and 'not p2p' not in message


def _train_shakespeare(net_local, data, labels, lr, device, criterion):
    """Run one SGD step on ``net_local`` over a full unrolled sequence batch.

    ``data``/``labels`` must hold exactly _SHAKESPEARE_BATCH * _SHAKESPEARE_SEQ_LEN
    elements (enforced by the reshape). The model is stepped one column
    (time step) at a time, carrying the hidden state from ``init_hidden``, and
    the per-step losses are averaged before a single backward/step.
    """
    optimizer = torch.optim.SGD(net_local.parameters(), lr=lr, weight_decay=0.0001)
    shape = (_SHAKESPEARE_BATCH, _SHAKESPEARE_SEQ_LEN)
    # Inputs must live on the same device as the hidden state; previously only
    # the hidden state was moved, which fails when data arrives on the CPU.
    data = data.reshape(shape).to(device)
    labels = labels.reshape(shape).to(device)
    optimizer.zero_grad()
    hdn = net_local.init_hidden(_SHAKESPEARE_BATCH)
    hidden = (hdn[0].to(device), hdn[1].to(device))
    loss = 0
    for step in range(_SHAKESPEARE_SEQ_LEN):
        output, hidden = net_local(data[:, step], hidden)
        loss = loss + criterion(output.view(_SHAKESPEARE_BATCH, -1), labels[:, step])
    loss = loss / _SHAKESPEARE_SEQ_LEN
    loss.backward()
    optimizer.step()


def local_comp(message, cl, net, data, labels, sample_type, lr, device, criterion, batch_size, n_iters, client_grads, rr_idx):
    """Train a deep copy of ``net`` on one client's data and store the result.

    ``net`` itself is never modified; all training happens on a copy.

    Args:
        message: Mode selector. ``'shakespeare'`` runs a single unrolled
            sequence SGD step and stores the trained weights. Any other value
            runs ``n_iters`` mini-batch SGD steps, then stores the trained
            weights for p2p messages (containing ``'p2p'`` but not
            ``'not p2p'``) or the weight delta (trained minus original)
            otherwise.
        cl: Key/index at which the result is written into ``client_grads``.
        net: Model to start from (deep-copied, not modified).
        data, labels: Client samples and targets, indexed with the batch
            indices returned by ``utils.create_batch``.
        sample_type, rr_idx, batch_size: Forwarded to ``utils.create_batch``;
            their semantics are defined there.
        lr: SGD learning rate.
        device: Device each mini-batch (and, for shakespeare, the inputs and
            hidden state) is moved to.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        n_iters: Number of local SGD steps for non-shakespeare modes; 0 is
            allowed and yields an unchanged model.
        client_grads: Mutable container written at ``client_grads[cl]``.
            Presumably shared between worker processes -- confirm with caller.

    Returns:
        None. The flattened, detached result is stored in ``client_grads[cl]``.
    """
    net_local = deepcopy(net)
    net_local.train()

    if message == 'shakespeare':
        _train_shakespeare(net_local, data, labels, lr, device, criterion)
        # Shakespeare mode reports the trained weights themselves, not a delta.
        client_grads[cl] = _flatten_detached(net_local.parameters())
        return

    optimizer = optim.SGD(net_local.parameters(), lr=lr)
    # Pre-bind so the cleanup below is valid even when n_iters == 0.
    outputs = loss = None
    for _ in range(n_iters):
        batch_idx = utils.create_batch(len(data), batch_size, rr_idx, cl, sample_type)
        optimizer.zero_grad()
        outputs = net_local(data[batch_idx].to(device))
        loss = criterion(outputs, labels[batch_idx].to(device))
        loss.backward()
        optimizer.step()

    if _is_p2p(message):
        # p2p: share the locally trained weights.
        client_grads[cl] = _flatten_detached(net_local.parameters())
    else:
        # Otherwise share the local update: trained weights minus start weights.
        client_grads[cl] = _flatten_detached(
            x.detach() - y.detach()
            for x, y in zip(net_local.parameters(), net.parameters())
        )

    # Drop references before releasing cached GPU memory.
    del net_local, outputs, loss
    torch.cuda.empty_cache()
    return


