import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from vgg import vgg16
from transform import *
from torch.nn import CrossEntropyLoss 
from tqdm import tqdm
from data_set import *
import sys
import argparse


# Scalar coefficients, all 0.5.
# NOTE(review): none of these are referenced in the visible code — confirm
# whether they are still needed.
a1, a2, b1, b2 = 0.5, 0.5, 0.5, 0.5

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--blind', type=int, default=0,
                    help='1 enables blinding; any other value disables it')
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch size passed to data_loader for the train and val loaders')
parser.add_argument('--load', type=int, default=-1,
                    help='epoch number passed to load_model before training; '
                         'negative values skip loading')
parser.add_argument('--bsize', type=int, default=2,
                    help='group size per step: (bsize - 1) real batches plus one '
                         'all-zero batch; must be at least 2')
args = parser.parse_args()
# (bsize - 1) is used both as a divisor and as a batch count by the training
# and validation loops: bsize == 1 would divide by zero and bsize <= 0 would
# silently run zero steps, so reject it up front with a usage error.
if args.bsize < 2:
    parser.error('--bsize must be at least 2')
load_epoch = args.load
is_blind = args.blind == 1
batch_size = args.batch_size
bsize = args.bsize
# Every tensor below is allocated on the first CUDA device.
device = torch.device("cuda:0")

# 3x3 blinding matrix and its inverse, used for unblinding.
blind_mat = torch.tensor(
    [[1.2, 0.8, 0.0],
     [0.8, 1.2, 0.0],
     [1.1, 0.4, 1.0]],
    dtype=torch.float32, device=device,
)
unblind_mat = blind_mat.inverse()

# Averaging matrix: every row is [0.5, 0.5, 0.0], i.e. each output component
# is the mean of the first two input components and the third is ignored.
# update_mat composes it with the unblinding step; grad_mat is the plain
# averaging matrix (a separate tensor).
update_mat = torch.tensor(
    [[0.5, 0.5, 0.0]] * 3, dtype=torch.float32, device=device
).matmul(unblind_mat)
grad_mat = torch.tensor([[0.5, 0.5, 0.0]] * 3, dtype=torch.float32, device=device)

# data loader
train_loader = data_loader(1, batch_size)
val_loader  = data_loader(1, batch_size, is_train=False)
print("train size", len(train_loader) * batch_size, "val size", len(val_loader) * batch_size)
train_len = len(train_loader)
val_len = len(val_loader)

num_classes = 10

# Build three VGG16 models. For each one, `transform` turns model.features into
# a layer container (collected in transformed_list) plus a companion
# blind_list of the same length (asserted below); add_classifier then attaches
# the classifier to that container.
# NOTE: blind_list is rebound on every iteration, so the one from the LAST
# model is what the training/validation code uses for all models.
transformed_list = []
for _ in range(3):
    model = vgg16(num_classes=num_classes)
    li, blind_list = transform(model.features)
    model.add_classifier(li, blind_list)
    transformed_list.append(li)
    model.cuda()

# Deep-copy the layer containers (see copy_sequentials).
copy_sequentials(transformed_list)

# Parameters to optimise, gathered across all models.
opt_list = collect_weights(transformed_list, is_blind)

assert len(transformed_list[0]) == len(blind_list)

# Shape specifiers.
# NOTE(review): these are not referenced in the visible code.
max_cpu_thread = 24
input_size = (batch_size, 3, 224, 224)
output_size = (batch_size, num_classes)

# One cross-entropy loss module per model, placed on the GPU.
loss_func = [CrossEntropyLoss().cuda() for _ in range(3)]

# Per-epoch log; opened in 'w' mode, so each run starts a fresh file.
filename = 'loss_log.txt'
log = open(filename, 'w')

# Optionally restore weights saved at `load_epoch` before training.
# NOTE(review): only the first two of the three models are restored here; the
# third keeps its fresh initialisation — confirm this is intended.
if load_epoch >= 0:
    for model_idx in (0, 1):
        load_model(transformed_list[model_idx], load_epoch)

