import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import model
import transform as tran
import adversarial as ad
import numpy as np
import os
import argparse
torch.set_num_threads(1)
from read_data import ImageList_aug as ImageList


# Ask cuDNN for deterministic kernels and turn off its autotuner so repeated
# runs with the same seed are comparable.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch DAregre experiment')
# nargs='?' lets `--gpu_id` appear without a value; `const` supplies the value
# for that case so that os.environ below never receives None.
parser.add_argument('--gpu_id', type=str, nargs='?', const='0', default='0', help="device id to run")
parser.add_argument('--lr', type=float, default=0.1,
                        help='init learning rate for fine-tune')
parser.add_argument('--gamma', type=float, default=0.001,
                        help='learning rate decay')
parser.add_argument('--seed', type=int, default=0,
                        help='random seed')
# NOTE(review): --label_ratio is parsed but not read anywhere in the code
# visible in this file.
parser.add_argument('--label_ratio', type=float, default=0.01,
                        help='labeled data ratio')
parser.add_argument('--tradeoff', type=float, default=0.1,
                        help='tradeoff for X-model')
# The three list files are opened unconditionally further down, so a missing
# path is reported here as a usage error instead of a TypeError from open().
parser.add_argument('--labeled_path', type=str, required=True,
                        help='labeled_path')
parser.add_argument('--unlabeled_path', type=str, required=True,
                        help='unlabeled_path')
parser.add_argument('--test_path', type=str, required=True,
                        help='test_path')
args = parser.parse_args()

# Seed the torch and NumPy generators from the same CLI value.
torch.manual_seed(args.seed)
np.random.seed(args.seed)


# Restrict the visible GPUs before the first CUDA query below.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

def _read_lines(list_path):
    """Return every line of the list file at *list_path*, closing it afterwards."""
    with open(list_path) as handle:
        return handle.readlines()


# One transform per role; 'train_aug' is passed as the second transform
# (`transform2`) to the labeled and unlabeled datasets below.
data_transforms = {
    'train': tran.rr_train(resize_size=224, crop_size=224),
    'train_aug': tran.rr_train_aug(resize_size=224, crop_size=224),
    'val': tran.rr_train(resize_size=224, crop_size=224),
    'test': tran.rr_eval(resize_size=224, crop_size=224),
}
# Dataset configuration: per-split batch sizes and the list files from the CLI.
batch_size = {"train": 36, "val": 36, "test": 4}
labeled_path = args.labeled_path
unlabeled_path = args.unlabeled_path
test_path = args.test_path


# "train" holds the labeled list, "val" the unlabeled list, and "test" the
# evaluation list. The list files are read through _read_lines so their
# handles are closed instead of being left to the garbage collector.
dsets = {"train": ImageList(_read_lines(labeled_path), transform=data_transforms["train"], transform2 = data_transforms["train_aug"],eval = False),
         "val": ImageList(_read_lines(unlabeled_path),transform=data_transforms["val"], transform2 = data_transforms["train_aug"],eval = False),
         "test": ImageList(_read_lines(test_path),transform=data_transforms["test"], eval = True)}
# The two training-time loaders are shuffled and load in the main process.
dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=batch_size[x],
                                               shuffle=True, num_workers=0)
                for x in ['train', 'val']}
# The evaluation loader keeps file order and uses worker processes.
dset_loaders["test"] = torch.utils.data.DataLoader(dsets["test"], batch_size=batch_size["test"],
                                                   shuffle=False, num_workers=64)

