import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from data import Scene, HandWritten, PIE, CUB, CNIST, LandUse
from loss_function import get_loss
from model import CCML

import matplotlib.pyplot as plt

np.set_printoptions(precision=4, suppress=True)


def validate():
    """Placeholder for a validation routine; currently a no-op returning None.

    Evaluation is presently performed inline in the ``__main__`` block of
    this script rather than through this function.
    """


def train():
    """Placeholder for a training routine; currently a no-op returning None.

    The training loop presently lives inline in the ``__main__`` block of
    this script rather than in this function.
    """

if __name__ == '__main__':
    # Script entry point: train CCML on the PIE dataset with a random 80/20
    # train/test split, then report test accuracy.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=200, metavar='N',
                        help='input batch size for training [default: 200]')
    parser.add_argument('--epochs', type=int, default=500, metavar='N',
                        help='number of epochs to train [default: 500]')
    parser.add_argument('--annealing_step', type=int, default=50, metavar='N',
                        help='gradually increase the value of lambda from 0 to 1')
    parser.add_argument('--lr', type=float, default=0.003, metavar='LR',
                        help='learning rate')
    parser.add_argument('--seed', type=int, default=None, metavar='S',
                        help='random seed for the split and model init [default: unseeded]')
    args = parser.parse_args()

    # Optional seeding; with the default (None) the run stays unseeded as before.
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    dataset = PIE()
    num_samples = len(dataset)
    num_classes = dataset.num_classes
    num_views = dataset.num_views
    dims = dataset.dims
    # delta and gamma are passed to get_loss; beta is passed to the model's forward.
    delta = 1
    gamma = 1
    beta = 1

    # Random 80/20 split of sample indices into train and test subsets.
    index = np.arange(num_samples)
    np.random.shuffle(index)
    split = int(0.8 * num_samples)
    train_index, test_index = index[:split], index[split:]
    train_loader = DataLoader(Subset(dataset, train_index), batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(Subset(dataset, test_index), batch_size=args.batch_size, shuffle=False)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Move the model to its device *before* building the optimizer, as the
    # torch.optim documentation recommends.
    model = CCML(num_views, dims, num_classes, device)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    def to_device(X, Y):
        """Move each per-view tensor X[0..num_views-1] (in place) and labels Y to `device`."""
        for v in range(num_views):
            X[v] = X[v].to(device)
        return X, Y.to(device)

    # Integer progress interval (~10 messages per run). The old float modulo
    # `epoch % (epochs / 10) == 0` fired unreliably when epochs % 10 != 0.
    log_every = max(1, args.epochs // 10)

    model.train()
    for epoch in range(1, args.epochs + 1):
        if epoch % log_every == 0:
            print(f'epoch ====> {epoch}')
        for X, Y, _ in train_loader:
            X, Y = to_device(X, Y)
            evidences, evidence_a, evidence_con, evidence_div = model(X, beta)
            loss = get_loss(evidences, evidence_a, evidence_con, evidence_div, Y, epoch,
                            num_classes, args.annealing_step, gamma, delta, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluation: the predicted class is the argmax of the aggregated evidence.
    model.eval()
    num_correct, num_sample = 0, 0
    with torch.no_grad():
        for X, Y, _ in test_loader:
            X, Y = to_device(X, Y)
            _, evidence_a, _, _ = model(X, beta)
            _, Y_pre = torch.max(evidence_a, dim=1)
            num_correct += (Y_pre == Y).sum().item()
            num_sample += Y.shape[0]

    # Guard against an empty test split (tiny datasets) instead of dividing by zero.
    if num_sample == 0:
        print('====> acc: n/a (empty test set)')
    else:
        print('====> acc: {:.4f}'.format(num_correct / num_sample))
