import os
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from model_FAML import TMC
from data import Hand_train, Hand_test
import warnings


# Install a blanket "ignore" filter: every warning issued through the
# `warnings` module is suppressed for the rest of the process.
warnings.filterwarnings("ignore")
# Set CUDA_VISIBLE_DEVICES to "0" for this process (restricts CUDA to the GPU
# with index 0).
# NOTE(review): this runs after `import torch`; it presumably still takes
# effect because CUDA is initialised lazily -- confirm nothing touches CUDA
# at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    and, when available, all CUDA devices). cuDNN auto-tuning is disabled
    and cuDNN is asked to use deterministic algorithms, so repeated runs
    with the same seed select the same kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
    # benchmark=False stops cuDNN from picking kernels by timing them;
    # deterministic=True additionally forbids non-deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class AverageMeter(object):
    """Running tracker for the latest value and the weighted mean of a metric.

    Attributes (all set to 0 by ``reset``):
        val:   most recent value passed to ``update``.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg:   ``sum / count`` after the latest update.
    """

    def __init__(self):
        # A fresh meter starts from the zeroed state.
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def _count_class_corrects(predicted, labels, num_classes):
    """Count, per class, the samples of one batch that were predicted correctly.

    Args:
        predicted: 1-D tensor of predicted class indices.
        labels: 1-D tensor of ground-truth class indices (same length).
        num_classes: Number of classes ``C``.

    Returns:
        Integer tensor of shape ``(C,)``; entry ``c`` is the number of samples
        whose label is ``c`` and whose prediction is also ``c``.
    """
    class_ids = torch.arange(num_classes, device=labels.device)
    is_correct = (predicted == labels).unsqueeze(1)
    # (batch, 1) & (batch, C) -> (batch, C); summing over the batch gives
    # the per-class number of hits.
    return (is_correct & (labels.unsqueeze(1) == class_ids)).sum(dim=0)