dset_sizes = {x: len(dsets[x]) for x in ['train', 'val','test']}
dset_classes = range(1)
# Re-selects the compute device; the first visible GPU when CUDA is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def Regression_test(loader, model,iter_num):
    """Evaluate ``model`` on ``loader['test']`` and print MSE / MAE figures.

    Columns 0, 2 and 3 of each label batch are used as the three regression
    targets; column 1 is ignored. Squared and absolute errors are summed over
    all batches and divided by the number of samples seen. The "all" figures
    sum the error over the three outputs together and are likewise divided by
    the sample count (not by the element count).

    Args:
        loader: mapping with a ``'test'`` entry yielding ``(imgs, labels)``.
        model: callable returning ``(predictions, extra)`` for an image batch.
        iter_num: accepted for the caller's convenience; not used here.

    Returns:
        None. Results are printed. Tensors are moved to the module-level
        ``device`` before use.
    """
    sum_squared = torch.nn.MSELoss(reduction='sum')
    sum_absolute = torch.nn.L1Loss(reduction='sum')
    kept_columns = [0, 2, 3]

    # Slots 0-2 hold the per-target sums, slot 3 the sum over all targets.
    MSE = [0, 0, 0, 0]
    MAE = [0, 0, 0, 0]
    number = 0
    with torch.no_grad():
        for imgs, raw_labels in loader['test']:
            imgs = imgs.to(device)
            labels = raw_labels.to(device)[:, kept_columns].float()
            pred, _ = model(imgs)
            for col in range(3):
                MSE[col] += sum_squared(pred[:, col], labels[:, col])
                MAE[col] += sum_absolute(pred[:, col], labels[:, col])
            MSE[3] += sum_squared(pred, labels)
            MAE[3] += sum_absolute(pred, labels)
            number += imgs.size(0)

    # Turn the accumulated sums into per-sample averages.
    for slot in range(4):
        MSE[slot] = MSE[slot] / number
        MAE[slot] = MAE[slot] / number

    print("\tMSE : {0},{1},{2}\n".format(*MSE[:3]))
    print("\tMAE : {0},{1},{2}\n".format(*MAE[:3]))
    print("\tMSEall : {0}\n".format(MSE[3]))
    print("\tMAEall : {0}\n".format(MAE[3]))





def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma, power, init_lr=0.001, weight_decay=0.0005):
    """Set every parameter group's learning rate and weight decay in place.

    The base rate is ``init_lr * (1 + gamma * iter_num) ** (-power)``. Group
    ``k`` receives that rate scaled by ``param_lr[k]``, and its weight decay is
    set to twice ``weight_decay``.

    Args:
        param_lr: per-group learning-rate multipliers, indexed like
            ``optimizer.param_groups``.
        optimizer: object exposing a ``param_groups`` list of dicts.
        iter_num: current iteration number.
        gamma: decay-speed coefficient.
        power: decay exponent.
        init_lr: learning rate at iteration 0.
        weight_decay: base weight decay (doubled before being stored).

    Returns:
        The same ``optimizer`` object, after modification.
    """
    base_rate = init_lr * (1 + gamma * iter_num) ** (-power)
    for position, group in enumerate(optimizer.param_groups):
        group['lr'] = base_rate * param_lr[position]
        group['weight_decay'] = weight_decay * 2
    return optimizer




class X_model(nn.Module):
    """Backbone with two parallel sigmoid regression heads.

    ``model_fc`` produces a feature vector that feeds two independent
    ``Linear(512, 3) -> Sigmoid`` heads. In training mode one head is selected
    per call, and the features of unlabeled batches are routed through ``grl``
    before reaching the head. In eval mode the two head outputs are averaged.

    NOTE(review): ``model.Resnet18Fc`` is presumably a ResNet-18 feature
    extractor with 512-dimensional output, and ``ad.AdversarialLayer`` is
    presumably a gradient-reversal layer; both are defined outside this
    file and should be confirmed there.
    """

    def __init__(self):
        super(X_model, self).__init__()
        self.model_fc = model.Resnet18Fc()
        # Heads are created in this order so random-number consumption and
        # state-dict keys stay the same as before the helper was introduced.
        self.classifier_layer = self._make_head()
        self.classifier_layer2 = self._make_head()
        self.grl = ad.AdversarialLayer(high=1.0)

    @staticmethod
    def _make_head():
        """Build one ``Linear(512, 3) -> Sigmoid`` head with N(0, 0.01) weights and zero bias."""
        linear = nn.Linear(512, 3)
        linear.weight.data.normal_(0, 0.01)
        linear.bias.data.fill_(0.0)
        return nn.Sequential(linear, nn.Sigmoid())

    def forward(self,x,origin=True,unlabel=False):
        """Run the backbone and the selected head(s).

        Args:
            x: input batch for ``model_fc``.
            origin: truthy selects ``classifier_layer``, falsy selects
                ``classifier_layer2``. Only consulted in training mode.
            unlabel: truthy passes the features through ``grl`` before the
                head. Only consulted in training mode.

        Returns:
            ``(output, feature)``. ``feature`` is always the raw backbone
            output (never the ``grl`` output). In eval mode ``output`` is the
            mean of both heads.
        """
        feature = self.model_fc(x)

        if not self.training:
            outC1 = self.classifier_layer(feature)
            outC2 = self.classifier_layer2(feature)
            return (outC1 + outC2) / 2, feature

        # Flags are read by truthiness, so a value that is neither True nor
        # False can no longer leave the output variable unassigned.
        head_input = self.grl(feature) if unlabel else feature
        head = self.classifier_layer if origin else self.classifier_layer2
        return head(head_input), feature






