import argparse
import os
import pickle
import time
import json
# import GPUtil
from datasets import load_dataset, load_metric
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, TrainingArguments, \
    Trainer, AutoConfig, DataCollatorWithPadding,AutoModelForSequenceClassification

from typing import Dict
# from modeling_llama_qst import QSTLlamaForSequenceClassification, LlamaForSequenceClassification

import warnings
from torch.optim import Adam

from transformers import (
    set_seed,
    AutoConfig
)

from modeling_opt_CAST import CAST_OPTForSequenceClassification
from CASTConfig import CASTConfig

from sockconnect import setup_socket_client
from sockconnect import send_msg, recv_msg, get_socket
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from argparse import Namespace

def dirichlet_partition(targets, num_clients, num_classes, alpha, seed=None):
    """Split sample indices across clients with per-class Dirichlet proportions.

    For every class ``k`` in ``range(num_classes)`` a proportion vector over the
    clients is drawn from ``Dirichlet(alpha, ..., alpha)``; the (shuffled)
    indices of that class are then dealt out to the clients according to those
    proportions.

    Args:
        targets: Class label of every sample. A ``torch.Tensor``, NumPy array
            or plain sequence is accepted.
        num_clients: Number of partitions to produce.
        num_classes: Labels are expected to lie in ``range(num_classes)``;
            samples with other labels are not assigned to any client.
        alpha: Dirichlet concentration parameter.
        seed: If not ``None``, seeds NumPy's *global* random state so that
            repeated calls (e.g. from different clients) yield the same split.

    Returns:
        Dict mapping each client id in ``range(num_clients)`` to a NumPy array
        of the sample indices assigned to it, grouped by ascending class.
    """
    if isinstance(targets, torch.Tensor):
        # .numpy() alone fails for CUDA tensors and tensors that track grads.
        targets = targets.detach().cpu().numpy()
    # A plain list compared with `== i` yields a scalar False, which would
    # silently produce empty partitions; normalise to an array first.
    targets = np.asarray(targets)

    if seed is not None:
        np.random.seed(seed)

    client_indices = {i: [] for i in range(num_clients)}
    class_indices = [np.where(targets == i)[0] for i in range(num_classes)]
    class_distribution = np.random.dirichlet([alpha] * num_clients, num_classes)

    for k in range(num_classes):
        np.random.shuffle(class_indices[k])
        class_size = len(class_indices[k])
        num_samples_per_client = np.round(class_distribution[k] * class_size).astype(int)

        # Rounding can leave the counts a few samples above or below the class
        # size; move one sample at a time until the totals agree.
        diff = class_size - int(np.sum(num_samples_per_client))
        step = 1 if diff > 0 else -1
        for _ in range(abs(diff)):
            candidates = np.where(num_samples_per_client > 0)[0]
            if candidates.size == 0:
                # Every share rounded down to zero (small class, many clients):
                # there is no non-empty client to prefer, so pick among all.
                candidates = np.arange(num_clients)
            client_id = np.random.choice(candidates)
            num_samples_per_client[client_id] += step

        start_idx = 0
        for client_id in range(num_clients):
            end_idx = start_idx + num_samples_per_client[client_id]
            client_indices[client_id].extend(class_indices[k][start_idx:end_idx])
            start_idx = end_idx

    for client_id in client_indices:
        client_indices[client_id] = np.array(client_indices[client_id])

    return client_indices


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Adds ``special_tokens_dict`` to ``tokenizer``, resizes the model's token
    embeddings to the new vocabulary size, and initialises the rows of the
    newly added tokens with the mean of the pre-existing rows. The input
    embeddings are always handled; the output embeddings only when the model
    exposes them (``get_output_embeddings()`` may return ``None``, e.g. for
    models without a language-modelling head).

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    for embedding_layer in (model.get_input_embeddings(), model.get_output_embeddings()):
        if embedding_layer is None:
            # No such embedding matrix on this model; nothing to initialise.
            continue
        weights = embedding_layer.weight.data
        # Mean over the original vocabulary only (new rows are the last ones).
        weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(dim=0, keepdim=True)



# Maps a GLUE task name to the dataset column(s) holding its text input(s).
# The second entry is None for single-sentence tasks. "mnli-mm" (MNLI evaluated
# on the mismatched validation split) shares MNLI's columns.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
}



