import os
import sys
import json
import time
import torch
import random
import numpy as np

from omegaconf import OmegaConf
from rich.console import Console
from argparse import ArgumentParser
from collections import defaultdict
from transformers import TrainingArguments
from typing import Any, Dict, List, Optional, Tuple, Union


# Shared rich console; print_config writes its styled configuration dump through it.
console = Console()

def _str2bool(text: Union[str, bool]) -> bool:
    """Interpret a CLI/config switch value as a boolean.

    Strings are true iff they equal "true" or "yes" (case-insensitive); any
    other string is false. Real booleans are returned unchanged -- a value
    merged from YAML (e.g. ``debug: true``) can already be a bool, and calling
    ``.lower()`` on it used to raise AttributeError. Other non-string values
    (e.g. None) are compared through ``str()``.
    """
    if isinstance(text, bool):
        return text
    return str(text).lower() in ('true', 'yes')


class AttrDict(dict):
    """``dict`` whose keys can also be read and written as attributes.

    ``d.x`` returns ``d['x']`` when ``'x'`` is a key; otherwise normal attribute
    lookup applies (so a missing name raises AttributeError). Keys are checked
    first, so a key such as ``'items'`` or ``'keys'`` shadows the dict method
    of the same name. ``d.x = v`` always stores ``d['x'] = v``.
    """

    __setattr__ = dict.__setitem__

    def __getattribute__(self, item):
        # Keys are consulted before normal lookup, hence they shadow methods.
        if item in self:
            return self[item]
        return super().__getattribute__(item)

    @classmethod
    def from_nested_dicts(cls, data):
        """Recursively convert ``data`` so every dict (also inside lists) becomes ``cls``.

        Non-dict, non-list values are returned unchanged.
        """
        if isinstance(data, dict):
            # Iterate keys + index (not .items()) so an AttrDict input whose
            # keys shadow dict methods is still handled correctly.
            return cls({key: cls.from_nested_dicts(data[key]) for key in data})
        if isinstance(data, list):
            return [cls.from_nested_dicts(value) for value in data]
        return data

    def update_missing(self, data):
        """Copy entries from ``data`` only for keys not already present."""
        for key in data:
            if key not in self:
                self[key] = data[key]

    def copy(self):
        """Return a shallow copy that keeps the concrete subclass (e.g. Timing)."""
        return type(self)(self)


class AvgTime:
    """Bounded window of recent durations whose ``str()`` is their mean.

    NOTE(review): TimingContext's average mode references ``AvgTime``, but no
    definition or import of it is visible in this module, so
    ``Timing.time_avg`` would fail with NameError. This fallback provides the
    interface TimingContext relies on (a ``values`` sequence to append to). If
    another ``AvgTime`` is defined later in the module, that later definition
    is the one looked up at call time.
    """

    def __init__(self, num_values_to_avg):
        from collections import deque  # local import keeps this block self-contained

        # Once num_values_to_avg samples are stored, appending drops the oldest.
        self.values = deque(maxlen=num_values_to_avg)

    def __str__(self):
        # max(1, ...) avoids ZeroDivisionError before the first sample arrives.
        return f'{sum(self.values) / max(1, len(self.values)):.4f}'


class TimingContext:
    """Context manager that records how long its ``with`` block took.

    The elapsed time (measured with ``time.perf_counter``) is stored in
    ``timer[key]`` in one of three modes:

    * default: overwrite with the latest duration;
    * ``additive=True``: add to a running total starting from 0;
    * ``average=N``: append to an ``AvgTime`` built with ``num_values_to_avg=N``.

    ``additive`` takes precedence over ``average`` if both are given.
    Durations are also recorded when the block raises.
    """

    def __init__(self, timer, key, additive=False, average=None):
        self._timer = timer
        self._key = key
        self._additive = additive
        self._average = average
        self._time_enter = None

    def __enter__(self):
        # perf_counter is monotonic; time.time() can jump with clock changes.
        self._time_enter = time.perf_counter()
        return self

    def __exit__(self, type_, value, traceback):
        # EPS keeps the duration strictly positive so it is safe to divide by.
        time_passed = max(time.perf_counter() - self._time_enter, 1e-8)
        timer, key = self._timer, self._key

        if self._additive:
            if key not in timer:
                timer[key] = 0
            timer[key] += time_passed
        elif self._average is not None:
            if key not in timer:
                timer[key] = AvgTime(num_values_to_avg=self._average)
            timer[key].values.append(time_passed)
        else:
            timer[key] = time_passed

        return False  # never suppress exceptions from the timed block


