import numpy as np
import os
import torch
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from collections import defaultdict
import pickle


def mean(ls):
    """Return the arithmetic mean of the numbers in ``ls``.

    ``ls`` must support ``len``; an empty sequence raises ZeroDivisionError.
    """
    total = sum(ls)
    count = len(ls)
    return total / count

def setup_optimizer(params, config):
    """Build a torch optimizer described by ``config``.

    Args:
        params: parameters (or parameter groups) to optimize.
        config: object exposing ``optimizer`` ('SGD' or 'Adam'), ``lr`` and
            ``weight_decay``; the SGD branch also reads ``momentum`` and
            ``dampening``.

    Returns:
        An ``optim.SGD`` or ``optim.Adam`` instance.

    Raises:
        ValueError: if ``config.optimizer`` names an unsupported optimizer
            (previously this silently returned None).
    """
    if config.optimizer == 'SGD':
        return optim.SGD(params,
                         lr=config.lr,
                         momentum=config.momentum,
                         dampening=config.dampening,
                         weight_decay=config.weight_decay)
    if config.optimizer == 'Adam':
        return optim.Adam(params,
                          lr=config.lr,
                          weight_decay=config.weight_decay)
    raise ValueError(
        f"unsupported optimizer {config.optimizer!r}; expected 'SGD' or 'Adam'")

def setup_dataset(transform, root):
    """Announce loading, then return an ``ImageFolder`` rooted at ``root``.

    Args:
        transform: transform passed through to ``ImageFolder``.
        root: directory handed to ``ImageFolder`` as its ``root``.
    """
    print('loading dataset...')
    dataset = ImageFolder(root=root, transform=transform)
    return dataset

