import os
# Both environment variables are assigned before `torch` is imported below.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# assigning it here presumably only affects child processes -- confirm.
os.environ["PYTHONHASHSEED"] = "42"
# NOTE(review): presumably the cuBLAS workspace setting needed for
# deterministic CUDA matmuls -- confirm against what seed_everything enables.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
import torch
import numpy as np
import random
import scipy
from tqdm import tqdm
from scipy.special import lambertw # for importance sampling
import matplotlib.pyplot as plt
from pathlib import Path
import copy
from pprint import pprint
import yaml
import argparse
import logging
import tempfile
from pathlib import Path
import pytorch_lightning as pl
import math
from torch.utils.data import TensorDataset, DataLoader, random_split
import re
 

import os
import sys
sys.path.append('../')
 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Module-wide compute device: 'cuda' when a CUDA device is available, else 'cpu'.
# Read as a global by getX0X1 (data placement) and create_ffm (model placement).
device = 'cuda' if torch.cuda.is_available() else 'cpu'


from functional_fm import *
from models.fno import FNO 
from util.ema import EMA
from util.util import load_checkpoint, add_config_to_argparser, plot_spectrum_comparison, seed_everything


 
from functional_kl import *
 
from kl_functions_torch import noise_x0_class # sample_X_data
basis = 'fourier' 

################################################################################
# Create data
################################################################################
def getX0X1(data_params, n_samples, upsample,
            ):
    """Load datasets A and B, standardise them and build the X0 noise sampler.

    Args:
        data_params: Nested config mapping. Keys read here:
            ``general.M``, ``general.x0source``, ``general.amp_way``,
            ``X0_GP_matern.white_noise``, ``X1.A_path`` and ``X1.B_path``.
            The two ``*_path`` entries are indexed by ``int(M * upsample)``
            and must yield paths loadable with ``np.load``.
        n_samples: Number of leading samples of A and B to return; also
            forwarded to the noise sampler's ``sample`` call.
        upsample: Resolution multiplier; selects which file is loaded and is
            forwarded to the noise sampler's ``sample`` call.

    Returns:
        Tuple ``(data_dict, sample_X0_data_obj, lam_k_K, Phi)`` where
        ``data_dict['A'|'B']['X1_data']`` holds the first ``n_samples``
        standardised samples, ``sample_X0_data_obj`` is the ``noise_x0_class``
        instance, and ``lam_k_K`` / ``Phi`` are its ``lam_k`` / ``Phi``
        attributes read after ``sample`` has been called.
    """
    general = data_params['general']
    M = general['M']
    x0source = general['x0source']
    amp_way = general['amp_way']
    white_noise = data_params['X0_GP_matern']['white_noise']

    resolution = int(M * upsample)
    path_A = data_params['X1']['A_path'][resolution]
    path_B = data_params['X1']['B_path'][resolution]

    def _load(path):
        # .npy file -> float32 tensor on the module-level `device`.
        return torch.from_numpy(np.load(path)).float().to(device=device)

    x1_org = _load(path_A)
    x1_org_B = _load(path_B)
    print(x1_org.shape)

    # Statistics are taken from dataset A only (reduced over the first two
    # axes, kept for broadcasting) and applied to A, B and the X0 source, so
    # all three share one normalisation.
    mu = x1_org.mean(dim=[0, 1], keepdim=True)
    std_x1A = x1_org.std(dim=[0, 1], keepdim=True)

    x1_time_mean0 = (x1_org - mu) / std_x1A
    x1_time_mean0_B = (x1_org_B - mu) / std_x1A
    print(f'A : mean={x1_time_mean0.mean()}, std={x1_time_mean0.std()}')
    # Fixed: this line used to report dataset A's std under the "B" label.
    print(f'B : mean={x1_time_mean0_B.mean()}, std={x1_time_mean0_B.std()}')

    # The noise sampler is fitted on a sibling file of A, found by replacing
    # 'gt' in A's path with `x0source`.
    # NOTE(review): str.replace swaps *every* 'gt' occurrence in the path,
    # including directory names -- confirm the paths contain it only once.
    x1_time_mean0_forx0 = (_load(path_A.replace('gt', x0source)) - mu) / std_x1A

    sample_X0_data_obj = noise_x0_class(
        x1_time_mean0=x1_time_mean0_forx0,
        noise_amp=None,
        white_noise=white_noise,
        amp_way=amp_way,
    )

    # The return value is discarded; the call is kept because `lam_k` and
    # `Phi` are read right after it (presumably populated by it -- confirm).
    sample_X0_data_obj.sample(n_samples=n_samples, upsample=upsample,
                              return_hat=True)

    lam_k_K = sample_X0_data_obj.lam_k
    Phi = sample_X0_data_obj.Phi

    data_dict = {
        'A': {'X1_data': x1_time_mean0[:n_samples]},
        'B': {'X1_data': x1_time_mean0_B[:n_samples]},
    }

    return data_dict, sample_X0_data_obj, lam_k_K, Phi


    



