import os
import argparse
import numpy as np
import warnings
from sklearn.exceptions import ConvergenceWarning

import torch
from torch.utils.data import DataLoader

from algorithms.wrapper import get_algorithm
from models.get_model import get_model
from set_model import MeanPooling, get_set_model
from data import EpisodicDataset, CustomTransform

# Checkpoint directories evaluated by main(), one list per pretraining method
# (paths are anonymized in this copy). main() picks the list from --method
# (and, for set_simclr, from --model) and loads `<dir>/encoder_400.pth` from
# every entry, plus `<dir>/decoder_400.pth` when --method is set_simclr.
# NOTE(review): the five entries per method are presumably independent
# pretraining runs / seeds whose results get averaged - confirm.

# simclr
SIMCLR = ['ANONYMIZED',
          'ANONYMIZED',
          'ANONYMIZED',
          'ANONYMIZED',
          'ANONYMIZED']

# moco
MOCO = ['ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED']

# byol
BYOL = ['ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED',
        'ANONYMIZED']

# barlow twins
BARLOW_TWINS = ['ANONYMIZED',
                'ANONYMIZED',
                'ANONYMIZED',
                'ANONYMIZED',
                'ANONYMIZED']

# set simclr conv5_64 + SAB,DMA(ln=True),Linear
# (used when --method set_simclr --model conv5_64)
SET_SIMCLR_CONV = ['ANONYMIZED',
                   'ANONYMIZED',
                   'ANONYMIZED',
                   'ANONYMIZED',
                   'ANONYMIZED']

# set simclr resnet18 + SAB(ln=False),DeepPooler
# (used when --method set_simclr --model resnet18)
SET_SIMCLR_RESNET = ['ANONYMIZED',
                     'ANONYMIZED',
                     'ANONYMIZED',
                     'ANONYMIZED',
                     'ANONYMIZED']


def _checkpoint_dirs(method, model):
    """Return the checkpoint directories to evaluate for ``method``/``model``.

    ``set_simclr`` has one list per backbone (``conv5_64`` / ``resnet18``).
    Any other method name is matched against the module-level constant with
    the same upper-cased name (e.g. ``simclr`` -> ``SIMCLR``), which is the
    mapping the script previously obtained via ``eval``.

    Raises:
        ValueError: if no checkpoint list is registered for the combination.
    """
    if method == 'set_simclr':
        set_simclr_dirs = {'conv5_64': SET_SIMCLR_CONV,
                           'resnet18': SET_SIMCLR_RESNET}
        if model not in set_simclr_dirs:
            raise ValueError(
                f"no set_simclr checkpoints registered for model {model!r}; "
                f"expected one of {sorted(set_simclr_dirs)}")
        return set_simclr_dirs[model]

    # Explicit registry instead of eval(): a mistyped or crafted --method
    # value can no longer evaluate an arbitrary expression.
    registry = {
        'SIMCLR': SIMCLR,
        'MOCO': MOCO,
        'BYOL': BYOL,
        'BARLOW_TWINS': BARLOW_TWINS,
        'SET_SIMCLR_CONV': SET_SIMCLR_CONV,
        'SET_SIMCLR_RESNET': SET_SIMCLR_RESNET,
    }
    try:
        return registry[method.upper()]
    except KeyError:
        raise ValueError(
            f"unknown method {method!r}; expected one of "
            f"{sorted(name.lower() for name in registry)} or 'set_simclr'"
        ) from None


def _mean_and_ci95(values):
    """Return ``(mean, 1.96 * std / sqrt(n))`` over ``values``.

    Uses NumPy's default population std (ddof=0), exactly as the original
    reporting did.
    """
    return np.mean(values), 1.96 * np.std(values) / np.sqrt(len(values))


