import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from utils.misc import *
from utils.coarse_set_helpers import *
import argparse
from torchvision.datasets import ImageFolder
from torchvision.models import ResNet50_Weights
from tqdm import tqdm
from collections import defaultdict
import os
import rna_classification as rna


# Default ImageNet locations; used as defaults for --train_root / --val_root.
IMAGENET_TRAIN_DIR = '/datasets/imagenet/train'
IMAGENET_VAL_DIR = '/datasets/imagenet/val'

# Compact form of the priority rules: for every coarse-label-set type, an
# ordered list of (interval_start, interval_end, pick_lowest) triples.
_PRIORITY_SPECS = {
    '85': [
        (10, 200, True),
        (5, 100, False),
        (10, 1000, True),
    ],
    '45': [
        (61, 62, True),
        (64, 67, True),
        (16, 61, False),
        (64, 122, True),
        (7, 16, False),
        (63, 1000, True),
    ],
    '26': [
        (140, 211, True),
        (101, 140, False),
        (213, 300, True),
        (64, 70, False),
        (73, 99, False),
        (18, 61, False),
        (7, 16, False),
        (101, 1000, True),
    ],
}

# Expanded table handed to get_lookuptable(): each set type maps the integer
# keys 1..n (in the order listed above) to a rule dict with an 'interval'
# two-element list and a 'pick_lowest' flag. How the rules are interpreted is
# defined by get_lookuptable, not here.
coarse_label_set_priorities = {
    set_type: {
        rank: {'interval': [low, high], 'pick_lowest': pick_lowest}
        for rank, (low, high, pick_lowest) in enumerate(specs, start=1)
    }
    for set_type, specs in _PRIORITY_SPECS.items()
}

# Use the first CUDA device when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'using {device}')

