import torch
from network import Network
from metric import valid
from torch.utils.data import Dataset
import numpy as np
import argparse
import random
from loss import Loss
from dataloader import load_data
import os

# Command-line configuration for a training run.
# Every numeric option declares an explicit ``type``: argparse only converts
# command-line values when a type is given, so without it any value supplied
# by the user would reach the training code as a string and break the
# arithmetic and comparisons performed there.
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default="Caltech7")
parser.add_argument('--batch_size', default=256, type=int)
# Weight applied to the criterion.prototype_dif2 loss term.
parser.add_argument("--lamda2", default=0.1, type=float)
# Weight applied to the criterion.forward_feature loss term.
parser.add_argument("--lamda3", default=1.0, type=float)
parser.add_argument("--learning_rate", default=0.0003, type=float)
parser.add_argument("--weight_decay", default=0., type=float)
parser.add_argument("--temperature_f", default=0.5, type=float)
parser.add_argument("--temperature_l", default=0.4, type=float)
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--feature_dim", default=512, type=int)
args = parser.parse_args()

# GPU index 1 stays the preferred device, but a host exposing a single CUDA
# device has no index 1, so fall back to index 0 there; use the CPU when
# CUDA is not available at all.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:{}".format(_gpu_index))
else:
    device = torch.device("cpu")

# Outer repetition loop; it currently runs exactly one round (index 0).
# The loop variable is named ``round_idx`` so that the builtin ``iter`` is
# not shadowed at module level.
for round_idx in range(0, 1):
    # Log the configuration of this round.
    print("---------------------------------------------------------")
    print("ROUND:", round_idx)
    print('dataset: ', args.dataset)
    print('learining_rate: ', args.learning_rate)
    print(args.lamda2, " ", args.lamda3)
    # Fixed seed used for every RNG in this round.
    seed = 10
    print("seed:", seed)

    def setup_seed(seed):
        """Seed every random number generator used by the run.

        Seeds Python's ``random`` module, NumPy's global generator and
        PyTorch (CPU generator and all CUDA devices), then configures cuDNN
        for reproducible execution.

        Args:
            seed: Integer seed applied to each generator.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning can pick different kernels from run to run, which
        # undermines the deterministic setting above; switch it off explicitly
        # in case it was enabled elsewhere.
        torch.backends.cudnn.benchmark = False


    # Re-seed all RNGs before any data is loaded.
    setup_seed(seed)

    # load_data returns the dataset together with its metadata; ``dims`` is
    # presumably the per-view input dimensionality (it is later handed to
    # Network) -- confirm against dataloader.load_data.
    dataset, dims, view, data_size, class_num, ts = load_data(args.dataset)

    # Pick the index holding the largest entry of ``dims``; when several
    # entries tie for the maximum, the highest index is chosen.
    max_value = max(dims)
    max_view = [idx for idx, width in enumerate(dims) if width == max_value][-1]
    print("max_view:", max_view)

    # Batches are drawn in dataset order (no shuffling) and a trailing
    # incomplete batch is dropped.
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)

    def contrastive_train(epoch, max_view):
        """Run one optimisation pass over ``data_loader``.

        For each batch the objective is accumulated in this order:
        for every view, first (only if the view is not ``max_view``) a
        ``criterion.forward_feature`` term between the aligned features of
        ``max_view`` and of that view, scaled by ``args.lamda3``, then an
        MSE term between the view's input and its reconstruction; finally a
        ``criterion.prototype_dif2`` term on ``model.get_C()``, scaled by
        ``args.lamda2``.

        Args:
            epoch: Current epoch number; the mean batch loss is printed
                whenever it is a multiple of 5.
            max_view: Index of the view used as the alignment anchor.
        """
        reconstruction = torch.nn.MSELoss()
        running = 0.
        for xs, _, _ in data_loader:
            # Move every view of the batch onto the training device in place.
            for v in range(view):
                xs[v] = xs[v].to(device)
            optimizer.zero_grad()
            # Only the reconstructions and the aligned features are needed;
            # the model must still return exactly six values.
            _, xrs, _, _, _, hs_align = model(xs, max_view)

            anchor = hs_align[max_view]
            # Start from integer 0 and add terms one by one, mirroring the
            # left-to-right accumulation of sum() over a list of terms.
            batch_loss = 0
            for v in range(view):
                if v != max_view:
                    aligned = hs_align[v]
                    batch_loss = batch_loss + args.lamda3 * criterion.forward_feature(
                        anchor, aligned)
                batch_loss = batch_loss + reconstruction(xs[v], xrs[v])
            prototypes = model.get_C()
            batch_loss = batch_loss + args.lamda2 * criterion.prototype_dif2(prototypes)

            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()

        # Report only on every fifth epoch.
        if epoch % 5 != 0:
            return
        mean_loss = running / len(data_loader)
        print(f'Epoch {epoch}', f'Loss:{mean_loss:.6f}')

    # Metric accumulators and best-score trackers. They are not used in the
    # visible part of the script; kept so any code relying on them still works.
    accs = []
    nmis = []
    purs = []
    # exist_ok avoids the check-then-create race of testing for the
    # directory before making it.
    os.makedirs('./models', exist_ok=True)
    T = 1
    t_max = 0
    t_epoch = 0

    # Coerce once so these are numeric even if they arrived as strings from
    # the command line.
    base_lr = float(args.learning_rate)
    weight_decay = float(args.weight_decay)
    num_epochs = int(args.epochs)

    for i in range(T):
        # Re-seed so model initialisation is reproducible for every trial.
        setup_seed(seed)
        model = Network(view, dims, args.feature_dim, class_num,
                        device, args.batch_size)
        model = model.to(device)

        # Two parameter groups (prototypes + contrastive module, and the
        # encoders/decoders) that currently share the same learning rate.
        # --weight_decay is forwarded here; previously it was parsed but
        # never reached the optimizer. Its default of 0 matches Adam's own.
        optimizer = torch.optim.Adam(
            [
                {'params': [model.C,
                            *model.feature_contrastive_module.parameters()],
                 'lr': base_lr},
                {'params': (list(model.encoders.parameters()) +
                            list(model.decoders.parameters())),
                 'lr': base_lr},
            ],
            weight_decay=weight_decay,
        )

        criterion = Loss(args.batch_size, class_num, args.temperature_f,
                         args.temperature_l, device).to(device)

        for epoch in range(1, num_epochs + 1):
            contrastive_train(epoch, max_view)
            # Evaluate only once, after the final training epoch.
            if epoch == num_epochs:
                acc, nmi, pur, ari = valid(model, device, dataset,
                                           view, data_size, class_num,
                                           max_view, epoch, ts)
