import pickle
import random
import argparse
import torch
import os
from collections import Counter
from omegaconf import OmegaConf

import numpy as np 
import pandas as pd
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import seaborn as sns

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ddpm_test import DDPMSampler
from save_ensemble import load_ensemble
from sample_model import sample_model, get_model, seed_everything
from make_histogram import get_min_max, norm_uncs, split_uncs


def build_arg_parser():
    """Create the command-line parser for the per-synset uncertainty script.

    Returns:
        argparse.ArgumentParser: parser exposing every option the script
        reads. ``--path`` and ``--train_path`` are required; all other
        options have defaults.
    """
    parser = argparse.ArgumentParser(description='Get Uncertainty Per Synset')
    parser.add_argument('--path', type=str, help='path to model', required=True)
    # The dataset directory used to be an unassigned local variable, which
    # made the script fail with NameError; it is now an explicit option.
    parser.add_argument('--train_path', type=str, required=True,
        help='directory containing subsets.pkl and index_synset.yaml')
    parser.add_argument('--sampler', type=str, help='which sampler to use', default='DDIM')
    parser.add_argument('--unc_branch', type=int, help='where to branch for uncertainty',
        default=200)
    # ddim_eta 0-1, 1=DDPM
    parser.add_argument('--ddim_eta', type=float, help='controls stdev for generative process',
        default=0.00)
    parser.add_argument('--ddim_steps', type=int, help='number of steps to take in ddim',
        default=200)
    parser.add_argument('--base_comp', type=int, help='comp to start from before branching',
        default=-1)
    parser.add_argument('--quickrun', action='store_true', help='run quick test')
    parser.add_argument('--scale', type=float, help='controls the amount of unconditional guidance',
        default=5.0)
    parser.add_argument('--ensemble_name', type=str, help='which ensemble to load', default='bootstrapped')
    parser.add_argument('--full_ensemble', action='store_true', help='ensemble from scratch')
    return parser


def _load_models(args, numb_comps):
    """Load the model(s) and pick the ensemble member that drives sampling.

    Args:
        args: parsed command-line namespace (reads ``path``,
            ``full_ensemble`` and ``ensemble_name``).
        numb_comps: component count forwarded to ``load_ensemble`` when a
            single multi-component model is loaded.

    Returns:
        tuple: ``(model, ensemble_comps, ensemble_size, comp_idx)``.
        ``ensemble_comps`` is an empty list unless ``--full_ensemble`` is
        set, in which case it holds one separately loaded model per entry
        and ``model`` is the entry at ``comp_idx``.

    Raises:
        SystemExit: if ``--full_ensemble`` is set but no entry of
            ``args.path`` contains ``args.ensemble_name``.
    """
    if not args.full_ensemble:
        model = load_ensemble(args.path, True, numb_comps)
        ensemble_size = model.model.diffusion_model.ensemble_size
        comp_idx = random.randint(0, ensemble_size - 1)
        return model, [], ensemble_size, comp_idx

    # os.listdir returns entries in arbitrary order; sorting keeps member
    # indices stable between runs.
    member_names = [name for name in sorted(os.listdir(args.path))
        if args.ensemble_name in name]
    ensemble_comps = [get_model(os.path.join(args.path, name)) for name in member_names]
    if not ensemble_comps:
        raise SystemExit(
            f"No entries containing '{args.ensemble_name}' found in {args.path}")
    ensemble_size = len(ensemble_comps)
    # Draw from the members actually found rather than an assumed count, so
    # the index can never fall outside the list.
    comp_idx = random.randint(0, ensemble_size - 1)
    return ensemble_comps[comp_idx], ensemble_comps, ensemble_size, comp_idx


def _encode_labels(labels, model, ensemble_comps, ensemble_size):
    """Return one learned conditioning per ensemble member for ``labels``.

    When ``ensemble_comps`` is empty the single ``model`` is queried once per
    member via its ``comp_idx`` keyword; otherwise each separately loaded
    model encodes the labels itself.
    """
    encoded = []
    for member in range(ensemble_size):
        if not ensemble_comps:
            cond = model.get_learned_conditioning(
                {model.cond_stage_key: labels}, comp_idx=member)
        else:
            comp = ensemble_comps[member]
            cond = comp.get_learned_conditioning({comp.cond_stage_key: labels})
        encoded.append(cond)
    return encoded