class Timing(AttrDict):
    """Dict of named timings with context-manager helpers that fill it in.

    Example::

        timing = Timing()
        with timing.timeit('step'):
            ...
        print(timing)  # "step: 0.0123"
    """

    def timeit(self, key):
        """Time the ``with`` block and store the duration under ``key`` (overwriting)."""
        return TimingContext(self, key)

    def add_time(self, key):
        """Time the ``with`` block and add the duration to the total under ``key``."""
        return TimingContext(self, key, additive=True)

    def time_avg(self, key, average=10):
        """Time the ``with`` block into an AvgTime under ``key`` (``num_values_to_avg=average``)."""
        return TimingContext(self, key, average=average)

    def __str__(self):
        """Render as ``key: value, key: value``; floats get 4 decimal places."""
        parts = []
        for key, value in self.items():
            str_value = f'{value:.4f}' if isinstance(value, float) else str(value)
            parts.append(f'{key}: {str_value}')
        return ', '.join(parts)


def get_default_args():
    """Parse the launcher's command-line flags into an ``argparse.Namespace``.

    Uses ``parse_known_args``, so flags not declared here (e.g. dotted config
    overrides) are ignored rather than rejected; cfg_override_from_cli re-reads
    ``sys.argv`` for those. ``--debug`` and ``--save_chat`` are kept as strings
    here and converted to booleans by build_config.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # Run / distributed setup
    add("--config_path", type=str, default='configs/train_multitask_3b.yaml')
    add("--set_to_reference", action="store_true")
    add("--use_proxy", action="store_true", help="")
    add("--local_rank", type=int, default=-1)
    add("--world_size", type=int, default=1)
    add("--device_id", type=str, default="cuda:0")
    add("--seed", type=int, default=42)

    # Data / task formatting
    add("--datapath", default="", help="path to datasets, defaults to {parlai_dir}/data")
    add("--blenderbot3-base-agent", type=str, default="r2c2", choices=["r2c2", "opt"],
        help="Base agent for which to format tasks.")

    # Hugging Face assets
    add("--hf_tokenizer_path", default="hexa/models/tokenizer")
    add("--hf_model_config_path", default="hexa/models/config/bb3_3b.json")
    add("--hf_trainer_args_path", default="configs/bb3_train.yaml")

    # Serving / chat behaviour
    add('--server_port', type=int, default=-1)
    add('--deploy_port', type=int, default=-1)
    add('--debug', type=str, default='true')
    add('--knowledge_conditioning', type=str, default='combined',
        choices=['combined', 'separate', 'both'])
    add('--memory_decision', type=str, default='compute', choices=['always', 'never', 'compute'])
    add('--save_chat', type=str, default='true')
    add('--chat_save_dir', type=str, default='chat_log')

    known_args = parser.parse_known_args()[0]
    return known_args

def save_cfg(args, save_dir, fname):
    """Serialize ``args`` with ``torch.save`` to ``save_dir/fname``, creating ``save_dir`` if needed."""
    target = os.path.join(save_dir, fname)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(args, target)

def cfg_override_from_file(args):
    """Merge the YAML file at ``args.config_path`` on top of the parsed CLI args.

    Returns a new OmegaConf config in which values from the file take
    precedence over the corresponding entries of ``vars(args)``.
    """
    base = OmegaConf.create(vars(args))
    file_cfg = OmegaConf.load(base.config_path)
    return OmegaConf.merge(base, file_cfg)

def cfg_override_from_cli(cfg):
    """Apply ``--key value`` / ``--a.b.c value`` overrides from ``sys.argv`` to ``cfg``.

    Only keys that already exist in ``cfg`` are overridden; unknown keys, or
    paths whose intermediate entries are missing or not mappings, are ignored.
    Dotted keys may be nested to any depth. The new value is cast to the type
    of the value it replaces:

    * bool targets become True iff the CLI value is "true" (any case); a bare
      flag such as ``--use_proxy`` (parsed as the bool True) counts as true;
    * None targets take the CLI value as given (no type to cast to);
    * anything else is cast with ``type(current)(value)``.

    When a key is given several times, the first occurrence wins.

    Args:
        cfg: Mapping-like config supporting ``in``, ``[]`` and item assignment
            (the OmegaConf config produced by cfg_override_from_file).

    Returns:
        ``cfg``, modified in place.
    """
    cli_args = clean_arguments(sys.argv[1:])
    for dotted_key, values in cli_args.items():
        *parents, leaf = dotted_key.split('.')

        # Walk to the mapping that owns ``leaf``; stop at any missing or
        # non-mapping step (a plain ``in`` on a str would do a substring test).
        node = cfg
        for part in parents:
            if not hasattr(node, 'keys') or part not in node:
                node = None
                break
            node = node[part]
        if node is None or not hasattr(node, 'keys') or leaf not in node:
            continue

        raw = values[0]
        current = node[leaf]
        if isinstance(current, bool):
            # str() so a bare flag's True works (it has no .lower()).
            node[leaf] = str(raw).lower() == 'true'
        elif current is None:
            node[leaf] = raw
        else:
            node[leaf] = type(current)(raw)
    return cfg

def cfg_override_from_hf(cfg, remove_unused_columns=False):
    """Build ``TrainingArguments`` from the YAML file at ``cfg.hf_trainer_args_path``.

    The file is expected to contain a top-level ``trainer`` section. The
    extraction logic is shared with hf_trainer_args_from_opt rather than
    duplicated here.

    Args:
        cfg: Config exposing ``hf_trainer_args_path``.
        remove_unused_columns: Forwarded to ``TrainingArguments``.

    Returns:
        ``(TrainingArguments, extra_args)`` where ``extra_args`` holds
        ``decay_forbidden_layer_types`` and ``fp16_safe`` (None when absent).
    """
    trainer_file_cfg = OmegaConf.load(cfg.hf_trainer_args_path)
    return hf_trainer_args_from_opt(trainer_file_cfg, remove_unused_columns=remove_unused_columns)

def hf_trainer_args_from_opt(opt, remove_unused_columns=False):
    """Build ``transformers.TrainingArguments`` from ``opt['trainer']``.

    The project-specific keys ``decay_forbidden_layer_types`` and ``fp16_safe``
    are split off and returned separately (None when absent); everything else
    in ``opt['trainer']`` is passed to ``TrainingArguments``. The keys are
    removed from a shallow copy, so the caller's ``opt['trainer']`` is not
    mutated and the function can be called repeatedly on the same config.

    Args:
        opt: Mapping with a ``trainer`` mapping of TrainingArguments kwargs.
        remove_unused_columns: Forwarded to ``TrainingArguments``. It replaces
            any ``remove_unused_columns`` entry in ``opt['trainer']``;
            previously such an entry collided with this keyword (TypeError).

    Returns:
        ``(TrainingArguments, extra_args)``.
    """
    trainer_kwargs = dict(opt['trainer'])
    extra_args = {
        name: trainer_kwargs.pop(name, None)
        for name in ('decay_forbidden_layer_types', 'fp16_safe')
    }
    trainer_kwargs['remove_unused_columns'] = remove_unused_columns
    cfg = TrainingArguments(**trainer_kwargs)
    return cfg, extra_args

def clean_arguments(args):
    """Group raw command-line tokens into ``{key: [values, ...]}``.

    Recognised forms (leading dashes and surrounding spaces are stripped from
    keys):

    * ``--key=value``: the value is everything after the first ``=``;
    * ``--key value`` / ``-k value``: the next token is the value if it does
      not start with ``-`` or is a number (so ``--lr -0.1`` works);
    * ``--key`` / ``-k`` with no value: the value is ``True``.

    Repeated keys collect their values in order. A bare token that is not
    consumed as a value raises ValueError, except as the final token, where
    it is ignored.

    Args:
        args: Sequence of strings, e.g. ``sys.argv[1:]``.

    Returns:
        ``defaultdict(list)`` mapping each key to its list of values.

    Raises:
        ValueError: on a stray positional token before the last position.
    """
    def _is_number(token):
        try:
            float(token)
        except ValueError:
            return False
        return True

    def _is_value(token):
        # Negative numbers ("-1", "-0.5") are values, not flags.
        return token is not None and (not token.startswith('-') or _is_number(token))

    ret_args = defaultdict(list)
    index = 0
    while index < len(args):
        token = args[index]
        nxt = args[index + 1] if index + 1 < len(args) else None

        if token.startswith('--') and '=' in token:
            # Split once so values may themselves contain '='.
            key, val = token.split('=', 1)
            step = 1
        elif token.startswith('-') and not _is_number(token):
            key = token
            if _is_value(nxt):
                val, step = nxt, 2
            else:
                val, step = True, 1
        else:
            if nxt is None:
                break  # a trailing stray token is ignored
            raise ValueError('Unexpected argument pair: %s, %s' % (token, nxt))

        ret_args[key.strip(' -')].append(val)
        index += step

    return ret_args

def print_config(config, opt=None, text="configuration", color="bold green", convert=False):
    """Save ``opt`` to ``<config.output_dir>/opt.json`` and pretty-print ``config``.

    ``config.output_dir`` is created on every process, but only the main
    process (``opt.local_rank`` of -1 or 0) saves and prints.

    Args:
        config: Object with an ``output_dir`` attribute. Entries are listed
            (sorted by key) via ``vars(config)`` when ``convert`` is True,
            otherwise via ``config.keys()`` / ``config[key]``.
        opt: JSON-serialisable options exposing ``local_rank`` as an attribute.
            When None, nothing is saved and this process is treated as the
            main one (previously the None default raised AttributeError).
        text: Banner title; printed upper-cased.
        color: rich style applied to the banner lines.
        convert: Read entries with ``vars()`` instead of mapping access.
    """
    os.makedirs(config.output_dir, exist_ok=True)
    local_rank = -1 if opt is None else opt.local_rank
    if local_rank not in (-1, 0):
        return

    if opt is not None:
        with open(os.path.join(config.output_dir, 'opt.json'), 'w') as f:
            json.dump(opt, f)

    if convert:
        entries = vars(config)
    else:
        entries = {key: config[key] for key in config.keys()}

    console.print(f"\n\n**************** {text.upper()} ****************", style=color)
    for key in sorted(entries):
        # Left-align keys in a 30-char column (longer keys are not truncated).
        console.print(f"{key:<30} -->   {entries[key]}")
    console.print(f"**************** {text.upper()} ****************\n\n", style=color)
        
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Prefer reproducible cuDNN kernels over autotuned (benchmark) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def build_config():
    """Assemble the run options from three layers, later layers winning.

    1. argparse defaults and known flags (get_default_args);
    2. the YAML file named by ``config_path`` (cfg_override_from_file);
    3. ``--key value`` overrides from the raw command line, dotted keys
       reaching nested sections (cfg_override_from_cli).

    The result is converted to a nested AttrDict, and the ``debug`` and
    ``save_chat`` switches are normalised to booleans with _str2bool.
    """
    merged = cfg_override_from_cli(cfg_override_from_file(get_default_args()))
    opt = AttrDict.from_nested_dicts(OmegaConf.to_container(merged))

    for switch in ('debug', 'save_chat'):
        setattr(opt, switch, _str2bool(getattr(opt, switch)))

    return opt