import os
import datetime
import time
import cv2
import numpy as np
import argparse
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.distributed as dist
import torch.distributed.launch
import torch.optim as optim
from sklearn import metrics
import torch.nn.functional as F

# from tensorboardX import SummaryWriter

from util.util import AverageMeter, get_model_para_number, setup_seed, get_logger, get_save_path, check_makedirs

from dataloader.dataloader import get_smkd_dataloder
from model.FewVS import FewVSBuilder

from config import get_parser
# Process-wide OpenCV configuration, applied once at import time.
# Turn off OpenCV's OpenCL code path.
cv2.ocl.setUseOpenCL(False)
# Ask OpenCV not to use its own worker threads.
# NOTE(review): presumably to avoid contention with DataLoader workers - confirm.
cv2.setNumThreads(0)


def prepare_label(args, split="test"):
    """Build the query labels of one episode as class indices.

    The result is ``[0, 1, ..., n - 1]`` tiled ``q`` times, where ``n`` and
    ``q`` are ``args.n_ways`` / ``args.n_queries`` for the "test" and "val"
    splits, and ``args.n_train_ways`` / ``args.n_train_queries`` for any
    other split. These are integer class indices, not one-hot vectors.

    Returns:
        A 1-D int64 tensor, placed on the GPU when CUDA is available.
    """
    if split in ["test", "val"]:
        num_ways, num_queries = args.n_ways, args.n_queries
    else:
        num_ways, num_queries = args.n_train_ways, args.n_train_queries

    class_ids = torch.arange(num_ways, dtype=torch.int16)
    label = class_ids.repeat(num_queries).type(torch.LongTensor)
    return label.cuda() if torch.cuda.is_available() else label

def get_model(args):
    """Build the model and its optimizer, then optionally restore weights.

    Relies on the module-level ``logger`` (set in ``main``) and on
    ``main_process()`` for rank-aware printing.

    Args:
        args: parsed configuration. Reads ``mode``, ``optim``, ``initial_lr``,
            ``momentum``, ``weight_decay``, ``resume``, ``is_test``,
            ``test_weight``, ``distributed``, ``local_rank``, and the
            ``snapshot_path`` / ``result_path`` directories, which are created
            here after ``get_save_path(args)`` has been called.

    Returns:
        ``(model, optimizer)``. The model is on the GPU and, when
        ``args.distributed`` is set, wrapped in ``DistributedDataParallel``.

    Raises:
        ValueError: if ``args.mode`` or ``args.optim`` is not supported.
    """
    if args.mode == 'FewVS':
        model = FewVSBuilder(args)
    else:
        raise ValueError('Dont support {}'.format(args.mode))

    if args.optim == "adam":
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=args.initial_lr)
    elif args.optim == "SGD":
        optimizer = optim.SGD(model.parameters(),
                              lr=args.initial_lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.initial_lr)
    else:
        # Fail early instead of hitting an unbound `optimizer` further down.
        raise ValueError('Dont support optimizer {}'.format(args.optim))

    if hasattr(model, 'freeze_modules'):
        model.freeze_modules()

    get_save_path(args)
    check_makedirs(args.snapshot_path)
    check_makedirs(args.result_path)

    # Move the parameters to the GPU *before* restoring any checkpoint.
    # Optimizer.load_state_dict places the restored state on the device of the
    # parameters it is attached to, so loading while the model is still on the
    # CPU would leave the optimizer state on the wrong device.
    model.cuda()

    if args.resume:
        # Resume: restore both model and optimizer state.
        resume_path = osp.join(args.snapshot_path, args.resume)
        load_checkpoint(model, optimizer, resume_path, logger)
    elif args.is_test and args.mode not in ('MMLM', 'clip'):
        # Test: restore model weights only. The excluded modes are currently
        # rejected by the mode check above; the guard is kept as written.
        load_checkpoint(model, None, args.test_weight, logger)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=False)

    # Report parameter counts on the main process only.
    total_number, learnable_number = get_model_para_number(model)
    if main_process():
        print('Number of Parameters: %d ' % (total_number / 1))
        print('Number of Learnable Parameters: %d ' % (learnable_number / 1))

    # NOTE(review): fixed 2 s pause kept from the original; purpose is not
    # visible here (presumably to let the output above flush) - confirm.
    time.sleep(2)
    return model, optimizer

