import copy
import os
import pickle
import sys
import warnings

import numpy as np
import torch
import wandb

from utils.arg_utils import get_args
from utils.model_utils import build_model, count_parameters, get_embedding_dim
from utils.lora_utils import add_adapters
from utils.data_utils import build_dataset
from utils.server_utils import fl_training, fl_training_het
from utils.init_utils import init_adapters

def _init_wandb(args, out_str):
    """Start a Weights & Biases run that records the experiment hyperparameters.

    Uses the project name "Federated Ravan" when ``args.wandb_project`` is the
    empty string, otherwise ``args.wandb_project`` as given.
    """
    if args.wandb_project == '':
        project_name = "Federated Ravan"
    else:
        project_name = args.wandb_project
    wandb.init(
        project=project_name,
        name="Dataset=" + str(args.dataset) + "_" + out_str,
        config={
            'seed': args.seed,
            'dataset': args.dataset,
            'clients': args.clients,
            'iid_alpha': args.iid_alpha,
            'method': args.adaptation_method,
            'rank': args.rank,
            'b_var': args.b_var,
            'a_var': args.a_var,
            'r_var': args.r_var,
            'num_heads': args.num_heads,
            'clients_per_round': args.clients_round,
            'agg_method': args.aggregation_method,
            'optimizer': args.optimizer,
            'lr': args.client_lr,
            'local_epochs': args.epochs,
            'local_steps': args.local_steps,
            'comm_rounds': args.comm_rounds,
            'num_clients': args.clients,
            'rank_het': args.het_ranks,
            'rank_dist': args.het_dist,
            'ranking': args.ranking,
            'init_scheme': args.init_scheme,
        },
    )


def _compute_sb_delta(server_model, tokenizer, model_name, args):
    """Run one round of full fine-tuning on a copy of the model and return weight deltas.

    Used by the 'sb' adaptation method. A deep copy of ``server_model`` is
    trained via ``fl_training`` for a single communication round with a single
    local step (``adaptation_method='full_ft'``) on a freshly built client split.

    Side effect: ``server_model`` is moved to CPU while the copy trains, then
    back to CUDA (in place) before the deltas are taken.

    Returns:
        dict mapping each parameter name of the copy to
        ``fine_tuned_tensor - original_tensor``.
    """
    print("Performing a single round of fine-tuning for SB method")
    one_round_args = copy.deepcopy(args)
    one_round_args.comm_rounds = 1
    one_round_args.local_steps = 1
    one_round_args.adaptation_method = 'full_ft'

    sb_clients, sb_testloader = build_dataset(args.dataset, one_round_args.batch_size, args.clients, args.iid_alpha, args.seed)

    model_copy = copy.deepcopy(server_model)
    # NOTE(review): presumably offloaded to free GPU memory while the copy trains — confirm.
    server_model.cpu()

    # LLaMA shortcut: freeze every parameter except the ones that will later
    # receive adapters (q_proj / v_proj).
    if 'llama' in model_name:
        for name, parameter in model_copy.named_parameters():
            if 'q_proj' not in name and 'v_proj' not in name:
                parameter.requires_grad = False
    fl_training(model_copy, sb_clients, sb_testloader, tokenizer, one_round_args)

    server_model.cuda()
    # Build each state dict once; calling state_dict() per parameter rebuilds
    # the whole mapping every iteration.
    tuned_state = model_copy.state_dict()
    base_state = server_model.state_dict()
    return {
        name: tuned_state[name].data - base_state[name].data
        for name, _ in model_copy.named_parameters()
    }