# Main loop: train for 200 epochs, validating and checkpointing after each.
# NOTE(review): epochs are always numbered from 0, even when --load restored a
# checkpoint, so save_model / the log re-use epoch numbers from 0. Confirm
# whether a resumed run should start at load_epoch + 1.


def _fetch_cuda_batches(data_iter, count):
    """Pull `count` (image, target) batches from `data_iter`, moved to the GPU.

    Returns two lists of length `count`; propagates StopIteration if the
    iterator is exhausted first.
    """
    images, targets = [], []
    for _ in range(count):
        img, tgt = next(data_iter)
        images.append(img.cuda())
        targets.append(tgt.cuda())
    return images, targets


for epoch in range(200):
    count = 1
    average_loss = 0.0
    print("training epoch ", epoch)
    train_iter = iter(train_loader)

    # All-zero padding batch appended as the last member of every group. It is
    # cached across steps but re-created whenever the real batch shape changes
    # (e.g. a smaller final batch), so it always matches the data it joins.
    zero_im = None
    zero_ta = None
    # Each step consumes (bsize - 1) real batches; leftover batches that do not
    # fill a whole group are skipped.
    for _ in tqdm(range(train_len // (bsize - 1))):
        image, target = _fetch_cuda_batches(train_iter, bsize - 1)

        if zero_im is None or zero_im.shape != image[0].shape:
            zero_im = torch.zeros_like(image[0])
        if zero_ta is None or zero_ta.shape != target[0].shape:
            zero_ta = torch.zeros_like(target[0])
        image.append(zero_im)
        target.append(zero_ta)

        # forward
        losses, res, node_list, loss_list, grad_list = forward(
            transformed_list, blind_list, blind_mat, unblind_mat,
            loss_func, is_blind, image, target)

        # Mean of every loss except the last one (presumably the loss of the
        # zero padding batch, which is appended last above).
        cur_loss = sum(loss.item() for loss in losses[:-1]) / (len(losses) - 1)
        # Incremental running mean of the per-step loss over this epoch.
        average_loss = average_loss + (cur_loss - average_loss) / count
        assert len(grad_list) == len(loss_list) == len(node_list)

        # backward
        group_zero_grad(transformed_list)
        backward(losses, res, node_list, loss_list, grad_list,
                 grad_mat=grad_mat, is_blind=is_blind)

        # optimizer step
        update(opt_list, update_mat, is_blind)
        count = count + 1

    # end of epoch: validation
    top1_correct = 0.0
    num_evaluated = 0
    print("validating epoch ", epoch)
    val_iter = iter(val_loader)
    # No gradients are needed for validation; avoid building autograd graphs.
    with torch.no_grad():
        for _ in tqdm(range(val_len // (bsize - 1))):
            image, target = _fetch_cuda_batches(val_iter, bsize - 1)
            # Only a zero padding IMAGE is appended; target keeps just the real
            # batches' labels, so res is presumably one row per real sample.
            image.append(torch.zeros_like(image[0]))

            res = forward_val(transformed_list, blind_list, blind_mat,
                              unblind_mat, is_blind, image)
            pred = torch.argmax(res, dim=1)
            target = torch.cat(target, dim=0)
            correct = (pred == target).float()

            top1_correct += correct.sum().item()
            num_evaluated += correct.numel()

    # Accuracy over the samples actually scored (not a hard-coded dataset
    # size): batches skipped by the grouping above are excluded.
    top1_acc = top1_correct / num_evaluated if num_evaluated else 0.0

    log.write(f"{epoch}, average_train loss {average_loss}, top1 acc {top1_acc}\n")
    log.flush()
    print("Epoch ", epoch, "average train loss ", average_loss, "top 1 acc ", top1_acc)

    save_model(transformed_list, epoch)

log.close()