################################################################################
# Create ffm model
################################################################################
def create_ffm(
        sample_X0_data_obj,
        t_train_sampling_scheme, prediction, loss_lam_time,
        D, M,
        modes,
        width,
        mlp_width,
        x_dim,
        **kwargs,
        ):
    """Build an FNO backbone on `device` and wrap it in an ``FFMModel``.

    Args:
        sample_X0_data_obj: Noise sampler, handed to ``FFMModel`` as
            ``sample_X0_data_func``.
        t_train_sampling_scheme, prediction, loss_lam_time: Forwarded
            unchanged to ``FFMModel``.
        D: Passed to ``FNO`` as ``vis_channels`` and to ``FFMModel`` as ``D``.
        M: Passed to ``FNO`` as ``t_scaling``.
        modes: First positional argument of ``FNO``.
        width: Passed to ``FNO`` as ``hidden_channels``.
        mlp_width: Passed to ``FNO`` as ``proj_channels``.
        x_dim: Passed to both ``FNO`` and ``FFMModel``.
        **kwargs: Accepted and ignored, so callers can splat whole config
            sections that carry additional keys.

    Returns:
        The constructed ``FFMModel``.
    """
    backbone_options = {
        'vis_channels': D,
        'hidden_channels': width,
        'proj_channels': mlp_width,
        'x_dim': x_dim,
        't_scaling': M,
    }
    backbone = FNO(modes, **backbone_options).to(device)

    wrapper_options = {
        'model': backbone,
        'D': D,
        'device': device,
        'sample_X0_data_func': sample_X0_data_obj,
        'x_dim': x_dim,
        'prediction': prediction,
        't_train_sampling_scheme': t_train_sampling_scheme,
        'loss_lam_time': loss_lam_time,
    }
    return FFMModel(**wrapper_options)