def _sample_client_ranks(num_clients, het_dist):
    """Assign each client a relative adapter rank for heterogeneous-rank training.

    Args:
        num_clients: total number of clients.
        het_dist: one of 'uniform', 'normal', 'skewed_left', 'skewed_right';
            selects the fraction of clients given each rank value.

    Returns:
        A shuffled list of length ``num_clients`` whose entries are drawn from
        {0.25, 0.5, 0.75, 1.0}.

    Raises:
        ValueError: if ``het_dist`` is not a known distribution name.
    """
    proportions_by_dist = {
        'uniform': [0.25, 0.25, 0.25, 0.25],
        'normal': [0.15, 0.35, 0.35, 0.15],
        'skewed_left': [0.10, 0.20, 0.30, 0.40],
        'skewed_right': [0.40, 0.30, 0.20, 0.10],
    }
    if het_dist not in proportions_by_dist:
        raise ValueError("Unknown het_dist type.")
    proportions = proportions_by_dist[het_dist]
    rank_values = [0.25, 0.5, 0.75, 1]

    counts = [int(num_clients * p) for p in proportions]
    # int() truncation can leave the total short; give the remainder to the 0.5 bucket.
    if sum(counts) < num_clients:
        counts[1] += num_clients - sum(counts)

    client_ranks = np.concatenate([np.repeat(val, count) for val, count in zip(rank_values, counts)])
    # NOTE(review): this uses the global numpy RNG, which torch.manual_seed does
    # not seed; the assignment is only reproducible if numpy is seeded elsewhere
    # (e.g. inside build_dataset) — confirm.
    np.random.shuffle(client_ranks)
    return client_ranks.tolist()


def _run_experiment(args, out_str):
    """Build model and data, attach adapters, and run federated training."""
    if args.wandb_logging:
        _init_wandb(args, out_str)

    print("Running Federated Ravan Adaptation...")
    print(f"Using dataset: {args.dataset}")
    print(f"Hyperparameters: {out_str}")

    server_model, tokenizer = build_model(args.dataset)
    total_params, _ = count_parameters(server_model)
    model_name = server_model.config._name_or_path
    print(f"Using model: {model_name}")
    server_model = server_model.cuda()

    # Build clients/validation dataset
    clients, testloader = build_dataset(args.dataset, args.batch_size, args.clients, args.iid_alpha, args.seed)

    # The SB method initialises its adapters from a one-round full fine-tuning delta.
    delta = None
    if args.adaptation_method == 'sb':
        delta = _compute_sb_delta(server_model, tokenizer, model_name, args)

    # Adds necessary parameters based on the specified adaptation method
    server_model = add_adapters(server_model, args.rank, args.alpha, args.b_var, args.r_var, args.a_var, args.num_heads, args.adaptation_method, delta=delta)
    if args.adaptation_method in ('vravan', 'ravan'):
        embed_dim = get_embedding_dim(server_model)
        init_adapters(server_model, args.init_scheme, embed_dim, args.rank, args.num_heads, args.b_var, args.a_var)
    if args.aggregation_method == 'FFA':
        # FFA keeps the LoRA A matrices frozen; only the remaining adapter weights train.
        for name, parameter in server_model.named_parameters():
            if 'lora_A' in name:
                parameter.requires_grad = False
    adapted_total, adapted_trainable = count_parameters(server_model)

    print()
    print("Trainable parameters:")
    for name, parameter in server_model.named_parameters():
        if parameter.requires_grad:
            print(name)
    print(f"Original total trainable params: {total_params} | scaling of total params: {adapted_total/total_params} | scaling of trainable params: {adapted_trainable/total_params}")
    print()

    if args.het_ranks:
        client_ranks = _sample_client_ranks(args.clients, args.het_dist)
        fl_training_het(server_model, clients, testloader, tokenizer, client_ranks, args)
    else:
        fl_training(server_model, clients, testloader, tokenizer, args)

    if args.wandb_logging:
        wandb.finish()


def main():
    """Parse CLI arguments and run one federated fine-tuning experiment.

    When ``args.verbose`` is false, all stdout output is written to
    ``./logs/<dataset>/<out_str>.log`` (the directory is created if missing).
    stdout is restored and the log file is closed even if the run raises.
    """
    args, out_str = get_args()
    torch.manual_seed(args.seed)
    torch.set_printoptions(sci_mode=False, precision=7)
    warnings.filterwarnings('ignore')

    # Write to log file instead of system output
    old_stdout = sys.stdout
    log_file = None
    if not args.verbose:
        file_name = "./logs/" + str(args.dataset) + "/" + out_str + ".log"
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        log_file = open(file_name, "w", buffering=1)
        sys.stdout = log_file

    try:
        _run_experiment(args, out_str)
    finally:
        # Always undo the redirection so tracebacks and later output are not lost,
        # and actually close the file (the original `log_file.close` was never called).
        if log_file is not None:
            sys.stdout = old_stdout
            log_file.close()

# Script entry point: only run the experiment when executed directly, not on import.
if __name__ == "__main__":
    main()