
import os, time, gc
import numpy as np

import torch
import torch.nn.functional as F
from torch.optim import SGD, RMSprop, Adagrad, AdamW, lr_scheduler, Adam
from torch.utils.tensorboard import SummaryWriter
from torch_ema import ExponentialMovingAverage
from metrics import MMD_loss,compute_metrics,metric_build
import policy
import sde
from loss import compute_sb_DSB_train
import data
import util

from ipdb import set_trace as debug
from einops import rearrange, repeat

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import imageio

def build_optimizer_ema_sched(opt, policy):
    """Create the optimizer, EMA tracker and optional LR scheduler for ``policy``.

    Args:
        opt: run options. Reads ``optimizer`` (one of 'Adam', 'AdamW',
            'Adagrad', 'RMSprop', 'SGD'), ``lr_f`` / ``lr_b``, ``l2_norm``
            (passed as ``weight_decay``), ``lr_gamma`` and ``lr_step``.
        policy: object exposing ``direction`` ('forward' selects ``lr_f``,
            anything else selects ``lr_b``) and ``parameters()``.

    Returns:
        ``(optimizer, ema, sched)``. ``ema`` is an ExponentialMovingAverage with
        decay 0.999. ``sched`` is a ``StepLR`` when ``opt.lr_gamma < 1.0`` and
        ``None`` otherwise.

    Raises:
        ValueError: if ``opt.optimizer`` is not one of the supported names.
    """
    direction = policy.direction

    optim_cls = {
        'Adam': Adam,
        'AdamW': AdamW,
        'Adagrad': Adagrad,
        'RMSprop': RMSprop,
        'SGD': SGD,
    }.get(opt.optimizer)
    if optim_cls is None:
        # Fail loudly here instead of with "'NoneType' object is not callable" below.
        raise ValueError(
            "Unsupported optimizer {!r}; expected one of "
            "'Adam', 'AdamW', 'Adagrad', 'RMSprop', 'SGD'".format(opt.optimizer))

    optim_dict = {
            "lr": opt.lr_f if direction=='forward' else opt.lr_b,
            'weight_decay':opt.l2_norm,
    }
    if opt.optimizer == 'SGD':
        optim_dict['momentum'] = 0.9

    optimizer   = optim_cls(policy.parameters(), **optim_dict)
    ema         = ExponentialMovingAverage(policy.parameters(), decay=0.999)
    # A gamma of 1.0 (or more) would be a no-op decay, so no scheduler is built.
    if opt.lr_gamma < 1.0:
        sched = lr_scheduler.StepLR(optimizer, step_size=opt.lr_step, gamma=opt.lr_gamma)
    else:
        sched = None

    return optimizer, ema, sched

def freeze_policy(policy):
    """Disable gradient tracking on ``policy`` and switch it to eval mode.

    Mutates ``policy`` in place and returns the same object so the call can
    be used inline (``p = freeze_policy(p)``).
    """
    params = policy.parameters()
    for weight in params:
        weight.requires_grad_(False)
    policy.eval()
    return policy

def activate_policy(policy):
    """Re-enable gradient tracking on ``policy`` and switch it to train mode.

    Mutates ``policy`` in place and returns the same object so the call can
    be used inline (``p = activate_policy(p)``).
    """
    params = policy.parameters()
    for weight in params:
        weight.requires_grad_(True)
    policy.train()
    return policy