def main():
    """Sample every selected class and report epistemic-uncertainty statistics.

    For each class the sampler is run once with ``return_unc=True``; the
    pairwise distance matrix it returns is saved as ``.npy`` under
    ``args.path`` and the uncertainty entry is collected into a dict that is
    pickled under ``args.path`` (skipped with ``--quickrun``). Finally the
    uncertainties are split per subset, normalised and printed.
    """
    args = build_arg_parser().parse_args()
    #seed_everything(42)

    train_path = args.train_path
    numb_comps = 5
    model, ensemble_comps, ensemble_size, comp_idx = _load_models(args, numb_comps)

    # NOTE: pickle.load executes arbitrary code from the file; only point
    # --train_path at a trusted dataset directory.
    with open(os.path.join(train_path, 'subsets.pkl'), 'rb') as fp:
        subsets = pickle.load(fp)
    # TODO:
    #could figure out which index is which by conditioning and then checking
    idx2synset = OmegaConf.load(os.path.join(train_path, 'index_synset.yaml'))
    synset2idx = {y: x for x, y in idx2synset.items()}
    # missing synset 'n02012849'
    # double cranes n02012849, n03126707

    # Fixed class indices used by --quickrun; the names mirror the subset
    # keys 1 and 1300 used in the final report.
    classes_1 = [392, 708, 729, 854]
    classes_1300 = [25, 187, 448, 992]
    classes2sample = classes_1 + classes_1300

    if args.sampler == 'DDIM':
        sampler = DDIMSampler(model)
    else:
        print('Not setup for DDPM')
        sampler = DDPMSampler(model)
    ddim_steps = args.ddim_steps
    ddim_eta = args.ddim_eta
    scale = args.scale

    if args.quickrun:
        all_classes = classes2sample
        n_samples_per_class = 2
    else:
        all_classes = []
        for synsets in subsets.values():
            all_classes.extend(synset2idx[syn] for syn in synsets)
        all_classes.sort()
        n_samples_per_class = 8

    # Built once: the name depends only on the command-line settings.
    filename_uncs = os.path.join(args.path, f'uncs_dict_eta{args.ddim_eta}'\
        f'_branch{args.unc_branch}_sampler{args.sampler}_ddimsteps{args.ddim_steps}_scale{scale}.pkl')

    epi_uncs = {}
    sample_shape = [3, 64, 64]
    # The same starting noise is reused for every class.
    x_T = torch.randn((n_samples_per_class, *sample_shape), device=model.device)
    with torch.no_grad(), model.ema_scope():
        # Label 1000 produces the conditioning passed to the sampler as the
        # unconditional branch of the guidance.
        uc_labels = torch.tensor(n_samples_per_class*[1000]).to(model.device)
        all_ucs = _encode_labels(uc_labels, model, ensemble_comps, ensemble_size)

        for class_label in all_classes:
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in"\
                f" {ddim_steps} steps and using s={scale:.2f}.")
            xc = torch.tensor(n_samples_per_class*[class_label]).to(model.device)
            all_cs = _encode_labels(xc, model, ensemble_comps, ensemble_size)

            # Only the uncertainty entry and the pairwise distance matrix
            # are used; the samples and the distribution are discarded.
            _, epi_unc, _, dist_mat = sampler.sample(S=ddim_steps,
                                                     conditioning=all_cs[comp_idx],
                                                     batch_size=n_samples_per_class,
                                                     shape=sample_shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=all_ucs[comp_idx],
                                                     eta=ddim_eta,
                                                     ensemble_comp=comp_idx,
                                                     return_distribution=True,
                                                     return_unc=True,
                                                     branch_split=args.unc_branch,
                                                     all_ucs=all_ucs,
                                                     all_cs=all_cs,
                                                     ensemble_comps=ensemble_comps,
                                                     x_T=x_T)
            np.save(os.path.join(args.path, f'pair_dist_class{class_label}.npy'), dist_mat.cpu().numpy())
            epi_uncs[class_label] = epi_unc
            # Checkpoint after every class so an interrupted run keeps the
            # results gathered so far.
            if not args.quickrun:
                with open(filename_uncs, 'wb') as fp:
                    pickle.dump(epi_uncs, fp)
                    print('Writing uncertainty dict to file')

    if not args.quickrun:
        with open(filename_uncs, 'wb') as fp:
            pickle.dump(epi_uncs, fp)
            print('Writing uncertainty dict to file')

    if not epi_uncs:
        raise SystemExit('No classes were sampled; nothing to summarise.')

    ## sample examples
    # Probe the first sampled class instead of a hard-coded index so the
    # check works for any class selection.
    # NOTE(review): assumes every entry exposes the same keys - confirm.
    first_unc = next(iter(epi_uncs.values()))
    kl_exist = 'KL' in first_unc.keys()
    uncs_bin_wass, uncs_bin_kl, uncs_bin_bhatt = split_uncs(epi_uncs, subsets, synset2idx, kl_exist)
    min_wass, max_wass = get_min_max(uncs_bin_wass)
    mean_std_wass = norm_uncs(uncs_bin_wass, min_wass, max_wass)
    print(f'epi unc Wass 1 \nmean:{mean_std_wass[1]["mean"]:.3f}       std:{mean_std_wass[1]["std"]:.3f}')
    print(f'epi unc Wass 1300 \nmean:{mean_std_wass[1300]["mean"]:.3f}       std:{mean_std_wass[1300]["std"]:.3f}')
    if kl_exist:
        min_kl, max_kl = get_min_max(uncs_bin_kl)
        min_bhatt, max_bhatt = get_min_max(uncs_bin_bhatt)
        mean_std_kl = norm_uncs(uncs_bin_kl, min_kl, max_kl)
        mean_std_bhatt = norm_uncs(uncs_bin_bhatt, min_bhatt, max_bhatt)
        print(f'epi unc KL 1 \nmean: {mean_std_kl[1]["mean"]:.3f}       std:{mean_std_kl[1]["std"]:.3f}')
        print(f'epi unc KL 1300 \nmean: {mean_std_kl[1300]["mean"]:.3f}       std:{mean_std_kl[1300]["std"]:.3f}')
        print(f'epi unc Bhatt 1 \nmean: {mean_std_bhatt[1]["mean"]:.3f}       std:{mean_std_bhatt[1]["std"]:.3f}')
        print(f'epi unc Bhatt 1300 \nmean: {mean_std_bhatt[1300]["mean"]:.3f}       std:{mean_std_bhatt[1300]["std"]:.3f}')


if __name__ == '__main__':
    main()
