import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
import copy
from vgg import vgg16
from transform import *
from torch.nn import CrossEntropyLoss 
from tqdm import tqdm
from data_set import *
import sys
import argparse

# Module-level scratch values. None of them is read anywhere else in this
# script. NOTE(review): likely leftovers from earlier experiments; confirm
# nothing imports them before deleting.
c = 0

a1 = a2 = 0.5
b1 = b2 = 0.5

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--blind', type=int, default=0,
                    help='1 enables blinding; any other value disables it')
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch size passed to data_loader')
parser.add_argument('--load', type=int, default=-1,
                    help='epoch passed to load_model before training; negative skips loading')
parser.add_argument('--bsize', type=int, default=2,
                    help='batches per step: bsize - 1 real batches plus one all-zero '
                         'padding batch (must be at least 2)')
args = parser.parse_args()
if args.bsize < 2:
    # The training/validation loops run (len(loader) // (bsize - 1)) steps:
    # bsize == 1 divides by zero and bsize <= 0 silently yields no steps.
    parser.error('--bsize must be at least 2')

load_epoch = args.load
is_blind = args.blind == 1
batch_size = args.batch_size
bsize = args.bsize
# specifying devices (hard-coded to the first CUDA device)
device = torch.device("cuda:0")

# Blinding matrices.
#
# Each replica group gets a 3x3 blinding matrix B, its inverse U = B^-1 (the
# unblinding matrix) and an update matrix A @ U, where every row of A is
# [0.5, 0.5, 0.0] -- so every row of the update matrix equals the mean of the
# first two rows of U. The second group's B differs from the first only in
# the third column of its first two rows (0.0 instead of 1.0).
#
# Earlier 2x2 experiments, kept for reference:
#   identity:  blind = unblind = [[1, 0], [0, 1]]
#   blind = [[1.2, 0.8], [0.8, 1.2]], unblind = [[1.5, -1], [-1, 1.5]],
#   update = [[0.5, 0.5], [0.5, 0.5]]


def _blinding_triplet(rows):
    """Return ``(blind, unblind, update)`` float32 tensors on ``device`` built from *rows*."""
    blind = torch.tensor(rows, dtype=torch.float32, device=device)
    unblind = blind.inverse()
    averager = torch.tensor([[0.5, 0.5, 0.0]] * 3, dtype=torch.float32, device=device)
    return blind, unblind, averager.matmul(unblind)


blind_mat, unblind_mat, update_mat = _blinding_triplet(
    [[1.2, 0.8, 1.0], [0.8, 1.2, 1.0], [1.1, 0.4, 0.6]])

blind_mat1, unblind_mat1, update_mat1 = _blinding_triplet(
    [[1.2, 0.8, 0.0], [0.8, 1.2, 0.0], [1.1, 0.4, 0.6]])


# Data loaders. NOTE(review): the leading positional argument 1 is passed
# straight through to data_loader; its meaning is not visible here.
train_loader = data_loader(1, batch_size)
val_loader = data_loader(1, batch_size, is_train=False)
train_len, val_len = len(train_loader), len(val_loader)
# Reported sizes are (number of batches) * batch_size, so they over-count
# when the last batch is partial.
print("train size", train_len * batch_size, "val size", val_len * batch_size)

num_classes = 10

# Build six VGG16 replicas. transform() converts each model's feature stack
# into a list of modules plus a blind_list; only the blind_list of the LAST
# replica survives and is used by the rest of the script.
transformed_list = []
for _ in range(6):
    model = vgg16(num_classes=num_classes)
    li, blind_list = transform(model.features)
    model.add_classifier(li, blind_list)
    model.cuda()
    transformed_list.append(li)

# deep copys -- NOTE(review): presumably makes the replicas share identical
# starting weights; confirm in copy_sequentials().
copy_sequentials(transformed_list)

