import torch
import torch.nn as nn

from concurrent import futures

import pickle
import math
import sys
from transformers.activations import get_activation
from dataclasses import dataclass

from torch.optim import Adam
from transformers import PretrainedConfig
from transformers import (
    set_seed,
    AutoConfig
)
import socket
import time
import struct
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)
from typing import Optional, Tuple
from sklearn.metrics import accuracy_score
from tqdm import tqdm

@dataclass
class CASTOutput(BaseModelOutputWithPast):
    """Output container returned by ``MiddleLayers.forward``.

    Extends ``BaseModelOutputWithPast`` with a classification ``loss`` and the
    ``logits`` produced by the ``MiddleLayers.score`` head. In this file only
    ``loss`` and ``logits`` are passed when it is constructed; the inherited
    fields are left at their defaults.
    """

    # Cross-entropy loss computed in MiddleLayers.forward.
    loss: Optional[torch.FloatTensor] = None
    # Classifier output; its last dimension is the number of labels.
    logits: Optional[torch.FloatTensor] = None


class Activations(nn.Module):
    """Wrap a transformers activation, looked up by name, as an ``nn.Module``.

    Args:
        activation_type: Name passed to ``transformers.activations.get_activation``.
    """

    def __init__(self, activation_type):
        super(Activations, self).__init__()
        # get_activation maps the name to the corresponding activation callable.
        self.f = get_activation(activation_type)

    def forward(self, x):
        """Apply the wrapped activation to ``x`` and return the result."""
        activation_fn = self.f
        return activation_fn(x)


def send_msg(sock, msg):
    """Pickle ``msg`` and send it on ``sock`` with a length prefix.

    Wire format: a 4-byte big-endian unsigned length, then the pickled bytes.

    Returns:
        The size of the pickled payload in bytes (the 4-byte header excluded).
    """
    payload = pickle.dumps(msg)
    payload_len = len(payload)
    header = struct.pack('>I', payload_len)
    sock.sendall(header + payload)
    return payload_len

