import random

from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import spearmanr
from sklearn.preprocessing import normalize
import torch
from torch import nn
import torchvision

import vendi_score
from vendi_score import vendi, image_utils, data_utils

import pickle

from dataclasses import asdict, dataclass, field
import random
from typing import Any, Dict, List, Optional

from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision import transforms

from diversity_metrics import *

@dataclass
class Example:
    """One data point: its raw input together with named features and labels."""

    # The raw input object; no type is enforced here.
    x: Any
    # Named derived representations of ``x`` (each instance gets its own dict).
    features: Dict[str, Any] = field(default_factory=dict)
    # Named annotations for ``x`` (each instance gets its own dict).
    labels: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Group:
    name: str = ""
    examples: List[Example] = field(default_factory=list)
    metrics: Dict[str, Any] = field(default_factory=dict)
    features: Dict[str, Any] = field(default_factory=dict)
    Ks: Dict[str, Any] = field(default_factory=dict)
    ws: Dict[str, Any] = field(default_factory=dict)
    vs: Dict[str, Any] = field(default_factory=dict)


def to_batches(lst, batch_size):
    """Split a sliceable sequence into consecutive chunks of ``batch_size`` items.

    Args:
        lst: Any sequence supporting ``len`` and slicing (list, str, ...).
        batch_size: Maximum chunk length; the last chunk may be shorter.

    Returns:
        A list of slices of ``lst`` in their original order (empty for empty input).

    Raises:
        ValueError: If ``batch_size`` is not positive. A non-positive step would
            never advance through ``lst``, so the batching could not terminate.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    return [lst[start:start + batch_size] for start in range(0, len(lst), batch_size)]

def get_inception(pretrained=True, pool=True):
    """Build a torchvision Inception-v3 network switched to eval mode.

    Args:
        pretrained: Use torchvision's default pretrained weights when true;
            otherwise construct the network without loading weights.
        pool: Swap the final ``fc`` layer for an identity so the network
            outputs the features that would feed the classifier, not logits.

    Returns:
        The configured ``nn.Module``.
    """
    weights = Inception_V3_Weights.DEFAULT if pretrained else None
    net = inception_v3(weights=weights, transform_input=True)
    net = net.eval()
    if pool:
        net.fc = nn.Identity()
    return net

def inception_transforms():
    """Return the per-image preprocessing pipeline used with Inception-v3.

    The image is resized (``Resize(299)``), centre-cropped to 299x299, turned
    into a tensor, and broadcast to three channels. No mean/std normalisation
    is applied here.
    """
    steps = [
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        # expand() only broadcasts singleton dims: 1-channel images become
        # 3-channel, 3-channel images pass through unchanged.
        transforms.Lambda(lambda img: img.expand(3, -1, -1)),
    ]
    return transforms.Compose(steps)

def get_embeddings(
    images,
    model=None,
    transform=None,
    batch_size=64,
    device=torch.device("cpu"),
):
    """Embed ``images`` with ``model`` in mini-batches and stack the results.

    Args:
        images: Sliceable sequence of inputs accepted by ``transform``.
        model: Callable mapping a stacked batch tensor to outputs. When ``None``,
            a pretrained pooled Inception-v3 is built and moved to ``device``.
            A user-supplied model is used as-is (it is not moved).
        transform: Per-image callable returning a tensor. When ``None`` it
            defaults to ``inception_transforms()`` if the default Inception model
            is used, otherwise to ``transforms.ToTensor()``.
        batch_size: Number of images per forward pass; must be positive.
        device: ``torch.device`` or device string the batches are moved to.

    Returns:
        ``np.ndarray`` whose first axis indexes ``images`` (one row per image).

    Raises:
        ValueError: If ``batch_size`` is not positive or ``images`` is empty.
    """
    if isinstance(device, str):
        device = torch.device(device)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(images) == 0:
        raise ValueError("images must be non-empty")

    if model is None:
        model = get_inception(pretrained=True, pool=True).to(device)
        # Only fill in the matching preprocessing; never override a caller's transform.
        if transform is None:
            transform = inception_transforms()
    if transform is None:
        transform = transforms.ToTensor()

    embeddings = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]
        x = torch.stack([transform(img) for img in batch], 0).to(device)
        with torch.no_grad():
            output = model(x)
        # Some models return a sequence of outputs (e.g. main + auxiliary heads);
        # keep the first one.
        if isinstance(output, (list, tuple)):
            output = output[0]
        features = output.cpu().numpy()
        # Drop singleton feature axes (what squeeze() was for) but always keep the
        # batch axis, so a trailing batch of a single image still concatenates.
        features = features.reshape(
            features.shape[0], *(d for d in features.shape[1:] if d != 1)
        )
        embeddings.append(features)
    return np.concatenate(embeddings, 0)


def get_inception_embeddings2(images, batch_size=64, device="cpu"):
    """Embed ``images`` with a freshly built pretrained, pooled Inception-v3.

    Thin wrapper around ``get_embeddings`` that constructs the model (already
    moved to ``device``) and its matching preprocessing explicitly.
    """
    dev = torch.device(device) if type(device) is str else device
    inception = get_inception(pretrained=True, pool=True).to(dev)
    return get_embeddings(
        images,
        model=inception,
        transform=inception_transforms(),
        batch_size=batch_size,
        device=dev,
    )

def mode_dropping_groups_simultaniously_2(label_to_examples, categories, N, n_cats, cat, n_steps=10, *, verbose=True):
    """Build groups in which all non-preferred categories shrink at the same rate.

    A reference sample is drawn once per category: ``n_cats[i]`` examples for
    every category other than ``cat``, and ``sum(n_cats)`` examples for ``cat``.
    At each of ``n_steps`` evenly spaced steps, every other category contributes
    the same number ``m`` of examples, with ``m`` growing from 0 to
    ``min(n_cats)``. ``cat`` fills the rest, so every group holds
    ``sum(n_cats)`` examples. The first group contains only ``cat``.

    When more than 1000 examples are needed for ``cat``, they are drawn in
    chunks of at most 1000. Each chunk is drawn without replacement, so an
    example can repeat across chunks.

    Args:
        label_to_examples: Mapping from category to a sequence of examples.
        categories: Categories to draw from. Must contain ``cat`` and at least
            one other category.
        N: Ignored. The group size is always recomputed as ``sum(n_cats)``; the
            parameter is kept for backward compatibility.
        n_cats: Per-category reference sample sizes, aligned with ``categories``.
        cat: Preferred category that absorbs the dropped mass.
        n_steps: Number of groups (points passed to ``np.linspace``).
        verbose: Print each group's composition.

    Returns:
        A list of ``Group`` objects whose ``name`` is the fraction of the group
        drawn from ``cat``.

    Raises:
        ValueError: If ``categories`` has fewer than two entries or lacks ``cat``.
    """
    n_other = len(categories) - 1
    if n_other < 1:
        raise ValueError("at least two categories are required")
    if cat not in categories:
        raise ValueError(f"preferred category {cat!r} is not in categories")

    # The N argument is deliberately overridden: groups always hold sum(n_cats) examples.
    N = sum(n_cats)
    n_per_cat = min(n_cats)

    # Reference sample per category, drawn once and sliced by every group below.
    samples_categ = {}
    for i, categ in enumerate(categories):
        population = label_to_examples[categ]
        if categ != cat:
            samples_categ[categ] = random.sample(population, n_cats[i])
        elif N > 1000:
            sample = []
            remaining = N
            while remaining > 0:
                sample.extend(random.sample(population, min(remaining, 1000)))
                remaining -= 1000
            samples_categ[categ] = sample
        else:
            samples_categ[categ] = random.sample(population, N)

    others = [categ for categ in categories if categ != cat]
    groups = []
    for k in np.linspace(0, n_per_cat * n_other, n_steps):
        # Every other category gets the same share, so k is rounded down to a
        # multiple of n_other and the remainder stays with the preferred category.
        per_other = int(round(k)) // n_other
        n_mode1 = N - per_other * n_other
        n_samples = [per_other] * n_other
        if verbose:
            print("----")
            print(n_mode1)
            print(n_samples)
            print("----")

        group = Group(n_mode1 / N, [])
        group.examples += samples_categ[cat][:n_mode1]
        for categ, n in zip(others, n_samples):
            if n > 0:
                group.examples += samples_categ[categ][:n]
        groups.append(group)
    return groups

def mode_dropping_groups_sequentially_2(label_to_examples, categories, N, n_cats, cat, n_steps=10, *, verbose=True):
    """Build groups in which non-preferred categories are dropped one after another.

    A reference sample is drawn once per category: ``n_cats[i]`` examples for
    every category other than ``cat``, and ``sum(n_cats)`` examples for ``cat``.
    For ``k`` evenly spaced from 0 to ``min(n_cats) * (len(categories) - 1)``:

    - The ``cl``-th other category, in ``categories`` order, contributes
      ``clip((cl + 1) * min(n_cats) - k, 0, min(n_cats))`` examples, so the other
      categories empty out one at a time.
    - ``cat`` contributes ``max(n_cats) + k`` examples.

    Each group therefore holds ``max(n_cats) + (len(categories) - 1) * min(n_cats)``
    examples. That equals ``sum(n_cats)`` when all ``n_cats`` are equal.

    When more than 1000 examples are needed for ``cat``, they are drawn in
    chunks of at most 1000. Each chunk is drawn without replacement, so an
    example can repeat across chunks.

    Args:
        label_to_examples: Mapping from category to a sequence of examples.
        categories: Categories to draw from; must contain ``cat``.
        N: Ignored. It is recomputed as ``sum(n_cats)``, which is also the
            denominator of each group's ``name``. Kept for backward compatibility.
        n_cats: Per-category reference sample sizes, aligned with ``categories``.
        cat: Preferred category that absorbs the dropped mass.
        n_steps: Number of groups (points passed to ``np.linspace``).
        verbose: Print each group's composition.

    Returns:
        A list of ``Group`` objects whose ``name`` is ``(examples from cat) / sum(n_cats)``.

    Raises:
        ValueError: If ``cat`` is not in ``categories``.
    """
    if cat not in categories:
        raise ValueError(f"preferred category {cat!r} is not in categories")
    n_other = len(categories) - 1

    # The N argument is deliberately overridden (see docstring).
    N = sum(n_cats)
    n_per_cat = min(n_cats)

    # Reference sample per category, drawn once and sliced by every group below.
    samples_categ = {}
    for i, categ in enumerate(categories):
        population = label_to_examples[categ]
        if categ != cat:
            samples_categ[categ] = random.sample(population, n_cats[i])
        elif N > 1000:
            sample = []
            remaining = N
            while remaining > 0:
                sample.extend(random.sample(population, min(remaining, 1000)))
                remaining -= 1000
            samples_categ[categ] = sample
        else:
            samples_categ[categ] = random.sample(population, N)

    others = [categ for categ in categories if categ != cat]
    groups = []
    for k in np.linspace(0, n_per_cat * n_other, n_steps):
        k = int(round(k))
        n_mode1 = max(n_cats) + k
        # The cl-th other category stays full until k passes cl * n_per_cat,
        # then empties linearly over the next n_per_cat steps of k.
        n_samples = [max(0, min(n_per_cat, (cl + 1) * n_per_cat - k)) for cl in range(n_other)]
        if verbose:
            print("----")
            print(n_mode1)
            print(n_samples)

        group = Group(n_mode1 / N, [])
        group.examples += samples_categ[cat][:n_mode1]
        for categ, n in zip(others, n_samples):
            if n > 0:
                group.examples += samples_categ[categ][:n]
        groups.append(group)
    return groups


def mode_dropping_groups_simultaniously(label_to_examples, categories, N, cat, n_steps=10, *, verbose=True):
    """Build groups of ``N`` examples whose non-preferred categories shrink together.

    ``N`` examples are sampled from ``cat``, and ``N // len(categories)`` from
    every other category. For ``k`` evenly spaced from 0 to
    ``N - N // len(categories)``:

    - Each other category contributes ``k // (len(categories) - 1)`` examples,
      capped at its reference size.
    - ``cat`` supplies the remainder, so every group holds exactly ``N`` examples.

    Args:
        label_to_examples: Mapping from category to a sequence of examples.
        categories: Categories to draw from. Must contain ``cat`` and at least
            one other category.
        N: Group size; also the number of reference examples drawn for ``cat``.
        cat: Preferred category that absorbs the dropped mass.
        n_steps: Number of groups (points passed to ``np.linspace``).
        verbose: Print each group's composition.

    Returns:
        A list of ``Group`` objects whose ``name`` is the fraction of the group
        drawn from ``cat``.

    Raises:
        ValueError: If ``categories`` has fewer than two entries or lacks ``cat``.
    """
    n_other = len(categories) - 1
    if n_other < 1:
        raise ValueError("at least two categories are required")
    if cat not in categories:
        raise ValueError(f"preferred category {cat!r} is not in categories")

    n_per_cat = N // len(categories)

    # Reference sample per category (drawn in `categories` order).
    samples_categ = {
        categ: random.sample(label_to_examples[categ], N if categ == cat else n_per_cat)
        for categ in categories
    }

    others = [categ for categ in categories if categ != cat]
    groups = []
    for k in np.linspace(0, N - n_per_cat, n_steps):
        k = int(round(k))
        # When N % len(categories) == n_other, k // n_other can exceed the
        # n_per_cat references available per category. Clamping keeps the group
        # size at N and its name (fraction from `cat`) accurate.
        per_other = min(k // n_other, n_per_cat)
        n_mode1 = N - per_other * n_other
        n_samples = [per_other] * n_other
        if verbose:
            print("----")
            print(k)
            print(n_mode1)
            print(n_samples)

        group = Group(n_mode1 / N, [])
        group.examples += samples_categ[cat][:n_mode1]
        for categ, n in zip(others, n_samples):
            if n > 0:
                group.examples += samples_categ[categ][:n]
        groups.append(group)
    return groups

def mode_dropping_groups_sequentially(label_to_examples, categories, N, cat, n_steps=10, *, verbose=True):
    """Build groups in which non-preferred categories are dropped one after another.

    ``N`` examples are sampled from ``cat``, and ``n_per_cat = N // len(categories)``
    from every other category. For ``k`` evenly spaced from 0 to ``N - n_per_cat``:

    - ``cat`` contributes ``n_per_cat + k`` examples.
    - The ``cl``-th other category, in ``categories`` order, contributes
      ``clip((cl + 1) * n_per_cat - k, 0, n_per_cat)`` examples, so the other
      categories empty out one at a time.

    Args:
        label_to_examples: Mapping from category to a sequence of examples.
        categories: Categories to draw from; must contain ``cat``.
        N: Number of reference examples drawn for ``cat``. Also the denominator
            of each group's ``name``.
        cat: Preferred category that absorbs the dropped mass.
        n_steps: Number of groups (points passed to ``np.linspace``).
        verbose: Print each group's composition.

    Returns:
        A list of ``Group`` objects whose ``name`` is ``(examples from cat) / N``.

    Raises:
        ValueError: If ``cat`` is not in ``categories``.
    """
    if cat not in categories:
        raise ValueError(f"preferred category {cat!r} is not in categories")
    n_other = len(categories) - 1
    n_per_cat = N // len(categories)

    # Reference sample per category (drawn in `categories` order).
    samples_categ = {
        categ: random.sample(label_to_examples[categ], N if categ == cat else n_per_cat)
        for categ in categories
    }

    others = [categ for categ in categories if categ != cat]
    groups = []
    for k in np.linspace(0, N - n_per_cat, n_steps):
        k = int(round(k))
        n_mode1 = n_per_cat + k
        # The cl-th other category stays full until k passes cl * n_per_cat,
        # then empties linearly over the next n_per_cat steps of k.
        n_samples = [max(0, min(n_per_cat, (cl + 1) * n_per_cat - k)) for cl in range(n_other)]
        if verbose:
            print("----")
            print(n_mode1)
            print(n_samples)

        group = Group(n_mode1 / N, [])
        group.examples += samples_categ[cat][:n_mode1]
        for categ, n in zip(others, n_samples):
            if n > 0:
                group.examples += samples_categ[categ][:n]
        groups.append(group)
    return groups

from sklearn.preprocessing import normalize
import pandas as pd

if __name__ == '__main__':
    import os

    # Inception embeddings of the CIFAR-10 test split are cached on disk so the
    # forward pass only runs once; delete the file to force recomputation.
    # NOTE(review): cached rows are matched to examples purely by position, which
    # assumes image_utils.get_cifar10 returns examples in a stable order — confirm.
    embeddings_path = './data/images/cifar10_interception.pkl'
    device = torch.device("cpu")

    # NOTE(review): a "cifar100" directory is used to load CIFAR-10 — confirm the path.
    cifar_examples10 = image_utils.get_cifar10("samples/cifar100/original", split="test")

    if os.path.exists(embeddings_path):
        # pickle is only acceptable because this script wrote the file itself;
        # never point embeddings_path at untrusted data.
        with open(embeddings_path, 'rb') as fp:
            embeddings10 = pickle.load(fp)
    else:
        embeddings10 = get_inception_embeddings2(
            [e.x for e in cifar_examples10],
            batch_size=64,
            device=device,
        )
        os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
        with open(embeddings_path, 'wb') as fp:
            pickle.dump(embeddings10, fp)

    for e, emb in zip(cifar_examples10, embeddings10):
        e.features["inception"] = emb

    # Group examples by label. The mapping does not depend on the seed, so build it once.
    categories10 = sorted(set(e.labels["y"] for e in cifar_examples10))
    category_groups10 = {c: data_utils.Group(c, []) for c in categories10}
    for e in cifar_examples10:
        category_groups10[e.labels["y"]].examples.append(e)
    groups10 = list(category_groups10.values())
    label_to_examples10 = {group.name: group.examples for group in groups10}

    N = 1000          # nominal group size (the *_2 builders recompute it as sum(n_cats))
    k = 100           # reference examples drawn per category
    n_steps = 30      # number of groups per mode-dropping trajectory
    target_scale = 0.95
    os.makedirs("./images", exist_ok=True)

    def stack_embeddings(group):
        """Stack the cached Inception embeddings of ``group`` (one row per example)."""
        return np.stack([e.features["inception"] for e in group.examples], 0)

    def vendis_mode(mode_dropping_groups, n_samples):
        """Vendi scores and negative mean similarity under exp(-distance) on row-normalised embeddings."""
        metrics = {}
        for idx in range(n_samples):
            X = stack_embeddings(mode_dropping_groups[idx])
            D = get_dist(normalize(X), metric="euclidean", p=2, check_for_duplicates=False)
            sim = np.exp(-D)
            neg_means_exp = similarity2negmeansimilarity(sim)
            vendi_exp = vendi.score_K(sim)
            vendi_exp_n = vendi.score_K(sim, normalize=True)
            metrics[idx] = {"vendi_n_exp": vendi_exp, "vendi_n_exp_n": vendi_exp_n, "neg_means_n_exp": neg_means_exp}
        return metrics

    def magnitude_metrics(groups, label_values, output_path):
        """Run calc_metrics_from_embeddings over ``groups`` and save the results."""
        def load_data(idx):
            return stack_embeddings(groups[idx])

        results = calc_metrics_from_embeddings(
            load_data, len(groups), n_ts=10, metrics=["L2"],
            reference_summaries=True,
            reference_scale="reference", scale=True, absolute_area=False,
            nearest_k=10, target_scale=target_scale,
        )
        # "lable_values" is the key downstream consumers read; keep the spelling.
        results["summary_statistics"]["lable_values"] = label_values
        save_magnitude_results(results, output_path, reference_summaries=True)

    for seed in range(3, 20):
        print("seed")
        print(seed)
        random.seed(seed)
        cats = categories10
        # Rotate the preferred category with the seed.
        cat = cats[seed % len(cats)]
        n_cats = [k] * len(cats)

        groups_by_tag = {}
        groups_by_tag["sim"] = mode_dropping_groups_simultaniously_2(
            label_to_examples10, cats, n_cats=n_cats, N=N, cat=cat, n_steps=n_steps
        )
        # Reverse so the trajectory runs from the most mixed group to the collapsed one.
        groups_by_tag["sim"].reverse()
        print("--......--")
        groups_by_tag["seq"] = mode_dropping_groups_sequentially_2(
            label_to_examples10, cats, n_cats=n_cats, N=N, cat=cat, n_steps=n_steps
        )

        for tag, groups in groups_by_tag.items():
            label_values = [group.name for group in groups]
            vendi_df = pd.DataFrame(vendis_mode(groups, len(groups))).T
            vendi_df.to_csv(f"./images/vendi_{tag}_{seed}_{round(target_scale, 2)}.csv")
            magnitude_metrics(groups, label_values, f"./images/{tag}_{seed}")
            print(f"{tag} done")