# Split into two groups of three: replicas 0-2 stay in transformed_list,
# replicas 3-5 form transformed_list1.
transformed_list1 = transformed_list[3:]
del transformed_list[3:]

# Optimizer inputs: one weight collection per replica group (consumed by
# update() in the training loop).
opt_list, opt_list1 = [collect_weights(group, is_blind)
                       for group in (transformed_list, transformed_list1)]

# Sanity check: the first replica has exactly one entry per blind_list entry.
assert len(transformed_list[0]) == len(blind_list)
# Shape bookkeeping. NOTE(review): max_cpu_thread, input_size, output_size and
# y are not read anywhere else in this script.
max_cpu_thread = 24
input_size = (batch_size, 3, 224, 224)
output_size = input_size[:1] + (num_classes,)
# Kept deliberately: this draws from the random generator of ``device``, so
# removing it could change any later random draws on that device.
y = torch.randn(2, 64, 224, 224, device=device)

# Save the state dict of the last module of the first replica to ./layer1.
torch.save(obj=transformed_list[0][-1].state_dict(), f="layer1")

# Loss modules (moved with .cuda()): three independent instances per group.
loss_func = [CrossEntropyLoss().cuda() for _ in range(3)]
loss_func1 = [CrossEntropyLoss().cuda() for _ in range(3)]

# Per-epoch log file: truncated at start-up; the handle is left open for the
# rest of the run.
filename = 'loss_log.txt'
log = open(filename, mode='w')

# Optionally resume from a checkpoint (skipped when --load is negative).
# NOTE(review): only replicas 0 and 1 of the first group are restored;
# transformed_list[2] and transformed_list1 are not -- confirm that is intended.
if load_epoch >= 0:
    for replica in transformed_list[:2]:
        load_model(replica, load_epoch)

def _mean_loss_excluding_last(losses):
    """Return the mean of ``loss.item()`` over every entry of *losses* except the last.

    The final entry is left out, as in the original inline loops.
    NOTE(review): presumably that entry belongs to the all-zero padding batch
    appended each step -- confirm against ``forward``.
    """
    total = 0.0
    for loss in losses[:-1]:
        total = total + loss.item()
    return total / (len(losses) - 1)


