import pickle
import random
import argparse
import torch
import os
from collections import Counter
from omegaconf import OmegaConf

import numpy as np 
import pandas as pd
from PIL import Image, ImageDraw, ImageFont 
from einops import rearrange
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import seaborn as sns

from save_ensemble import load_ensemble
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ddpm_test import DDPMSampler


def load_model_from_config(config, ckpt, map_location=None, verbose=False):
    """Instantiate ``config.model``, load checkpoint weights, and put it on the GPU in eval mode.

    Args:
        config: config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file that contains a ``"state_dict"`` entry.
        map_location: forwarded to ``torch.load``. The default ``None`` keeps the
            previous behaviour (tensors are restored onto the devices they were
            saved from); pass ``"cpu"`` to load checkpoints saved on another GPU.
        verbose: if True, print the full lists of missing/unexpected keys rather
            than only their counts.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them instead of discarding
    # the information silently.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"{len(missing)} missing keys when loading {ckpt}")
        if verbose:
            print(missing)
    if unexpected:
        print(f"{len(unexpected)} unexpected keys when loading {ckpt}")
        if verbose:
            print(unexpected)
    model.cuda()
    model.eval()
    return model


def get_model(path, ckpt_name='last.ckpt'):
    """Load the model stored in a training-run directory.

    The config file name is derived from the part of the directory name before
    the first underscore: ``<path>/configs/<prefix>-project.yaml``. Weights are
    read from ``<path>/checkpoints/<ckpt_name>``.

    Args:
        path: run directory (a trailing path separator is tolerated).
        ckpt_name: checkpoint file inside ``checkpoints/``; defaults to
            ``last.ckpt`` (e.g. pass ``epoch=000012.ckpt`` for a specific epoch).

    Returns:
        The model returned by ``load_model_from_config``.
    """
    # normpath strips a trailing separator, which would otherwise make the
    # basename empty and produce a bogus config path.
    run_name = os.path.basename(os.path.normpath(path))
    timestamp = run_name.split('_')[0]
    config_path = os.path.join(path, 'configs', f'{timestamp}-project.yaml')
    config = OmegaConf.load(config_path)
    model_path = os.path.join(path, 'checkpoints', ckpt_name)
    return load_model_from_config(config, model_path)

def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: the seed applied to every generator.
    """
    random.seed(seed)
    # Only affects subprocesses; the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True autotunes conv algorithms per run (possibly choosing
    # nondeterministic ones), which defeats ``deterministic``; keep it off.
    torch.backends.cudnn.benchmark = False

def _ensemble_conditionings(model, ensemble_comps, ensemble_size, labels):
    """Return one learned conditioning per ensemble member for the label tensor ``labels``."""
    conds = []
    for member in range(ensemble_size):
        if not ensemble_comps:
            # Single ensemble model: the member is selected through ``comp_idx``.
            conds.append(model.get_learned_conditioning(
                {model.cond_stage_key: labels.to(model.device)},
                comp_idx=member))
        else:
            # Independently trained components: each has its own cond stage.
            comp = ensemble_comps[member]
            conds.append(comp.get_learned_conditioning(
                {comp.cond_stage_key: labels.to(model.device)}))
    return conds


def _output_file_name(path, synset, class_label, subsets, human_by_subset,
        unc_branch, scale, ddim_eta):
    """Build the PNG path for ``synset`` from the first subset (1/10/100/1300) containing it."""
    for n_images, human in human_by_subset:
        if synset in subsets[n_images]:
            return os.path.join(path, (f'{human[synset]}'
                f'_{n_images}images_{synset}_uncbranch{unc_branch}_scale{scale}_eta{ddim_eta}.png'))
    # Previously this fell through and reused (overwrote) the last file name.
    raise ValueError(f'synset {synset} (class {class_label}) is not in any of the '
                     f'{[n for n, _ in human_by_subset]}-image subsets')