def _train_one_epoch(model, optimizer, train_loader, device, num_classes,
                     num_views, yitas, yita_a, epoch, extra_model_args=()):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network called as ``model(X, Y, yitas, yita_a, epoch, *extra)``
            and returning a 6-tuple whose first three items are the per-view
            evidences, the combined evidence and the loss to back-propagate.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_loader: Iterable yielding ``(X, Y, indexes)`` batches.
        device: Device the batch tensors are moved to.
        num_classes: Number of classes.
        num_views: Number of views.
        yitas: Per-view weights forwarded to the model.
        yita_a: Combined-prediction weights forwarded to the model.
        epoch: Epoch index forwarded to the model.
        extra_model_args: Extra positional arguments appended to the model
            call (empty during warm-up).

    Returns:
        ``(class_corrects, class_corrects_per_view)``: float tensors of shape
        ``(num_classes,)`` and ``(num_views, num_classes)`` holding the number
        of correct training predictions per class for the combined output and
        for every view, accumulated over the epoch.
    """
    model.train()
    class_corrects = torch.zeros(num_classes, device=device)
    class_corrects_per_view = torch.zeros(num_views, num_classes, device=device)

    for X, Y, _ in train_loader:
        # X is indexed by view number (list or int-keyed mapping), so it is
        # deliberately iterated by index rather than by value.
        for v_num in range(len(X)):
            X[v_num] = X[v_num].to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        evidences, evidence_a, train_loss, _, _, _ = model(
            X, Y, yitas, yita_a, epoch, *extra_model_args)

        # Statistics only: detach so no graph is kept alive.
        _, predicted = torch.max(evidence_a.detach(), 1)
        class_corrects += _count_class_corrects(predicted, Y, num_classes)
        for v in range(num_views):
            _, predicted_v = torch.max(evidences[v].detach(), 1)
            class_corrects_per_view[v] += _count_class_corrects(
                predicted_v, Y, num_classes)

        train_loss.backward()
        optimizer.step()

    return class_corrects, class_corrects_per_view


def normal(args):
    """Train a TMC model (warm-up stage, then re-weighted stage) and test it.

    During warm-up the model is trained with all-ones weights. In the main
    stage the weights are recomputed every epoch as
    ``gamma * nums / (corrects + 1)`` from the per-class correct counts, which
    are refreshed after warm-up and then every 5th epoch.

    Args:
        args: Namespace providing ``DatasetName``, ``batch_size``, ``beta``,
            ``annealing_step``, ``lr``, ``warm_up_epochs``, ``epochs`` and
            ``gamma``.

    Returns:
        Test accuracy (correct / total) of the fused prediction.

    Raises:
        ValueError: If ``args.DatasetName`` is not a supported dataset.
    """
    if args.DatasetName == 'Hand':
        dataset_train = Hand_train()
        dataset_test = Hand_test()
    else:
        # Fail loudly instead of printing and crashing later on an
        # undefined variable.
        raise ValueError(f'Dataset not recognized: {args.DatasetName!r}')

    num_classes = dataset_train.num_classes
    num_views = dataset_train.num_views
    dims = dataset_train.dims
    # presumably the number of training samples per class -- TODO confirm
    nums = torch.tensor(dataset_train.nums, dtype=torch.float32)

    train_index = np.arange(len(dataset_train))
    test_index = np.arange(len(dataset_test))
    # Redundant with shuffle=True below, but kept so that the sample order
    # for a given seed stays the same as before.
    np.random.shuffle(train_index)
    train_loader = DataLoader(Subset(dataset_train, train_index),
                              batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(Subset(dataset_test, test_index),
                             batch_size=args.batch_size, shuffle=False)

    model = TMC(num_classes, num_views, dims, args.beta, args.annealing_step)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    yitas = torch.ones((num_views, 1, num_classes), dtype=torch.float32, device=device)
    yita_a = torch.ones((1, num_classes), dtype=torch.float32, device=device)
    nums = nums.to(device)

    # Created on `device` so the weight computation below also works when
    # warm_up_epochs == 0 (previously these stayed on the CPU in that case).
    num_corrects = torch.zeros(num_classes, device=device)
    num_corrects_per_view = torch.zeros(num_views, num_classes, device=device)

    print('============warm_up阶段开始============')
    for epoch in range(args.warm_up_epochs):
        num_corrects, num_corrects_per_view = _train_one_epoch(
            model, optimizer, train_loader, device, num_classes, num_views,
            yitas, yita_a, epoch)

    print('============真正阶段开始============')
    for epoch in range(args.epochs):
        # The +1 keeps the ratio finite for classes with zero correct
        # predictions.
        yita_a = (args.gamma * nums / (num_corrects + 1)).unsqueeze(0)
        for v in range(num_views):
            yitas[v] = (args.gamma * nums / (num_corrects_per_view[v] + 1)).unsqueeze(0)

        epoch_corrects, epoch_corrects_per_view = _train_one_epoch(
            model, optimizer, train_loader, device, num_classes, num_views,
            yitas, yita_a, epoch,
            extra_model_args=(num_corrects, num_corrects_per_view, nums))

        # The statistics driving the weights are refreshed only every 5th epoch.
        if epoch % 5 == 0:
            num_corrects = epoch_corrects
            num_corrects_per_view = epoch_corrects_per_view

    # Inference phase
    model.eval()
    fused_total = 0
    fused_correct = 0

    with torch.no_grad():
        for X, Y, _ in test_loader:
            for v_num in range(len(X)):
                X[v_num] = X[v_num].to(device)
            Y = Y.to(device)
            _, evidence_a, _, _, _, _ = model(X, Y, yitas, yita_a, args.epochs - 1)
            _, pred_fused = torch.max(evidence_a, dim=1)
            fused_correct += (pred_fused == Y).sum().item()
            fused_total += Y.size(0)

    # Accuracy computation
    acc_allsample = fused_correct / fused_total

    return acc_allsample



if __name__ == '__main__':
    # Script entry point: parse hyper-parameters, seed the RNGs, run one
    # train/evaluate cycle and print the resulting test accuracy.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training [default: 256]')
    parser.add_argument('--warm_up_epochs', type=int, default=20, metavar='N',
                        help='number of warm-up epochs [default: 20]')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train [default: 200]')
    parser.add_argument('--annealing_step', type=int, default=50, metavar='N',
                        help='gradually increase the value of lambda from 0 to 1')
    parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                        help='learning rate')
    parser.add_argument('--gamma', type=float, default=5, metavar='gamma',
                        help='0.1/0.5/1/5/10')
    parser.add_argument('--beta', type=float, default=1, metavar='beta',
                        help='0.01/0.1/1/5/10 ')
    parser.add_argument('--DatasetName', type=str, default='Hand', metavar='N',
                        help='Hand/Scene/Animal/Caltech10120/YaleB')
    # Replaces the previously unused hard-coded seed list; the default keeps
    # the old behaviour (seed 42).
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed, e.g. 42/123/456/789/101112 [default: 42]')
    args = parser.parse_args()

    set_seed(args.seed)
    acc_allsample = normal(args)
    print('acc_all: ', acc_allsample)

