import pickle
import random
import argparse
import torch
import os
from collections import Counter
from omegaconf import OmegaConf

import numpy as np 
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from einops import rearrange
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import seaborn as sns

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ddpm_test import DDPMSampler
from save_ensemble import load_ensemble
from sample_model import sample_model, get_model, seed_everything
from make_histogram import get_min_max, norm_uncs, split_uncs
from uncertainty_estimation.uncertainty_estimators import pairwise_exp



def _measure_text(draw, text, font):
    """Return the ``(width, height)`` in pixels of ``text`` drawn with ``font``.

    ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is used
    whenever the draw object provides it; ``textsize`` is only a fallback for
    older Pillow releases.
    """
    if hasattr(draw, 'textbbox'):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return int(right - left), int(bottom - top)
    return draw.textsize(text, font=font)


def _render_label_cell(text, width, height, font, font_color=(0, 0, 0),
                       border_width=1):
    """Return a white ``width`` x ``height`` RGB cell with ``text`` centred.

    The cell gets a black outline of ``border_width`` pixels so that adjacent
    cells are visually separated once pasted into the combined image.
    """
    cell = Image.new("RGB", (width, height), (255, 255, 255))
    cell_draw = ImageDraw.Draw(cell)

    text_width, text_height = _measure_text(cell_draw, text, font)
    text_x = (width - text_width) // 2
    text_y = (height - text_height) // 2
    cell_draw.text((text_x, text_y), text, font=font, fill=font_color)

    border_box = [(0, 0), (width - 1, height - 1)]
    cell_draw.rectangle(border_box, outline=(0, 0, 0), width=border_width)
    return cell


