import copy
import os

import torch
import torch.nn as nn
from typing import *
from torch.utils.data import DataLoader
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer,
                          )

import federated
import utils


class FederatedController:
    """Drive federated training and evaluation of a shared server model.

    The controller holds the per-client data loaders, the current server
    model, and a deep copy of the server model as it was at construction.
    Evaluation metrics are written to ``<config.save_dir>/stats.json``
    through ``utils.log_stats``.
    """

    def __init__(self,
                 server_model: "Union[nn.Module, AutoModelForCausalLM]",
                 client_loaders: Dict[int, DataLoader],
                 eval_loader: Optional[Dict[str, DataLoader]] = None,
                 tokenizer: "Optional[AutoTokenizer]" = None,
                 config: Dict = None,
                 ):
        """Set up the controller and reset the stats log.

        Args:
            server_model: Global model that federated training updates.
            client_loaders: Mapping from client id to that client's loader.
            eval_loader: Mapping with ``"train"`` and ``"test"`` loaders read
                by :meth:`evaluate`. With the ``"remaining"`` evaluation
                strategy it is rebuilt from the client loaders before each
                evaluation.
            tokenizer: Tokenizer forwarded to the LLM evaluation helpers.
            config: Experiment configuration accessed by attribute
                (``data.name``, ``save_dir``, ``device``, ``federated.*``,
                ``local.*``). Required despite the ``None`` default.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        # Fail with a clear message instead of an AttributeError on
        # `None.data` a few lines below.
        if config is None:
            raise ValueError("FederatedController requires a config object.")

        self.server_model = server_model
        self.client_loaders = client_loaders
        self.eval_loader = eval_loader
        self.config = config
        self.tokenizer = tokenizer
        # LLM runs are detected from the dataset name; they switch evaluation
        # to the tokenizer-aware *_llm helpers in `utils`.
        self.is_llm = ('huggingface' in self.config.data.name)
        print("[DEBUG] Number of clients:", len(self.client_loaders))

        # Independent snapshot of the starting weights; later training on
        # `self.server_model` does not modify it.
        self.init_server_model = copy.deepcopy(self.server_model)
        self.log_path = os.path.join(self.config.save_dir, "stats.json")
        # Drop stats left by a previous run so this run's log starts empty
        # (presumably utils.log_stats appends to the file — TODO confirm).
        if os.path.exists(self.log_path):
            print(f"Found existing results in {self.log_path}. Will delete it.")
            os.remove(self.log_path)
        self.device = self.config.device

    def train(self, num_rounds):
        """Evaluate the current model, then run ``num_rounds`` of federated training.

        ``num_rounds`` is written into ``config.federated.num_rounds``, which
        ``federated.train_fed`` receives along with the config. The model
        returned by ``train_fed`` replaces ``self.server_model``.
        ``self.evaluate`` is passed as ``eval_fn``.
        """
        self.config.federated.num_rounds = num_rounds
        # Baseline metrics before any federated round.
        self.evaluate(self.server_model, verbose=True)
        self.server_model = federated.train_fed(self.server_model,
                                                self.client_loaders,
                                                self.config,
                                                eval_fn=self.evaluate,
                                                device=self.device,
                                                is_llm=self.is_llm,
                                                tokenizer=self.tokenizer,
                                                )

    def leave(self, unlearned_client):
        """Validate that ``unlearned_client`` is currently in the coalition.

        This base implementation only performs the membership check; it
        does not remove the client's loader. Presumably subclasses implement
        the actual leave/unlearning step — TODO confirm.
        """
        assert unlearned_client in self.client_loaders, (
            f"Client {unlearned_client!r} is not part of the coalition."
        )

    def join(self, client_id, dataloader, attach_to=None):
        """Add a new client and its data loader to the coalition.

        ``attach_to`` is accepted for interface compatibility but is unused
        by this base implementation.
        """
        assert client_id not in self.client_loaders, (
            f"Client {client_id!r} is already part of the coalition."
        )
        self.client_loaders[client_id] = dataloader

    def prepare_evaluation(self):
        """Refresh ``self.eval_loader`` according to the evaluation strategy.

        ``"fixed"`` keeps the loaders supplied at construction.
        ``"remaining"`` rebuilds them from the current ``client_loaders``.

        Raises:
            NotImplementedError: For any other strategy.
        """
        strategy = self.config.federated.evaluation_strategy
        if strategy == "fixed":
            return
        if strategy == "remaining":
            # Assign to `eval_loader`, the attribute `evaluate` actually reads
            # (this previously went to an unused `eval_loaders`). The merged
            # result is assumed to be a {"train", "test"} mapping — TODO
            # confirm against federated.merge_client_loaders.
            # TODO: incompatible to LLM experiments
            self.eval_loader = federated.merge_client_loaders(self.client_loaders,
                                                              self.config.local.train_batch_size,
                                                              self.config.local.eval_batch_size)
            return
        raise NotImplementedError(f"Unknown evaluation strategy: {strategy!r}")

    def _build_loss_fn(self):
        """Instantiate ``torch.nn.<config.local.loss>``.

        Constructor kwargs are picked from ``config.local`` by
        ``utils.get_kwargs``.
        """
        loss_cls = getattr(torch.nn, self.config.local.loss)
        return loss_cls(**utils.get_kwargs(loss_cls, self.config.local))

    def _check_accuracy(self, server_model, loader):
        """Return accuracy of ``server_model`` on ``loader`` via the matching ``utils`` helper."""
        if self.is_llm:
            return utils.check_accuracy_llm(server_model,
                                            loader,
                                            tokenizer=self.tokenizer,
                                            device=self.device,
                                            )
        return utils.check_accuracy(server_model, loader, device=self.device)

    def _check_loss(self, server_model, loader, compute_loss=None):
        """Return loss of ``server_model`` on ``loader``.

        ``compute_loss`` is only used on the non-LLM path.
        """
        if self.is_llm:
            return utils.check_loss_llm(server_model,
                                        loader,
                                        tokenizer=self.tokenizer,
                                        device=self.device,
                                        )
        return utils.check_loss(server_model,
                                loader,
                                compute_loss=compute_loss,
                                device=self.device,
                                )

    def evaluate(self, server_model, verbose=False):
        """Compute train/test accuracy and loss for ``server_model`` and log them.

        Args:
            server_model: Model to evaluate.
            verbose: If true, also print the metrics.

        Returns:
            Dict with keys ``train_acc``, ``train_loss``, ``test_acc`` and
            ``test_loss``.

        Raises:
            ValueError: If no evaluation loaders are available.
        """
        self.prepare_evaluation()
        if self.eval_loader is None:
            raise ValueError(
                "No eval_loader available; pass one to FederatedController "
                "or use the 'remaining' evaluation strategy."
            )

        train_loader = self.eval_loader["train"]
        test_loader = self.eval_loader["test"]

        # Call order (train acc, test acc, train loss, test loss) is kept
        # identical to the original implementation.
        train_acc = self._check_accuracy(server_model, train_loader)
        test_acc = self._check_accuracy(server_model, test_loader)
        compute_loss = None if self.is_llm else self._build_loss_fn()
        train_loss = self._check_loss(server_model, train_loader, compute_loss)
        test_loss = self._check_loss(server_model, test_loader, compute_loss)

        metrics = {
            "train_acc": train_acc,
            "train_loss": train_loss,
            "test_acc": test_acc,
            "test_loss": test_loss,
        }
        utils.log_stats(self.log_path, **metrics)
        if verbose:
            print("eval:", metrics)
        return metrics