def trainer(data_params, nn_params,
            upsample_gen,
            n_samples_train,
            n_samples_gen,
            t_train_sampling_scheme,
            prediction,
            loss_lam_time,
            batch_size,
            num_iterations,
            ema_opti,
            sfolder,
            epochs_A=-1,
            epochs_B=-1,
            lr=1e-4,
            lr_sch_step=50,
            lr_gamma=0.1,
            ):
    """Train (or reload) one FFM model per dataset 'A' / 'B' and plot samples.

    For each dataset the function builds a model, trains it only when the
    expected checkpoint ``../{sfolder}/{id}/ckpt/epoch_{N}.pt`` is missing,
    loads that checkpoint, and writes comparison plots for every entry of
    ``upsample_gen`` whose PDF does not exist yet.

    Args:
        data_params: Data config; forwarded to ``getX0X1`` and its
            ``'general'`` section is splatted into ``create_ffm``.
        nn_params: Network config, splatted into ``create_ffm``.
        upsample_gen: Iterable of upsample factors to generate plots for.
        n_samples_train: Number of training samples requested from ``getX0X1``.
        n_samples_gen: Number of samples to generate; also passed to
            ``fmot.train`` as ``generate`` and to the plot as ``plot_samples``.
        t_train_sampling_scheme, prediction, loss_lam_time: Forwarded to
            ``create_ffm``.
        batch_size: Training ``DataLoader`` batch size.
        num_iterations: Target number of optimisation steps; converted to an
            epoch count via ``ceil(num_iterations / steps_per_epoch)``.
        ema_opti: When truthy, the optimizer is wrapped in ``EMA``.
        sfolder: Output folder (relative to ``..``). Its last path component
            is split on ``"_"`` to name the wandb runs and its third
            ``"/"``-separated component names the wandb project, so it needs
            at least three path components and one underscore when training.
        epochs_A, epochs_B: Per-dataset epoch override; ``-1`` means "use the
            epoch count derived from ``num_iterations``".
        lr, lr_sch_step, lr_gamma: Adam learning rate and ``StepLR`` settings.

    Returns:
        Dict keyed by 'A' and 'B'. Each entry has ``'y_train'``; entries that
        were trained/loaded also have ``'fmot'`` and ``'epoch_path'``.
        Dataset 'A' is skipped (no ``'fmot'``) when its checkpoint is missing
        and ``'am'`` is not part of ``sfolder``.
    """
    # ------------------------------------------------------------------
    # Training data (always at the base resolution, upsample=1).
    # ------------------------------------------------------------------
    data_dict, sample_X0_data_obj, _, _ = getX0X1(
        data_params=data_params,
        n_samples=n_samples_train,
        upsample=1,
    )

    results = {
        ti_id: {'y_train': data_dict[ti_id]['X1_data']}
        for ti_id in ('A', 'B')
    }

    # The epoch count is derived from dataset 'B'. The previous code read a
    # loop variable after its loop had ended, which always resolved to 'B';
    # that choice is kept but made explicit so checkpoint names stay stable.
    n_samples_train_real = data_dict['B']['X1_data'].shape[0]
    steps_per_epoch = math.ceil(n_samples_train_real / batch_size)
    epochs = math.ceil(num_iterations / steps_per_epoch)
    print(f'training epochs: {epochs}, steps_per_epoch: {steps_per_epoch}')

    epoch_overrides = {'A': epochs_A, 'B': epochs_B}

    for ti_id, result_dict in results.items():
        print(f'generate training data {ti_id}')

        y_train = result_dict['y_train'].to(dtype=torch.double)
        loader_tr = DataLoader(
            TensorDataset(y_train),
            batch_size=batch_size,
            shuffle=True,
        )

        # The validation loader is a copy of A's training loader and is reused
        # unchanged when B is trained ('A' is always visited first).
        if ti_id == 'A':
            loader_val = copy.deepcopy(loader_tr)

        fmot = create_ffm(
            sample_X0_data_obj,
            t_train_sampling_scheme, prediction, loss_lam_time,
            **data_params['general'],
            **nn_params,
        )

        # Output locations.
        spath = Path(f'../{sfolder}/{ti_id}')
        spath.mkdir(parents=True, exist_ok=True)
        (spath / 'imgs').mkdir(parents=True, exist_ok=True)
        (spath / 'ckpt').mkdir(parents=True, exist_ok=True)

        override = epoch_overrides[ti_id]
        epochs_selected = override if override != -1 else epochs
        epoch_path = spath / f'ckpt/epoch_{epochs_selected}.pt'
        print(ti_id, epoch_path)

        if not epoch_path.exists():
            # Model A is only (re)trained for folders whose name contains 'am'.
            if ti_id == 'A' and 'am' not in str(sfolder):
                logging.warning(
                    "checkpoint %s is missing and %r does not contain 'am'; "
                    "skipping dataset A (results['A'] will have no 'fmot')",
                    epoch_path, str(sfolder),
                )
                continue
            print(sfolder)
            print(f'to train {ti_id}')

            # First two '_'-separated tokens of the folder name label the
            # A run and the B run respectively.
            ti_methods = str(sfolder).split("/")[-1].split("_")[:2]
            ti_method = ti_methods[0] if ti_id == 'A' else ti_methods[1]

            # NOTE(review): `wandb` is not imported by name in this file; it
            # is presumably re-exported by one of the star imports -- confirm.
            wandb.init(
                project=str(sfolder).split("/")[2],
                name=ti_method,
                config={"ti_id": ti_id, "ti_name": ti_method,
                        "epochs": epochs_selected},
                reinit=True,
            )
            try:
                optimizer = optim.Adam(
                    fmot.model.parameters(),
                    lr=lr,
                    weight_decay=0.0,
                )
                # The scheduler is attached to the raw Adam optimizer, before
                # the optional EMA wrapper replaces `optimizer`.
                scheduler = optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=lr_sch_step,
                    gamma=lr_gamma,
                )
                if ema_opti:
                    optimizer = EMA(optimizer, ema_decay=0.999)

                fmot.train(
                    loader_tr,
                    optimizer=optimizer,
                    epochs=epochs_selected,
                    scheduler=scheduler,
                    eval_int=25,
                    generate=n_samples_gen,
                    save_path=spath,
                    test_loader=loader_val,
                )
            finally:
                # Close the wandb run even when training raises, so a failed
                # run is not left open.
                wandb.finish()

        # Whether freshly trained or pre-existing, the model is always
        # restored from the checkpoint file on disk.
        fmot.model = load_checkpoint(epoch_path, fmot.model)[0]
        fmot.model.eval()
        print('load saved ckpt --- done!')

        for upsample in upsample_gen:
            pdf_path = spath / 'imgs' / f'epoch_{epochs_selected}_gen_upsample{upsample}.pdf'
            # Only the sample PDF is checked; both plots are produced together.
            if pdf_path.exists():
                continue
            print(f'to gen, upsample = {upsample}')

            y_gen, x0 = fmot.sample(
                n_samples=n_samples_gen,
                upsample=upsample,
                return_x0=True,
            )

            plot_real_vs_fake(
                y_real=y_train,
                y_fake=y_gen,
                save_path=pdf_path,
                plot_samples=n_samples_gen,
            )

            print(x0.shape, y_gen.shape, y_train.shape)
            plot_spectrum_comparison(
                x1_time=y_train, x1_time_gen=y_gen, x0_time=x0,
                x1_time_A=results['A']['y_train'],
                save_path=spath / 'imgs' / f'epoch_{epochs_selected}_energyspectrum_upsample{upsample}.pdf',
            )

        results[ti_id].update({
            'fmot': fmot,
            'epoch_path': epoch_path,
        })

    return results


 

