import torch
import os
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from datasets import load_dataset, Dataset, concatenate_datasets, DatasetDict
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, LoftQConfig
# from sklearn.datasets import fetch_20newsgroups
from peft import get_peft_model, LoraConfig, TaskType, PrefixTuningConfig, PromptTuningConfig, PromptEncoderConfig, AdaLoraConfig, AdaLoraModel
import torch.nn as nn
from transformers import BitsAndBytesConfig, QuantoConfig

def initialize_networks(model: str, n_classes: int = 20, adapter: str = 'lora', 
                        quantize: bool = False, random_quantize: bool = False, 
                        random_rank: bool = False, default_rank: int = 16, 
                        client_id: int = 0):
    """Load a sequence-classification model, optionally quantize it, and attach an adapter.

    The base model is first requested by name (``model``); if that fails for
    any reason, a hard-coded local checkpoint directory is tried instead.
    A CPU copy of the base model's state dict is written to ``save_path``
    before any adapter is attached, and the final model is moved to CUDA.

    Args:
        model: Model name or path passed to ``from_pretrained``.
        n_classes: Number of labels for the classification head.
        adapter: Adapter kind (case-insensitive): ``'lora'``, ``'prefix'``,
            ``'prompt'``, ``'p-tuning'``, ``'adalora'`` or ``'hyper'``. Any
            other value returns the base model without an adapter.
        quantize: If true, load the weights with a Quanto quantization config.
        random_quantize: If true, pick the bit width from
            ``quantize_bits_list`` by ``client_id`` (round-robin); otherwise
            the bit width is 4.
        random_rank: If true, draw the LoRA rank with ``random.choice`` from
            the candidate list; otherwise use ``default_rank``.
        default_rank: LoRA rank used when ``random_rank`` is false.
        client_id: Index used to select the bit width when
            ``random_quantize`` is true.

    Returns:
        A tuple ``(model, quantize_bits, lora_rank)``. ``quantize_bits`` is
        reported even when ``quantize`` is false.

    Raises:
        ValueError: If ``quantize`` is true and the selected bit width is not
            2, 4 or 8.
        RuntimeError: If the model can be loaded neither by name nor from the
            local fallback path.
    """
    # Since the server cannot connect to the internet, loading will be done from the local path.
    # We also provide a method to load the model directly from the internet.
    adapter = adapter.lower()
    model_name_or_path = model
    local_model_path = "your_local_path/distilbert-base-multilingual-cased"

    # NOTE(review): this list was empty, which made random_quantize=True fail
    # with ZeroDivisionError. It now holds the widths the loader below
    # supports; adjust if a different per-client mix is intended.
    quantize_bits_list = [2, 4, 8]

    if random_quantize:
        quantize_bits = quantize_bits_list[client_id % len(quantize_bits_list)]
    else:
        quantize_bits = 4

    if random_rank:
        lora_rank = random.choice([16])
    else:
        lora_rank = default_rank

    dtype = torch.float32

    def load_model_with_fallback(model_name_or_path, local_model_path, n_classes, quantization_config=None):
        """Load by name first; on any failure, retry from ``local_model_path``."""
        try:
            return AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                num_labels=n_classes,
                quantization_config=quantization_config,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            # Deliberately broad: any failure of the primary load triggers the
            # local fallback.
            print(f"Could not download model from Hugging Face Hub: {e}")
            if not os.path.exists(local_model_path):
                raise RuntimeError(f"Local path '{local_model_path}' does not exist or does not contain necessary files. Model could not be downloaded and local model path is missing.") from e
            try:
                return AutoModelForSequenceClassification.from_pretrained(
                    local_model_path,
                    num_labels=n_classes,
                    quantization_config=quantization_config,
                    torch_dtype=dtype,
                    low_cpu_mem_usage=True,
                )
            except Exception as e_local:
                raise RuntimeError(f"Failed to load model from local path '{local_model_path}': {e_local}") from e_local

    # Map the bit width to the Quanto weight dtype name.
    quanto_weights = {2: "int2", 4: "int4", 8: "int8"}
    quantization_config = None
    if quantize:
        if quantize_bits not in quanto_weights:
            raise ValueError(" only 2-bit, 4-bit and 8-bit")
        quantization_config = QuantoConfig(weights=quanto_weights[quantize_bits])
    model_pre = load_model_with_fallback(model_name_or_path, local_model_path, n_classes, quantization_config)

    # Snapshot the base weights (on CPU) before any adapter is attached.
    W1_state_dict = {k: v.cpu().clone() for k, v in model_pre.state_dict().items() if isinstance(v, torch.Tensor)}

    save_path = "your_local_path"

    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    torch.save(W1_state_dict, save_path)

    # The attention projection names depend on the architecture, not on
    # whether the weights are quantized. Both LoRA variants use the same check.
    model_type = getattr(getattr(model_pre, "config", None), "model_type", "") or ""
    is_distilbert = model_type == "distilbert" or "distilbert" in model_name_or_path.lower()
    attention_targets = ["q_lin", "v_lin"] if is_distilbert else ["query", "value"]

    if adapter == 'lora':
        peft_config = LoraConfig(
            r=lora_rank,
            lora_alpha=32,
            target_modules=attention_targets,
            lora_dropout=0.1,
            bias="none"
        )
        model = get_peft_model(model_pre, peft_config)
    elif adapter == 'prefix':
        peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_CLS)
        model = get_peft_model(model_pre, peft_config)
    elif adapter == 'prompt':
        peft_config = PromptTuningConfig(task_type=TaskType.SEQ_CLS)
        model = get_peft_model(model_pre, peft_config)
    elif adapter == 'p-tuning':
        peft_config = PromptEncoderConfig(
            task_type=TaskType.SEQ_CLS, num_virtual_tokens=20,
            token_dim=768, num_transformer_submodules=1,
            num_attention_heads=12, num_layers=12,
            encoder_reparameterization_type="MLP", encoder_hidden_size=768)
        model = get_peft_model(model_pre, peft_config)
    elif adapter == 'adalora':
        peft_config = AdaLoraConfig(
            task_type=TaskType.SEQ_CLS, r=lora_rank, lora_alpha=32,
            target_modules=attention_targets,
            lora_dropout=0.01)
        model = AdaLoraModel(model_pre, peft_config, "default")
    elif adapter == 'hyper':
        # Project-local hyper-network variant; imported lazily so the other
        # adapters do not require it.
        from misc.transformer_model import HyperDistilBertForSequenceClassification, HyperTransformer
        base_plm = HyperDistilBertForSequenceClassification.from_pretrained(model_name_or_path, num_labels=n_classes)
        plm = HyperDistilBertForSequenceClassification.from_pretrained(model_name_or_path, num_labels=n_classes)
        plm.distilbert.transformer = HyperTransformer(config=plm.config, rank=lora_rank, con_dim=32)
        # strict=False: the replaced transformer has parameters that the base
        # checkpoint does not provide.
        plm.load_state_dict(base_plm.state_dict(), strict=False)
        model = plm
    else:
        model = model_pre

    # Enable gradients on every floating-point parameter and disable them on
    # all others.
    # NOTE(review): this also re-enables gradients on base weights that the
    # adapter wrapper may have frozen - confirm this is intended.
    trainable_dtypes = (torch.float16, torch.float32, torch.bfloat16)
    for _name, param in model.named_parameters():
        param.requires_grad = param.dtype in trainable_dtypes

    # NOTE(review): unconditional move to GPU; this fails on hosts without CUDA.
    model = model.cuda()

    return model, quantize_bits, lora_rank