def main(args):
    """Evaluate pretrained checkpoints on the meta-test split of each dataset.

    For every dataset selected by ``args.data_name`` ("all" expands to the six
    names listed below) and every checkpoint directory registered for
    ``args.method`` / ``args.model``:

    * build the encoder with ``get_model`` and load ``encoder_400.pth``;
    * for ``set_simclr`` also build the set decoder and load
      ``decoder_400.pth``, otherwise use ``MeanPooling``;
    * call ``algo.run`` for ``args.test_episodes`` episodes.

    It then prints, across checkpoints, the mean and 1.96*std/sqrt(n) of the
    ``mean`` and ``ci`` values returned by ``algo.run`` (presumably episode
    accuracy and its confidence interval - confirm in algorithms.wrapper).

    Side effect: sets ``args.transform``, which downstream code is presumably
    reading from ``args``.

    Raises:
        ValueError: if no checkpoints are registered for the method/model.
    """
    device = torch.device(f"cuda:{args.gpu_id}")

    if args.data_name == "all":
        data_list = ["mini", "tiny", "cifar", "aircraft", "cars", "cub"]
    else:
        data_list = [args.data_name]

    # Resolve checkpoints once, before any dataset is loaded, so a bad
    # --method/--model combination fails fast with a clear message.
    dirs_list = _checkpoint_dirs(args.method, args.model)

    # Loop-invariant. NOTE(review): the transform is always built with 'mini'
    # regardless of the dataset being evaluated - presumably because the
    # encoders were pretrained on miniImageNet; confirm this is intended.
    args.transform = CustomTransform('mini', args.img_size)

    for data_name in data_list:
        # data
        meta_test_ds = EpisodicDataset(data_name, 'meta_test')

        # algorithm (rebuilt per dataset, as before)
        algo = get_algorithm(args.algorithm)

        mean_list, ci_list = [], []
        for ckpt_dir in dirs_list:
            # encoder
            encoder = get_model(args.model, args.img_size).to(device)
            encoder.eval()
            # Probe the feature size with a dummy image. The tensor is drawn on
            # the CPU and then moved, as before, so RNG state is unchanged;
            # no_grad avoids building an autograd graph for the probe.
            with torch.no_grad():
                probe = torch.randn(1, 3, args.img_size, args.img_size).to(device)
                last_hidden_size = encoder(probe).shape[-1]
            encoder.load_state_dict(
                torch.load(f"{ckpt_dir}/encoder_400.pth", map_location=device))

            # decoder
            if args.method == 'set_simclr':
                decoder = get_set_model(args, last_hidden_size, device)
                decoder.load_state_dict(
                    torch.load(f"{ckpt_dir}/decoder_400.pth", map_location=device))
            else:
                decoder = MeanPooling()

            mean, ci = algo.run(args, args.test_episodes, encoder, decoder,
                                meta_test_ds, device)
            mean_list.append(mean)
            ci_list.append(ci)

        # The parser defines --support rather than --shot; fall back to it so
        # the summary header no longer raises AttributeError.
        shot = getattr(args, 'shot', args.support)
        print(f"---------------------------------{data_name}-{args.method}-{shot}-{args.algorithm}--------------------------------------")
        print(*_mean_and_ci95(mean_list))
        print(*_mean_and_ci95(ci_list))


if __name__ == "__main__":
    # The first positional parameter of ArgumentParser is `prog` (the name
    # shown in the usage line); the title belongs in `description`.
    parser = argparse.ArgumentParser(description='Unsupervised Meta-learning')

    # Dataset name passed to EpisodicDataset; "all" makes main() loop over
    # every supported benchmark.
    parser.add_argument('--data-name', type=str, default='mini')
    parser.add_argument('--img-size', type=int, default=84)

    # Model Argument
    # --method selects the checkpoint list (and the set decoder for set_simclr).
    parser.add_argument('--method', type=str, default='simclr')
    # Name handed to get_algorithm().
    parser.add_argument('--algorithm', type=str, default='lr_warmstart')
    # Backbone handed to get_model(); set_simclr checkpoints exist only for
    # conv5_64 and resnet18.
    parser.add_argument('--model', type=str, default='resnet18')
    # NOTE(review): --set-model / --num-heads are presumably read by
    # get_set_model() through `args`; not visible in this file.
    parser.add_argument('--set-model', type=str, default='SAB(ln=False),DeepPooler')
    parser.add_argument('--num-heads', type=int, default=4)

    # Evaluation Argument
    # NOTE(review): way/support/query presumably define the N-way K-shot
    # episode layout consumed by algo.run() via `args`; --support is the value
    # main() reports as the shot count.
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--support', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--test-episodes', type=int, default=1000)

    # System Argument
    # NOTE(review): --debug is not read anywhere in this file.
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--gpu-id', type=int, default=0)
    args = parser.parse_args()

    # Silence sklearn ConvergenceWarning - presumably emitted by a sklearn
    # solver inside the evaluation algorithm; confirm.
    warnings.filterwarnings(action='ignore', category=ConvergenceWarning)

    main(args)