class Runner():
    def __init__(self,opt):
        """Build data, dynamics, both policies and their optimizers for one run.

        Args:
            opt: run options (problem name, time grid ``t0``/``T``/``interval``,
                optimizer settings, ``device``, leave-one-out index ``LOO``
                (-1 disables it) and an optional checkpoint in ``opt.load``).
        """
        super(Runner,self).__init__()

        self.opt = opt
        self.start_time = time.time()
        # Iteration counters incremented by ``update_count``. They were never
        # initialised before, so its first call raised AttributeError. They are
        # set before ``restore_checkpoint`` so a checkpoint can still override them.
        self.it_f = 0
        self.it_b = 0

        if opt.problem_name == "RNA5dim":
            # Entries [0], [1] and [-1] of data.build's result are used as the
            # training, validation and test marginals respectively.
            dists = data.build(opt)
            self.x_dists  = dists[0]
            self.val_dists = list(dists[1])
            self.test_dists = list(dists[-1])
        elif opt.problem_name == "hesc":
            # hesc additionally returns ``gt``, used as evaluation targets.
            self.x_dists, self.gt = data.build(opt)
        else:
            self.x_dists  = data.build(opt)

        self.num_dist = len(self.x_dists)

        # One time per marginal: 0, 1, ..., num_dist-1.
        self.t_dists  = torch.linspace(0, self.num_dist-1, self.num_dist)
        # Time grid used inside one bridge, and its step size.
        self.ts = torch.linspace(opt.t0, opt.T, opt.interval)
        self.dt = self.ts[1] - self.ts[0]

        # Evaluation grid: ``interval`` points on every [t_i, t_{i+1}], concatenated;
        # torch.unique drops the duplicated shared endpoints (and sorts).
        self.eval_ts = [torch.linspace(self.t_dists[i], self.t_dists[i+1], opt.interval) for i in range(self.num_dist - 1)]
        self.eval_ts = torch.unique(torch.hstack(self.eval_ts))

        # for visualize training data
        if opt.problem_name =='RNAsc':
            self.x_data = [dist.ground_truth for dist in self.x_dists]

        self.xs_data_ = [self.x_dists[i].sample()[:, None] for i in range(self.num_dist)]
        self.xs_data = torch.hstack(self.xs_data_)

        if self.opt.LOO != -1:
            # Leave-one-out: only an interior marginal (1 .. num_dist-2) may be dropped.
            assert not(opt.LOO==0 or opt.LOO>=self.num_dist-1)
            indices = [i for i in range(self.num_dist) if i != self.opt.LOO]
            self.indices = torch.tensor(indices).to(opt.device)
            # Index with a CPU tensor: t_dists lives on the CPU, and indexing a CPU
            # tensor with indices on another device (e.g. CUDA) raises RuntimeError.
            self.t_dists  = self.t_dists[torch.tensor(indices)]

        # Build metrics
        self.metrics    = metric_build(opt)
        # build dynamics, forward (z_f) and backward (z_b) policies and corresponding optimizer
        self.dyn        = sde.build(opt, self.x_dists)
        self.dyn.dt = self.dt.cpu()

        self.z_f        = policy.build(opt, self.dyn, 'forward')  # p -> q
        self.z_b        = policy.build(opt, self.dyn, 'backward') # q -> p

        self.optimizer_f, self.ema_f, self.sched_f = build_optimizer_ema_sched(opt, self.z_f)
        self.optimizer_b, self.ema_b, self.sched_b = build_optimizer_ema_sched(opt, self.z_b)

        if opt.load:
            util.restore_checkpoint(opt, self, opt.load)

    def update_count(self, direction):
        """Advance and return the per-direction iteration counter.

        Args:
            direction: 'forward' (counter ``it_f``) or 'backward' (counter ``it_b``).

        Returns:
            The counter value after incrementing.

        Raises:
            RuntimeError: for any other ``direction``.
        """
        if direction == 'forward':
            counter = 'it_f'
        elif direction == 'backward':
            counter = 'it_b'
        else:
            raise RuntimeError()
        setattr(self, counter, getattr(self, counter) + 1)
        return getattr(self, counter)

    def get_optimizer_ema_sched(self, z):
        """Return the ``(optimizer, ema, sched)`` triple that belongs to policy ``z``.

        Raises:
            RuntimeError: if ``z`` is neither ``self.z_f`` nor ``self.z_b``.
        """
        if z == self.z_f:
            suffix = 'f'
        elif z == self.z_b:
            suffix = 'b'
        else:
            raise RuntimeError()
        return tuple(getattr(self, '{}_{}'.format(name, suffix))
                     for name in ('optimizer', 'ema', 'sched'))

    @torch.no_grad()
    def evaluate(self, opt, stage, direction):
        """Roll out a policy under its EMA weights and evaluate / plot the result.

        Only ``direction == 'forward'`` produces samples: ``z_f`` is rolled out
        from problem-specific initial samples, and the result is passed to
        ``compute_metrics`` and/or ``save_trajectories_pdf``. For 'backward' the
        backward policy is only frozen under its EMA weights.

        ``test`` is False only when no checkpoint was loaded (``opt.load is None``)
        and ``direction == 'forward'``. In that case a checkpoint is also saved
        for the last two stages.

        Args:
            opt: run options.
            stage: current stage index (used in checkpoint and figure names).
            direction: 'forward' (use ``z_f``) or 'backward' (use ``z_b``).
        """
        policy_impt = {
            'forward': self.z_f, # sample from forward
            'backward': self.z_b, # sample from backward
        }.get(direction)

        # Default rollouts start at the first marginal time, samp_bs times.
        t0s = self.t_dists[0][None]
        t0s = repeat(t0s, "n ->n b", b=opt.samp_bs)

        xs = self.xs_data

        if opt.load is None and direction =='forward':
            test=False
            if stage == opt.num_stage or stage == opt.num_stage - 1 :
                keys = ['z_f','optimizer_f','ema_f','z_b','optimizer_b','ema_b']
                util.save_checkpoint(opt, self, keys, stage)
        else:
            test=True

        _, ema_impt, _      = self.get_optimizer_ema_sched(policy_impt)
        with ema_impt.average_parameters():
            policy_impt     = freeze_policy(policy_impt)

            if direction != "forward":
                # Nothing is sampled for the backward direction.
                return

            problem = self.opt.problem_name

            # ---- initial samples/times, time grid and evaluation targets ----
            # RNA5dim, hesc and the default set-ups are identical for training-time
            # and test-time evaluation; only RNAsc uses different initial samples.
            if problem == 'RNA5dim':
                if test:
                    print("Test || RNA5dim")
                # One rollout batch from every marginal except the last.
                init_samples = torch.hstack(
                    [self.test_dists[i][:, None] for i in range(self.num_dist-1)])
                target_samples = self.test_dists
                sample_size = init_samples.shape[0]
                init_times = repeat(self.t_dists[:-1], "n ->n b", b=sample_size)
                ts = self.ts
            elif problem == 'hesc':
                init_samples = self.gt[0][:, None]
                init_times = t0s[:, :init_samples.shape[0]]
                ts = self.eval_ts
                target_samples = self.gt
            elif problem == 'RNAsc':
                if test:
                    print("Test || RNAsc")
                    # Tile ``test_sample`` until it covers exactly samp_bs rows.
                    init_samples = torch.Tensor(self.x_dists[0].test_sample)
                    init_samples = init_samples.repeat(
                        int(self.opt.samp_bs/init_samples.shape[0]) + 1, 1)
                    init_samples = init_samples[0:self.opt.samp_bs, None, ...]
                else:
                    init_samples = xs[:, :1, :]
                init_times = t0s
                ts = self.eval_ts
                # DMSB : we use whole data which is similar to FID computation
                target_samples = self.x_data
            else:
                init_samples = xs[:, :1, :]
                init_times = t0s
                ts = self.eval_ts
                target_samples = None

            sample_xs, _ = self.dyn.sample_traj(ts, policy_impt, init_samples, init_times)

            # ---- score and/or plot the rollout ----
            if problem == "RNAsc":
                compute_metrics(opt, sample_xs.cpu().numpy(), target_samples, self.metrics, self, stage, "forward", test)
                fn = os.path.join(opt.forward_generated_data_path, 'forward_path_stage_{}'.format(stage))
                self.save_trajectories_pdf(xs.cpu(), sample_xs.cpu(), self.eval_ts,
                                           fn, direction=policy_impt.direction)
            elif problem == "RNA5dim":
                # Take the last entry along dim 1 (presumably the final time step —
                # confirm against sample_traj) and regroup per starting marginal.
                sample_xs = rearrange(sample_xs[:, -1], '(n b) d -> b n d', b=sample_size)
                compute_metrics(opt, sample_xs.cpu(), target_samples, self.metrics, self, stage, "forward", test)
            elif problem == "hesc":
                compute_metrics(opt, sample_xs.cpu().numpy(), target_samples, self.metrics, self, stage, "forward", test)
                fn = os.path.join(opt.forward_generated_data_path, 'forward_path_stage_{}'.format(stage))
                self.save_trajectories_pdf(self.gt, sample_xs.cpu(), self.eval_ts,
                                           fn, direction=policy_impt.direction)
            else:
                fn = os.path.join(opt.forward_generated_data_path, 'forward_path_stage_{}'.format(stage))
                self.save_trajectories_pdf(xs.cpu(), sample_xs.cpu(), self.eval_ts,
                                           fn, direction=policy_impt.direction)
                    
                                                          
    @torch.no_grad()
    def sample_msbm_coupling(self, opt, stage, direction, policy_opt, policy_impt):
        """Build the endpoint coupling ``(x0, x1, t0, t1)`` used to train ``policy_opt``.

        A fresh batch is drawn from every marginal kept for training (all of
        them, or ``self.indices`` when leave-one-out is active). The batches are
        stacked along dim 1.

        * ``stage == 1``: independent coupling. Each later marginal is paired
          with the previous one (``x0 = xs[:, 1:]``, ``x1 = xs[:, :-1]``) and the
          time blocks are flipped.
        * ``stage > 1``: the frozen ``policy_impt`` (under its EMA weights) is
          rolled out from one side of every interval. The last entry of each
          rollout becomes ``x0`` and the starting samples become ``x1``.
          - 'backward' starts at the later marginals and rolls out over flipped
            start times.
          - 'forward' starts at the earlier marginals, and the returned time
            blocks are flipped.

        ``policy_opt`` is accepted for interface symmetry but is not used here.

        Returns:
            ``(train_x0s, train_x1s, train_t0s, train_t1s)``, flattened to one row
            per pair. The time tensors have shape ``(rows, 1)``.

        Raises:
            ValueError: if ``stage < 1`` (this previously returned ``None``, which
                broke the caller's unpacking), or if ``stage > 1`` and
                ``direction`` is not 'forward'/'backward'.
        """
        if stage < 1:
            raise ValueError('stage must be >= 1, got {}'.format(stage))
        if stage > 1 and direction not in ('forward', 'backward'):
            raise ValueError(
                "direction must be 'forward' or 'backward', got {!r}".format(direction))

        t0s, t1s = self.t_dists[:-1], self.t_dists[1:]

        marginal_ids = self.indices if self.opt.LOO != -1 else range(self.num_dist)
        xs = torch.hstack([self.x_dists[i].sample()[:, None] for i in marginal_ids])

        train_t0s = repeat(t0s, "n ->n b", b=opt.samp_bs)
        train_t1s = repeat(t1s, "n ->n b", b=opt.samp_bs)

        if stage == 1:
            train_x0s = xs[:, 1:, :]
            train_x1s = xs[:, :-1, :]

            train_t0s = train_t0s.flip(dims=[0])
            train_t1s = train_t1s.flip(dims=[0])

            print('generate train data from [{}]!'.format(util.red('Independent Coupling')))

        else:
            _, ema_impt, _      = self.get_optimizer_ema_sched(policy_impt)
            with ema_impt.average_parameters():
                policy_impt     = freeze_policy(policy_impt)

                if direction == "backward":  # train forward | sample backward
                    inits = xs[:, 1:, :]
                    train_xs, _ = self.dyn.sample_traj(self.ts, policy_impt, inits, train_t0s.flip(dims=[0]))

                    train_x0s = rearrange(train_xs[:, -1], '(n b) d -> b n d', b=opt.samp_bs)
                    train_x1s = inits

                else:  # direction == "forward": train backward | sample forward
                    inits = xs[:, :-1, :]
                    train_xs, _  = self.dyn.sample_traj(self.ts, policy_impt, inits, train_t0s)

                    train_x0s = rearrange(train_xs[:, -1], '(n b) d -> b n d', b=opt.samp_bs)
                    train_x1s = inits

                    train_t0s = train_t0s.flip(dims=[0])
                    train_t1s = train_t1s.flip(dims=[0])

        # Flatten (batch, marginal) into rows, marginal-major, to match the times.
        train_x0s = rearrange(train_x0s, 'b n d -> (n b) d')
        if train_x1s.shape[0] == opt.samp_bs:
            train_x1s = rearrange(train_x1s, 'b n d -> (n b) d')

        train_t0s = rearrange(train_t0s,  "n b -> (n b) 1")
        if train_t1s.shape[1] == opt.samp_bs:
            train_t1s = rearrange(train_t1s,  "n b -> (n b) 1")

        assert train_x0s.shape[0] == train_x1s.shape[0]
        assert train_t0s.shape[0] == train_t1s.shape[0]
        assert train_x0s.shape[0] == train_t0s.shape[0]

        gc.collect()

        return train_x0s, train_x1s, train_t0s, train_t1s

    def msbm_alternate_train(self, opt):
        
        self.save_samples_pdf(opt.problem_name)
        
        bridge_ep = opt.num_epoch
        # if opt.problem_name =='petal': bridge_ep = 1 #Special handle for petal. the distance between distributions are too close.
        for stage in range(1, opt.num_stage+1):
            if (stage % 2) != 0:
                self.msbm_alternate_train_stage(opt, stage, bridge_ep, 'forward')
                if stage > 1:

                    self.evaluate(opt, stage, 'forward')
                
            else:
                self.msbm_alternate_train_stage(opt, stage, bridge_ep, 'backward')


        
    def msbm_alternate_train_stage(self, opt, stage, epoch, direction):
        """Train one stage: ``epoch`` rounds of (resample coupling, train one epoch).

        ``direction`` names the policy that is *sampled* (``policy_impt``); the
        other policy (``policy_opt``) is the one trained:

        * 'forward'  -> train ``z_b`` on couplings sampled with ``z_f``;
        * 'backward' -> train ``z_f`` on couplings sampled with ``z_b``.
        """
        roles = {
            'forward':  [self.z_b, self.z_f],
            'backward': [self.z_f, self.z_b],
        }
        policy_opt, policy_impt = roles.get(direction)

        for ep_idx in range(epoch):
            # Fresh coupling every epoch.
            coupling = self.sample_msbm_coupling(opt, stage, direction, policy_opt, policy_impt)
            train_x0s, train_x1s, train_t0s, train_t1s = coupling

            # Only the trained policy carries gradients.
            policy_impt = freeze_policy(policy_impt)
            policy_opt = activate_policy(policy_opt)

            self.msbm_alternate_train_ep(
                opt, stage, ep_idx, direction,
                train_x0s, train_x1s, train_t0s, train_t1s, policy_opt,
            )

    def msbm_alternate_train_ep(self, opt, stage, epoch, direction, 
                                train_x0s, train_x1s, train_t0s, train_t1s, policy):
        
        optimizer, ema, sched = self.get_optimizer_ema_sched(policy)
        use_amp = opt.use_amp
        scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
        
        for it in range(opt.num_itr):
            # -------- sample x_idx and t_idx \in [0, interval] --------
            samp_x_idx = torch.randint(opt.samp_bs * (len(self.t_dists)-1),  (opt.train_bs_x,),device='cpu')
            
            train_x0s_sample, train_x1s_sample = train_x0s[samp_x_idx], train_x1s[samp_x_idx]
            train_t0s_sample, train_t1s_sample = train_t0s[samp_x_idx], train_t1s[samp_x_idx]
            
            t_T = (train_t1s_sample - train_t0s_sample) - self.dt * 0.1

            train_t = (torch.rand(size=(opt.train_bs_x, 1)).to(opt.device) * t_T)
            
            train_t_sample = torch.cat([train_t0s_sample, train_t0s_sample + train_t, train_t1s_sample], dim=-1)
                
            train_xts_sample = self.dyn.sample_bridge(train_x0s_sample, train_x1s_sample, train_t_sample)
            train_xts_target = self.dyn.sample_target(train_x0s_sample, train_x1s_sample, train_t_sample)
            
            alpha_t = policy(train_xts_sample, train_t_sample[:, 1])
            
            loss = torch.mean((train_xts_target - alpha_t) ** 2)
            
            scaler.scale(loss).backward()
            if opt.grad_clip is not None:
                torch.nn.utils.clip_grad_norm(policy.parameters(), opt.grad_clip)
            
            scaler.step(optimizer)
            scaler.update()
            optimizer.step()
            ema.update()
            if sched is not None: sched.step()
            
            self.log_sb_alternate_train(opt, it, epoch, stage, loss, optimizer, direction, opt.num_epoch)

    def log_sb_alternate_train(self, opt, it, ep, stage, loss, optimizer, direction, num_epoch,
                               log_every=1000):
        """Print a one-line training progress report every ``log_every`` iterations.

        Args:
            opt: run options (``num_stage``, ``num_itr``).
            it: 0-based iteration index within the current epoch.
            ep: 0-based epoch index.
            stage: current stage.
            loss: scalar loss tensor (``.item()`` is read only when printing).
            optimizer: optimizer whose first param group's lr is reported.
            direction: sampling direction label.
            num_epoch: number of epochs per stage.
            log_every: print when ``(it + 1)`` is a multiple of this (positive).
        """
        # This runs every training iteration; bail out before doing any
        # formatting or lookups on the iterations that do not print.
        if (it + 1) % log_every != 0:
            return
        time_elapsed = util.get_time(time.time()-self.start_time)
        lr = optimizer.param_groups[0]['lr']
        print("[{0}] stage {1}/{2} | ep {3}/{4} | train_it {5}/{6} | lr:{7} | loss:{8} | time:{9}"
            .format(
                util.magenta("SB {} sampling".format(direction)),
                util.cyan("{}".format(stage)),
                opt.num_stage,
                util.cyan("{}".format(1+ep)),
                num_epoch,
                util.cyan("{}".format(1+it+opt.num_itr*ep)),
                opt.num_itr*num_epoch,
                util.yellow("{:.2e}".format(lr)),
                util.red("{:+.4f}".format(loss.item())),
                util.green("{0}:{1:02d}:{2:05.2f}".format(*time_elapsed)),
        ))
