import numpy as np
import os, sys, time, shutil, random
from typing import Optional, Tuple, Union
import argparse
import torch

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.distributed as dist

from torch.utils.data import Dataset
from torch.nn.parallel import DistributedDataParallel

from tqdm import tqdm 
from ruamel.yaml import YAML
import torch.optim as optim
from torch.optim import lr_scheduler
from collections import OrderedDict
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts # pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'

from fm4npp.utils import *
from fm4npp.datasets.dataset import *
from fm4npp.models.mambagpt import MambaGPT

            
class Trainer():
    """ trainer class """
    def __init__(self, params, args):

        
        
        ''' init vars for distributed training (ddp) and logging'''
        self.root_dir = args.root_dir
        self.global_log_dir = os.path.join(args.root_dir, args.global_log_dir)
        self.config = args.config 
        self.run_num = args.run_num
        self.world_size = 1
        
        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])

        self.local_rank = 0
        self.world_rank = 0
        
        if self.world_size > 1: # multigpu, use DDP with standard NCCL backend for communication routines
            dist.init_process_group(backend='nccl',
                                    init_method='env://')
            self.world_rank = dist.get_rank()
            self.local_rank = int(os.environ["LOCAL_RANK"])

        if torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            torch.backends.cudnn.benchmark = True

        self.log_to_screen = (self.world_rank==0)
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        else:
            self.device = torch.device('cpu')
        
        self.params = params
        print("running on rank {} with world size {}".format(self.world_rank, self.world_size))



    def init_exp_dir(self, exp_dir):
        # If finisher detected, stop the code and finish.
        self.finisher = os.path.join(exp_dir, 'finished.txt')
        
        if os.path.exists(self.finisher):
            raise FinishedTrainingError
            
        if self.world_rank==0:
            if not os.path.isdir(exp_dir):
                os.makedirs(exp_dir)
                os.makedirs(os.path.join(exp_dir, 'checkpoints/'))
                
        self.params['experiment_dir'] = os.path.abspath(exp_dir)
        self.params['checkpoint_path'] = os.path.join(exp_dir, 'checkpoints/ckpt.tar')

        if self.params.continue_from_best:
            self.params['checkpoint_path'] = os.path.join(exp_dir, 'checkpoints/ckpt_best.tar')

        self.params['resuming'] = True if os.path.isfile(self.params.checkpoint_path) else False
        idx = 0
        logfile = os.path.join(exp_dir, 'performance{}.log'.format(idx))
        
        if self.world_rank==0:    
            while os.path.exists(logfile):
                idx += 1
                logfile = os.path.join(exp_dir, 'performance{}.log'.format(idx))
        if dist.is_initialized():
            dist.barrier()
        
        self.logfile = logfile

        
        if self.world_rank==0:            
            with open(self.logfile, 'w') as f:
                f.write('Initialized at: {}\n'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
    
            # Preparing global log directory
            if not os.path.isdir(self.global_log_dir):
                os.makedirs(self.global_log_dir)   
                
            self.globalfile = os.path.join(self.global_log_dir, 
                                           'config_{}_run_{}_{}.csv'.format(
                                               self.config,
                                               self.run_num, 
                                               self.parse_exp_details(self.params.params, 
                                                                      partial = ['data_version', 
                                                                                 'limit_size',
                                                                                 'model_version'],
                                                                      globalfile=True)
                                           ))
            print(self.globalfile)

        if dist.is_initialized():
            dist.barrier()
        
        if self.world_rank == 0 and not os.path.exists(self.globalfile):
            with open(self.globalfile, 'w') as f:
                pass
        if dist.is_initialized():
            dist.barrier()  
        
    def log_infile(self, log):
        with open(self.logfile, "a") as f:
            f.write("{}\n".format(log))

    def log_globalfile(self, split, step, loss, lr):
        with open(self.globalfile, "a") as f:
            f.write("{},{},{},{}\n".format(split, step, loss, lr))

    def finish_training(self):
        with open(self.finisher, 'w') as f:
            f.write(' ')
        raise FinishedTrainingError
    
    def parse_exp_details(self, D, partial=None, globalfile = False):
        """Render a parameter dictionary as text.

        D: a dictionary listing parameters
        partial: a list of columns of interest; only consulted when
            ``globalfile`` is true (None keeps every entry)
        globalfile: when true, return a compact ``key:value|key:value`` string;
            otherwise return a multi-line block headed by 'Important Details:'
        """
        if not globalfile:
            body = ''.join('{}: {}\n'.format(key, value) for key, value in D.items())
            return 'Important Details:\n' + body

        selected = [
            (key, value)
            for key, value in D.items()
            if partial is None or key in partial
        ]
        return '|'.join('{}:{}'.format(key, value) for key, value in selected)

    def get_bin_index(self, seq_length):
        """Return the index i of the half-open bin [bins[i], bins[i+1]) holding seq_length.

        When no bin matches, the index of the last bin (len(bins) - 2) is returned.
        NOTE(review): ``self.bins`` is not assigned anywhere in the visible class.
        """
        edges = self.bins
        n_bins = len(edges) - 1
        matching = (
            idx for idx in range(n_bins)
            if edges[idx] <= seq_length < edges[idx + 1]
        )
        # Out-of-range lengths are assigned to the last bin.
        return next(matching, n_bins - 1)

    def update_moving_average(self, bin_idx, loss_value):
        """Fold ``loss_value`` into the exponential moving average of one bin.

        new = smoothing_factor * old + (1 - smoothing_factor) * loss_value,
        stored back into ``self.loss_moving_avg[bin_idx]``.
        """
        alpha = self.smoothing_factor
        previous = self.loss_moving_avg[bin_idx]
        self.loss_moving_avg[bin_idx] = alpha * previous + (1 - alpha) * loss_value

    def compute_inverse_loss_weights(self):
        """Return per-bin weights proportional to 1 / (moving-average loss + epsilon).

        The returned dict uses the keys of ``self.loss_moving_avg`` and its
        values are normalized to sum to one.
        """
        inverse = {}
        for bin_idx in self.loss_moving_avg:
            inverse[bin_idx] = 1 / (self.loss_moving_avg[bin_idx] + self.epsilon)

        norm = sum(inverse.values())
        return {bin_idx: weight / norm for bin_idx, weight in inverse.items()}

    def launch(self):
        """Assemble the full training setup and start training.

        In order: prepare the experiment directory, derive per-rank batch
        sizes, build and initialize the MambaGPT model, wrap it in
        DistributedDataParallel when a process group exists, create the
        optimizer, grad scaler and LR scheduler, build the data loaders and
        loss functions, restore a checkpoint when ``params.resuming`` is set,
        and finally call ``self.train()``.
        """
        exp_dir = os.path.join(self.root_dir, self.config, self.run_num)
        self.init_exp_dir(exp_dir)

        self.params['global_batch_size'] = self.params.batch_size
        self.params['local_batch_size'] = int(self.params.batch_size//self.world_size)
        self.params['global_valid_batch_size'] = self.params.valid_batch_size
        self.params['local_valid_batch_size'] = int(self.params.valid_batch_size//self.world_size)

        # Only rank 0 creates the performance log in init_exp_dir; the other
        # ranks keep the index-0 file name, so writing from them would append
        # duplicate headers to performance0.log (possibly an older run's log).
        if self.world_rank == 0:
            self.log_infile(self.parse_exp_details(self.params.params))

        # get the model
        self.klen = self.params.klen
        self.model = MambaGPT(embed_dim=self.params.embed_dim,
                              num_layers=self.params.num_layers_backbone,
                              d_state=self.params.d_state,
                              d_conv=4,
                              expand=2,
                              klen=self.klen,
                              dropout=self.params.dropout,
                              embed_method=self.params.embed_method,
                              pe_method=self.params.pe_method
                              )

        def initialize_mamba2(model, d_state, embed_dim):
            """Re-initialize selected parameters, chosen by substring of their name."""
            with torch.no_grad():
                for name, param in model.named_parameters():

                    if "lin_B" in name:
                        param.normal_(mean=0.0, std=(d_state / embed_dim)**0.5)

                    elif "lin_C" in name:
                        param.normal_(mean=0.0, std=(1.0 / (embed_dim*d_state))**0.5)

                    elif "norm.weight" in name:
                        init.ones_(param)

                    # Bias Terms
                    elif "bias" in name:
                        init.zeros_(param)

                print("✅ Mamba v2 Model Initialized")

        Nu = self.params.embed_dim
        Nx = self.params.d_state

        initialize_mamba2(self.model, Nx, Nu)

        self.mup_width_multiplier = 1.0
        print('mup: ', self.mup_width_multiplier)

        self.model = self.model.to(self.device)
        print('Nparams: ', count_parameters(self.model))

        # distributed wrapper for data parallel
        if dist.is_initialized():
            # output_device takes a single device (int or torch.device), not a list.
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.local_rank],
                                                 output_device=self.local_rank,
                                                 find_unused_parameters=True)

        # set an optimizer and learning rate scheduler
        # Parameters are split into groups by name so that each group can get
        # its own learning-rate scaling.
        params_a = []
        params_b = []
        params_c = []
        params_else = []

        for name, p in self.model.named_parameters():
            if "A_log" in name:
                params_a.append(p)   # might do LR ~ Nu
            elif "lin_B" in name:
                params_b.append(p)   # might do LR ~ Nx / sqrt(Nu)
            elif "lin_C" in name:
                params_c.append(p)   # might do LR ~ sqrt(Nu) / Nx
            else:
                params_else.append(p)

        # NOTE(review): confirm that CosineAnnealingWarmupRestarts keeps these
        # per-group learning rates; if it assigns one schedule value to every
        # param group, the scaling below has no effect.
        self.optimizer = torch.optim.AdamW([
            {"params": params_a,    "lr": self.params.min_lr * Nu},                   # e.g. for A
            {"params": params_b,    "lr": self.params.min_lr * Nx / (Nu**0.5)},       # e.g. for B
            {"params": params_c,    "lr": self.params.min_lr * (Nu**0.5) / Nx},       # e.g. for C
            {"params": params_else, "lr": self.params.min_lr},
        ], weight_decay=0.1, betas=(0.9, 0.95))

        self.scaler = torch.amp.GradScaler('cuda')

        self.scheduler = CosineAnnealingWarmupRestarts(self.optimizer,
                                                       first_cycle_steps=self.params.total_steps,
                                                       max_lr=self.params.max_lr,
                                                       min_lr=self.params.min_lr,
                                                       warmup_steps=self.params.warmup_steps)

        # get the dataloaders
        self.train_data_loader, self.train_sampler, self.val_data_loader, _ = get_data_loader(
            self.params, dist.is_initialized())

        # set loss functions (element-wise; masking and reduction happen in the loops)
        self.loss_func = nn.MSELoss(reduction='none')
        self.loss_func_eval = nn.MSELoss(reduction='none')

        # checkpointing
        self.iters = 0
        self.startEpoch = 0
        self.resumed = False
        if self.params.resuming:
            print("Loading checkpoint %s"%self.params.checkpoint_path)
            self.restore_checkpoint(self.params.checkpoint_path)
            self.resumed = True

        self.epoch = self.startEpoch
        self.logs = {}

        # launch training
        self.train()

    def train(self):
        """Run the epoch loop from ``self.startEpoch`` up to ``params.max_epochs``."""
        # Debugging: hooks on the forward pass and on parameter gradients.
        self.fwd_hooks = register_fine_grained_forward_hooks(self.model)
        self.bwd_hooks = register_param_backward_nan_hooks(self.model)

        if self.log_to_screen:
            print("Starting training loop...")

        self.best_loss = np.inf

        # Pickled statistics used to weight the training loss.
        stat_dir = self.params.stat_dir
        self.loss_bin = pickle_load('{}/loss_bin_pp.pkl'.format(stat_dir))
        self.loss_weight = pickle_load('{}/loss_weight_pp.pkl'.format(stat_dir))

        first_epoch, last_epoch = self.startEpoch, self.params.max_epochs
        for epoch in range(first_epoch, last_epoch):
            self.epoch = epoch

            # shuffles data before every epoch
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)

            self.resumed = False
            self.starttime = time.time()

            # training
            self.train_one_epoch()
                    
    def report_loss(self, loss_, dist):
        step_loss = torch.zeros((1), dtype=torch.float32, device=self.device)
        step_loss += loss_.detach()

        if dist.is_initialized():
            dist.all_reduce(step_loss)
            loss_log = float(step_loss.item()/dist.get_world_size())
        else:
            loss_log = step_loss.item()
        return loss_log

    def set_portion_condition(self, tmask, portion = 0.2):
        """tmask: a mask showing effective (i.e., non-padding area) region as 1"""
        total = tmask.sum(-1)
        condidx = torch.ceil(total * portion).long()        
        index_tensor = torch.arange(tmask.size(1)).expand(tmask.size(0), -1).to(tmask.device)  # Shape (B, N)
        newmask = (index_tensor < condidx.unsqueeze(1)).float()
        return newmask.bool()
        
    
    def train_one_epoch(self):
        """Run one pass over the training loader, validating every ``n_eval_steps``.

        Each step: forward pass, masked and length-bin-weighted MSE, scaled
        backward pass, gradient value clipping, optimizer and scheduler step,
        then a CSV log row written by world rank 0.

        Returns 0.
        """
        tr_time = 0
        self.model.train()

        # The bin edges and bin weights do not change between batches, so build
        # the tensors and move them to the device once instead of every step.
        loss_bin = torch.Tensor(self.loss_bin).to(self.device)
        loss_weight = torch.Tensor(self.loss_weight).to(self.device)

        tr_start = time.time()
        for grouped, _, knearest in self.train_data_loader:
            self.iters += 1
            b, c = grouped.size(0), grouped.size(-1)
            targets = grouped.reshape(b, -1, 4)[:, :, 1:].to(self.device)
            klabel = knearest.reshape(b, -1, self.klen * 3).to(self.device)
            grouped = grouped.reshape(b, -1, c).to(self.device)

            self.model.zero_grad()

            kpred = self.model(grouped)
            # -100 entries are excluded from the loss (presumably padding — TODO confirm).
            kmask = klabel != -100  # B X N X C
            tmask = targets[..., 0] != -100
            seq_lengths = tmask.sum(-1)  # number of unmasked points per sample

            loss = self.loss_func(kpred, klabel)  # B X N X C
            # Per-sample mean over unmasked entries (sequence first, batch later).
            loss = (loss * kmask).sum(-1).sum(-1) / kmask.sum(-1).sum(-1)

            # Weighting loss by the bin number. This compensates for inherent difficulty.
            loss_weight_ = apply_bin_weights_torch(loss_bin, loss_weight, seq_lengths)
            loss = (loss * loss_weight_).mean()

            # Scale the loss, then unscale the gradients before clipping them.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)

            # Gradients are clipped by value, so no gradient norm is computed;
            # the log column is kept and filled with a placeholder.
            grad_norm = 0.0
            torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                            clip_value=self.params.grad_clip_value)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            self.scheduler.step()

            # report_loss may all_reduce, so it has to run on every rank.
            param_report = ''
            if self.iters % 100 == 0 or torch.isnan(loss):
                param_report = check_model_parameters(self.model)
            seq_lengths_f = seq_lengths.float()
            loss_log = '{},{},{},{},{}'.format(self.report_loss(loss, dist),
                                               grad_norm,
                                               seq_lengths_f.mean().item(),
                                               seq_lengths_f.std().item(),
                                               param_report)

            if self.world_rank == 0:
                self.log_globalfile('train', self.iters, loss_log, self.scheduler.get_lr()[0])

            # Every n_eval_steps: validate; validation time is kept out of tr_time.
            if self.iters % self.params.n_eval_steps == 0:
                tr_time += time.time() - tr_start
                self.val_one_epoch(tr_time)
                tr_start = time.time()

        return 0

    
    def val_one_epoch(self, tr_time):
        self.model.eval()
        val_start = time.time()

        logs_buff = torch.zeros((1), dtype=torch.float32, device=self.device)
        self.logs['val_loss'] = logs_buff[0].view(-1)
        start_idx = 0
        with torch.no_grad():
            for i, (grouped, _, knearest) in enumerate(self.val_data_loader):
                b, c = grouped.size(0), grouped.size(-1)
                targets = grouped.reshape(b, -1, 4)[:, :, 1:].to(self.device)
                klabel = knearest.reshape(b, -1, self.klen * 3).to(self.device)
                grouped = grouped.reshape(b, -1, c).to(self.device)
    
                self.model.zero_grad()
    
                # with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                point_pred = self.model(grouped)
                kpred = point_pred
                kmask = klabel != -100  # B X N X C
                tmask = targets[..., 0] != -100
                loss_kpred = self.loss_func_eval(kpred[kmask], klabel[kmask]).mean()         
        

                loss = loss_kpred
                    
                self.logs['val_loss'] += loss.detach()

        self.logs['val_loss'] /= len(self.val_data_loader)
        if dist.is_initialized():
            for key in ['val_loss']:
                dist.all_reduce(self.logs[key].detach())
                self.logs[key] = self.logs[key]/dist.get_world_size()

        val_time = time.time() - val_start
        
        # keep track of best model according to validation loss
        if self.logs['val_loss'] <= self.best_loss:
            is_best_loss = True
            self.best_loss = self.logs['val_loss']
        else:
            is_best_loss = False

        # save checkpoint (if best epoch additionally save the best epoch too)
        if self.params.save_checkpoint:
            if self.world_rank == 0:
                #checkpoint at the end of every "save_step" steps
                self.save_checkpoint(self.params.checkpoint_path, is_best=is_best_loss)

        # some print statements
        tolog = 'Time taken {:.2f} sec; with {:.2f} / {:.2f} in tr/val\n'.format(time.time()-self.starttime, tr_time, val_time)
        tolog += 'Step = {}, Val loss = {}'.format(self.iters, float(self.logs['val_loss']))

        if self.world_rank == 0 and self.log_to_screen:
            print(tolog)

        # Logging results
        if self.world_rank == 0:
            self.log_infile(tolog)
            self.log_globalfile('val', self.iters, float(self.logs['val_loss']), self.scheduler.get_lr()[0])  
        
        # If total step is reached, finish the training
        if self.iters >= self.params['total_steps']:
            self.finish_training()

        self.model.train()
        return 0

    def save_checkpoint(self, checkpoint_path, is_best=False):
        state_dict = self.model.state_dict()
            
        torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': state_dict, 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': (self.scheduler.state_dict() if self.scheduler is not None else None)}, checkpoint_path)
        if is_best:
            torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': state_dict, 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': (self.scheduler.state_dict() if  self.scheduler is not None else None)}, checkpoint_path.replace('.tar', '_best.tar'))

            
    def restore_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank), weights_only=False) 
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except:
            new_state_dict = OrderedDict()
            for key, val in checkpoint['model_state'].items():
                name = key[7:]
                new_state_dict[name] = val 
            self.model.load_state_dict(new_state_dict)

        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']+1 if self.iters % len(self.train_data_loader) == 0 else checkpoint['epoch']
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])


if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the YAML config and train.
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_config", required=True, type=str)
    parser.add_argument("--config", required=True, type=str)
    parser.add_argument("--run_num", default='0', type=str, help='sub run number')
    parser.add_argument("--root_dir", default='<ROOT_DIR_SPECIFY>', type=str, help='root dir to store results')
    parser.add_argument("--global_log_dir", default='globallogs', type=str, help='global dir to store logging only')

    args = parser.parse_args()
    params = YParams(os.path.abspath(args.yaml_config), args.config)

    try:
        trainer = Trainer(params, args)
        # Exceptions raised inside launch() (including FinishedTrainingError)
        # still propagate to the caller; only the cleanup below is added.
        trainer.launch()

        if dist.is_initialized():
            dist.barrier()

        print('Training complete')
    finally:
        # A process group only exists for multi-process runs; destroying a
        # group that was never initialized raises, so guard the call.
        if dist.is_initialized():
            dist.destroy_process_group()