# Epoch loop. As written this is a one-step debugging harness: the first step
# runs both replica groups (blind_mat/unblind_mat vs blind_mat1/unblind_mat1),
# dumps intermediate tensors to files in the working directory, prints the
# L1 difference between the two groups' update() results, then calls
# sys.exit(0). Everything after the inner loop (validation, logging,
# checkpointing) is unreachable until that exit is removed.
for epoch in range(200):
    count = 1
    average_loss = 0.0
    average_loss1 = 0.0
    print("training epcho ", epoch)
    train_iter = iter(train_loader)
    zero_im = None
    zero_ta = None
    # Every step consumes (bsize - 1) real batches from the loader.
    for _ in tqdm(range(train_len // (bsize - 1))):
        image = []
        target = []
        for imagerange in range(bsize - 1):
            i, t = next(train_iter)
            i = i.cuda()
            t = t.cuda()
            image.append(i)
            target.append(t)
            # Debug dump of this step's real inputs (overwritten every step).
            torch.save(i, "img" + str(imagerange))
            torch.save(t, "tar" + str(imagerange))

        # One all-zero batch (same size/dtype/device as the first real batch)
        # pads the step to bsize entries; it is built once and reused.
        if zero_im is None or zero_ta is None:
            zero_im = torch.zeros(image[0].size(), dtype=image[0].dtype, device=image[0].device)
            zero_ta = torch.zeros(target[0].size(), dtype=target[0].dtype, device=target[0].device)
        image.append(zero_im)
        target.append(zero_ta)

        # The second group works on independent copies of the same inputs.
        image1 = copy.deepcopy(image)
        target1 = copy.deepcopy(target)

        if len(image) != bsize:
            break

        losses, res, node_list, loss_list, grad_list = forward(transformed_list, blind_list, blind_mat, unblind_mat, loss_func, is_blind, image, target)
        losses1, res1, node_list1, loss_list1, grad_list1 = forward(transformed_list1, blind_list, blind_mat1, unblind_mat1, loss_func1, is_blind, image1, target1)

        cur_loss = _mean_loss_excluding_last(losses)
        cur_loss1 = _mean_loss_excluding_last(losses1)
        print("cur loss ", cur_loss)
        print("cur loss ", cur_loss1)

        # Incremental running mean over the steps of this epoch.
        average_loss = average_loss + (cur_loss - average_loss) / count
        average_loss1 = average_loss1 + (cur_loss1 - average_loss1) / count
        print("average loss", average_loss)
        print("average loss", average_loss1)
        assert (len(grad_list) == len(loss_list)
                and len(loss_list) == len(node_list))

        # Backward pass for both groups. The trailing 0 / 1 argument is passed
        # through to backward(); its meaning is not visible here.
        group_zero_grad(transformed_list)
        group_zero_grad(transformed_list1)
        backward(losses, res, node_list, loss_list, grad_list, 0)
        backward(losses1, res1, node_list1, loss_list1, grad_list1, 1)

        # Debug dumps. Each group is indexed by its OWN length (the second
        # group previously used len(loss_list) by mistake).
        for bx, inputs in enumerate(loss_list[len(loss_list) - 2]):
            torch.save(inputs, "oi_0_" + str(bx))
        for bx, inputs in enumerate(loss_list1[len(loss_list1) - 2]):
            torch.save(inputs, "oi_1_" + str(bx))
        for bx, inputs in enumerate(grad_list[-1]):
            torch.save(inputs, "linear_input_0_" + str(bx))
        for bx, inputs in enumerate(grad_list1[-1]):
            torch.save(inputs, "linear_input_1_" + str(bx))

        print(grad_list[-1][0].size())

        # Optimizer step for each group with its own update matrix.
        grads = update(opt_list, update_mat, is_blind)
        grads1 = update(opt_list1, update_mat1, is_blind)
        count = count + 1

        # Per-tensor L1 distance between the two groups' update() results.
        with torch.no_grad():
            for ix, grad in enumerate(grads):
                print(torch.sum(torch.abs(grad - grads1[ix])))

        # NOTE: debugging stop after the first step; remove to actually train.
        sys.exit(0)

    # Validation: same grouping as training -- (bsize - 1) real batches plus
    # one all-zero padding image batch per step (no padding target is added).
    top1_correct = 0.0
    val_seen = 0
    print("validating epoch ", epoch)
    val_iter = iter(val_loader)
    for _ in tqdm(range(val_len // (bsize - 1))):
        image = []
        target = []
        for imagerange in range(bsize - 1):
            i, t = next(val_iter)
            image.append(i.cuda())
            target.append(t.cuda())
        image.append(torch.zeros(image[0].size(), dtype=image[0].dtype, device=image[0].device))

        if len(image) != bsize:
            break
        res = forward_val(transformed_list, blind_list, blind_mat, unblind_mat, is_blind, image)

        pred = torch.argmax(res, dim=1)
        target = torch.cat(target, dim=0)
        # One tensor reduction instead of a Python-level per-element sum().
        correct = (pred == target).float()
        top1_correct = top1_correct + correct.sum().item()
        val_seen = val_seen + correct.numel()

    # Divide by the number of predictions actually evaluated (the loop drops
    # any tail batches) rather than a hard-coded 10000.
    top1_acc = top1_correct / val_seen if val_seen else 0.0

    log.write(f"{epoch}, average_train loss {average_loss}, top1 acc {top1_acc}\n")
    log.flush()
    print("Epoch ", epoch, "average train loss ", average_loss, "top 1 acc ", top1_acc)

    if epoch % 10 == 0:
        save_model(transformed_list[0], epoch)