def _run_forward_steps(model, dataloader, device, num_steps, label, cycle, num_cycles):
    """Run ``num_steps`` gradient-free forward passes, re-iterating the loader as needed."""
    steps_done = 0
    with torch.no_grad():
        while steps_done < num_steps:
            steps_before_pass = steps_done
            for batch in dataloader:
                batch = {key: val.to(device) for key, val in batch.items()}
                model(**batch)
                steps_done += 1
                print(f"{label} Step {steps_done}/{num_steps}, Cycle {cycle}/{num_cycles}")
                if steps_done >= num_steps:
                    break
            if steps_done == steps_before_pass:
                # An empty loader would otherwise keep this loop spinning forever.
                raise RuntimeError(f"{label} dataloader yielded no batches")


def train(args, task):
    """Run the client-side forward-pass loop for one GLUE task.

    The function loads a frozen fp16 sequence-classification backbone, wraps it
    in ``CAST_OPTForSequenceClassification``, tokenises the GLUE task, keeps
    only this client's Dirichlet share of the training split, connects to the
    server through ``sockconnect`` and then, for ``args.mytrain_looptime``
    cycles, performs ``args.mytrain_onetime_step`` forward passes over the
    client's training data followed by ``args.myeval_step`` forward passes over
    the validation data. All passes run in eval mode under ``torch.no_grad()``;
    no optimizer step happens in this function.
    NOTE(review): whatever is exchanged with the server during a forward pass
    is defined by the CAST model / sockconnect modules, not visible here.

    Args:
        args: Namespace with the fields ``seed``, ``my_modelname``,
            ``model_checkpoint``, ``num_labels``, ``batch_size``,
            ``max_seqlen``, ``pad_seqlen``, ``CAST_activation``,
            ``CAST_hidden_size``, ``CAST_dropout``, ``num_clients``,
            ``dirichlet_alpha``, ``client_id``, ``mytrain_onetime_step`` and
            ``mytrain_looptime``. ``args.myeval_step`` is overwritten with the
            number of full validation batches before ``args`` is sent to the
            server.
        task: GLUE task name (a key of ``task_to_keys``, or ``"mnli-mm"``).

    Raises:
        ValueError: If ``args.my_modelname`` is not supported, or this client's
            partition is empty although training steps were requested.
        RuntimeError: If a dataloader yields no batches while steps remain.
    """
    # Fail before the expensive model load instead of hitting an unbound
    # `model` later on.
    if args.my_modelname != "CAST":
        raise ValueError(f"Unsupported model name: {args.my_modelname!r}")

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    LLM = AutoModelForSequenceClassification.from_pretrained(
        args.model_checkpoint,
        num_labels=args.num_labels,
        torch_dtype=torch.float16
    ).to(device)

    # The backbone is frozen.
    for param in LLM.parameters():
        param.requires_grad = False

    # NOTE(review): `max_length` is forwarded to from_pretrained unchanged;
    # `model_max_length` may have been intended — confirm.
    tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, use_fast=True, max_length=args.max_seqlen)

    # Use the public attribute rather than the private `_pad_token`.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=LLM,
        )

    # "mnli-mm" is MNLI evaluated on the mismatched split; the GLUE config
    # itself is plain "mnli".
    actual_task = "mnli" if task == "mnli-mm" else task
    print(f"Loading dataset for task: {actual_task}")
    dataset = load_dataset("glue", actual_task)
    print(dataset)

    sentence1_key, sentence2_key = task_to_keys[actual_task]

    def preprocess_function(examples):
        # Tokenise one sentence or a sentence pair, padded to a fixed length.
        texts = (examples[sentence1_key],)
        if sentence2_key is not None:
            texts += (examples[sentence2_key],)
        return tokenizer(
            *texts,
            truncation=True,
            padding="max_length",
            max_length=args.pad_seqlen
        )

    encoded_dataset = dataset.map(preprocess_function, batched=True)

    if task == "mnli-mm":
        validation_key = "validation_mismatched"
    elif task == "mnli":
        validation_key = "validation_matched"
    else:
        validation_key = "validation"

    # Keep only whole validation batches and publish the resulting step count
    # through args (args is sent to the server below).
    num_batches = len(encoded_dataset[validation_key]) // args.batch_size
    args.myeval_step = num_batches
    encoded_dataset[validation_key] = encoded_dataset[validation_key].select(
        range(num_batches * args.batch_size)
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    config = AutoConfig.from_pretrained(args.model_checkpoint)

    CAST_config = CASTConfig(
        CAST_add_layer_norm_before_adapter=False,
        CAST_add_layer_norm_after_adapter=True,
        CAST_activation=args.CAST_activation,
        CAST_hidden_size=args.CAST_hidden_size,
        CAST_dropout=args.CAST_dropout,
    )

    model = CAST_OPTForSequenceClassification(LLM, config, CAST_config).to(device)
    # Casting the root module converts every sub-module as well.
    model.to(torch.float16)

    print(model)
    for name, param in model.named_parameters():
        print(f"Name: {name}, Shape: {param.shape}, Dtype: {param.dtype}, Trainable: {param.requires_grad}, Device: {param.device}")

    # Drop the raw text column(s) of this task plus "idx" so that only
    # collatable features remain.
    text_columns = [key for key in (sentence1_key, sentence2_key) if key is not None]
    encoded_dataset = encoded_dataset.remove_columns(text_columns + ["idx"])

    train_labels = encoded_dataset["train"]["label"]

    client_partitions = dirichlet_partition(
        np.array(train_labels),
        args.num_clients,
        args.num_labels,
        args.dirichlet_alpha,
        seed=args.seed,
    )
    selected_indices = client_partitions[args.client_id]
    client_train_dataset = encoded_dataset["train"].select(selected_indices.tolist())

    if len(client_train_dataset) == 0 and args.mytrain_onetime_step > 0:
        # Detect this before connecting: the step loop could never finish.
        raise ValueError(
            f"Client {args.client_id} received no training samples from the Dirichlet partition"
        )

    train_dataloader = DataLoader(
        client_train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        collate_fn=data_collator
    )

    validation_dataloader = DataLoader(
        encoded_dataset[validation_key],
        batch_size=args.batch_size,
        collate_fn=data_collator
    )

    # Connect to the server and send it the run configuration.
    setup_socket_client()
    ss = get_socket()
    send_msg(ss, args)
    print("Model configuration sent!")

    success = recv_msg(ss)
    print(success)
    if success is True:
        print("Successfully connected to the server")
    else:
        print("Receiving end confirmation failed")

    steps_per_cycle = args.mytrain_onetime_step
    eval_steps = args.myeval_step
    num_cycles = args.mytrain_looptime

    cycle = 0
    while cycle < num_cycles:
        cycle += 1
        print(f"\n=== Cycle {cycle}/{num_cycles} ===")

        model.eval()
        _run_forward_steps(
            model, train_dataloader, device, steps_per_cycle, "Train", cycle, num_cycles
        )
        _run_forward_steps(
            model, validation_dataloader, device, eval_steps, "Accuracy validation", cycle, num_cycles
        )



# Tasks to run. The full GLUE sweep is kept for reference:
# GLUE_TASKS = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb"]
GLUE_TASKS = ["sst2"]
# Model variants the driver iterates over.
TRAIN_MODEL = ["CAST"]
# Pad token added by train() when the tokenizer does not define one.
DEFAULT_PAD_TOKEN = "[PAD]"


if __name__ == "__main__":

    # Run configuration for this client; train() sends the whole namespace to
    # the server. my_modelname, num_labels and myeval_step are placeholders
    # that get overwritten before they are used (see below / train()).
    args = Namespace(
        seed=0,
        my_modelname="CAST",
        model_checkpoint="facebook/opt-1.3b",
        # model_checkpoint="facebook/opt-350m",
        num_labels=2,
        batch_size=8,
        max_seqlen=512,
        pad_seqlen=128,
        CAST_activation="swish",
        CAST_hidden_size=64,
        CAST_dropout=0.1,
        mytrain_onetime_step=100,
        mytrain_looptime=1,
        mytrain_log_step=10,
        myeval_step=54,
        dirichlet_alpha=0.5,
        num_clients=100,
        client_id=0,
    )

    for mymodel in TRAIN_MODEL:
        args.my_modelname = mymodel

        for task in GLUE_TASKS:
            # MNLI has three classes, STS-B is a regression task (one output),
            # every other GLUE task handled here is binary.
            if task.startswith("mnli"):
                args.num_labels = 3
            elif task == "stsb":
                args.num_labels = 1
            else:
                args.num_labels = 2

            train(args, task)



