import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from omegaconf import OmegaConf

# Set the default matplotlib font family to Times New Roman for every figure
# created after this point.
mpl.rc('font',family='Times New Roman')

# Subset keys looked up in ``subsets.pkl``; one summary statistic is produced
# per key and per branch directory.
SUBSET_SIZES = (1, 10, 100, 1300)


def resolve_meta_path(argv, environ):
    """Return the metadata directory given on the command line or in the environment.

    Args:
        argv: Command-line argument list (``sys.argv`` style); ``argv[1]``, when
            present, is used as the metadata directory.
        environ: Mapping of environment variables; ``META_PATH`` is the fallback.

    Returns:
        The metadata directory as a string.

    Raises:
        SystemExit: If neither source provides a path.
    """
    if len(argv) > 1:
        return argv[1]
    try:
        return environ['META_PATH']
    except KeyError:
        raise SystemExit(
            'usage: pass the metadata directory as the first argument '
            'or set the META_PATH environment variable'
        ) from None


def ensemble_mse_per_class(imgs, n_members=5):
    """Compute, for each entry along axis 0, the MSE of its members to their mean.

    Args:
        imgs: Tensor whose axis 0 indexes the loaded files and whose axis 1 is
            averaged over (presumably the ensemble members -- TODO confirm).
        n_members: Maximum number of axis-1 entries to use. If fewer are
            available, all of them are used.

    Returns:
        Dict mapping the axis-0 position to a scalar tensor holding the mean
        squared deviation of that entry's members from their own mean.
    """
    members = imgs[:, :n_members]
    # keepdim lets the mean broadcast over however many members were kept,
    # instead of assuming there are exactly ``n_members`` of them.
    mu_ensemble = members.mean(1, keepdim=True)
    crit = nn.MSELoss()
    return {
        i: crit(class_imgs, mu_class.expand_as(class_imgs))
        for i, (class_imgs, mu_class) in enumerate(zip(members, mu_ensemble))
    }


def mean_and_std(values):
    """Return ``(mean, std)`` of a list of scalar tensors.

    An empty list yields a pair of NaN tensors rather than the RuntimeError
    that ``torch.stack`` raises on an empty list.
    """
    if not values:
        return torch.tensor(float('nan')), torch.tensor(float('nan'))
    stacked = torch.stack(values)
    return stacked.mean(), stacked.std()


if __name__ == '__main__':
    import sys  # only the script entry point needs it

    print('plot img diff graph and latex for table')
    meta_path = resolve_meta_path(sys.argv, os.environ)

    # The metadata below that is not used by the computation (file list,
    # human-readable names, data frame) is kept so it remains available in
    # the debugger session at the end of the script.
    with open(os.path.join(meta_path, 'filelist.txt')) as f:
        filelist = f.readlines()
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this script at a metadata directory you trust.
    with open(os.path.join(meta_path, 'subsets.pkl'), 'rb') as fp:
        subsets = pickle.load(fp)
    # TODO:
    # could figure out which index is which by conditioning and then checking
    human_synset = {}
    with open(os.path.join(meta_path, 'synset_human.txt')) as f:
        for line in f:
            items = line.split()
            if not items:
                # Blank line: nothing to record.
                continue
            human_synset[items[0]] = ' '.join(items[1:])
    syn_human = {
        n: {k: human_synset[k] for k in subsets[n]} for n in SUBSET_SIZES
    }
    idx2synset = OmegaConf.load(os.path.join(meta_path, 'index_synset.yaml'))
    synset2idx = {y: x for x, y in idx2synset.items()}
    df = pd.DataFrame(human_synset.items(), columns=['synset', 'text'])
    # missing synset 'n02012849'
    # double cranes n02012849, n03126707
    classes = {n: [synset2idx[k] for k in subsets[n]] for n in SUBSET_SIZES}

    model_path = os.path.join(meta_path, 'bootstrapped_imagenet_10')
    unc_paths = [p for p in os.listdir(model_path) if 'samps_uncs_branch' in p]
    mse_mean = {}
    mse_std = {}
    for up in unc_paths:
        img_dir = os.path.join(model_path, up, 'images')
        # NOTE(review): the position of a file in os.listdir() is used below
        # as its class index, but listdir order is arbitrary -- confirm that
        # the file naming/order really matches the class indices.
        # NOTE(review): torch.load unpickles the files; trusted inputs only.
        imgs = torch.stack(
            [torch.load(os.path.join(img_dir, f)) for f in os.listdir(img_dir)]
        )
        mse_dict = ensemble_mse_per_class(imgs, n_members=5)
        class_mse = {
            n: [mse_dict[c] for c in classes[n] if c in mse_dict]
            for n in SUBSET_SIZES
        }
        print(f'{up}')
        print(class_mse[1])
        print(class_mse[100])
        bp = int(up.split('branch')[1])
        stats = {n: mean_and_std(class_mse[n]) for n in SUBSET_SIZES}
        mse_mean[bp] = {n: stats[n][0] for n in SUBSET_SIZES}
        mse_std[bp] = {n: stats[n][1] for n in SUBSET_SIZES}

    # NOTE(review): the script ends in an interactive debugger; presumably
    # this is how mse_mean / mse_std are inspected, since nothing is plotted
    # or saved yet. It will block non-interactive runs.
    import pdb; pdb.set_trace()