def main_process():
    """Tell whether this process is the one that should log and write files.

    Reads the module-level ``args``: every non-distributed run counts as the
    main process; in a distributed run only local rank 0 does.
    """
    if not args.distributed:
        return True
    return args.local_rank == 0

def load_checkpoint(model, optimizer, weight_path, logger):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is read onto the CPU and ``checkpoint['state_dict']`` is
    loaded strictly first. If that raises ``RuntimeError``, a leading
    ``"module."`` (the prefix added by DataParallel/DistributedDataParallel)
    is removed from the keys that carry it and the load is retried with
    ``strict=False``, so checkpoints that lack some entries still load.

    Args:
        model: module that receives the state dict.
        optimizer: optimizer to restore from ``checkpoint['optimizer']``, or
            ``None`` to skip optimizer state.
        weight_path: path of the checkpoint file.
        logger: logger used for progress messages on the main process.

    Raises:
        FileNotFoundError: if ``weight_path`` is not an existing file.
    """
    if not os.path.isfile(weight_path):
        # Raise explicitly: an `assert` would be stripped under `python -O`
        # and the run would silently continue with uninitialised weights.
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(weight_path))

    if main_process():
        logger.info(
            "=> loading test checkpoint '{}'".format(weight_path))
    checkpoint = torch.load(
        weight_path, map_location=torch.device('cpu'))

    state_dict = checkpoint['state_dict']
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:                   # 1GPU loads mGPU model
        # Only rename keys that really have the wrapper prefix. Blindly
        # cutting 7 characters off every key would corrupt the names of an
        # unwrapped checkpoint, and strict=False would then hide that nothing
        # was loaded.
        prefix = 'module.'
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
        result = model.load_state_dict(state_dict, strict=False)
        if main_process():
            logger.info(
                "=> non-strict load: {} missing key(s), {} unexpected key(s)".format(
                    len(result.missing_keys), len(result.unexpected_keys)))

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if main_process():
        logger.info(
            "=> loaded checkpoint ({}) for testing".format(weight_path))


@torch.no_grad()
def validate(model, loader, args):
    """Evaluate ``model`` on every episode of ``loader`` and report accuracy.

    For each batch the model output is scored against the episode labels from
    ``prepare_label``; the per-episode accuracy is collected and summarised as
    mean and 95% confidence interval (``1.96 * std / sqrt(n)``), both in
    percent. When ``args.is_test`` is set, the main process also appends the
    result to a text file under ``eval_results/``.

    Relies on the module-level ``logger`` and on ``main_process()``.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: data loader yielding ``(images, global_labels)``; its
            ``dataset`` must expose ``split`` and ``label2class``.
        args: parsed configuration (``manual_seed``, ``local_rank``,
            ``seed_deterministic``, ``mode``, ``extract_dataset``,
            ``print_freq``, ``n_ways``, ``n_shots``, ``is_test``, ``dataset``).

    Returns:
        Mean episode accuracy in percent.
    """
    if main_process():
        logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')

    # Derive the per-rank seed locally. Writing it back into args would add
    # local_rank again on every call and make the seed drift between calls.
    seed = args.manual_seed
    if seed is not None:
        seed = seed + args.local_rank
        setup_seed(seed, args.seed_deterministic)

    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    split = loader.dataset.split
    label2class = loader.dataset.label2class

    # The episode labels depend only on args and the split, so build them once
    # instead of once per batch.
    meta_labels = prepare_label(args, split=split).cuda(non_blocking=True)
    meta_labels_np = meta_labels.detach().cpu().numpy()

    acc = []
    end = time.time()
    for idx, (images, global_labels) in enumerate(loader):
        if isinstance(images, list):
            images = [image.cuda(non_blocking=True) for image in images]
        else:
            images = images.cuda(non_blocking=True)

        # Take as many leading labels as there are distinct classes.
        # NOTE(review): assumes the first entries of global_labels list each
        # class of the episode once - confirm against the dataloader.
        num_classes = len(set(global_labels.tolist()))
        selected_labelidx = global_labels[:num_classes].tolist()
        selected_classes = [
            label2class[label_idx] for label_idx in selected_labelidx]

        # compute output
        if args.mode in ['train_proj', 'FewVS']:
            output = model(images, selected_classes, args.extract_dataset, split=split)
        else:
            output = model(images, split=split)

        # measure accuracy and record loss
        loss = criterion(output, meta_labels)
        loss_meter.update(loss.item(), meta_labels.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if idx % args.print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            if main_process():
                logger.info(
                    f'Test: [{idx}/{len(loader)}]\t'
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                    f'Mem {memory_used:.0f}MB')

        predictions = torch.argmax(output, dim=1).detach().cpu().numpy()
        acc.append(metrics.accuracy_score(meta_labels_np, predictions))

    acc_list = [episode_acc * 100 for episode_acc in acc]
    mean_acc = np.mean(acc_list)
    ci95 = 1.96 * np.std(acc_list, axis=0) / np.sqrt(len(acc_list))

    if main_process():
        logger.info(
            f' * Acc on {args.n_ways} way-{args.n_shots} shot: {mean_acc:.3f}({ci95:.3f})')
    if args.is_test and main_process():
        os.makedirs("eval_results", exist_ok=True)
        result_file = os.path.join(
            "eval_results",
            args.dataset + "_{}_eval_results_{}way_{}shot.txt".format(
                args.mode, args.n_ways, args.n_shots))
        # The with-block closes the file; no explicit close() is needed.
        with open(result_file, "a") as f:
            f.write("{}: {}({:.3f})\n".format(seed, mean_acc, ci95))
    return mean_acc


def main():
    """Entry point: build the model, load the data and run the test pass.

    Sets the module-level ``args`` and ``logger`` that the other functions in
    this file read. Distributed mode is enabled only when more than one GPU is
    visible and ``args.is_test`` is not set. Only the test path is implemented
    here: with ``args.is_test`` the current weights (without ``enc_t``
    entries) are re-saved under ``weights/<dataset>/<backbone>/`` and the
    model is evaluated on the test loader; otherwise the function returns
    after loading the datasets.
    """
    global args, logger
    args = get_parser()
    logger = get_logger()
    args.distributed = torch.cuda.device_count() > 1 and not args.is_test
    if main_process():
        print(args)

    if args.distributed:
        # Initialize Process Group
        dist.init_process_group(backend='nccl')
        print('args.local_rank: ', args.local_rank)
        torch.cuda.set_device(args.local_rank)

    # ========================= initialize ============================
    if args.manual_seed is not None:
        # Offset the seed by the local rank so each process seeds differently.
        args.manual_seed = args.manual_seed + args.local_rank
        setup_seed(args.manual_seed, args.seed_deterministic)

    if main_process():
        logger.info("=> creating model ...")

    model, optimizer = get_model(args)

    # ----------------------  DATASET  ----------------------
    if main_process():
        logger.info("=> loading datasets ...")
    # Build the loaders on every process; only the log line is rank-gated.
    # Creating them inside the main-process branch would leave them undefined
    # on the other ranks.
    _, _, test_loader = get_smkd_dataloder(args)

    # ----------------------  TEST  ----------------------
    if args.is_test:
        root = 'weights/{}/{}'.format(args.dataset, args.backbone)
        # makedirs also creates the missing parent directory and tolerates an
        # existing one.
        os.makedirs(root, exist_ok=True)
        filename = 'weights/{}/{}/{}shot.pth'.format(args.dataset, args.backbone, args.n_shots)

        # Export the weights without any entry whose key contains "enc_t".
        state_dict = model.state_dict()
        filtered_state_dict = {key: value for key, value in state_dict.items() if "enc_t" not in key}
        torch.save({'state_dict': filtered_state_dict}, filename)

        validate(model, test_loader, args)
        return



# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
    
    
    
    
    
    
    
    
    