import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from resnet import resnet50
from transform import *
from torch.nn import CrossEntropyLoss 
from tqdm import tqdm
from data_set import *
import sys
import argparse

# Command-line configuration: the only tunable is the per-loader batch size.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
args = parser.parse_args()

device = torch.device("cuda:0")

# batch_size: samples per mini-batch handed to data_loader.
# bsize: number of mini-batches pulled from the training loader and passed
#        together to the model in a single training step (see the training
#        loop below, which draws `bsize` batches per step).
batch_size, bsize = args.batch_size, 2

# Training and validation loaders. The meaning of the leading positional
# argument (1) is defined by data_loader in data_set.py -- TODO confirm.
train_loader = data_loader(1, batch_size)
val_loader = data_loader(1, batch_size, is_train=False)

def _device_matrix(rows):
    """Build a float32 tensor from nested lists on the script's target device."""
    return torch.tensor(rows, dtype=torch.float32, device=device)


# One independent cross-entropy criterion per model output (two outputs).
loss_fn = [CrossEntropyLoss().cuda() for _ in range(2)]

num_classes = 10

# 2x2 mixing matrices handed to the model and to the update step.
# unblind_mat is the exact inverse of blind_mat (their product is the
# identity); update_mat weights both entries equally (0.5 each).
# How each matrix is applied is defined by resnet50 / update -- see those
# modules for the precise semantics.
blind_mat = _device_matrix([[1.2, 0.8], [0.8, 1.2]])
unblind_mat = _device_matrix([[1.5, -1], [-1, 1.5]])
update_mat = _device_matrix([[0.5, 0.5], [0.5, 0.5]])

# Two identically configured ResNet-50 instances: `model` is driven through
# custom_forward/custom_backward during training, `model_tar` is the copy
# evaluated on the validation set.
model = resnet50(pretrained=False, unblind_mat=unblind_mat, blind_mat=blind_mat, is_blind=True)
model_tar = resnet50(pretrained=False, unblind_mat=unblind_mat, blind_mat=blind_mat, is_blind=True)
model = model.cuda()
model_tar = model_tar.cuda()
print(model)
# NOTE: a leftover debugging `sys.exit(0)` directly after the print above
# terminated the script here, so the weight sync, training and validation
# below could never run. It has been removed so the script actually trains.

# Start both networks from exactly the same weights.
model_tar.load_state_dict(model.state_dict())

# Module-level scratch lists; nothing in this script reads them, they are
# kept in case other code expects these names to exist.
seq_list = []
blind_list = []
skip_list = []
skip_blind = []

# Let each model build its internal layer lists, then link model_tar's
# lists into model (semantics are defined by the resnet module).
model.build_list()
model_tar.build_list()
model.merge_list([model_tar])

# Parameter groups of both models, consumed by update() after each step.
opt_list = collect_weights([[model], [model_tar]], is_blind=True)

# Number of mini-batches the training loader yields per pass.
train_len = len(train_loader)

# Running count of completed training steps across all epochs.
i = 0

for epoch in range(50):
    train_iter = iter(train_loader)

    # Every step consumes `bsize` consecutive mini-batches, so only
    # train_len // bsize complete steps fit into one pass over the loader
    # (this also guarantees next() never exhausts the iterator).
    for _step in tqdm(range(train_len // bsize)):
        image = []
        target = []

        # Distinct names for the batch tensors: reusing `i` here used to
        # clobber the step counter with a CUDA image tensor.
        for _part in range(bsize):
            batch_images, batch_targets = next(train_iter)
            image.append(batch_images.cuda())
            target.append(batch_targets.cuda())

        y_split = model.custom_forward(image)

        # group_detach presumably cuts the outputs out of the autograd graph
        # (TODO confirm in its definition). They are then re-enabled for
        # gradients so that loss.backward() below only fills in their .grad,
        # which is what custom_backward receives.
        y_split = group_detach(y_split)
        y_split[0].requires_grad = True
        y_split[1].requires_grad = True

        # One loss per output, each paired with its own target batch.
        loss_list = [
            criterion(y_split[part], target[part])
            for part, criterion in enumerate(loss_fn)
        ]

        model.zero_grad()
        model_tar.zero_grad()

        for loss in loss_list:
            loss.backward()

        # Propagate the output gradients through the model's custom backward
        # pass, then apply the weight update to the collected parameters.
        model.custom_backward([y_split[0].grad, y_split[1].grad])
        update(opt_list, update_mat, is_blind=True)

        i += 1

    # ---- validation: top-1 accuracy of model_tar on the validation set ----
    # NOTE(review): model_tar.eval() is never called anywhere in this script,
    # so layers that behave differently in training mode stay in that mode
    # during validation -- confirm whether this is intended.
    top1_correct = 0
    num_seen = 0
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for image, target in tqdm(val_loader):
            image = image.cuda()
            target = target.cuda()
            y_pred = model_tar(image)
            pred = torch.argmax(y_pred, dim=1)
            top1_correct += (pred == target).sum().item()
            num_seen += target.size(0)

    # Divide by the number of samples actually evaluated rather than a
    # hard-coded dataset size; guard against an empty validation loader.
    top1_acc = top1_correct / num_seen if num_seen else 0.0
    print(top1_acc)