import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
import sys
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
from datetime import datetime
from pyhocon import ConfigFactory
import numpy as np
import argparse
import json
import torch
from models import gradient
import utils.general as utils
from sampler.sample import Sampler
from utils.metric import MMD
from utils.plots import plot_2D
from tqdm import trange
    

class ReconstructionRunner:
    """Train a space-time network ``u(t, x)`` that transports one point cloud onto another.

    Network inputs are rows ``(t, x)`` with time in column 0 and ``d_in``
    spatial coordinates after it (the network is built with
    ``d_in=self.d_in + 1``). Points move along the spatial gradient of ``u``:
    ``x(T) = x(0) + T * du/dx(0, x(0))`` and ``x(0) = x(T) - T * du/dx(T, x(T))``.
    The training loss is a weighted sum of the configured regularizers:

    * ``implicithj`` -- residual of the implicit characteristic solution
      ``u(t, x) = u(0, x - t du/dx) + 0.5 t |du/dx|^2``;
    * ``pinn`` -- residual of ``u_t + 0.5 |du/dx|^2 = 0``, the Hamilton-Jacobi
      equation that the implicit formula above solves;
    * ``mmd`` -- MMD between transported and target samples, both directions.
    """

    # Regularizer names that train_epoch knows how to evaluate.
    _KNOWN_REGULARIZERS = ('implicithj', 'pinn', 'mmd')

    def run(self):
        """Train for epochs ``0..nepochs`` (inclusive), one optimizer step per epoch.

        Every epoch draws ``batch_data`` distinct rows from each point cloud and
        ``batch_hj`` space-time points from the sampler. Checkpoints and 2-D
        plots are written *before* that epoch's step, so the artifacts for
        epoch ``k`` reflect the model after ``k`` steps. Every
        ``status_frequency`` epochs a status line is appended to
        ``<cur_exp_dir>/logs.txt``.
        """
        print("running")
        tbar = trange(0, self.nepochs + 1)
        for epoch in tbar:
            tbar.set_description_str(f"Epoch {epoch}/{self.nepochs}")

            # Draw order (source, target, collocation) fixes the RNG stream for a given seed.
            input_data_pnts = self._sample_batch(self.input_data)
            output_data_pnts = self._sample_batch(self.output_data)
            pnts = self.random_sampler(N=self.batch_hj)

            if epoch > 0 and epoch % self.checkpoint_frequency == 0:
                print('saving checkpoint: ', epoch)
                self.save_checkpoints(epoch)

            if epoch > 0 and epoch % self.plot_frequency == 0:
                print('plot validation epoch: ', epoch)
                if self.d_in == 2:
                    # Per-run plots directory, so concurrent/successive runs do not overwrite each other.
                    plot_2D(self.network, self.input_data, self.output_data, self.T,
                            file_name=os.path.join(self.plots_dir, f'results_ep{epoch}'), device=self.gpu)

            losses = self.train_epoch(epoch=epoch, input_data_samples=input_data_pnts,
                                      output_data_samples=output_data_pnts, pnts=pnts)
            tbar.set_postfix({k: f'{v:.6f}' for k, v in losses.items()})

            if epoch > 0 and epoch % self.status_frequency == 0:
                self._append_log(epoch, losses)

    def _sample_batch(self, data):
        """Return ``batch_data`` distinct random rows of ``data`` as a new float32 tensor.

        The rows are detached from ``data``'s graph and copied; gradients with
        respect to data samples are never used by the losses.
        """
        indices = torch.as_tensor(np.random.choice(data.shape[0], self.batch_data, False))
        # detach().clone() is the recommended way to copy an existing tensor
        # (torch.tensor(tensor) warns and is discouraged).
        return data[indices].detach().to(torch.float32).clone()

    def _append_log(self, epoch, losses):
        """Append one status line (tab-separated losses) for ``epoch`` to ``<cur_exp_dir>/logs.txt``."""
        width = len(str(self.nepochs))
        header = (f'Train Epoch: [{str(epoch).rjust(width)}/{self.nepochs} '
                  f'({100 * epoch / self.nepochs:3.0f}%)] | ')
        fields = '\t'.join(f'{k}: {v:.6f}' for k, v in losses.items())
        with open(f'{self.cur_exp_dir}/logs.txt', 'a') as f:
            f.write(header + fields + '\n')

    def train_epoch(self, epoch, input_data_samples, output_data_samples, pnts):
        """Perform one optimizer step on the configured losses.

        Args:
            epoch: current epoch; sets the learning rate via the schedules.
            input_data_samples: batch of source points (rows are points).
            output_data_samples: batch of target points (rows are points).
            pnts: space-time collocation points, time in column 0. The network
                output is differentiated w.r.t. ``pnts``, so it must be part of
                an autograd graph (presumably the sampler sets
                ``requires_grad`` -- TODO confirm in ``Sampler``).

        Returns:
            dict of loss name -> Python float. ``'Train loss'`` (the weighted
            total) is always present; the individual terms appear only for the
            enabled regularizers.
        """
        input_data_samples = input_data_samples.to(self.gpu)
        output_data_samples = output_data_samples.to(self.gpu)
        pnts = pnts.to(self.gpu)
        self.network.train()
        self.adjust_learning_rate(epoch, self.optimizer)

        loss = torch.zeros((), device=self.gpu)
        losses = {'Train loss': None}

        # Both HJ regularizers need u and its gradient at the collocation
        # points: evaluate the network once and share the graph.
        if 'implicithj' in self.regularizer_type or 'pinn' in self.regularizer_type:
            pred_sol = self.network(pnts)  # u(t, x)
            grad_full = gradient(pnts, pred_sol)  # column 0: du/dt, columns 1: du/dx
            grad_t = grad_full[:, [0]]
            grad_x = grad_full[:, 1:]
            grad_x_sq = torch.sum(grad_x * grad_x, dim=1, keepdim=True)  # |du/dx|^2
            t = pnts[:, [0]]

        # Implicit HJ loss: u(t, x) - 0.5 t |du/dx|^2 - u(0, x - t du/dx) = 0
        if 'implicithj' in self.regularizer_type:
            init_x = pnts[:, 1:] - t * grad_x  # foot of the characteristic through (t, x)
            init_xt = torch.cat((torch.zeros_like(t), init_x), dim=1)
            loss_implicithj = ((pred_sol - 0.5 * t * grad_x_sq - self.network(init_xt)) ** 2).mean()

            loss = loss + self.regularizer_coord[self.regularizer_index['implicithj']] * loss_implicithj
            losses['Implicit HJ loss'] = loss_implicithj.item()

        # PINN loss: residual of u_t + 0.5 |du/dx|^2 = 0, the PDE whose solution
        # the implicit formula above describes (time derivative is column 0 of
        # the full gradient, not of the spatial part).
        if 'pinn' in self.regularizer_type:
            loss_pde = ((grad_t + 0.5 * grad_x_sq) ** 2).mean()

            loss = loss + self.regularizer_coord[self.regularizer_index['pinn']] * loss_pde
            losses['PINN HJ loss'] = loss_pde.item()

        # MMD loss: push source samples forward to T and target samples back to 0.
        if 'mmd' in self.regularizer_type:
            init_time = torch.zeros((input_data_samples.shape[0], 1), device=self.gpu)
            init_spatialtemporal_pnts = torch.cat((init_time, input_data_samples), dim=1).requires_grad_(True)
            pred_sol_0 = self.network(init_spatialtemporal_pnts)
            transported_pnts_to_T = (init_spatialtemporal_pnts[:, 1:]
                                     + self.T * gradient(init_spatialtemporal_pnts, pred_sol_0)[:, 1:])
            loss_MMD_T = MMD(transported_pnts_to_T, output_data_samples, kernel=self.MMD_kernel, k_sigma=self.MMD_sigma)

            terminal_time = self.T * torch.ones((output_data_samples.shape[0], 1), device=self.gpu)
            spatialtemporal_pnts = torch.cat((terminal_time, output_data_samples), dim=1).requires_grad_(True)
            pred_sol_T = self.network(spatialtemporal_pnts)
            transported_pnts_to_0 = (spatialtemporal_pnts[:, 1:]
                                     - self.T * gradient(spatialtemporal_pnts, pred_sol_T)[:, 1:])
            loss_MMD_0 = MMD(transported_pnts_to_0, input_data_samples, kernel=self.MMD_kernel, k_sigma=self.MMD_sigma)

            loss_MMD = loss_MMD_0 + loss_MMD_T
            loss = loss + self.regularizer_coord[self.regularizer_index['mmd']] * loss_MMD
            losses['MMD loss'] = loss_MMD.item()

        # Total loss
        losses['Train loss'] = loss.item()
        self.optimizer.zero_grad()
        # NOTE(review): retain_graph is kept on purpose. `pnts` may be derived
        # from self.xmin/self.xmax, which carry a graph built once in __init__
        # (the data tensors require grad); freeing it would break the next
        # epoch's backward. Confirm against Sampler before removing.
        loss.backward(retain_graph=True)
        self.optimizer.step()

        return losses

    def __init__(self, args):
        """Build the runner from parsed CLI ``args`` and the HOCON config at ``args.conf``.

        Loads both point clouds, creates the sampler, network and optimizer,
        creates the per-run experiment directories, and writes the effective
        config to ``<cur_exp_dir>/config.json``.

        Raises:
            ValueError: for an unsupported ``sampler.sampler_type``, an empty or
                unknown ``train.regularizer_type`` entry, or when
                ``train.regularizer_type`` and ``train.regularizer_coord``
                differ in length.
        """
        self.home_dir = args.home_dir
        self.conf_filename = args.conf
        self.conf = ConfigFactory.parse_file(self.conf_filename)
        self.data_dir = args.data_dir
        self.input = args.input.lower()
        self.output = args.output.lower()
        if args.expname is None:
            # Strip the file extension from both dataset names.
            self.expname = f"OT_{self.input.split('.')[0]}_to_{self.output.split('.')[0]}"
        else:
            self.expname = args.expname

        utils.set_random_seed(5884)

        if torch.cuda.is_available() and args.gpu > -1:
            print(f'MLP in gpu {args.gpu}')
            self.gpu = torch.device(args.gpu)
        else:
            print('MLP in cpu')
            self.gpu = torch.device('cpu')

        # sampler config
        self.d_in = self.conf.get_int('train.d_in')
        self.batch_hj = self.conf.get_int('sampler.batch_hj')
        self.batch_data = self.conf.get_int('sampler.batch_data')
        self.spatial_range = self.conf.get_float('sampler.spatial_range')
        self.T = self.conf.get_float('sampler.terminal_time')
        self.sampler_type = self.conf.get_string('sampler.sampler_type').lower()

        self.input_data = utils.load_data(os.path.join(self.data_dir, self.input)).requires_grad_()
        self.output_data = utils.load_data(os.path.join(self.data_dir, self.output)).requires_grad_()

        # Batches are drawn without replacement, so never ask for more rows
        # than the smaller cloud has.
        self.batch_data = min(self.batch_data, self.input_data.shape[0], self.output_data.shape[0])

        if self.sampler_type != 'random':
            raise ValueError(f'unknown sampler type "{self.sampler_type}"')
        self.dim = self.input_data.shape[-1]
        all_data = torch.cat([self.input_data, self.output_data], dim=0)
        self.xmin = torch.min(all_data, dim=0).values
        self.xmax = torch.max(all_data, dim=0).values
        self.random_sampler = Sampler(dim=self.dim, xmin=self.xmin, xmax=self.xmax, T=self.T)

        # train config
        self.nepochs = self.conf.get_int('train.nepochs')
        self.regularizer_type = list(map(str.lower, self.conf.get_list('train.regularizer_type')))
        self.regularizer_coord = self.conf.get_list('train.regularizer_coord')
        if len(self.regularizer_type) != len(self.regularizer_coord):
            raise ValueError(f'match regularizer coordinates: {len(self.regularizer_type)} types vs '
                             f'{len(self.regularizer_coord)} coordinates')
        if not self.regularizer_type:
            raise ValueError('train.regularizer_type must list at least one regularizer')
        unknown = [t for t in self.regularizer_type if t not in self._KNOWN_REGULARIZERS]
        if unknown:
            raise ValueError(f'unknown regularizer type(s) {unknown}; expected any of {list(self._KNOWN_REGULARIZERS)}')
        self.regularizer_index = {t: i for i, t in enumerate(self.regularizer_type)}

        if 'mmd' in self.regularizer_type:
            self.MMD_kernel = self.conf.get_string('train.MMD_kernel').lower()
            # get_int would silently truncate a fractional bandwidth (0.5 -> 0);
            # keep integral values as int, fractional ones as float.
            sigma = self.conf.get_float('train.MMD_sigma')
            self.MMD_sigma = int(sigma) if sigma.is_integer() else sigma

        self.checkpoint_frequency = self.conf.get_int('train.checkpoint_frequency')
        self.status_frequency = self.conf.get_int('train.status_frequency')
        self.plot_frequency = self.conf.get_int('train.plot_frequency')

        # make paths: <home>/exps/<expname>/<timestamp>/{plots,checkpoints/...}
        self.timestamp = f'{datetime.now():%Y_%m_%d_%H_%M_%S}'

        self.exps_folder_name = 'exps'
        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        self.optimizer_params_subdir = "OptimizerParameters"
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        # network and optimizer; the network sees (t, x), hence d_in + 1 inputs
        self.network_class = self.conf.get_string('model.network_class')
        self.weight_decay = self.conf.get_float('model.optimizer.weight_decay')
        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('model.optimizer.learning_rate_schedule'))
        self.network = utils.get_class(self.network_class)(d_in=self.d_in + 1, **self.conf.get_config('model.network_inputs')).to(self.gpu)
        self.optimizer = torch.optim.Adam(
            [
                {
                    "params": self.network.parameters(),
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay,
                },
            ])

        # save config (record both datasets so the run can be reproduced)
        self.conf['data'] = self.input
        self.conf['output_data'] = self.output
        self.conf['expname'] = self.expname
        self.conf['train']['d_in'] = self.d_in
        self.conf['train']['checkpoint frequency'] = self.checkpoint_frequency
        self.conf['train']['plot frequency'] = self.plot_frequency
        self.conf['train']['status frequency'] = self.status_frequency
        with open(f'{self.cur_exp_dir}/config.json', 'w') as f:
            json.dump(self.conf, f, indent=4)

    def get_learning_rate_schedules(self, schedule_specs):
        """Build one learning-rate schedule per entry of ``schedule_specs``.

        Each spec is a mapping with a ``"Type"`` key. Only ``"Step"`` is
        supported; it also needs ``"Initial"``, ``"Interval"`` and ``"Factor"``,
        which are passed positionally to ``utils.StepLearningRateSchedule``.

        Raises:
            ValueError: for any other ``"Type"``.
        """
        schedules = []
        for spec in schedule_specs:
            if spec["Type"] != "Step":
                raise ValueError('no known learning rate schedule of type "{}"'.format(spec["Type"]))
            schedules.append(
                utils.StepLearningRateSchedule(spec["Initial"], spec["Interval"], spec["Factor"])
            )
        return schedules

    def adjust_learning_rate(self, epoch, optimizer):
        """Set each param group's ``lr`` from the schedule with the same index at ``epoch``."""
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)

    def save_checkpoints(self, epoch):
        """Save network and optimizer state dicts to ``<checkpoints>/<subdir>/<epoch>.pth``."""
        torch.save(
            {"epoch": epoch, "model_sol_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, f'{epoch}.pth'))

        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, f'{epoch}.pth'))


if __name__ == '__main__':
    # Command-line entry point: build the runner from CLI arguments and train.
    parser = argparse.ArgumentParser()
    parser.add_argument('--home_dir', type=str, default='/home/yesom/Codes/ImplicitHJ_OT')
    parser.add_argument('--conf', type=str, default='/home/yesom/Codes/ImplicitHJ_OT/setup.conf')
    parser.add_argument('--data_dir', type=str, default='/home/yesom/Codes/ImplicitHJ_OT/data/synthetic')
    parser.add_argument('--input', type=str, default='spiral.npy', help='data name')
    parser.add_argument('--output', type=str, default='two_moons.npy', help='data name')
    parser.add_argument('--expname', type=str, default=None)
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    trainrunner = ReconstructionRunner(args)
    try:
        trainrunner.run()
    except BaseException:
        # Record the traceback (including Ctrl-C) in the run's log file, then
        # re-raise so the failure still reaches the console and the process
        # exits with a non-zero status instead of silently "succeeding".
        import traceback
        with open(f'{trainrunner.cur_exp_dir}/logs.txt', 'a') as f:
            f.write(traceback.format_exc())
        raise