def kl_estimate(data_params,
                results,
                sfolder,
                n_t,
                t_kl_sampling_scheme,
                n_samples_kl,
                upsample_kl,
                vdiff_source,
                v_source,
                ):
    """Estimate forward and reverse KL for every requested upsample factor.

    Args:
        data_params: Data config forwarded to ``getX0X1``.
        results: Output of ``trainer``; ``results['A']['fmot']`` and
            ``results['B']['fmot']`` are handed to ``FKLModel``.
        sfolder: Output folder, forwarded to the estimator call.
        n_t, t_kl_sampling_scheme: Forwarded to the estimator call.
        n_samples_kl: Number of samples requested from ``getX0X1``.
        upsample_kl: Iterable of upsample factors to evaluate.
        vdiff_source, v_source: Forwarded to the estimator call. When
            ``vdiff_source == 'sample'`` the mean/std are also printed.

    Returns:
        ``{upsample: {'forward_KL': mean, 'reverse_KL': mean}}`` where each
        mean is the first value returned by
        ``FKLModel.kl_estimate_noisediffchannel``.
    """
    kl_by_upsample = {}

    for upsample in upsample_kl:
        per_direction = {}
        kl_by_upsample[upsample] = per_direction

        print(f'to estimate kl, upsample = {upsample}')

        # Fresh data (and basis statistics) at this resolution.
        data_dict, sample_X0_data_obj, lam_k_K, Phi = getX0X1(
            data_params=data_params,
            n_samples=n_samples_kl,
            upsample=upsample,
        )

        # forward KL is evaluated on samples of A, reverse KL on samples of B.
        evaluation_sets = (
            ('forward_KL', data_dict['A']['X1_data']),
            ('reverse_KL', data_dict['B']['X1_data']),
        )

        for direction, samples in evaluation_sets:
            estimator = FKLModel(
                ffm_model_A=results['A']['fmot'],
                ffm_model_B=results['B']['fmot'],
            )

            print('kl_estimate_func: kl_estimate_gen')
            estimate = estimator.kl_estimate_noisediffchannel(
                X1_data_A=samples,
                # ---- stats ----
                upsample=upsample,
                Phi=Phi,
                lam_k_K=lam_k_K,
                # ---- integration over the FFM time variable ----
                n_t=n_t,
                t_kl_sampling_scheme=t_kl_sampling_scheme,
                # ---- source of v_t and of its difference ----
                vdiff_source=vdiff_source,
                v_source=v_source,
                # ---- where diagnostic output is saved ----
                sfolder=sfolder,
            )
            KL_est_mean, KL_est_std, _ = estimate

            if vdiff_source == 'sample':
                print(f"sample-wise KL: {KL_est_mean:.3f} ± {KL_est_std:.3f}")

            per_direction[direction] = KL_est_mean

    return kl_by_upsample
        

 