def recv_msg(sock):
    """Receive one length-prefixed pickled message, as framed by ``send_msg``.

    Returns:
        ``(msg, msglen)`` with the unpickled object and its payload size in
        bytes, or ``None`` if the connection closes before a complete message
        (header or body) has arrived.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. Only use this with trusted peers on a trusted network.
    """
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]

    payload = recvall(sock, msglen)
    if payload is None:
        # Peer disconnected mid-message: report it like a closed connection
        # rather than letting pickle.loads(None) raise a TypeError.
        return None

    # SECURITY: unpickling network data; see the docstring.
    return pickle.loads(payload), msglen

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns:
        The ``n`` bytes read as ``bytes``, or ``None`` if the peer closes the
        connection (``recv`` returns an empty chunk) before ``n`` bytes arrive.
    """
    # A bytearray grows in amortised O(1) per append. Growing an immutable
    # ``bytes`` with ``+=`` copies the whole buffer on every chunk, which is
    # quadratic for large payloads such as serialized hidden states.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

class StandardAdapter(nn.Module):
    """Bottleneck residual adapter.

    Computes ``x + post_norm(dropout(up(act(down(pre_norm(x))))))``. Each
    optional piece (``pre_norm``, ``act``, ``post_norm``, dropout) is enabled by
    the constructor arguments. The residual is always the un-normalised input.

    Args:
        in_features: Width of the input and output.
        adapter_size: Width of the bottleneck.
        activation: Activation name passed to ``Activations`` (lower-cased).
            ``None`` or ``""`` means no activation, i.e. a linear bottleneck.
        add_layer_norm_before_adapter: LayerNorm the input before ``down_proj``.
        add_layer_norm_after_adapter: LayerNorm the adapter branch before the
            residual add.
        dropout: Dropout probability on the adapter branch (0 disables it).
        bias: Whether the two projections have a bias term.
    """

    def __init__(self,
                 in_features: int,
                 adapter_size: int,
                 activation=None,
                 add_layer_norm_before_adapter=False,
                 add_layer_norm_after_adapter=False,
                 dropout=0.0,
                 bias=False):
        super(StandardAdapter, self).__init__()
        # Creation order (down_proj, then up_proj) is kept so weight
        # initialisation consumes the RNG exactly as before.
        self.down_proj = nn.Linear(in_features, adapter_size, bias=bias)

        # Falsy names (None, or CASTConfig's default "") mean "no activation".
        if activation:
            self.activation = Activations(activation.lower())
        else:
            self.activation = None

        self.up_proj = nn.Linear(adapter_size, in_features, bias=bias)

        self.add_layer_norm_before = add_layer_norm_before_adapter
        self.add_layer_norm_after = add_layer_norm_after_adapter
        if self.add_layer_norm_before:
            self.pre_layer_norm = nn.LayerNorm(in_features)
        if self.add_layer_norm_after:
            self.post_layer_norm = nn.LayerNorm(in_features)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        """Apply the adapter branch to ``x`` and add it back to ``x``."""
        residual = x

        if self.add_layer_norm_before:
            x = self.pre_layer_norm(x)

        x = self.down_proj(x)
        # Without an activation the bottleneck is purely linear.
        if self.activation is not None:
            x = self.activation(x)
        x = self.up_proj(x)

        x = self.dropout(x)

        if self.add_layer_norm_after:
            x = self.post_layer_norm(x)

        return residual + x
    

@dataclass
class CASTConfig(PretrainedConfig):
    """Hyperparameters for the CAST adapter stack built by ``MiddleLayers``.

    NOTE(review): the ``@dataclass``-generated ``__init__`` only assigns the
    fields below; it does not invoke ``PretrainedConfig.__init__``, so the
    base-class instance state set there is not populated. Only the ``CAST_*``
    fields are read in this file. Confirm nothing else relies on the
    base-class attributes (e.g. serialisation or ``repr``).
    """

    # Forwarded to StandardAdapter(add_layer_norm_before_adapter=...).
    CAST_add_layer_norm_before_adapter: bool = False
    # Forwarded to StandardAdapter(add_layer_norm_after_adapter=...).
    CAST_add_layer_norm_after_adapter: bool = True
    # Activation name forwarded to StandardAdapter. NOTE(review): confirm how
    # StandardAdapter treats the empty-string default.
    CAST_activation: str = ""
    # Bottleneck width (adapter_size) of every StandardAdapter.
    CAST_hidden_size: int = 16
    # Dropout probability inside every StandardAdapter (0 disables dropout).
    CAST_dropout: float = 0.0
    # Number of output classes; MiddleLayers sizes its score head with it.
    CAST_num_labels: int = 2


class MiddleLayers(nn.Module):
    """Server-side CAST head over the backbone's per-layer activations.

    For each backbone layer ``idx`` the running state is blended with
    ``activations[idx]`` through a learned gate ``sigmoid(z[idx])``. The blend
    is then passed through that layer's ``StandardAdapter``. The result is:

    1. optionally layer-normed,
    2. projected up to ``config.word_embed_proj_dim`` when that differs from
       ``hidden_size``,
    3. mixed with ``project_out`` through a learned per-dimension gate
       ``sigmoid(lm_head_z)``,
    4. scored by a bias-free linear classifier.

    Args:
        config: Backbone config. Reads ``num_hidden_layers``, ``hidden_size``,
            ``word_embed_proj_dim``, ``do_layer_norm_before``,
            ``_remove_final_layer_norm`` and ``layer_norm_elementwise_affine``.
        CASTConfig: A ``CASTConfig`` instance with the adapter hyperparameters.
            The parameter name shadows the class; it is kept so keyword
            callers keep working.
    """

    def __init__(self, config, CASTConfig):
        super(MiddleLayers, self).__init__()
        cast_cfg = CASTConfig  # an instance, despite the class-like name
        hidden_size = int(config.hidden_size)
        proj_dim = config.word_embed_proj_dim

        self.num_layers = config.num_hidden_layers
        # Read from the argument rather than a module-level global, so the
        # class works with whatever config it is given.
        self.num_labels = cast_cfg.CAST_num_labels

        # Submodules are created in a fixed order. Keep it: the order decides
        # how the RNG is consumed during weight initialisation.
        self.CAST_layers = nn.ModuleList([
            StandardAdapter(
                in_features=hidden_size,
                adapter_size=int(cast_cfg.CAST_hidden_size),
                activation=cast_cfg.CAST_activation,
                add_layer_norm_after_adapter=cast_cfg.CAST_add_layer_norm_after_adapter,
                add_layer_norm_before_adapter=cast_cfg.CAST_add_layer_norm_before_adapter,
                dropout=cast_cfg.CAST_dropout,
            )
            for _ in range(self.num_layers)
        ])

        # Per-layer fusion gates. At init, sigmoid(0.5) ~= 0.62 is the weight
        # on the running state and ~0.38 the weight on the new layer activation.
        self.z = nn.ParameterList([
            nn.Parameter(torch.tensor([0.5])) for _ in range(self.num_layers)
        ])

        if proj_dim != config.hidden_size:
            self.upsample = nn.Linear(hidden_size, proj_dim)
        else:
            self.upsample = None

        # Gate between the CAST path and project_out; zeros give a 50/50 mix at init.
        self.lm_head_z = nn.Parameter(torch.zeros(proj_dim))

        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm_CAST = nn.LayerNorm(
                hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm_CAST = None

        self.score = nn.Linear(proj_dim, self.num_labels, bias=False)

    def forward(self, activations, project_out, labels=None):
        """Fuse per-layer activations and classify.

        Args:
            activations: Indexable with at least ``num_layers`` entries, one
                per backbone layer (assumed to have width ``hidden_size``;
                TODO confirm the shape sent by the client).
            project_out: Tensor mixed with the CAST path; its last dimension
                must be ``word_embed_proj_dim``.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed. NOTE(review): ``logits.view(-1, num_labels)`` is
                compared with ``labels.view(-1)``, so this assumes one label per
                row of logits (e.g. pooled hidden states); confirm.

        Returns:
            ``CASTOutput`` with ``logits`` and ``loss`` (``None`` when
            ``labels`` is ``None``).
        """
        hidden = activations[0]
        for idx in range(self.num_layers):
            gate = torch.sigmoid(self.z[idx])
            hidden = (1 - gate) * activations[idx] + gate * hidden
            hidden = self.CAST_layers[idx](hidden)

        if self.final_layer_norm_CAST is not None:
            hidden = self.final_layer_norm_CAST(hidden)
        if self.upsample is not None:
            hidden = self.upsample(hidden)

        mix = torch.sigmoid(self.lm_head_z)
        final_hidden_states = mix * hidden + (1 - mix) * project_out
        logits = self.score(final_hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return CASTOutput(loss=loss, logits=logits)

def data_forward():
    """Receive one batch from the first client and run ``model`` on it.

    Uses module-level globals set up by the script below: ``optimizer``,
    ``clientsoclist``, ``model`` and the receive-size bookkeeping lists.

    Returns:
        ``(outputs, labels)``: the result of ``model(...)`` and the labels
        sent by the client.

    Raises:
        ConnectionError: if the client closes the connection before a full
            message arrives.
    """
    optimizer.zero_grad()
    received = recv_msg(clientsoclist[0])
    if received is None:
        raise ConnectionError("client closed the connection before sending a batch")
    msg, datasize = received

    # Traffic accounting. NOTE(review): this function is called from both the
    # training and the validation loop, so train_receivesize_list also counts
    # validation traffic.
    total_receivesize_list.append(datasize)
    client_receivesize_list[0].append(datasize)
    train_receivesize_list.append(datasize)

    # NOTE(review): tensors are used as received. They are not moved to
    # ``device`` or cast to the model's float16 dtype here; confirm the client
    # sends them ready to use.
    backbone_hidden_states = msg['backbone_hidden_states']
    project_out = msg['projectout']
    labels = msg['labels']

    outputs = model(backbone_hidden_states, project_out, labels)
    return outputs, labels


# --- Server socket: wait for `users` clients to connect --------------------
users = 1
# NOTE(review): host/port are redacted placeholders. socket.bind() needs an
# int port, so 'xxx' must be replaced with real values before running.
host = 'xxx'
port = 'xxx'

s = socket.socket()
s.bind((host, port))
s.listen(5)  # backlog of pending connections

# Connected client sockets, in accept order.
clientsoclist = []

# Traffic bookkeeping. data_forward() appends each received payload size to
# total_receivesize_list, client_receivesize_list[0] and
# train_receivesize_list; the remaining lists are not written in the code shown.
train_total_batch, val_acc = [], []
total_sendsize_list, total_receivesize_list = [], []
client_sendsize_list = [[] for _ in range(users)]
client_receivesize_list = [[] for _ in range(users)]
train_sendsize_list, train_receivesize_list = [], []

for i in range(users):
    print("Waiting for client connection")
    conn, addr = s.accept()
    print('Conntected with', addr)
    clientsoclist.append(conn)

# --- Handshake: the client's first message carries the run arguments -------
_handshake = recv_msg(clientsoclist[0])
if _handshake is None:
    raise ConnectionError("client closed the connection before sending its arguments")
args, datasize = _handshake
print("Received client model information")

# Acknowledge receipt so the client can proceed.
send_msg(clientsoclist[0], True)

seed = args.seed
model_name_or_path = args.model_checkpoint

set_seed(seed)
config = AutoConfig.from_pretrained(args.model_checkpoint, num_labels=args.num_labels)

CAST_config = CASTConfig(
    CAST_add_layer_norm_before_adapter=False,
    CAST_add_layer_norm_after_adapter=True,
    CAST_activation=args.CAST_activation,
    CAST_hidden_size=args.CAST_hidden_size,
    CAST_dropout=args.CAST_dropout,
    CAST_num_labels=args.num_labels,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Move and cast to float16 in one call. Module.to is recursive, so this
# converts every submodule. Doing it BEFORE building the optimizer guarantees
# Adam holds the final parameter objects, even if module conversion is
# configured to replace Parameters instead of updating them in place.
model = MiddleLayers(config, CAST_config).to(device=device, dtype=torch.float16)

# eps is raised well above Adam's default; presumably so it stays
# representable in float16 (TODO confirm).
optimizer = Adam(model.parameters(), lr=args.learning_rate, eps=1e-4)


# --- Alternating train / validate cycles ------------------------------------
# Each cycle runs `x` training steps, then `y` validation steps. Every step
# consumes one batch from the client via data_forward().
x = args.mytrain_onetime_step   # training steps per cycle
y = args.myeval_step            # validation steps per cycle
n = args.mytrain_looptime       # number of cycles
step_count = 0                  # steps taken across all cycles
cycle_count = 0

while cycle_count < n:
    print(f"\n=== Cycle {cycle_count + 1}/{n} ===")

    # Training phase.
    model.train()
    train_step_count = 0
    with tqdm(total=x, desc="Training Progress", dynamic_ncols=True) as pbar:
        while train_step_count < x:
            optimizer.zero_grad()
            outputs, labels = data_forward()
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            step_count += 1
            train_step_count += 1
            pbar.set_postfix({"Loss": loss.item()})
            pbar.update(1)

    # Validation phase: gradients disabled; collect argmax predictions.
    model.eval()
    val_step_count = 0
    all_predictions, all_labels = [], []
    with torch.no_grad(), tqdm(total=y, desc="Accuracy Validation Progress") as pbar:
        while val_step_count < y:
            outputs, labels = data_forward()
            predictions = outputs.logits.argmax(dim=-1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            step_count += 1
            val_step_count += 1
            pbar.update(1)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Cycle {cycle_count + 1}:  Validation Accuracy = {accuracy:.4f}")
    cycle_count += 1