def sample_model(model, n_samples_per_class, classes2sample, ddim_steps, ddim_eta, scale, 
        ensemble_size, path, idx2synset, subsets, unc_branch, sampler, comp_idx, 
        syn1_human, syn10_human, syn100_human, syn1300_human, ensemble_comps,
        numbfilesbycomp, bin_images=False, max_members=2):
    """Sample each class in ``classes2sample`` with ``sampler`` and save one PNG grid per class.

    Conditionings (class and null class 1000) are computed for every ensemble
    member, either from ``model`` via ``comp_idx`` (when ``ensemble_comps`` is
    empty) or from each entry of ``ensemble_comps``, and all of them are passed
    to ``sampler.sample``; member ``comp_idx`` provides the primary conditioning.

    Args:
        path: output directory (created if missing); ``images_per_bin`` is
            appended when ``bin_images`` is True.
        idx2synset: maps class index -> synset id, used for the file name.
        subsets: maps 1/10/100/1300 -> collection of synset ids.
        syn*_human: synset id -> human-readable name for the matching subset.
        numbfilesbycomp: per-class per-component file counts. Currently unused;
            it fed a label column that had been disabled. Kept for interface
            compatibility.
        max_members: for 5-D sampler output, keep only the first ``max_members``
            entries along dim 0 (presumably ensemble members — confirm against
            the sampler). ``None`` keeps all. Defaults to the previous hard-coded 2.

    Raises:
        ValueError: if a sampled class's synset is in none of the subsets.
    """
    if bin_images:
        path = os.path.join(path, 'images_per_bin')
    os.makedirs(path, exist_ok=True)
    human_by_subset = ((1, syn1_human), (10, syn10_human),
                       (100, syn100_human), (1300, syn1300_human))
    with torch.no_grad():
        with model.ema_scope():
            print('feed in all condtioned and use appropriate one at branch')
            # Label 1000 is used as the unconditional (null) class.
            null_labels = torch.tensor(n_samples_per_class*[1000])
            all_ucs = _ensemble_conditionings(model, ensemble_comps, ensemble_size,
                                              null_labels)

            for class_label in classes2sample:
                print(f"rendering {n_samples_per_class} examples of class '{class_label}' in"\
                    f" {ddim_steps} steps and using s={scale:.2f}.")
                xc = torch.tensor(n_samples_per_class*[class_label])
                all_cs = _ensemble_conditionings(model, ensemble_comps, ensemble_size, xc)

                samples_ddim, _epi_unc, _intermediates, _dist_mat = sampler.sample(
                    S=ddim_steps,
                    conditioning=all_cs[comp_idx],
                    batch_size=n_samples_per_class,
                    shape=[3, 64, 64],
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=all_ucs[comp_idx],
                    eta=ddim_eta,
                    ensemble_comp=comp_idx,
                    return_distribution=True,
                    return_unc=True,
                    branch_split=unc_branch,
                    all_ucs=all_ucs,
                    all_cs=all_cs,
                    ensemble_comps=ensemble_comps)
                # Truncate/flatten only when the output has the extra leading
                # dim; slicing 5 axes of a 4-D tensor used to raise here.
                if samples_ddim.dim() == 5:
                    samples_ddim = samples_ddim[:max_members]
                    samples_ddim = samples_ddim.reshape(-1, *samples_ddim.shape[2:])

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                             min=0.0, max=1.0)

                grid = make_grid(x_samples_ddim, nrow=n_samples_per_class)
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                # Save the grid at its own size; the old canvas height
                # int(h/ensemble_size)*ensemble_size cropped the bottom rows.
                grid_class_img = Image.fromarray(grid.astype(np.uint8))

                synset = idx2synset[class_label]
                file_name = _output_file_name(path, synset, class_label, subsets,
                                              human_by_subset, unc_branch, scale, ddim_eta)
                grid_class_img.save(file_name)
                print(file_name)

