import argparse
import torch
# torch.set_printoptions(profile="full")
import os
import random
import matplotlib.pyplot as plt
import numpy as np


from dassl.utils import setup_logger, set_random_seed, collect_env_info
from configs.my_default_config.my_default import get_cfg_default
from dassl.engine import build_trainer

# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet

import trainers.mytrainer
import trainers.hhzsclip


def print_args(args, cfg):
    """Print the parsed arguments and the configuration to stdout.

    Arguments are listed one per line as ``name: value`` in alphabetical
    order of their attribute names, followed by ``cfg`` printed as-is.

    Args:
        args: object whose instance attributes are the parsed options
            (e.g. an ``argparse.Namespace``).
        cfg: configuration object; only passed to ``print``.
    """
    for banner_line in ('***************', '** Arguments **', '***************'):
        print(banner_line)

    arg_values = vars(args)
    for name in sorted(arg_values):
        print(f'{name}: {arg_values[name]}')

    for banner_line in ('************', '** Config **', '************'):
        print(banner_line)
    print(cfg)

def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` onto ``cfg`` in place.

    An option is applied only when its value is truthy, so empty strings,
    ``None``, empty lists and ``0`` leave the corresponding config entry
    untouched.

    Args:
        cfg: configuration tree exposing the dotted attributes listed below.
        args: parsed command-line options.
    """
    # (attribute on args, dotted destination inside cfg), applied in order.
    overrides = (
        ('root', 'DATASET.ROOT'),
        ('output_dir', 'OUTPUT_DIR'),
        ('resume', 'RESUME'),
        ('seed', 'SEED'),
        ('source_domains', 'DATASET.SOURCE_DOMAINS'),
        ('target_domains', 'DATASET.TARGET_DOMAINS'),
        ('transforms', 'INPUT.TRANSFORMS'),
        ('trainer', 'TRAINER.NAME'),
        ('backbone', 'MODEL.BACKBONE.NAME'),
        ('head', 'MODEL.HEAD.NAME'),
    )
    for arg_name, cfg_path in overrides:
        value = getattr(args, arg_name)
        if not value:
            continue
        # Walk to the parent node only when there is something to write, so
        # config sections for unset options are never touched.
        *parents, leaf = cfg_path.split('.')
        node = cfg
        for section in parents:
            node = getattr(node, section)
        setattr(node, leaf, value)


def extend_cfg(cfg, args):
    """Register the ``TRAINER.MYTrainer`` option group on ``cfg``.

    A new, empty yacs node is attached at ``cfg.TRAINER.MYTrainer`` and then
    filled with fixed defaults followed by values taken from ``args``.

    Args:
        cfg: configuration tree that already has a ``TRAINER`` node.
        args: parsed command-line options providing ``fp``, ``fp_type``,
            ``topk``, ``clean_type`` and ``encode_type``.
    """
    from yacs.config import CfgNode as CN

    cfg.TRAINER.MYTrainer = CN()
    group = cfg.TRAINER.MYTrainer

    # Defaults that do not depend on the command line.
    fixed_defaults = (
        ('N_CTX', 16),  # number of context vectors
        ('CSC', False),  # class-specific context
        ('CTX_INIT', ""),  # initialization words
        ('PREC', "fp16"),  # fp16, fp32, amp
        ('CLASS_TOKEN_POSITION', "end"),  # 'middle' or 'end' or 'front'
    )
    for key, value in fixed_defaults:
        setattr(group, key, value)

    # Entries mirrored from the command line: (config key, args attribute).
    from_command_line = (
        ('FP', 'fp'),  # false positive training samples per class
        ('FP_TYPE', 'fp_type'),  # noise type, e.g. 'sym' or 'asym'
        ('TOPK', 'topk'),
        ('CLEAN_TYPE', 'clean_type'),
        ('ENCODE_TYPE', 'encode_type'),
    )
    for key, arg_name in from_command_line:
        setattr(group, key, getattr(args, arg_name))
    

def setup_cfg(args):
    """Assemble and freeze the configuration for this run.

    Sources are layered so that later ones override earlier ones:
    project defaults plus the MYTrainer options, then the dataset, method
    and hh config files (each only if a path was given), then the dedicated
    command-line flags, and finally the free-form ``opts`` list.

    Args:
        args: parsed command-line options.

    Returns:
        The frozen configuration object.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg, args)

    # Config files, merged in this order; empty paths are skipped.
    for file_option in ('dataset_config_file', 'config_file', 'hh_config_file'):
        config_path = getattr(args, file_option)
        if config_path:
            cfg.merge_from_file(config_path)

    # Dedicated flags (--root, --seed, ...) take precedence over the files.
    reset_cfg(cfg, args)

    # Trailing KEY VALUE pairs from the command line win over everything.
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def main(args):
    """Score saved test logits per seed and report mean/std/ensemble accuracy.

    Builds the config and a trainer (optionally loading model weights), then
    for every seed loads ``test_logits.pt`` from that seed's analysis
    directory and evaluates it with ``trainer.test_with_existing_logits``.
    The element-wise average of all loaded logits is evaluated as the
    ensemble. The summary is printed, written to
    ``ensemble_accuracy_result.txt`` inside the last seed's analysis
    directory, and written to
    ``./results/<dataset>/<architecture>_<fp>_<fp_type>/<runtype>/results.txt``.

    Args:
        args: parsed command-line options.

    Raises:
        SystemExit: with status 1 when the logits of a seed cannot be loaded
            or evaluated; the underlying error is printed first.
    """
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print('Setting fixed seed: {}'.format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    trainer = build_trainer(cfg)
    if args.model_dir:
        # The previous code passed an undefined name `i` as model_id, which
        # raised NameError whenever --model-dir was given. Only one trainer
        # is built here, so its index 0 is used.
        # NOTE(review): confirm 0 is the id load_model_by_id expects.
        trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=0)

    fp_desc = 'fp' + str(cfg.TRAINER.MYTrainer.FP)
    if cfg.TRAINER.MYTrainer.FP_TYPE != 'sym':
        fp_desc += '_asym'

    def run_dir(seed):
        # Directory (with trailing slash) holding the artifacts of one seed.
        return './analysis_results_test/{}/{}/{}/50_{}_{}_random_init_end_encode_type_{}_clean_type_{}/'.format(
            cfg.DATASET.NAME, fp_desc, cfg.MODEL.BACKBONE.NAME, seed,
            cfg.DATASET.NUM_SHOTS, args.encode_type, args.clean_type,
        )

    # Widen to range(1, 6) to evaluate and ensemble five seeds.
    seeds = range(1, 2)

    prob_end = []
    results = []
    for seed in seeds:
        logits_path = run_dir(seed) + 'test_logits.pt'
        print(logits_path)
        try:
            prob = torch.load(logits_path)
            res = trainer.test_with_existing_logits(prob, encode_type=args.encode_type)
        except Exception as err:
            # Keep the original marker line, but also show what went wrong
            # and signal the failure through a non-zero exit status.
            print(f'loss {seed}')
            print(f'reason: {err!r}')
            raise SystemExit(1) from err
        prob_end.append(prob)
        results.append(res['accuracy'])
        # Clear accumulated evaluator state before the next evaluation.
        trainer.evaluator.reset()

    acc_avg = np.mean(results)
    acc_std = np.std(results)

    print(len(prob_end), ' shots ensemble')
    prob_test = sum(prob_end) / len(prob_end)
    results_end = trainer.test_with_existing_logits(prob_test, encode_type=args.encode_type)
    acc_ens = results_end['accuracy']

    summary = f'Avg: {acc_avg:.4f} Std: {acc_std:.4f} ensemble: {acc_ens:.4f}'
    # The ensemble summary is stored next to the logits of the last seed.
    save_path = run_dir(seeds[-1])
    with open(os.path.join(save_path, 'ensemble_accuracy_result.txt'), "w") as f:
        print(summary, file=f)
        print(f'Acc statistics: {results}', file=f)
    print(summary)
    results_print = ' '.join(f'{val:.3f}' for val in results)
    print(f'Acc statistics: {results_print}')

    # e.g. 'configs/rn50_ep50.yaml' -> 'rn50'
    architecture = args.config_file.split('/')[-1].split('.')[0].split('_')[0]
    # Dataset names are expected to contain 'MY'; fall back to the full name
    # instead of failing with IndexError after all evaluation has finished.
    name_parts = cfg.DATASET.NAME.split('MY')
    dataset = name_parts[1] if len(name_parts) > 1 else cfg.DATASET.NAME
    runtype = args.tag.split('_')[-1]
    save_file_dir = f'./results/{dataset}/{architecture}_{args.fp}_{args.fp_type}/{runtype}/'
    os.makedirs(save_file_dir, exist_ok=True)
    with open(os.path.join(save_file_dir, 'results.txt'), 'w') as f:
        print(f'{results_print} {acc_avg:.3f} {acc_std:.4f}', file=f)



if __name__ == '__main__':
    # Command-line interface, declared as (name, add_argument options) pairs
    # and registered in exactly this order.
    cli_specs = [
        ('--root', {'type': str, 'default': '', 'help': 'path to dataset'}),
        ('--output-dir', {'type': str, 'default': '', 'help': 'output directory'}),
        ('--resume', {
            'type': str,
            'default': '',
            'help': 'checkpoint directory (from which the training resumes)',
        }),
        ('--seed', {
            'type': int,
            'default': -1,
            'help': 'only positive value enables a fixed seed',
        }),
        ('--encode-type', {'type': str, 'default': 'none', 'help': 'Encode type'}),
        ('--clean-type', {
            'type': str,
            'default': 'none',
            'help': 'Cleansing type, [none, pond, gce]',
        }),
        ('--source-domains', {
            'type': str,
            'nargs': '+',
            'help': 'source domains for DA/DG',
        }),
        ('--target-domains', {
            'type': str,
            'nargs': '+',
            'help': 'target domains for DA/DG',
        }),
        ('--transforms', {
            'type': str,
            'nargs': '+',
            'help': 'data augmentation methods',
        }),
        ('--config-file', {'type': str, 'default': '', 'help': 'path to config file'}),
        ('--dataset-config-file', {
            'type': str,
            'default': '',
            'help': 'path to config file for dataset setup',
        }),
        ('--hh-config-file', {'type': str, 'default': '', 'help': 'path to config file'}),
        ('--trainer', {'type': str, 'default': '', 'help': 'name of trainer'}),
        ('--backbone', {'type': str, 'default': '', 'help': 'name of CNN backbone'}),
        ('--head', {'type': str, 'default': '', 'help': 'name of head'}),
        ('--eval-only', {'action': 'store_true', 'help': 'evaluation only'}),
        ('--model-dir', {
            'type': str,
            'default': '',
            'help': 'load model from this directory for eval-only mode',
        }),
        ('--load-epoch', {
            'type': int,
            'help': 'load model weights at this epoch for evaluation',
        }),
        ('--fp', {
            'type': int,
            'default': 0,
            'help': 'portion of false positive training samples per class',
        }),
        ('--fp-type', {
            'type': str,
            'default': 'sym',
            'help': 'Noisy type (sym,asym,inst)',
        }),
        ('--topk', {'type': int, 'default': 3, 'help': 'Top-k selection'}),
        ('--no-train', {'action': 'store_true', 'help': 'do not call trainer.train()'}),
        # Positional catch-all: everything left over is collected as config
        # overrides.
        ('opts', {
            'default': None,
            'nargs': argparse.REMAINDER,
            'help': 'modify config options using the command-line',
        }),
        ('--tag', {'type': str, 'default': '', 'help': 'tag for method'}),
    ]

    parser = argparse.ArgumentParser()
    for option_name, option_settings in cli_specs:
        parser.add_argument(option_name, **option_settings)

    args = parser.parse_args()
    main(args)
