#import argparse
from logging.config import dictConfig
import pickle
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm
from os import path
import os
import sys
# module_path = os.path.abspath(os.path.join('/home/hscho/newvae/models/stylegan2/stylegan2-pytorch'))
# if module_path not in sys.path:
#     sys.path.append(module_path)

from pathlib import Path
# Lookup from a sublayer key (string '1'..'8') to an integer rank value.
# NOTE(review): no live code visible in this file reads this table, and the
# meaning of the values is not evident here — confirm before relying on it.
rank_dict = {
    '1': 5,
    '2': 60,
    '3': 71,
    '4': 67,
    '5': 59,
    '6': 51,
    '7': 43,
    '8': 36,
}

def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Return the Frechet distance between two Gaussians given as (mean, covariance).

    FID = ||mu_s - mu_r||^2 + Tr(C_s) + Tr(C_r) - 2 * Tr(sqrt(C_s @ C_r))

    If the matrix square root of the covariance product contains non-finite
    values, ``eps`` is added to the diagonal of both covariances and the root
    is recomputed. A complex root is accepted (its real part is kept) only when
    the imaginary parts of its diagonal are within 1e-3 of zero.

    Args:
        sample_mean: mean vector of the generated-sample features.
        sample_cov: covariance matrix of the generated-sample features.
        real_mean: mean vector of the reference features.
        real_cov: covariance matrix of the reference features.
        eps: diagonal jitter used only when the first square root is not finite.

    Returns:
        The scalar FID value.

    Raises:
        ValueError: if the square root has a non-negligible imaginary diagonal.
    """
    root, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(root).all():
        print('product of cov matrices is singular')
        # Regularise both covariances so their product becomes well conditioned.
        jitter = np.eye(sample_cov.shape[0]) * eps
        root = linalg.sqrtm((sample_cov + jitter) @ (real_cov + jitter))

    if np.iscomplexobj(root):
        # sqrtm may return tiny imaginary parts from numerical error; only the
        # diagonal is checked because only the trace of the root is used below.
        diag_imag = np.diagonal(root).imag
        if not np.allclose(diag_imag, 0, atol=1e-3):
            worst = np.max(np.abs(root.imag))
            raise ValueError(f'Imaginary component {worst}')
        root = root.real

    delta = sample_mean - real_mean
    trace_term = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(root)
    return delta @ delta + trace_term

def _resolve_sublayer(args):
    """Return the sublayer identifier used in file names for the configured generator.

    stylegan1 configs use ``args.sublayer_str``; stylegan2 configs use
    ``args.sublayer_num``.

    Raises:
        ValueError: if ``args.generator`` is neither 'stylegan1' nor 'stylegan2'
            (previously this surfaced later as an UnboundLocalError).
    """
    if args.generator == 'stylegan1':
        return args.sublayer_str
    if args.generator == 'stylegan2':
        return args.sublayer_num
    raise ValueError(
        f"unsupported generator {args.generator!r}; expected 'stylegan1' or 'stylegan2'"
    )


def _fid_from_feature_file(tensor_path, real_mean, real_cov):
    """Load a saved feature tensor and return its FID against the given reference stats.

    The tensor is mapped to CPU and detached so ``.numpy()`` also works for
    tensors saved from a GPU or with ``requires_grad`` set. Rows are treated as
    samples (``rowvar=False``), so the tensor is presumably
    (n_samples, feature_dim) — TODO confirm against the producer of these files.
    """
    features = torch.load(tensor_path, map_location='cpu').detach().numpy()
    print(f'extracted {features.shape[0]} features')
    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)
    return calc_fid(sample_mean, sample_cov, real_mean, real_cov)


@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Compute and log FID for saved perturbed-feature tensors (Hydra entry point).

    Does nothing unless ``cfg.args.feature`` is truthy. Otherwise it handles
    two feature files, "Local_basis_SN" (lbsn) and "GANspace_SN" (gssn). For
    each one it loads the tensor and computes FID against the reference
    statistics in ``ffhq_legacy_pytorch_trainval70k_1024.npz`` (keys ``mu``
    and ``sigma``). It prints the result and writes it to a per-method text log.

    A tensor is skipped with a message when its log already exists or when
    the tensor file is missing. The log is written only after the FID has been
    computed successfully, so a failed run does not leave an empty log that
    would block later runs.

    Raises:
        ValueError: if ``cfg.args.generator`` is not 'stylegan1'/'stylegan2'
            (when features are requested), or if ``calc_fid`` rejects the
            covariance square root.
    """
    args = cfg.args
    if not args.feature:
        return
    sublayer = _resolve_sublayer(args)

    # Every file name shares this suffix; the component index is fixed at 0.
    comp = 0
    suffix = (f'comp{comp}_ptb{args.perturb:.1f}_n{args.n_sample}'
              f'_sublayer{sublayer}_{args.generator}')

    img_root = Path("/home/hscho/newvae/notebooks/out/fid") / 'iter'
    tensor_paths = [
        (img_root / f'lbsn_features_{suffix}.pt', "Local_basis_SN"),
        (img_root / f'gssn_features_{suffix}.pt', "GANspace_SN"),
    ]

    other_stats_dir = "/home/hscho/newvae/notebooks/"
    other_stats = "ffhq_legacy_pytorch_trainval70k_1024.npz"
    print("Modified!!!")
    # Read the arrays inside the context so the NpzFile handle is closed.
    with np.load(os.path.join(other_stats_dir, other_stats)) as stats:
        real_mean, real_cov = stats["mu"], stats["sigma"]

    log_dir = Path('/home/hscho/newvae/notebooks/fid_log/iter')
    for tensor_path, method in tensor_paths:
        savetxtfile = log_dir / f'{method}_{suffix}.txt'
        if savetxtfile.exists():
            print(savetxtfile, ' already EXIST!')
            continue
        if not tensor_path.exists():
            print(tensor_path, ' does NOT EXIST!')
            continue

        # Compute first: opening the log beforehand would leave an empty file
        # behind on failure, and later runs would then skip this tensor.
        fid = _fid_from_feature_file(tensor_path, real_mean, real_cov)
        print("    ", tensor_path)
        print('   fid:', fid)
        print()

        savetxtfile.parent.mkdir(parents=True, exist_ok=True)
        with open(savetxtfile, 'w') as f:
            print(tensor_path, file=f)
            print('fid:', fid, file=f)
            print(file=f)
            
if __name__ == "__main__":
    # Called with no arguments: the hydra.main decorator supplies the config.
    my_app()