import copy
import os
import glob
import json
import random
import numpy as np
from collections import defaultdict, OrderedDict
from pathlib import Path
from misc.forked_pdb import ForkedPdb
import torch
from torch import Tensor

def get_num_classes(dataset: str) -> int:
    """Return the number of target classes registered for a dataset name.

    Args:
        dataset: Dataset identifier, e.g. ``'ag_news'`` or ``'gpt3mix/sst2'``.

    Returns:
        The class count stored for ``dataset``.

    Raises:
        SystemExit: With exit status 1 when ``dataset`` is not registered.
            A message naming the dataset is printed first.
    """
    dataset_classes = {
        'ag_news': 4,
        'trec': 6,
        'gpt3mix/sst2': 2,
        'tweet_eval': 20,
        '20_newsgroups': 20,
        'multi_sent': 10,
    }
    n_classes = dataset_classes.get(dataset)
    if n_classes is None:
        print(f'No {dataset} dataset in data directory')
        # Raise SystemExit directly instead of calling the ``exit`` helper:
        # that helper is injected by the ``site`` module for interactive use
        # (absent under ``python -S``) and would report success (status 0)
        # for what is an error condition.
        raise SystemExit(1)
    return n_classes

def str2bool(v: str) -> bool:
    """Interpret a string flag as a boolean.

    Only ``'true'`` and ``'t'`` (compared case-insensitively) map to ``True``;
    every other string maps to ``False``.
    """
    lowered = v.lower()
    return lowered == 'true' or lowered == 't'

