import json
import argparse
import glob
import os

import copy

import numpy as np
import torch as ch
import torch.multiprocessing as mp
import torch.nn as nn

# import statistical tests
from test_statistics import MMD_Stat, Luminance_Stat, RFR_Stat, Contrast_Stat
from kernel_tests import MMD
from c2st import C2ST

from helpers import load_files, dict2namedtuple
from helpers import subsample
# No need to use oversampled mixer
from helpers import mix_datasets

# import solver
from active_set import as_simplex,pgd

def all_stats(stat):
    """Construct the statistical test named by ``stat``.

    Args:
        stat: one of ``'luminance'``, ``'rfr'``, ``'mmd'``, ``'contrast'``
            or ``'c2st'``.

    Returns:
        A freshly constructed statistic object for the requested test.

    Raises:
        ValueError: if ``stat`` is not one of the recognised names.
    """
    if stat == 'luminance':
        return Luminance_Stat()
    elif stat == 'rfr':
        return RFR_Stat()
    elif stat == 'mmd':
        return MMD_Stat()
    elif stat == 'contrast':
        return Contrast_Stat()
    elif stat == 'c2st':
        return C2ST('resnet18')
    # Fail loudly here instead of returning None, which would only surface
    # later as an unrelated AttributeError on the caller's side.
    raise ValueError(
        f"Unknown stat {stat!r}; expected one of "
        "'luminance', 'rfr', 'mmd', 'contrast', 'c2st'"
    )

# Functions for solver
def finite_diff(f, z, i, eps=0.05):
    """Two-sided finite-difference estimate of df/dz_i, clipped at zero.

    ``f`` is evaluated with coordinate ``i`` of ``z`` moved down by ``eps``
    (clamped so no coordinate becomes negative) and up by ``eps``. The change
    in ``f`` is divided by the actual distance between the two points.

    Args:
        f: callable taking a tensor shaped like ``z`` and returning a tensor.
        z: tensor whose first-dimension entry ``i`` is perturbed.
        i: index of the coordinate to differentiate.
        eps: half-width of the step.

    Returns:
        The detached difference quotient.
    """
    step = ch.zeros_like(z)
    step[i] = eps
    # Clamping keeps weights non-negative, so the real step can be < 2*eps.
    lower = (z - step).clamp(min=0)
    upper = z + step
    width = upper[i] - lower[i]

    # Evaluate the lower point first (same order as before, in case f is
    # stochastic and consumes RNG state).
    f_low = f(lower)
    f_high = f(upper)
    quotient = (f_high - f_low) / width
    return quotient.detach()

def custom_mmd(encoder, sample_1, sample_2, alphas=None):
    """Encode two samples and return their detached MMD statistic.

    Args:
        encoder: module applied to both samples before comparison.
        sample_1: first batch; dimension 0 indexes samples.
        sample_2: second batch; dimension 0 indexes samples.
        alphas: kernel bandwidths forwarded to ``MMD``; defaults to ``[500]``.

    Returns:
        The statistic returned by the project ``MMD`` object, detached.
    """
    # Build a fresh list per call instead of sharing a mutable default that
    # the downstream MMD call could mutate.
    if alphas is None:
        alphas = [500]
    # HACK: 768-wide inputs are treated as language embeddings, which use a
    # much smaller bandwidth. Note this overrides any alphas passed in.
    if sample_1.size(1) == 768:
        alphas = [0.1]
    ds1 = encoder(sample_1)
    ds2 = encoder(sample_2)
    # flatten everything after the sample dimension
    x1 = ds1.view(ds1.size(0), -1)
    x2 = ds2.view(ds2.size(0), -1)
    mmd = MMD(x1.size(0), x2.size(0))
    statistic = mmd(x1, x2, alphas=alphas, ret_matrix=False)
    return statistic.detach()