def setup_dataloader(dataset, config):
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    Args:
        dataset: dataset to iterate.
        config: object exposing ``seed``, ``batch_size`` and ``num_workers``.
            When ``seed`` is not None, a dedicated ``torch.Generator`` seeded
            with it is passed to the loader so the shuffle order is
            reproducible; otherwise the loader uses its default (``None``).

    Returns:
        The configured ``DataLoader``.
    """
    print('setting up dataloader...')
    generator = None  # DataLoader's own default for ``generator``
    if config.seed is not None:
        generator = torch.Generator()
        generator.manual_seed(config.seed)
    return DataLoader(dataset,
                      batch_size=config.batch_size,
                      shuffle=True,
                      generator=generator,
                      num_workers=config.num_workers)

def save_dict(file, filename, save_dir):
    """Pickle ``file`` into ``save_dir`` as ``<filename>.pickle``.

    Every occurrence of '.pickle' is removed from ``filename`` before the
    extension is re-appended, so the name may be given with or without it.
    """
    extension = '.pickle'
    stem = filename.replace(extension, '')
    save_file(file, stem, save_dir, file_format=extension)

def load_dict(filename, load_dir):
    """Load the object pickled at ``<load_dir>/<filename>.pickle``.

    Mirrors ``save_dict``: every '.pickle' occurrence is removed from
    ``filename`` and the extension is appended once.
    """
    extension = '.pickle'
    stem = filename.replace(extension, '')
    return load_file(stem, load_dir, file_format=extension)

def save_file(file, filename, save_dir, file_format=''):
    """Pickle ``file`` (highest protocol) to ``save_dir/filename+file_format``.

    Any existing file at that path is overwritten.
    """
    target = os.path.join(save_dir, filename + file_format)
    with open(target, 'wb') as fh:
        pickle.dump(file, fh, protocol=pickle.HIGHEST_PROTOCOL)

def load_file(filename, load_dir, file_format=''):
    """Unpickle and return the object stored at ``load_dir/filename+file_format``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only use this on files produced by trusted code.
    """
    source = os.path.join(load_dir, filename + file_format)
    with open(source, 'rb') as fh:
        obj = pickle.load(fh)
    return obj


def get_info_from_outputs(outputs, labels=None, coarse_labels=None, history=None):
    """Compute fine and coarse top-1 / top-5 accuracies from model scores.

    Args:
        outputs: score tensor indexed as ``outputs[sample, class]``; each row
            is sorted descending and its five highest-scoring classes are the
            predictions.
        labels: true class per row, compared elementwise with the predictions
            (presumably a 1-D tensor of class indices — confirm with callers).
            Required despite the ``None`` default.
        coarse_labels: tensor indexed as ``coarse_labels[sample, class]``; a
            prediction counts as coarse-correct when that entry equals 1
            (presumably a mask of the fine classes sharing the sample's coarse
            class — confirm). Rows are assumed aligned with ``labels``.
            Required despite the ``None`` default.
        history: optional dict-like of lists. When given, the four accuracies
            are appended under 'top1_acc', 'top5_acc', 'top1_coarse_acc' and
            'top5_coarse_acc' and the same object is returned.

    Returns:
        If ``history`` is None: ``(preds, correct_preds, top1_acc, top5_acc,
        coarse_top1_acc, coarse_top5_acc)``, where ``preds[k]`` is the rank-k
        prediction per sample and ``correct_preds[k]`` its float64 0/1
        correctness (both on CPU). Otherwise the updated ``history``.

    Raises:
        ValueError: if ``labels`` or ``coarse_labels`` is None.
    """
    if labels is None or coarse_labels is None:
        raise ValueError('get_info_from_outputs requires labels and coarse_labels')

    n_data = labels.shape[0]
    preds_sorted = torch.argsort(outputs, dim=1, descending=True)
    preds, correct_preds = [], []
    for rank in range(5):
        preds.append(preds_sorted[:, rank].cpu())
        correct_preds.append((preds[-1].to(labels.device) == labels).type(torch.float64).cpu())
    top1_acc, top5_acc = get_accs_from_preds(correct_preds, labels, is_one_hot=True)

    # Vectorised coarse check: fetch coarse_labels[i, preds[k][i]] for the
    # top-5 predictions of every row in one gather instead of a Python loop.
    coarse_labels = torch.as_tensor(coarse_labels)
    top5_idx = preds_sorted[:coarse_labels.shape[0], :5].to(coarse_labels.device)
    coarse_hits = coarse_labels.gather(1, top5_idx) == 1
    coarse_top1_sum = coarse_hits[:, 0].sum().item()
    # A row is a coarse top-5 hit if any of its five predictions is a hit.
    coarse_top5_sum = coarse_hits.any(dim=1).sum().item()

    if history is None:
        return preds, correct_preds, top1_acc, top5_acc, coarse_top1_sum / n_data, coarse_top5_sum / n_data
    history['top1_acc'].append(top1_acc)
    history['top5_acc'].append(top5_acc)
    history['top1_coarse_acc'].append(coarse_top1_sum / n_data)
    history['top5_coarse_acc'].append(coarse_top5_sum / n_data)
    return history


def get_accs_from_preds(preds, labels, is_one_hot=False, enable_top5=True):
    """Compute top-1 (and optionally top-5) accuracy from ranked predictions.

    Args:
        preds: list of tensors, ``preds[k]`` holding the rank-k prediction for
            every sample. If ``is_one_hot`` is False these are class indices
            compared with ``labels``; if True they are already 0/1 correctness
            indicators.
        labels: true classes; its device is where the computation runs.
        is_one_hot: whether ``preds`` are correctness indicators rather than
            class indices.
        enable_top5: if True, ``preds`` must have exactly five entries and the
            top-5 accuracy is returned too.

    Returns:
        ``(acc1, acc5)`` as Python floats when ``enable_top5``, else ``acc1``.

    Raises:
        AssertionError: if ``enable_top5`` and ``len(preds) != 5``.
    """
    device = labels.device
    if enable_top5:
        assert len(preds) == 5, f'expected 5, got {len(preds)}'
    preds_one_hot = []
    # Accumulate in float64 regardless of preds' dtype: with integer class-index
    # predictions, zeros_like(preds[0]) would be an integer tensor and the
    # in-place add of float64 indicators would raise a dtype-cast error.
    preds_one_hot_sum = torch.zeros(preds[0].shape, dtype=torch.float64, device=device)
    for pred in preds:
        if is_one_hot:
            preds_one_hot.append(pred.to(device))
        else:
            preds_one_hot.append((pred.to(device) == labels).type(torch.float64))
        if enable_top5:
            preds_one_hot_sum += preds_one_hot[-1]
    acc1 = preds_one_hot[0].mean().item()
    if enable_top5:
        # A sample is a top-5 hit if any of its ranked predictions was correct.
        acc5 = (preds_one_hot_sum > 0).type(torch.float64).mean().item()
        return acc1, acc5
    return acc1


class ModelHistory:
    """Accumulates per-key metric lists for one named model run.

    ``concated_hist`` maps a metric name to the list of recorded values;
    ``additional_info`` is a second mapping of the same kind for extra data.
    """

    def __init__(self, name, config):
        """Create an empty history.

        Args:
            name: identifier for this run.
            config: object with a ``__dict__`` (e.g. an argparse Namespace);
                a shallow copy of its attributes is stored in ``self.config``.
        """
        self.name = name
        # vars() returns the live __dict__; copy it so later mutation of
        # ``config`` does not silently rewrite the recorded configuration.
        self.config = dict(vars(config))
        self.additional_info = defaultdict(list)
        self.reset()

    def reset(self):
        """Discard all accumulated metrics in ``concated_hist``."""
        self.concated_hist = defaultdict(list)

    def update(self, input_dict):
        """Record the values in ``input_dict``.

        List values are appended element-wise (an empty list just ensures the
        key exists instead of raising IndexError); any other value is
        appended as a single entry.
        """
        for key, value in input_dict.items():
            if isinstance(value, list):
                self.concated_hist[key].extend(value)
            else:
                self.concated_hist[key].append(value)

    def concat_history(self, history):
        """Merge another ``ModelHistory``'s metrics and additional info into this one."""
        # merge_dicts is the module-level helper; the previous ``self.merge_dicts``
        # referred to a non-existent method and raised AttributeError.
        self.concated_hist = merge_dicts(self.concated_hist, history.concated_hist)
        self.additional_info = merge_dicts(self.additional_info, history.additional_info)
        

def merge_dicts(dict1, dict2):
    """Merge two metric mappings into a new ``defaultdict(list)``.

    Every key of either input appears in the result. For each key, values
    from ``dict1`` come first, then those from ``dict2``. A list value
    contributes its elements; any other value contributes itself as a single
    entry. Neither input is modified.

    Args:
        dict1: first mapping of key -> list (or single value).
        dict2: second mapping of key -> list (or single value).

    Returns:
        A new ``defaultdict(list)`` holding the concatenated values.
    """
    dict_final = defaultdict(list)
    # Walk both inputs so keys present only in dict1 are kept, and never index
    # dict1 by a dict2 key (that inserted keys into a defaultdict input or
    # raised KeyError for a plain dict).
    for source in (dict1, dict2):
        for key, value in source.items():
            if isinstance(value, list):
                dict_final[key].extend(value)
            else:
                dict_final[key].append(value)
    return dict_final

    