def torch_save(base_dir: str, filename: str, data: Tensor):
    """Write ``data`` to ``base_dir/filename`` using ``torch.save``.

    ``base_dir`` is created first, together with any missing parent
    directories; an already existing directory is accepted.
    """
    target_dir = Path(base_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / filename
    torch.save(data, destination)

def torch_load(base_dir: str, filename: str) -> Tensor:
    """Load the object stored at ``base_dir/filename`` with ``torch.load``.

    Storages are remapped to the CPU via ``map_location``.
    """
    source = Path(base_dir) / filename
    cpu = torch.device('cpu')
    return torch.load(source, map_location=cpu)

def shuffle(seed: int, x: list, y: list) -> tuple[list, list]:
    """Shuffle two parallel sequences with the same seeded permutation.

    Args:
        seed: Seed for the permutation; the same seed always yields the same
            ordering.
        x: First sequence; its length determines the permutation size.
        y: Second sequence, indexed with the same permuted positions as ``x``.

    Returns:
        A ``(shuffled_x, shuffled_y)`` tuple of new lists. The inputs are not
        modified.
    """
    idx = np.arange(len(x))
    # Use a private RandomState rather than np.random.seed(): it produces the
    # same permutation as the legacy global seed/shuffle pair, but does not
    # reseed the process-wide NumPy generator as a hidden side effect.
    np.random.RandomState(seed).shuffle(idx)
    return [x[i] for i in idx], [y[i] for i in idx]

def save(base_dir: str, filename: str, data: dict):
    """Serialize ``data`` as JSON into ``base_dir/filename``.

    ``base_dir`` is created first (including missing parents). An existing
    file of the same name is truncated and rewritten.
    """
    out_dir = Path(base_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / filename, 'w+') as handle:
        json.dump(data, handle)

def exists(base_dir: str, filename: str) -> bool:
    """Report whether the path ``base_dir/filename`` exists."""
    candidate = Path(base_dir).joinpath(filename)
    return candidate.exists()

def join_glob(base_dir: str, filename: str) -> list:
    """Return the paths matching the glob pattern ``base_dir/filename``.

    ``filename`` may contain glob wildcards; the result is whatever
    ``glob.glob`` returns for the joined pattern.
    """
    pattern = str(Path(base_dir).joinpath(filename))
    return glob.glob(pattern)

def remove_if_exist(base_dir: str, filename: str):
    """Delete every file matched by the glob pattern ``base_dir/filename``.

    Matches come from ``join_glob``; nothing happens when there are none.
    """
    for match in join_glob(base_dir, filename):
        os.remove(match)



def get_state_dict(model) -> OrderedDict:
    """Snapshot the parameters yielded by ``model.named_parameters()``.

    Each parameter is stored as a ``clone()`` under its own name. When a
    parameter carries a ``weight_qtype`` attribute, that value is stored as
    well under ``<name>_weight_qtype``, directly after the parameter entry.
    """
    snapshot = OrderedDict()
    for key, tensor in model.named_parameters():
        snapshot[key] = tensor.clone()
        if not hasattr(tensor, 'weight_qtype'):
            continue
        snapshot[f'{key}_weight_qtype'] = tensor.weight_qtype
    return snapshot

def convert_tensor_to_np(state_dict: OrderedDict) -> OrderedDict:
    """Return a copy of ``state_dict`` with tensors converted to NumPy arrays.

    Tensor values are cloned, detached and moved to the CPU before
    conversion. ``bfloat16`` tensors are cast with ``.float()`` first
    (presumably because NumPy has no bfloat16 dtype). Non-tensor values are
    carried over unchanged, and key order is preserved.
    """
    converted = OrderedDict()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            converted[key] = value
            continue
        host_copy = value.clone().detach().cpu()
        if value.dtype == torch.bfloat16:
            host_copy = host_copy.float()
        converted[key] = host_copy.numpy()
    return converted

def convert_np_to_tensor(state_dict: OrderedDict, gpu_id: int, skip_stat=False, skip_mask=False, model=None) -> OrderedDict:
    """Convert list/ndarray entries of ``state_dict`` to CUDA tensors.

    Args:
        state_dict: Mapping of names to values; key order is preserved.
        gpu_id: CUDA device index the new tensors are moved to.
        skip_stat: When true, entries whose key contains ``'running'`` or
            ``'tracked'`` are taken from ``model`` instead of being converted.
        skip_mask: When true, entries whose key contains ``'mask'``, ``'pre'``
            or ``'pos'`` are taken from ``model`` instead of being converted.
        model: Object indexed by key (``model[key]``) to supply the skipped
            entries; must be provided whenever a skip rule matches.

    Returns:
        A new ``OrderedDict``. Values that are neither lists nor NumPy arrays
        are passed through unchanged.
    """
    stat_tags = ('running', 'tracked')
    mask_tags = ('mask', 'pre', 'pos')

    converted = OrderedDict()
    for key, value in state_dict.items():
        take_from_model = (
            (skip_stat and any(tag in key for tag in stat_tags))
            or (skip_mask and any(tag in key for tag in mask_tags))
        )
        if take_from_model:
            converted[key] = model[key]
        elif isinstance(value, (list, np.ndarray)):
            converted[key] = torch.tensor(value).cuda(gpu_id)
        else:
            converted[key] = value
    return converted


def set_state_dict(model, state_dict: OrderedDict, gpu_id: int, params_to_update=None):
    """Copy values from ``state_dict`` into ``model`` and reload it.

    Args:
        model: Module exposing ``state_dict()``, ``named_parameters()`` and
            ``load_state_dict()``.
        state_dict: Mapping of parameter names to tensors. It may also hold
            ``<name>_weight_qtype`` metadata entries next to a parameter.
        gpu_id: Not used by this function; kept so existing callers keep
            working.
        params_to_update: Optional collection of names; when given, only
            entries whose name is in it are applied.

    A name missing from the model is retried with its first
    ``'base_model.model.'`` occurrence replaced by
    ``'base_model.model.base_model.model.'``. Names that still cannot be
    resolved produce a printed warning and are skipped.
    """
    qtype_suffix = '_weight_qtype'
    model_state_dict = model.state_dict()
    # state_dict() hands back detached tensors, so attributes set on them never
    # reach the model; keep the live parameters around for the qtype tag.
    live_params = dict(model.named_parameters())

    for name, param in state_dict.items():
        if params_to_update is not None and name not in params_to_update:
            continue

        if name in model_state_dict:
            target = name
        else:
            target = name.replace('base_model.model.', 'base_model.model.base_model.model.', 1)
            if target not in model_state_dict:
                is_qtype_metadata = (
                    name.endswith(qtype_suffix)
                    and name[:-len(qtype_suffix)] in state_dict
                )
                # Metadata entries accompany a real parameter and are applied
                # together with it, so they are not "missing parameters".
                if not is_qtype_metadata:
                    print(f"Warning: Parameter {name} not found in model state_dict.")
                continue

        model_state_dict[target].copy_(param)

        # The metadata is stored under the incoming name; the adjusted name is
        # still consulted for snapshots that already use the nested prefix.
        qtype_key = name + qtype_suffix
        if qtype_key not in state_dict:
            qtype_key = target + qtype_suffix
        if qtype_key in state_dict:
            qtype = state_dict[qtype_key]
            model_state_dict[target].weight_qtype = qtype
            if target in live_params:
                live_params[target].weight_qtype = qtype

    model.load_state_dict(model_state_dict)

def convert_np_to_tensor_cpu(state_dict):
    """Return a copy of ``state_dict`` with lists/ndarrays turned into tensors.

    Conversion uses ``torch.tensor`` without moving the result to another
    device. Every other value is passed through unchanged; key order is
    preserved.
    """
    return OrderedDict(
        (key, torch.tensor(value) if isinstance(value, (list, np.ndarray)) else value)
        for key, value in state_dict.items()
    )