# Worker function for the multiprocessing pool used by ASObjective.finite_diff
def process(i, start, end, counts, diff, encoder, target, ds):
    """Pool worker: upper/lower-bound MMD values for coordinates [start, end).

    Worker ``i`` is pinned to GPU ``i % len(CUDA_VISIBLE_DEVICES)``.

    Args:
        i: worker index; used only to pick a GPU.
        start: first coordinate handled by this worker.
        end: one past the last coordinate handled by this worker.
        counts: per-source sample counts at the current point.
        diff: per-source number of trailing rows separating the upper-bound
            dataset from the lower-bound dataset.
        encoder: encoder passed to ``custom_mmd``.
        target: target dataset tensor.
        ds: ``(contiguous_tensor, sizes)`` -- the concatenated upper-bound
            subsamples of every source and their individual lengths.

    Returns:
        ``(f_us, f_ls)``: CPU float tensors of length ``end - start`` with the
        statistic at the upper and lower bound of each coordinate.
    """
    cuda_ids = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
    device = i % len(cuda_ids)
    sources = [part.cuda(device) for part in ch.split(ds[0], ds[1])]

    target = target.cuda(device)
    encoder = encoder.cuda(device)
    diff = diff.cuda(device)
    counts = counts.cuda(device)

    # Cached lower-bound value shared by all coordinates with counts[k] == 0.
    f_zero = None

    f_us, f_ls = [], []
    # `k`, not `i`: don't shadow the worker index.
    for k in range(start, end):
        # every other source at its current count, then source k's whole
        # upper-bound subsample at the end
        parts = [src[:c] for j, (src, c) in enumerate(zip(sources, counts)) if j != k]
        parts.append(sources[k])
        ds_u = ch.cat(parts, dim=0)
        # Explicit end index: `ds_u[:-0]` would be EMPTY when diff[k] == 0.
        ds_l = ds_u[:ds_u.size(0) - int(diff[k])]

        f_us.append(custom_mmd(encoder, ds_u, target))

        # small optimization: if the entry is zero, the lower bound equals
        # f(z), which is the same for every such coordinate -- compute once.
        if counts[k] == 0:
            if f_zero is None:
                f_zero = custom_mmd(encoder, ds_l, target)
            f_ls.append(f_zero)
        else:
            f_ls.append(custom_mmd(encoder, ds_l, target))

    # .item() works for CUDA scalars and yields plain floats for the CPU result
    return (ch.tensor([v.item() for v in f_us]),
            ch.tensor([v.item() for v in f_ls]))

