import yaml
from datetime import datetime
import os
import argparse
import logging
import torch
import numpy as np
from typing import Optional, Dict
from torch.utils.data import DataLoader
from torch import optim

from .expdata_config import merge_images_labels

class RunningVariable:
    """Container for the mutable state of a single training run.

    Attributes annotated with a bare ``None`` and no assigned value have no
    class-level default: they do not exist on an instance until other code
    assigns them, so reading one before that raises ``AttributeError``.
    """
    ########################      General     #########################
    runID: int = -1 # run id; class default is -1, but __init__ sets it to 0
    roundID: int = -1 # epoch id
    ########################     Training     #########################
    device: None
    trainset: None # static
    testset: None # static
    train_loader: None # Reference
    test_loader: None 
    fg_optimizer: None
    all_models: None
    ########################       Stage       ########################
    stage: str = 's1'
    moe_gates: None # `mot/ViTGating`
    moe_tracker: None # `mot/MoTTracker`

    def __init__(self):
        # Only the run id is initialised here; every other field keeps its
        # class-level default (or stays unset, see the class docstring).
        self.runID = 0

class ExpStatus:
    """Experiment-wide state: configuration, output paths, RNGs and stage control.

    An instance loads the global experiment config, validates the dataset and
    log base paths, allocates an experiment header (serial id) and creates the
    output folders.  It also owns the random generators used by the experiment
    and the helpers that (re)configure trainable parameters and the optimizer
    at stage boundaries.

    ``ds`` and ``d`` are annotated only; they must be assigned by the caller
    before any method that reads them is used.
    """
    #########################   subClasses   ##########################
    args: argparse.Namespace = None
    ds: None # Class MOTDataset
    d: None # Class ExpData
    r: None # Class RunningVariable
    ########################## Name and Path ##########################
    path_exp_config: str = ""
    proj_codename: str = ""
    path_dataset_prefix: str = ""
    path_explog_prefix: str = ""
    path_explog_fullpath: str = ""
    path_expfolder: str = ""
    exp_header: str = ""
    ###########################    Flags    ###########################
    f1: bool = False  # True if the project folder already existed at init
    f2: bool = False  # True if the experiment folder already existed at init
    ##########################    States    ###########################
    # NOTE: these class-level dicts are kept only as declarations/defaults.
    # `__init__` binds fresh per-instance dicts so that instances never share
    # (and silently overwrite) each other's seeds and generators.
    rngs: Dict[str, Optional[np.random.Generator]] = {
        "primary": None,
        "class_order": None,
        "train_subset": None,
    }
    rngs_seed: Dict[str, int] = {}
    ########################## Configuration ##########################
    roundid_for_early_stop: int = -1
    should_save_final_model: bool = True

    ######################## End of Attributes ########################
    def __get_and_update_exp_id(self, path_exp_id):
        """Return the next experiment id stored in a YAML file and bump it.

        Reads ``next_exp_id`` (default 1 when the key is missing or the file is
        empty), writes back ``next_exp_id + 1`` and returns the value read.

        Raises:
            FileNotFoundError: if ``path_exp_id`` does not exist.
        """
        try:
            with open(path_exp_id, 'r') as f:
                data = yaml.safe_load(f)
        except FileNotFoundError as err:
            # The file is never created here, so say so plainly.
            raise FileNotFoundError(
                f'Experiment-id config file not found: {path_exp_id}'
            ) from err

        # An empty YAML document loads as None; treat it as an empty mapping.
        if data is None:
            data = {}

        current_id = data.get('next_exp_id', 1)
        data["next_exp_id"] = current_id + 1
        with open(path_exp_id, 'w') as f:
            yaml.safe_dump(data, f)

        return current_id

    def __gen_exp_header(self, path_exp_id='config/exp_id.yaml'):
        """Build the experiment header ``<YYYYMMDD>_exp<id>`` and consume one id."""
        exp_id = self.__get_and_update_exp_id(path_exp_id=path_exp_id)
        return f"{datetime.now().strftime('%Y%m%d')}_exp{exp_id}"

    def rng_seeds_setup(self):
        """Create the primary generator and derive the seeds of the other RNGs.

        The primary seed is ``args.random_seed``; the ``class_order`` and
        ``train_subset`` seeds are drawn, in that order, from the primary
        generator.
        """
        logger = logging.getLogger()
        logger.info("SEED: setup all rngs (typically called at the beginning of exp).")
        self.rngs_seed['primary'] = self.args.random_seed
        primary = np.random.default_rng(self.rngs_seed['primary'])
        self.rngs['primary'] = primary
        # Draw order matters for reproducibility: class_order first.
        self.rngs_seed['class_order'] = primary.integers(0, 2**32 - 1)
        self.rngs_seed['train_subset'] = primary.integers(0, 2**32 - 1)

    def rng_states_init(self):
        """Recreate the derived generators from their stored seeds."""
        logger = logging.getLogger()
        logger.info("SEED: (re)init all rngs (typically called before each run).")
        for name in ('class_order', 'train_subset'):
            self.rngs[name] = np.random.default_rng(self.rngs_seed[name])

    def random_states_init(self):
        """Reseed the global NumPy and torch random states with the primary seed."""
        logger = logging.getLogger()
        logger.info("SEED: reset all random states NOT managed by rngs.")
        primary_seed = self.rngs_seed['primary']
        np.random.seed(primary_seed)
        torch.manual_seed(primary_seed)

    def rng_seeds_print(self) -> None:
        """Log every stored seed, one per line."""
        logger = logging.getLogger()
        logger.info("[Seed setup]")
        for name, seed in self.rngs_seed.items():
            logger.info(f"\t{name:<20} {seed}")

    def __init__(self, args, path_exp_config='main/config/exp_global.yaml'):
        """Load the global config, validate base paths and create output folders.

        Args:
            args: parsed command-line namespace. Reads
                ``override_set_dataset_path``, ``n_expert``,
                ``should_use_fixed_header`` and (when that flag is set)
                ``serial_header``.
            path_exp_config: path of the global experiment YAML config.

        Raises:
            FileNotFoundError: if the config file, the dataset base path, the
                explog base path or the experiment-id file does not exist.
            AssertionError: if ``should_save_final_model`` is not a bool or
                ``args.n_expert`` is below 1.
        """
        # Init basic instances
        self.args = args
        self.r = RunningVariable()
        self.path_exp_config = path_exp_config
        # Per-instance RNG containers (shadow the shared class-level dicts).
        self.rngs = {"primary": None, "class_order": None, "train_subset": None}
        self.rngs_seed = {}
        logger = logging.getLogger()

        # Load exp_config
        with open(self.path_exp_config, 'r') as f:
            data = yaml.safe_load(f)
        self.proj_codename = data['proj_codename']
        self.path_dataset_prefix = data['path_dataset_prefix']
        self.path_explog_prefix = data['path_explog_prefix']
        self.roundid_for_early_stop = data['roundid_for_early_stop']
        self.should_save_final_model = data['should_save_final_model']
        assert isinstance(self.should_save_final_model, bool), f"Expected bool, got {type(self.should_save_final_model)}"

        # The literal string 'none' means "no override given".
        if self.args.override_set_dataset_path != 'none':
            self.path_dataset_prefix = self.args.override_set_dataset_path
        if not os.path.exists(self.path_dataset_prefix):
            raise FileNotFoundError(f"Base dataset path not exist, exiting: {self.path_dataset_prefix}")
        # check and add a slash for ds basefolder
        if not self.path_dataset_prefix.endswith('/'):
            logger.warning(f"Basefolder fullpath not ending with (/): {self.path_dataset_prefix}")
            self.path_dataset_prefix += '/'
        if not os.path.exists(self.path_explog_prefix):
            raise FileNotFoundError(f"Base explog path not exist, exiting: {self.path_explog_prefix}")
        assert args.n_expert >= 1

        # Init serial id and result folder
        if self.args.should_use_fixed_header:
            self.exp_header = self.args.serial_header
        else:
            self.exp_header = self.__gen_exp_header(path_exp_id='main/config/exp_local.yaml')

        # Join output path: <explog_prefix>/<proj_codename>/log/<exp_header>
        _proj_path = os.path.join(self.path_explog_prefix, self.proj_codename)

        _proj_path, self.f1 = _check_create_folder(_proj_path)
        self.path_explog_fullpath, _ = _check_create_folder(os.path.join(_proj_path, 'log'))
        self.path_expfolder, self.f2 = _check_create_folder(os.path.join(self.path_explog_fullpath, self.exp_header))

    def print_attributes(self):
        """Log every instance attribute as an aligned ``name: value`` line."""
        logger = logging.getLogger()
        logger.info('[Experiment setting]')
        for key, value in vars(self).items():
            logger.info(f"\t{key:<30}: {value}")

    def __dump_model_helper(self, path):
        """Save expert/gating state dicts, the CLI args and ``ds.vit_param`` to ``path``."""
        checkpoint = {
            "experts": [model.state_dict() for model in self.r.all_models],
            "gating": self.r.moe_gates.state_dict(),
            "args": vars(self.args),
            "vit_param": self.ds.vit_param
        }
        torch.save(checkpoint, path)

    def dump_model(self, is_final=False):
        """Write the final checkpoint of the current run, if enabled.

        A checkpoint is written only when ``is_final`` is true and
        ``should_save_final_model`` is set; otherwise a skip message is logged.
        """
        logger = logging.getLogger()
        if self.should_save_final_model and is_final:
            model_ckp_path = os.path.join(self.path_expfolder, 'models', f'run_{self.r.runID}_final.pt')
            os.makedirs(os.path.dirname(model_ckp_path), exist_ok=True)
            logger.info(f'Saving final model checkpoint file at {model_ckp_path}.')
            self.__dump_model_helper(model_ckp_path)
        else:
            logger.info('Skip storing model checkpoint.')

    def handle_override(self):
        """Ensure ``ds.basefolder_fullpath`` ends with a slash (warns when fixing)."""
        logger = logging.getLogger()
        if not self.ds.basefolder_fullpath.endswith('/'):
            logger.warning(f"Basefolder fullpath not ending with (/): {self.ds.basefolder_fullpath}")
            self.ds.basefolder_fullpath += '/'

    def loader_init(self):
        """Attach the current train/val samples to the datasets and build loaders.

        The merged (image, label) lists replace both ``imgs`` and ``samples``
        of the existing ``r.trainset`` / ``r.testset`` objects; batch sizes
        come from ``ds.lr_param`` and worker count from ``args.num_workers``.
        """
        logger = logging.getLogger()
        logger.info(f"Loading {len(self.d.Y_train_total)} images for training")
        _train_imgs = merge_images_labels(self.d.X_train_total, self.d.Y_train_total)

        self.r.trainset.imgs = self.r.trainset.samples = _train_imgs
        self.r.train_loader = DataLoader(
            self.r.trainset,
            batch_size=self.ds.lr_param['TrBatchSz'],
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=True)

        logger.info(f"Loading {len(self.d.Y_val_total)} images for testing")
        logger.info(f"Sample X/y: {self.d.X_val_total[0]}, {self.d.Y_val_total[0]}")
        _val_imgs = merge_images_labels(self.d.X_val_total, self.d.Y_val_total)
        self.r.testset.imgs = self.r.testset.samples = _val_imgs
        self.r.test_loader = DataLoader(
            self.r.testset,
            batch_size=self.ds.lr_param['TsBatchSz'],
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True)

    @staticmethod
    def _freeze_params(module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def __set_trainable(self, net):
        """Set ``requires_grad`` on ``net`` according to the current stage.

        All parameters are first made trainable, then:
          * ``s1`` / ``s3``: each encoder block's ``attention`` and
            ``layernorm_1`` are frozen when the block has them;
          * ``s2``: the classifier and each block's ``mlp`` and
            ``layernorm_2`` are frozen (these attributes are required).
        Any other stage value leaves everything trainable.
        """
        for param in net.parameters():
            param.requires_grad = True

        stage = self.r.stage
        if stage in ('s1', 's3'):
            for blk in net.encoder.blocks:
                for attr in ("attention", "layernorm_1"):
                    if hasattr(blk, attr):
                        self._freeze_params(getattr(blk, attr))
        elif stage == 's2':
            self._freeze_params(net.classifier)
            for blk in net.encoder.blocks:
                self._freeze_params(blk.mlp)
                self._freeze_params(blk.layernorm_2)

    def __set_gating(self):
        """Freeze the gating network once ``roundID`` reaches ``vit_param['fg_round']``.

        Besides clearing ``requires_grad`` the stored gradients are dropped:
        an optimizer built earlier still holds the gate parameters and would
        otherwise keep updating them (momentum / weight decay) from a stale
        or zeroed ``.grad``.  Optimizers skip parameters whose grad is None.
        """
        if self.r.roundID >= self.ds.vit_param['fg_round']:
            for param in self.r.moe_gates.parameters():
                param.requires_grad = False
                param.grad = None

    def __set_optimizer(self):
        """Build a fresh AdamW over all currently trainable gate/expert params.

        @Behavior: called before training at beginning of new stage,
          after `set_trainable` has been called.
        """
        logger = logging.getLogger()
        _stage_lr = self.ds.lr_param['BaseLR']
        logger.info(f"STAGE: setting lr={_stage_lr}")

        _all_params = list(self.r.moe_gates.parameters())
        for expert in self.r.all_models:
            _all_params.extend(expert.parameters())
        self.r.fg_optimizer = optim.AdamW(
            [p for p in _all_params if p.requires_grad],
            lr=_stage_lr,
            weight_decay=self.ds.lr_param['WDecay'])

    def clip_gradients(self, max_norm=1.0):
        """Clip the joint gradient norm of all trainable gate/expert parameters.

        Parameters that are frozen or currently have no gradient are ignored;
        nothing happens when no parameter qualifies.
        """
        modules = [self.r.moe_gates, *self.r.all_models]
        params = [
            p
            for module in modules
            for p in module.parameters()
            if p.requires_grad and p.grad is not None
        ]
        if params:
            torch.nn.utils.clip_grad_norm_(params, max_norm)

    def stage_init(self):
        """Select the stage for the current round and reconfigure at boundaries.

        @Behavior: called at the beginning of every round.

        ``t_param['stageN_last_round']`` acts as an exclusive upper bound,
        i.e. it is the first round of the following stage.  With
        ``stage1_last_round = 50``:
            roundID= 0: first round
            roundID=49: last round of STAGE-I
            roundID=50: first round of STAGE-II

        On the first round of a stage the experts' trainable flags are reset
        and a new optimizer is created.  The gating freeze is applied every
        round, and before the optimizer is built so that a frozen gate is
        never registered in a new optimizer.
        """
        logger = logging.getLogger()
        _flag_stage_switch = False

        # Check current round location
        if self.r.roundID < self.ds.t_param['stage1_last_round']:
            self.r.stage = 's1'
        elif self.r.roundID < self.ds.t_param['stage2_last_round']:
            self.r.stage = 's2'
        else:
            self.r.stage = 's3'

        if self.r.stage == 's1' and self.r.roundID == 0:
            _flag_stage_switch = True
            logger.warning("STAGE: enter Stage I.")
        if self.r.stage == 's2' and self.r.roundID == self.ds.t_param['stage1_last_round']:
            _flag_stage_switch = True
            logger.warning("STAGE: enter Stage II.")
        if self.r.stage == 's3' and self.r.roundID == self.ds.t_param['stage2_last_round']:
            _flag_stage_switch = True
            logger.warning("STAGE: enter Stage III.")

        if _flag_stage_switch:
            logger.warning("STAGE: reset trainable and optimizer at stage boundary")
            for _m in self.r.all_models:
                self.__set_trainable(_m)
        # Freeze the gate (if due) before any optimizer is (re)built.
        self.__set_gating()
        if _flag_stage_switch:
            self.__set_optimizer()

def _check_create_folder(path):
    """Make sure ``path`` exists and report whether it already did.

    Args:
        path: directory to create (missing parents are created too).

    Returns:
        ``(path, existed)`` where ``existed`` is True when ``path`` was
        already present and False when this call created it.
    """
    # EAFP instead of exists()+makedirs(): two processes starting at the same
    # time can no longer race between the check and the creation.
    try:
        os.makedirs(path)
    except FileExistsError:
        return path, True
    return path, False
