import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import math
from timeit import default_timer as timer
import random
from trainer import Trainer
from models.model import Model
from datasets.datasets_kfold import LoadDataset
from sklearn.model_selection import KFold


def _str2bool(value):
    """Convert a command-line option string to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for any
    non-empty string, so ``--use_pretrained_weights False`` would silently
    enable the option. This parser interprets the text explicitly instead.

    Args:
        value: Raw option text (or an already-boolean default).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    # The empty string keeps its previous meaning (bool('') is False).
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _build_parser():
    """Return the command-line parser for the downstream k-fold experiment."""
    parser = argparse.ArgumentParser(description='Big model down stream')
    # Help strings use %(default)s so the documented default can never drift
    # from the real one.
    parser.add_argument('--seed', type=int, default=42, help='random seed (default: %(default)s)')
    parser.add_argument('--cuda', type=int, default=6, help='CUDA device index (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=50, help='number of epochs (default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size for training (default: %(default)s)')
    parser.add_argument('--num_of_classes', type=int, default=4, help='number of classes')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate (default: %(default)s)')
    parser.add_argument('--weight_decay', type=float, default=5e-2, help='weight decay (default: %(default)s)')
    parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer (AdamW, SGD)')
    parser.add_argument('--clip_value', type=float, default=1, help='clip_value')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--loss_function', type=str, default='CrossEntropyLoss',
                        help='name of the loss function (default: %(default)s)')
    parser.add_argument('--datasets_dir', type=str,
                        default='datasets_dir',
                        help='datasets_dir')
    parser.add_argument('--model_dir', type=str, default='/data/xxxxxxxx/models_weights/Big/BigMI', help='model_dir')
    parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
    parser.add_argument('--label_smoothing', type=float, default=0.1, help='label_smoothing')
    parser.add_argument('--use_pretrained_weights', type=_str2bool,
                        default=True, help='use_pretrained_weights (default: %(default)s)')
    parser.add_argument('--foundation_dir', type=str,
                        default='pretrained_weights/pretrained_weights.pth',
                        help='foundation_dir')
    return parser


def main():
    """Run subject-wise k-fold training and report metric mean/std across folds.

    Parses command-line options, seeds all RNGs, and selects the CUDA device.
    It then runs one fold per subject. KFold without shuffling and with
    ``n_splits == len(subjects)`` holds out exactly one subject per fold. For
    each fold a fresh ``Model`` is built and trained via ``Trainer``. The
    per-fold best evaluations are stacked, and their mean and standard
    deviation along axis 0 are printed.
    """
    params = _build_parser().parse_args()
    print(params)

    setup_seed(params.seed)
    # NOTE(review): the default index (6) needs a machine with >= 7 GPUs;
    # pass --cuda to choose another device.
    torch.cuda.set_device(params.cuda)

    subjects = ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09']
    # Derive the split count from the subject list so adding or removing a
    # subject keeps the split leave-one-subject-out.
    kfold = KFold(n_splits=len(subjects))
    evaluations = []
    for train_index, test_index in kfold.split(subjects):
        # train_index / test_index are positions into `subjects`.
        load_dataset = LoadDataset(params, train_index, test_index)
        data_loader = load_dataset.get_data_loader()
        model = Model(params)  # fresh model per fold, no weight leakage across folds
        t = Trainer(params, data_loader, model)
        # presumably (metric vector, confusion matrix) of the best epoch -- confirm in Trainer
        evaluation_best, cm_best = t.train()
        print(evaluation_best)
        print(cm_best)
        evaluations.append(evaluation_best)

    results = np.array(evaluations)
    mean = np.mean(results, axis=0)
    std = np.std(results, axis=0)
    print(mean, std)



def setup_seed(seed):
    """Seed every RNG used by the experiment and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
    all-CUDA-device generators with the same ``seed``.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # `deterministic` restricts cuDNN to deterministic kernels. `benchmark`
    # must also be off, otherwise cuDNN autotuning can select different
    # algorithms from run to run. PyTorch's reproducibility notes set both.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Launch the experiment only when run as a script, never on import.
if __name__ == "__main__":
    main()