# Active-set objective: evaluates the statistic on a mixture of the sources,
# with a finite_diff method that estimates its gradient by finite differences.
class ASObjective:
    """Objective ``|stat(mixture of sources, target)|`` for the solvers.

    Calling the object mixes ``sources`` according to a weight vector ``z``
    and evaluates the statistic against ``target``. ``finite_diff`` estimates
    the gradient of that objective with respect to ``z``.

    Args:
        stat: callable ``stat(synth, target, encoder=..., **stat_kwargs)``
            returning a tensor. The pooled ``finite_diff`` path also reads
            ``stat.encoder``.
        sources: per-source datasets, indexed like ``z``.
        target: target dataset tensor.
        stat_kwargs: extra keyword arguments forwarded to ``stat``.
    """

    def __init__(self, stat, sources, target, stat_kwargs=None):
        self.stat = stat
        self.sources = sources
        self.target = target
        # fresh dict per instance rather than a shared mutable default
        self.stat_kwargs = {} if stat_kwargs is None else stat_kwargs

    def calculate_stat(self, synth, encoder=None, device='cuda'):
        """Return ``|stat(synth, target)|`` computed on ``device``.

        The result is moved back to ``synth``'s device and detached.
        """
        value = self.stat(synth.to(device), self.target.to(device),
                          encoder=encoder, **self.stat_kwargs)
        return value.abs().to(synth.device).detach()

    def objective(self, z, device='cuda'):
        """Evaluate the statistic on the mixture of sources weighted by ``z``."""
        synth = mix_datasets(self.sources, z)
        # honour the caller's device instead of always using the default
        return self.calculate_stat(synth, device=device)

    def __call__(self, *args, **kwargs):
        return self.objective(*args, **kwargs)

    def finite_diff(self, f, z, eps, pop_size=1000, pool=None):
        """Finite-difference gradient of the objective with respect to ``z``.

        Each source is subsampled once at its upper-bound count
        ``(z + eps) * pop_size``. For coordinate k, the upper dataset is every
        other source at its current count plus source k's whole subsample.
        The lower dataset drops the last ``diff[k]`` rows of the upper one.

        Args:
            f: unused; kept so the signature matches the ``grad_fn`` call.
            z: weight tensor, one entry per source along dimension 0.
            eps: step half-width (the lower point is clamped at 0).
            pop_size: population size used to turn weights into counts.
            pool: number of worker processes, or ``None`` to compute in the
                current process.

        Returns:
            ``(f_upper - f_lower) / (zu - zl)`` for every coordinate.
        """
        zl = ch.clamp(z - eps, min=0)
        zu = z + eps

        counts = (z * pop_size).long()
        counts_l = (zl * pop_size).long()
        counts_u = (zu * pop_size).long()
        diff = counts_u - counts_l

        # subsample the upper bound once for all finite differences
        ds = [subsample(self.sources[k], counts_u[k]) for k in range(z.size(0))]

        if pool is not None:
            f_us, f_ls = self._fd_parallel(ds, counts, diff, pool, z.device)
            return (f_us - f_ls) / (zu - zl)
        return self._fd_serial(ds, counts, diff) / (zu - zl)

    def _fd_parallel(self, ds, counts, diff, nproc, out_device):
        """Compute upper/lower statistics in ``nproc`` worker processes."""
        # Workers receive one shared contiguous tensor plus the split sizes.
        ds_contig = ch.cat(ds, dim=0)
        ds_contig.share_memory_()
        shared_ds = (ds_contig, [part.size(0) for part in ds])

        counts.share_memory_()
        diff.share_memory_()
        # FOR MMD: workers call custom_mmd with (a CPU copy of) the encoder
        encoder = copy.deepcopy(self.stat.encoder).cpu()
        encoder.share_memory()
        target = self.target
        target.share_memory_()

        n = len(ds)
        chunk_sz = (n - 1) // nproc + 1
        bounds = list(range(0, n, chunk_sz)) + [n]
        jobs = [(w, s, e, counts, diff, encoder, target, shared_ds)
                for w, (s, e) in enumerate(zip(bounds[:-1], bounds[1:]))
                if e > s]

        workers = mp.Pool(processes=nproc)
        try:
            results = workers.starmap(process, jobs)
        finally:
            # always release the pool, even if a worker raised
            workers.close()
            workers.join()

        # results: [(f_us_chunk, f_ls_chunk), ...] -> two full-length tensors
        f_us, f_ls = (ch.cat(chunks).to(out_device) for chunks in zip(*results))
        return f_us, f_ls

    def _fd_serial(self, ds, counts, diff):
        """Compute ``f_upper - f_lower`` per coordinate in this process."""
        fd = []
        for k in range(len(ds)):
            # every other source at its current count, then source k's whole
            # upper-bound subsample at the end
            others = [src[:c] for j, (src, c) in enumerate(zip(ds, counts)) if j != k]
            ds_u = ch.cat(others + [ds[k]], dim=0)
            # Explicit end index: `ds_u[:-0]` would be EMPTY when diff[k] == 0.
            ds_l = ds_u[:ds_u.size(0) - int(diff[k])]

            f_u = self.calculate_stat(ds_u)
            f_l = self.calculate_stat(ds_l)
            fd.append(f_u - f_l)
        return ch.stack(fd)
        
# grad function for solver
def create_grad_fn(pool, eps=0.05):
    """Build the gradient callback handed to the solvers.

    Args:
        pool: number of worker processes, or ``None`` to stay in-process.
        eps: default finite-difference step size.

    Returns:
        ``grad_fn(f, z, eps=eps, pool=pool)``. It uses ``f.finite_diff`` when
        ``f`` provides one and a pool is configured. Otherwise it stacks
        per-coordinate module-level ``finite_diff`` estimates.
    """
    def grad_fn(f, z, eps=eps, pool=pool):
        if hasattr(f, 'finite_diff') and pool is not None:
            return f.finite_diff(f, z, eps, pool=pool)
        # forward eps so both paths use the same step size
        return ch.stack([finite_diff(f, z, i, eps=eps) for i in range(z.size(0))])
    return grad_fn
    