def run_one_epoch(filmed_model, optimizer):
    """Run a single pass over ``dataloaders[phase]`` and gather its metrics.

    Besides its two arguments, the function reads these module-level names,
    which the driver code in this file assigns before calling it: ``phase``,
    ``epoch``, ``dataloaders``, ``lookuptable_tensor``, ``experiment``,
    ``config`` and ``device``.

    Args:
        filmed_model: Called as ``filmed_model(data, coarse_labels)``; its
            output is passed to ``nn.CrossEntropyLoss`` together with the
            batch labels.
        optimizer: Zeroed and stepped once per batch when ``phase`` is
            ``'train'``; not touched in any other phase.

    Returns:
        The ``ModelHistory`` created for this pass. It is updated once per
        batch, and its ``additional_info`` receives the sample-weighted
        epoch averages under the keys ``'loss_<phase>'``,
        ``'top1_acc_<phase>'``, ``'top1_coarse_acc_<phase>'``,
        ``'top5_acc_<phase>'`` and ``'top5_coarse_acc_<phase>'``.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    is_train = phase == 'train'
    history = ModelHistory(experiment, config)
    top1_n_correct, top1_n_coarse_correct = 0, 0
    top5_n_correct, top5_n_coarse_correct = 0, 0
    n_data = 0
    running_loss = 0
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the model's train()/eval() mode is not switched here;
    # confirm that the model wrapper manages it as intended.
    with tqdm(dataloaders[phase], unit='b') as tepoch:
        for data, labels in tepoch:
            tepoch.set_description(f'Epoch {epoch}, {phase}'.ljust(25))

            batch_size = labels.shape[0]
            n_data += batch_size
            data, labels = data.to(device), labels.to(device)
            coarse_labels = lookuptable_tensor[labels].to(device)

            # Only record the autograd graph when a backward pass will follow;
            # evaluation phases skip it to save memory and compute.
            with torch.set_grad_enabled(is_train):
                outputs = filmed_model(data, coarse_labels)
                loss = criterion(outputs, labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Read the scalar once: every .item() call copies from the device.
            batch_loss = loss.item()

            history_temp = get_info_from_outputs(
                outputs,
                labels=labels,
                coarse_labels=coarse_labels,
                history=defaultdict(list),
            )
            history_temp['loss'] = batch_loss
            history.update(history_temp)

            # The per-batch metrics are averages, so weight them by the batch
            # size to keep the running totals exact for a short last batch.
            running_loss += batch_loss * batch_size
            top1_n_correct += history_temp['top1_acc'][-1] * batch_size
            top1_n_coarse_correct += history_temp['top1_coarse_acc'][-1] * batch_size
            top5_n_correct += history_temp['top5_acc'][-1] * batch_size
            top5_n_coarse_correct += history_temp['top5_coarse_acc'][-1] * batch_size

            tepoch.set_postfix(
                loss=running_loss / n_data,
                top1=100. * top1_n_correct / n_data,
                top1c=100. * top1_n_coarse_correct / n_data,
                top5=100. * top5_n_correct / n_data,
                top5c=100. * top5_n_coarse_correct / n_data,
            )

    history.additional_info['loss' + '_' + phase] = running_loss / n_data
    history.additional_info['top1_acc' + '_' + phase] = top1_n_correct / n_data
    history.additional_info['top1_coarse_acc' + '_' + phase] = top1_n_coarse_correct / n_data
    history.additional_info['top5_acc' + '_' + phase] = top5_n_correct / n_data
    history.additional_info['top5_coarse_acc' + '_' + phase] = top5_n_coarse_correct / n_data

    return history

# Command-line configuration for the training script.
parser = argparse.ArgumentParser('train')
parser.add_argument('--optimizer', default='Adam', type=str)
parser.add_argument('--lr', type=float, default=1e-4)
# Weight decay is fractional (default 1e-4), so it has to be parsed as a
# float; with type=int a value such as "1e-4" or "0.001" is rejected.
parser.add_argument('--weight_decay', type=float, default=1e-4)
# Restrict to the keys of coarse_label_set_priorities so a typo is reported
# by argparse instead of surfacing later as a KeyError.
parser.add_argument('--coarse_label_set_type', default='85', type=str,
                    choices=sorted(coarse_label_set_priorities),
                    help='which coarse label set priority table to use')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--n_epoch', type=int, default=60,
                    help='number of epochs to run')
parser.add_argument('--num_workers', type=int, default=32)
parser.add_argument('--save_dir', type=str, default='rna_classification_log',
                    help='directory for checkpoints and history files')
# NOTE(review): --resume and --seed are parsed but never read in this file;
# confirm whether a helper receiving `config` consumes them.
parser.add_argument('--resume', type=int, default=None)
parser.add_argument('--seed', type=int, default=None)
parser.add_argument('--train_root', type=str, default=IMAGENET_TRAIN_DIR,
                    help='root directory of the training dataset')
parser.add_argument('--val_root', type=str, default=IMAGENET_VAL_DIR,
                    help='root directory of the validation dataset')

# parse_known_args() returns unrecognised arguments separately; they are
# discarded here rather than treated as an error.
config, _ = parser.parse_known_args()

# Pretrained ResNet-50 weights and the preprocessing transforms shipped with them.
weights = ResNet50_Weights.IMAGENET1K_V1
preprocessing = weights.transforms()

# One dataset and one dataloader per phase, keyed by the phase name.
phases = ['train', 'val']
roots = dict(train=config.train_root, val=config.val_root)

datasets = {}
for split in phases:
    datasets[split] = setup_dataset(preprocessing, root=roots[split])

dataloaders = {}
for split in phases:
    dataloaders[split] = setup_dataloader(datasets[split], config=config)


os.makedirs(config.save_dir, exist_ok=True)

# Build the lookup table that run_one_epoch indexes with the batch labels to
# obtain the coarse labels.
priorities = coarse_label_set_priorities[config.coarse_label_set_type]
lookuptable_tensor, key_hypers = get_lookuptable(priorities, table_type='tensor')
display_lookuptable(key_hypers)

# Wrap the pretrained backbone in the RNA model and move both to the device.
backbone = torchvision.models.resnet50(weights=weights).to(device)
rna_model = rna.RNA(backbone).to(device)

# Only the parameters the RNA wrapper reports as trainable are optimised.
params = rna_model.get_trainable_params()
optimizer = setup_optimizer(params, config)

# Experiment tag used in checkpoint and history file names.
experiment = f'rna{config.coarse_label_set_type}'
step_temp = 0
for epoch in range(config.n_epoch):
    # run_one_epoch reads the module-level `phase` and `epoch` set by these loops.
    # NOTE(review): history_temp is overwritten on every phase, so only the
    # history of the last phase in `phases` is saved below.
    for phase in phases:
        history_temp = run_one_epoch(rna_model, optimizer)

    # Checkpoint and dump the history after every epoch.
    tag = f'{experiment}_epoch_{epoch}'
    checkpoint = {
        'film_generator': rna_model.film_generator.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, os.path.join(config.save_dir, tag + '_models.pth'))
    save_dict(history_temp.__dict__, tag + '_history', config.save_dir)