def main():
    """Sample one image per class and save per-pixel uncertainty figures.

    For every tuple of class indices (one class from each of the subsets keyed
    1, 10, 100 and 1300 in ``subsets.pkl``) an image with four columns is
    written to ``<path>/pixel_unc/pixel_uncertainty_<j>.png``:

    0. the human-readable class name,
    1. the generated sample,
    2. a heat map of the per-pixel uncertainty returned by ``pairwise_exp``,
    3. the mean number of training files per ensemble component.
    """
    parser = argparse.ArgumentParser(description='Get Uncertainty Per Synset')
    parser.add_argument('--path', type=str, help='path to model', required=True)
    parser.add_argument('--train_path', type=str, required=True,
        help='directory with subsets.pkl, synset_human.txt and index_synset.yaml')
    parser.add_argument('--base_dir', type=str, required=True,
        help='directory with the per-component filelist_comp<i>.txt files')
    parser.add_argument('--sampler', type=str, help='which smapler to use', default='DDIM')
    parser.add_argument('--unc_branch', type=int, help='where to branch for uncertainty',
        default=0)
    # ddim_eta 0-1, 1=DDPM
    parser.add_argument('--ddim_eta', type=float, help='controls stdev for generative process',
        default=0.00)
    parser.add_argument('--ddim_steps', type=int, help='number of steps to take in ddim',
        default=200)
    parser.add_argument('--base_comp', type=int, help='comp to start from before branching',
        default=-1)
    parser.add_argument('--scale', type=float, help='controls the amount of unconditional guidance',
        default=5.0)
    parser.add_argument('--ensemble_name', type=str, help='which ensemble to load', default='bootstrapped')
    parser.add_argument('--full_ensemble', action='store_true', help='ensemble from scratch')
    args = parser.parse_args()
    seed_everything(42)

    train_path = args.train_path
    base_dir = args.base_dir

    numb_comps = 10
    if not args.full_ensemble:
        # Single checkpoint carrying all ensemble members; an empty
        # ensemble_comps list is what marks this mode further below.
        model = load_ensemble(args.path, True, numb_comps)
        ensemble_comps = []
        ensemble_size = model.model.diffusion_model.ensemble_size
        comp_idx = random.randint(0, ensemble_size - 1)
    else:
        # The index is drawn before loading the models, as the original
        # script did, so model loading cannot influence the seeded draw.
        comp_idx = random.randint(0, numb_comps - 1)
        # sorted(): os.listdir order is arbitrary, which would make
        # "component i" differ between machines/runs.
        ensemble_comps = [get_model(os.path.join(args.path, name))
                          for name in sorted(os.listdir(args.path))
                          if args.ensemble_name in name]
        ensemble_size = len(ensemble_comps)
        if comp_idx >= ensemble_size:
            raise ValueError(
                f'selected component {comp_idx} but only {ensemble_size} '
                f'checkpoint(s) matching {args.ensemble_name!r} were found in {args.path}')
        model = ensemble_comps[comp_idx]

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point --train_path at trusted data.
    with open(os.path.join(train_path, 'subsets.pkl'), 'rb') as fp:
        subsets = pickle.load(fp)

    # synset id -> human readable description (first token is the id).
    human_synset = {}
    with open(os.path.join(train_path, 'synset_human.txt')) as f:
        for line in f:
            items = line.split()
            human_synset[items[0]] = ' '.join(items[1:])

    idx2synset = OmegaConf.load(os.path.join(train_path, 'index_synset.yaml'))
    synset2idx = {synset: idx for idx, synset in idx2synset.items()}
    # Original author's notes: missing synset 'n02012849';
    # double cranes n02012849, n03126707
    classes_1 = [synset2idx[k] for k in subsets[1]]
    classes_10 = [synset2idx[k] for k in subsets[10]]
    classes_100 = [synset2idx[k] for k in subsets[100]]
    classes_1300 = [synset2idx[k] for k in subsets[1300]]
    classes_1300 = random.sample(classes_1300, 100)
    print(classes_1)
    print(classes_10)
    print(classes_100)
    print(classes_1300)

    # One class from every subset per output figure; zip stops at the
    # shortest of the four lists.
    classes2sample = list(zip(classes_1, classes_10, classes_100, classes_1300))

    if args.sampler == 'DDIM':
        sampler = DDIMSampler(model)
    else:
        print('Not setup for DDPM')
        sampler = DDPMSampler(model)
    ddim_steps = args.ddim_steps
    ddim_eta = args.ddim_eta
    scale = args.scale
    n_samples_per_class = 1

    # Per-component training file lists, used to report how many training
    # files each sampled class has in each ensemble component.
    comp_files = []
    for comp in range(ensemble_size):
        with open(os.path.join(base_dir, f'filelist_comp{comp}.txt')) as fl:
            comp_files.append(fl.readlines())
    synsets2sample = {idx: idx2synset[idx]
                      for batch in classes2sample for idx in batch}
    numbfilesbycomp = {
        idx: [sum(synset in line for line in cf) for cf in comp_files]
        for idx, synset in synsets2sample.items()
    }

    out_dir = os.path.join(args.path, 'pixel_unc')
    os.makedirs(out_dir, exist_ok=True)
    heat_map_path = os.path.join(args.path, 'heat_map.png')

    font_size = 20
    # Loaded once; the font does not change between figures.
    font = ImageFont.truetype("Roboto-Thin.ttf", font_size)

    with torch.no_grad():
        with model.ema_scope():
            # Conditioning for label 1000 per component; it is passed to the
            # sampler as the unconditional conditioning.
            # NOTE(review): presumably 1000 is the "no class" index - confirm.
            all_ucs = []
            for comp in range(ensemble_size):
                if not ensemble_comps:
                    uc = model.get_learned_conditioning(
                        {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)},
                        comp_idx=comp)
                else:
                    member = ensemble_comps[comp]
                    uc = member.get_learned_conditioning(
                        {member.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)})
                all_ucs.append(uc)

            # The same starting noise is reused for every class.
            x_T = torch.randn((n_samples_per_class, 3, 64, 64), device=model.device)
            for j, label_batch in enumerate(classes2sample):
                batch_pix = []
                batch_unc_pix = []
                batch_numb_samples = []
                print('make grid')

                for class_label in label_batch:
                    print(f"rendering {n_samples_per_class} examples of class '{class_label}' in"
                        f" {ddim_steps} steps and using s={scale:.2f}.")
                    xc = torch.tensor(n_samples_per_class*[class_label])
                    all_cs = []
                    for comp in range(ensemble_size):
                        if not ensemble_comps:
                            c = model.get_learned_conditioning(
                                {model.cond_stage_key: xc.to(model.device)},
                                comp_idx=comp)
                        else:
                            member = ensemble_comps[comp]
                            c = member.get_learned_conditioning(
                                {member.cond_stage_key: xc.to(model.device)})
                        all_cs.append(c)
                    # Re-seed so every class is sampled with identical randomness.
                    seed_everything(42)

                    samples_ddim, _epi_unc, _inter, _dist_mat = sampler.sample(
                        S=ddim_steps,
                        conditioning=all_cs[comp_idx],
                        batch_size=n_samples_per_class,
                        shape=[3, 64, 64],
                        verbose=False,
                        unconditional_guidance_scale=scale,
                        unconditional_conditioning=all_ucs[comp_idx],
                        eta=ddim_eta,
                        ensemble_comp=comp_idx,
                        return_distribution=True,
                        return_unc=True,
                        branch_split=args.unc_branch,
                        all_ucs=all_ucs,
                        all_cs=all_cs,
                        ensemble_comps=ensemble_comps,
                        unc_per_pixel=True,
                        x_T=x_T)
                    if len(samples_ddim.shape) == 5:
                        # Fold the two leading dimensions into one batch axis.
                        samples_ddim = samples_ddim.reshape(-1, samples_ddim.shape[2],
                            samples_ddim.shape[3], samples_ddim.shape[4])

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # Map decoded images from [-1, 1] to [0, 1].
                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                                 min=0.0, max=1.0)
                    batch_pix.append(x_samples_ddim)
                    unc_pix, _ = pairwise_exp(x_samples_ddim.unsqueeze(1), 0, 'Wass_0',
                        numb_comps, unc_per_pixel=True)
                    batch_unc_pix.append(unc_pix)
                    batch_numb_samples.append(np.mean(numbfilesbycomp[class_label]))

                # Column 1: first decoded sample of every class, stacked vertically.
                grid = torch.stack([p[0, :, :, :] for p in batch_pix], 0)
                grid = make_grid(grid, nrow=1)
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                grid_class_img = Image.fromarray(grid.astype(np.uint8))

                single_image_width = grid_class_img.width
                image_height = grid_class_img.height
                single_image_height = int(image_height/len(label_batch))+1
                combined_image = Image.new("RGB", (4*single_image_width, image_height),
                    (255, 255, 255))
                combined_image.paste(grid_class_img, (single_image_width, 0))

                # Columns 0 and 3: class name and mean training-file count.
                for row, class_label in enumerate(label_batch):
                    image_y = row * single_image_height
                    human_txt = human_synset[idx2synset[class_label]].split(',')[0]
                    name_cell = _render_label_cell(f'{human_txt}',
                        single_image_width, single_image_height, font)
                    combined_image.paste(name_cell, (0, image_y))

                    count_cell = _render_label_cell(f'{batch_numb_samples[row]} images',
                        single_image_width, single_image_height, font)
                    combined_image.paste(count_cell, (3*single_image_width, image_y))

                # Column 2: heat maps sharing one colour range per figure so
                # the rows are comparable with each other.
                stacked_unc = torch.stack(batch_unc_pix)
                max_unc = stacked_unc.max().item()
                min_unc = stacked_unc.min().item()
                px = 1/plt.rcParams['figure.dpi']
                for row in range(len(label_batch)):
                    # Channel-mean uncertainty of the first sample.
                    uncs = batch_unc_pix[row][0, :, :, :].mean(0)
                    # Plot raw values against the shared raw range. The
                    # original rescaled to [0, 1] and then still passed the
                    # raw min/max as vmin/vmax, which normalised twice and
                    # divided by zero when max == min.
                    fig, ax = plt.subplots(
                        figsize=(single_image_width*px, single_image_height*px))
                    ax.imshow(uncs.cpu().numpy(), cmap='cividis',
                        interpolation='nearest', vmin=min_unc, vmax=max_unc)
                    ax.tick_params(left=False, right=False, labelleft=False,
                        labelbottom=False, bottom=False)
                    fig.tight_layout(pad=0.0)
                    fig.savefig(heat_map_path)
                    # Close explicitly: pyplot otherwise keeps every figure alive.
                    plt.close(fig)
                    with Image.open(heat_map_path) as im:
                        combined_image.paste(im,
                            (single_image_width*2, single_image_height*row))

                file_name = os.path.join(out_dir, f'pixel_uncertainty_{j}.png')
                combined_image.save(file_name)


if __name__ == '__main__':
    main()