class SumStats:
    """Mean of hinged statistics: ``mean_k max(stat_k(...) - hinge_k, 0)``.

    Args:
        *stats: callables sharing one call signature, each returning a tensor.
        hinges: per-statistic thresholds; a statistic contributes only the
            amount by which it exceeds its hinge. Defaults to 0 for each.

    Raises:
        ValueError: if no statistics are given, or ``hinges`` does not have
            one entry per statistic.
    """

    def __init__(self, *stats, hinges=None):
        if not stats:
            raise ValueError("SumStats needs at least one statistic")
        if hinges is None:
            hinges = [0] * len(stats)
        # explicit check instead of assert, which is stripped under -O
        if len(hinges) != len(stats):
            raise ValueError(
                f"got {len(stats)} stats but {len(hinges)} hinges")
        self.stats = stats
        self.hinges = hinges

    def __call__(self, *args, **kwargs):
        total = sum(ch.clamp(stat(*args, **kwargs) - h, min=0)
                    for stat, h in zip(self.stats, self.hinges))
        return total / len(self.stats)


if __name__ == "__main__":
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser(description='Create synthetic data from sources')
    parser.add_argument('-c', '--config', type=str, default="configs/in_cifar10_class_airplane_eps_1.json")
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--out-dir', type=str, default='results/active_set')
    parser.add_argument('--stat', type=str, default='mmd')
    parser.add_argument('--pool', type=int, default=0)
    parser.add_argument('--pgd', action='store_true')
    parser.add_argument('--restart', action='store_true')
    parser.add_argument('--skip-encoder', action='store_true')

    args = parser.parse_args()

    # results/<out-dir>/<config file stem>/
    out_dir = os.path.join(args.out_dir, os.path.splitext(os.path.basename(args.config))[0])
    os.makedirs(out_dir, exist_ok=True)
    stat = all_stats(args.stat)
    if args.skip_encoder:
        # identity encoder: compare raw inputs
        stat.encoder = nn.Sequential()

    with open(args.config) as f:
        config_dict = json.load(f)
        if args.pgd:
            config_dict['solver'] = 'pgd'
        config = dict2namedtuple(config_dict)
    # Seed both RNGs: the random initial subset below uses torch.randperm,
    # which np.random.seed alone does not make reproducible.
    np.random.seed(args.seed)
    ch.manual_seed(args.seed)

    sources, i2f_src = load_files(config.base_dir, config.sources)
    targets, i2f_tgt = load_files(config.base_dir, config.targets)

    if config.normalize_source:
        sources = [s.float() / 255 for s in sources]
    if config.normalize_target:
        targets = [t.float() / 255 for t in targets]

    assert len(config.target_dist) == len(targets)
    target_dist = ch.Tensor(config.target_dist)
    target_ds = mix_datasets(targets, target_dist)

    n = len(sources)
    # Initialise to uniform if there are few sources; otherwise start from a
    # uniform distribution over a random subset of n_init sources.
    n_init = 25
    if n <= n_init:
        z = ch.ones(n) / n
    else:
        z = ch.zeros(n)
        z[ch.randperm(n)[:n_init]] = 1.0 / n_init

    if args.verbose:
        for i in range(n):
            print(f'{i}: {os.path.basename(i2f_src[i])}')

    # multiprocessing: the pool size is passed through; ASObjective creates the pool
    pool = args.pool if args.pool > 0 else None
    grad_fn_ = create_grad_fn(pool)
    if os.path.exists(os.path.join(out_dir, f'{args.stat}.pth')) and not args.restart:
        print("already done, skipping")
    else:
        with ch.no_grad():
            obj = ASObjective(stat, sources, target_ds, stat_kwargs=config.stat_kwargs)

            checkpoint_path = os.path.join(out_dir, f'{args.stat}.pth.latest')
            if config.solver == "simplex":
                print("using simplex solver")
                history, z0 = as_simplex(z, obj, grad_fn=grad_fn_, verbose=args.verbose, checkpoint=checkpoint_path, restart=args.restart, **config.solver_args)
            elif config.solver == "pgd":
                print("using pgd solver")
                history, z0 = pgd(z, obj, grad_fn=grad_fn_, verbose=args.verbose, checkpoint=checkpoint_path, **config.solver_args)
            else:
                raise ValueError("Unknown solver")

            # report the sources that received non-negligible weight
            for i, z_ in enumerate(z0):
                z_ = z_.item()
                if z_ > 1e-4:
                    print(f'{i}: {os.path.basename(i2f_src[i])} -- {z_:.3f}')

            print(z0)

            ch.save({
                'z': z0,
                'history': history,
                'config': config_dict,
                'args': args,
                'i2f_src': i2f_src
            }, os.path.join(out_dir, f'{args.stat}.pth'))