import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision
from tqdm import tqdm
from arguments import get_args
from augmentations import get_aug
from models import get_model, get_backbone
from tools import AverageMeter
from datasets import get_dataset
from optimizers import get_optimizer, LR_Scheduler
import random

def _evaluate(model, classifier, loader, device):
    """Return the top-1 accuracy of ``classifier(model(x))`` over ``loader``.

    Accuracy is the total number of correct predictions divided by the total
    number of samples, so a smaller final batch is not over-weighted the way
    a mean of per-batch accuracies would be.

    Args:
        model: feature extractor, called as ``model(images)``.
        classifier: head applied to the features; its output is reduced with
            ``argmax(dim=1)`` to obtain the predicted class indices.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        ``loader`` yields no samples.
    """
    model.eval()
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            labels = labels.to(device)
            preds = classifier(model(images.to(device))).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += preds.shape[0]
    return correct / total if total else 0.0


def main(args):
    """Train a linear classifier on a frozen backbone and report test accuracy.

    The backbone weights are read from the checkpoint at ``args.eval_from``
    (entries of ``state_dict`` whose key starts with ``'backbone.'``). Only
    the linear classifier is handed to the optimizer; backbone features are
    computed under ``torch.no_grad()``. Test accuracy is printed after every
    epoch and once more when training finishes.

    Args:
        args: configuration object. This function reads ``eval_from``,
            ``device``, ``model.backbone``, ``aug_kwargs``,
            ``dataset_kwargs``, ``dataloader_kwargs`` and the ``eval.*``
            settings (``batch_size``, ``num_class``, ``num_epochs``,
            ``base_lr``, ``warmup_lr``, ``final_lr``, ``warmup_epochs``,
            ``optimizer.name``, ``optimizer.momentum``,
            ``optimizer.weight_decay``).

    Returns:
        The final test accuracy as a fraction in ``[0, 1]``.

    Raises:
        ValueError: if ``args.eval_from`` is ``None``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.eval_from is None:
        raise ValueError('args.eval_from must be set to the checkpoint to evaluate')

    # NOTE(review): presumably `train_classifier=True` selects the
    # augmentation used while fitting the linear head -- confirm in get_aug.
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(
            transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs
        ),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(
            transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
            train=False,
            **args.dataset_kwargs
        ),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs
    )

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(
        in_features=model.output_dim,
        out_features=args.eval.num_class,
        bias=True,
    ).to(args.device)
    print('models.outputdim', model.output_dim)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point `eval_from` at checkpoints from a trusted source.
    save_dict = torch.load(args.eval_from, map_location='cpu')
    prefix = 'backbone.'
    backbone_state = {
        key[len(prefix):]: value
        for key, value in save_dict['state_dict'].items()
        if key.startswith(prefix)
    }
    # strict=True makes a missing or unexpected key an error rather than a
    # silently half-initialised backbone.
    model.load_state_dict(backbone_state, strict=True)

    model = torch.nn.DataParallel(model.to(args.device))
    classifier = torch.nn.DataParallel(classifier)

    # Learning rates are scaled linearly with the batch size (reference 256).
    base_lr = args.eval.base_lr*args.eval.batch_size/256
    warmup_lr = args.eval.warmup_lr*args.eval.batch_size/256
    final_lr = args.eval.final_lr*args.eval.batch_size/256

    # Only the classifier is passed to the optimizer; the backbone is frozen.
    optimizer = get_optimizer(
        args.eval.optimizer.name, classifier,
        lr=base_lr,
        momentum=args.eval.optimizer.momentum,
        weight_decay=args.eval.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs, warmup_lr,
        args.eval.num_epochs, base_lr, final_lr,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    trainacc_meter = AverageMeter(name='trainacc')

    accuracy = None
    global_progress = tqdm(range(0, args.eval.num_epochs), desc='Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        trainacc_meter.reset()
        model.eval()  # keep the backbone in eval mode while the head trains
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)

        for images, labels in local_progress:
            labels = labels.to(args.device)
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            # One forward pass provides both the loss and the train accuracy.
            preds = classifier(feature)
            loss = F.cross_entropy(preds, labels)
            correct = (preds.detach().argmax(dim=1) == labels).sum().item()
            trainacc_meter.update(correct/preds.shape[0])

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())

            lr = lr_scheduler.step()

            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg,
                'train_acc': trainacc_meter.avg,
            })

        accuracy = _evaluate(model, classifier, test_loader, args.device)
        print(epoch, f'Accuracy = {accuracy*100:.2f}')

    if accuracy is None:
        # No epoch ran: report the untrained classifier instead of failing on
        # an undefined `epoch`.
        accuracy = _evaluate(model, classifier, test_loader, args.device)
        print(f'Accuracy = {accuracy*100:.2f}')
    else:
        # Final summary line. Nothing changed since the last epoch's
        # evaluation, so the value is reused rather than recomputed.
        print(epoch, f'Accuracy = {accuracy*100:.2f}')
    return accuracy




# Script entry point: obtain the run configuration from get_args() and hand it
# to main().
if __name__ == "__main__":
    parsed_args = get_args()
    main(args=parsed_args)
