# Build the network and move it to the selected device.
model_x = X_model()
model_x = model_x.to(device)

model_x.train(True)
criterion = {"classifier": nn.MSELoss()}
# The per-group "lr" entries serve as multipliers: they are captured into
# `param_lr` below and inv_lr_scheduler rescales the decayed base rate by them
# on every iteration.
optimizer_dict = [{"params": [p for p in model_x.model_fc.parameters() if p.requires_grad], "lr": 0.1},
                  {"params": [p for p in model_x.classifier_layer.parameters() if p.requires_grad], "lr": 1},
                  {"params": [p for p in model_x.classifier_layer2.parameters() if p.requires_grad], "lr": 1}]
optimizer = optim.SGD(optimizer_dict, lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
train_cross_loss = train_transfer_loss = train_total_loss = train_sigma = 0.0
# The training loop restarts each iterator whenever the iteration count is a
# multiple of these values, i.e. one batch before the loader runs out. Clamp
# to at least 1 so a single-batch loader does not cause a modulo-by-zero.
len_source = max(1, len(dset_loaders["train"]) - 1)
len_target = max(1, len(dset_loaders["val"]) - 1)
iter_source = iter(dset_loaders["train"])
iter_target = iter(dset_loaders["val"])
param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
test_interval = 500
num_iter = 10002
print(args)
# Main optimisation loop: one labeled and one unlabeled batch per iteration.
for iter_num in range(1, num_iter + 1):
    model_x.train(True)
    optimizer = inv_lr_scheduler(param_lr, optimizer, iter_num, init_lr=args.lr, gamma=args.gamma, power=0.75,
                                 weight_decay=0.0005)
    optimizer.zero_grad()

    # Restart a loader's iterator before it is exhausted.
    if iter_num % len_source == 0:
        iter_source = iter(dset_loaders["train"])
    if iter_num % len_target == 0:
        iter_target = iter(dset_loaders["val"])
    # Use the built-in next(): the `.next()` method is a Python 2 style alias
    # that DataLoader iterators do not provide on every PyTorch version.
    data_source = next(iter_source)
    data_target = next(iter_target)
    inputs_source, inputs2_source, labels_source = data_source
    inputs_target, inputs2_target, _ = data_target

    # Keep label columns 0, 2 and 3 as the three regression targets.
    labels_source = labels_source[:, [0, 2, 3]].float()

    inputs_source = inputs_source.to(device)
    inputs2_source = inputs2_source.to(device)
    inputs_target = inputs_target.to(device)
    inputs2_target = inputs2_target.to(device)
    labels = labels_source.to(device)

    # First view -> head 1, second view -> head 2; unlabeled batches also go
    # through the model's `grl` layer (unlabel=True).
    outC_s, feature_s = model_x(inputs_source,origin=True,unlabel=False)
    outC2_s, feature2_s = model_x(inputs2_source,origin=False,unlabel=False)
    outC_t, feature_t = model_x(inputs_target,origin=True,unlabel=True)
    outC2_t, feature2_t = model_x(inputs2_target,origin=False,unlabel=True)

    # Supervised MSE on both heads, minus the weighted squared disagreement of
    # the two heads on the unlabeled batch.
    classifier_loss = criterion["classifier"](outC_s, labels)+criterion["classifier"](outC2_s, labels)
    minimax_loss = -args.tradeoff*(torch.mean(((outC_t - outC2_t) ** 2)))
    total_loss = classifier_loss + minimax_loss
    total_loss.backward()
    optimizer.step()

    train_cross_loss += classifier_loss.item()
    train_total_loss += total_loss.item()
    if iter_num % test_interval == 0:
        # The label says "Cross Entropy" although the value is the MSE
        # regression loss; the text is kept unchanged for log compatibility.
        print("Iter {:05d}, Average Cross Entropy Loss: {:.4f};  Average Training Loss: {:.4f}".format(
            iter_num, train_cross_loss / float(test_interval),
            train_total_loss / float(test_interval)))
        # Reset both running sums so each report covers only its own interval.
        train_cross_loss = train_total_loss = 0.0
        model_x.eval()
        Regression_test(dset_loaders, model_x,iter_num)