def _to_yaml_safe(value):
    """Recursively convert *value* into plain Python containers and scalars."""
    if isinstance(value, dict):
        return {key: _to_yaml_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_yaml_safe(item) for item in value]
    # torch tensors, NumPy arrays and NumPy scalars all expose ``tolist``,
    # which yields built-in numbers / nested lists.
    if hasattr(value, "tolist"):
        return value.tolist()
    return value


def main(
      config,
        ):
    """Run training, estimate KL and write the augmented config to YAML.

    Args:
        config: Mapping with the sections ``data_params``, ``nn_params``,
            ``trainer_params`` (must contain ``sfolder``) and ``kl_params``.
            It is mutated in place: the KL results are stored under
            ``config['kl_result']``.

    Side effects:
        Writes ``../{sfolder}/config_kl_FINAL.yaml``.
    """
    data_params = config.get("data_params")
    trainer_params = config.get("trainer_params")
    sfolder = trainer_params['sfolder']

    results = trainer(
        data_params=data_params,
        nn_params=config.get("nn_params"),
        **trainer_params,
    )

    KL_est_mean_dict = kl_estimate(
        data_params=data_params,
        results=results,
        sfolder=sfolder,
        **config.get("kl_params"),
    )

    # yaml.safe_dump only represents built-in types. The KL estimates may be
    # tensors or NumPy scalars (their type is decided by FKLModel, which is
    # not visible here), so convert them before dumping. Otherwise the dump
    # can fail part-way and leave a truncated YAML file behind.
    config['kl_result'] = _to_yaml_safe(KL_est_mean_dict)

    save_file = f"../{sfolder}/config_kl_FINAL.yaml"

    with open(save_file, "w") as f:
        yaml.safe_dump(
            config,
            f,
            sort_keys=False,  # preserve the key order of the loaded config
            default_flow_style=False,
        )
    print('result saved to ',  save_file)
 
if __name__ == "__main__":
    # Script entry point: seed, parse the CLI, derive the output folder name
    # from the config, snapshot the config once, then run `main`.

    seed_everything(42)

    # Setup Logger
    logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Setup argparse (unknown extra arguments are tolerated)
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str,  default="/home/wangch/projects/functional_flow_matching/configs/toycase_sde_ESTKL_CASE2.yaml", help="Path to yaml config")
    args, _ = parser.parse_known_args()

    with open(args.config, "r") as fp:
        config = yaml.safe_load(fp)

    # Build the folder-name suffix: noise type, loss weighting, t-sampling scheme.
    uses_white_noise = config['data_params']['X0_GP_matern']['white_noise']
    noise_tag = '_X0white' if uses_white_noise else "_X0rougher"

    trainer_cfg = config['trainer_params']
    loss_tag = "_losst" + str(trainer_cfg['loss_lam_time'])

    # Scheme name -> suffix; an unrecognised scheme contributes no suffix.
    scheme_tags = (
        ("importance_sampling_t/(1-t)", "_traintIStail"),
        ("importance_sampling_t*(1-t)", "_traintISmid"),
        ("uniform", "_traintUniform"),
        ("logit_normal", "_traintLogit_normal"),
    )
    chosen_scheme = trainer_cfg['t_train_sampling_scheme']
    scheme_tag = next(
        (tag for scheme_name, tag in scheme_tags if chosen_scheme == scheme_name),
        "",
    )

    trainer_cfg['sfolder'] += noise_tag + loss_tag + scheme_tag

    spath = Path(f"../{config['trainer_params']['sfolder']}")
    print(spath)
    spath.mkdir(parents=True, exist_ok=True)

    # Snapshot the (already suffixed) config once; an existing snapshot is
    # never overwritten.
    save_file = spath / "config_pretrain.yaml"
    if not save_file.exists():
        with open(save_file, "w") as f:
            yaml.safe_dump(
                config,
                f,
                sort_keys=False,  # keep the key order of the loaded file
                default_flow_style=False,
            )
        print('config_pretrain saved to ',  save_file)

    main(config)