def main():
    """Parse CLI arguments, load the ensemble, and render one sample grid per class.

    ``--train_path`` must contain ``subsets.pkl``, ``synset_human.txt`` and
    ``index_synset.yaml``; ``--base_dir`` must contain ``filelist_comp{i}.txt``
    for every ensemble component ``i``.
    """
    parser = argparse.ArgumentParser(description='Get Uncertainty Per Synset')
    parser.add_argument('--path', type=str, help='path to model', required=True)
    parser.add_argument('--sampler', type=str, help='which smapler to use', default='DDIM')
    parser.add_argument('--ddim_eta', type=float, help='controls stdev for generative process', 
        default=0.00)
    # ddim_eta 0-1, 1=DDPM 
    parser.add_argument('--scale', type=float, help='controls the amount of unconditional guidance', 
        default=5.0)
    # higher scale less diversity
    parser.add_argument('--ddim_steps', type=int, help='number of steps to take in ddim', 
        default=200)
    parser.add_argument('--unc_branch', type=int, help='when to split for generative proccess', 
        default=200)
    parser.add_argument('--ensemble_name', type=str, help='which ensemble to load', default='bootstrapped')
    parser.add_argument('--full_ensemble', action='store_true', help='ensemble from scratch')
    # These were previously undefined names (the script raised NameError).
    parser.add_argument('--train_path', type=str, required=True,
        help='directory with subsets.pkl, synset_human.txt and index_synset.yaml')
    parser.add_argument('--base_dir', type=str, required=True,
        help='directory with filelist_comp{i}.txt for each ensemble component')
    parser.add_argument('--classes', type=int, nargs='+', default=[992],
        help='class indices to sample')
    args = parser.parse_args()
    seed_everything(42)

    imagenet = True
    numb_comps = 10
    if not args.full_ensemble:
        model = load_ensemble(args.path, imagenet, numb_comps)
        ensemble_comps = []
        ensemble_size = model.model.diffusion_model.ensemble_size
        comp_idx = random.randint(0, (ensemble_size-1))
    else:
        # Sorted for a reproducible component order; os.listdir order is arbitrary.
        comp_dirs = sorted(
            d for d in os.listdir(args.path)
            if args.ensemble_name in d and os.path.isdir(os.path.join(args.path, d)))
        if not comp_dirs:
            raise SystemExit(f'no ensemble components matching {args.ensemble_name!r} '
                             f'found in {args.path}')
        # Draw from the actual number of components (drawing from numb_comps
        # could index past the list), before model loading as before.
        comp_idx = random.randint(0, len(comp_dirs) - 1)
        ensemble_comps = [get_model(os.path.join(args.path, d)) for d in comp_dirs]
        model = ensemble_comps[comp_idx]
        ensemble_size = len(ensemble_comps)

    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.train_path, 'subsets.pkl'), 'rb') as fp:
        subsets = pickle.load(fp)
    # TODO: could figure out which index is which by conditioning and then checking
    human_synset = {}
    with open(os.path.join(args.train_path, 'synset_human.txt')) as f:
        for line in f:
            items = line.split()
            if not items:  # tolerate blank lines
                continue
            human_synset[items[0]] = ' '.join(items[1:])
    syn1_human = {k: human_synset[k] for k in subsets[1]}
    syn1300_human = {k: human_synset[k] for k in subsets[1300]}
    syn100_human = {k: human_synset[k] for k in subsets[100]}
    syn10_human = {k: human_synset[k] for k in subsets[10]}
    idx2synset = OmegaConf.load(os.path.join(args.train_path, 'index_synset.yaml'))
    # missing synset 'n02012849'
    # double cranes n02012849, n03126707
    classes2sample = list(args.classes)
    synsets2sample = {i: idx2synset[i] for i in classes2sample}

    comp_files = []
    for i in range(ensemble_size):
        with open(os.path.join(args.base_dir, f'filelist_comp{i}.txt')) as fl:
            comp_files.append(fl.readlines())
    # Number of lines in each component's file list mentioning each sampled synset.
    numbfilesbycomp = {
        k: [sum(v in entry for entry in cf) for cf in comp_files]
        for k, v in synsets2sample.items()
    }

    if args.sampler == 'DDIM':
        sampler = DDIMSampler(model)
    else: 
        print('Not setup for DDPM')
        sampler = DDPMSampler(model)
    n_samples_per_class = 1 

    pic_path = args.path
    if args.full_ensemble:
        pic_path = os.path.join(args.path, f'pics_fullenesemble_{numb_comps}')
    os.makedirs(pic_path, exist_ok=True)
    ## sample examples
    sample_model(model, n_samples_per_class, classes2sample, args.ddim_steps, args.ddim_eta,
        args.scale, ensemble_size, pic_path, idx2synset, subsets, args.unc_branch, sampler,
        comp_idx, syn1_human, syn10_human, syn100_human, syn1300_human, ensemble_comps,
        numbfilesbycomp, bin_images=False)


if __name__ == '__main__':
    main()
