import torch
from config import parse_args

import torchvision.models as models
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import random
import numpy as np
from collections import defaultdict
import copy

def setup_seed(seed):
    """Seed every random number generator used by this script.

    The same ``seed`` is handed to PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed``), NumPy and Python's ``random`` module, and
    cuDNN is switched to deterministic mode with benchmarking disabled.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Run configuration returned by config.parse_args(); read by the functions below.
args = parse_args()

# Tokenizer loaded from the "tokenizer" subfolder of the configured pretrained
# model; used by tokenize_captions and collate_fn.
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)

def tokenize_captions(examples, is_train=False):
    """Tokenize one caption per example with the module-level tokenizer.

    Args:
        examples: Iterable of captions. Each entry is either a string or a
            list / ``np.ndarray`` of candidate strings.
        is_train: When an entry holds several candidates, pick one at random
            if True, otherwise always take the first candidate.

    Returns:
        The tokenizer's ``input_ids`` for the selected captions, truncated to
        ``tokenizer.model_max_length`` and left unpadded.

    Raises:
        ValueError: If an entry is neither a string nor a list / ndarray.
    """
    captions = []
    for caption in examples:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            # The previous message interpolated an undefined `caption_column`,
            # which turned this ValueError into a NameError.
            raise ValueError(
                "Captions should contain either strings or lists of strings, "
                f"got {type(caption).__name__}."
            )
    inputs = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="do_not_pad",
        truncation=True,
    )
    return inputs.input_ids
        
def collate_fn(examples):
    """Assemble a list of samples into a single batch dictionary.

    Each sample is read positionally: index 0 is stacked into the image
    batch, index 1 is the token-id sequence, index 2 the domain id and
    index 3 the class id.

    Returns:
        dict with the keys ``pixel_values``, ``input_ids``,
        ``attention_mask``, ``domain_ids`` and ``class_ids``.
    """
    images = torch.stack([sample[0] for sample in examples])
    images = images.to(memory_format=torch.contiguous_format).float()

    # Token sequences may differ in length, so the tokenizer pads them and
    # returns PyTorch tensors together with the matching attention mask.
    token_batch = tokenizer.pad(
        {"input_ids": [sample[1] for sample in examples]},
        padding=True,
        return_tensors="pt",
    )

    domains = torch.tensor([sample[2] for sample in examples])
    classes = torch.tensor([sample[3] for sample in examples])

    return dict(
        pixel_values=images,
        input_ids=token_batch.input_ids,
        attention_mask=token_batch.attention_mask,
        domain_ids=domains,
        class_ids=classes,
    )

def train(seed, train_setting):
    """Train a ResNet-18 classifier with federated averaging and print test accuracy.

    One training loader is built per (domain, client) pair. In every round
    each client starts from a copy of the server model, runs its local SGD
    epochs, and the client weights are then averaged uniformly into the
    server model. The server model is evaluated every 5 rounds and after the
    final round.

    Args:
        seed: Value passed to ``setup_seed`` before loaders and model are built.
        train_setting: Name of the training setting forwarded to
            ``get_dataloader_domain`` (for example ``'train'``).

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported datasets.
    """
    print(f"Training with seed {seed} and setting {train_setting}")
    setup_seed(seed)

    # Each dataset ships its own module exposing the same two loader factories.
    if args.dataset == 'domainnet':
        from domainnet_data import get_dataloader, get_dataloader_domain
    elif args.dataset == 'pacs':
        from pacs_data import get_dataloader, get_dataloader_domain
    elif args.dataset == 'officehome':
        from officehome_data import get_dataloader, get_dataloader_domain
    elif args.dataset == 'bloodmnist':
        from bloodmnist_data import get_dataloader, get_dataloader_domain
    elif args.dataset == 'dermamnist':
        from dermamnist_data import get_dataloader, get_dataloader_domain
    elif args.dataset == 'ucm':
        from ucm_data import get_dataloader, get_dataloader_domain
    else:
        # Fail early with a clear message instead of an unbound-name error
        # at the first loader construction.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    # One loader per (domain, client) pair; every loader acts as one federated client.
    trainloaders = []
    for domain in args.domains:
        for client_id in range(args.client_num):
            trainloader = get_dataloader_domain(
                    args, args.train_batch_size, None,
                    train_setting, domain, tokenize_captions,
                    collate_fn, num_shot=args.num_shot,
                    client_id=client_id)
            trainloaders.append(trainloader)

    # Dataset-specific num_shot for the test loader; -1 is passed for all
    # other datasets (its meaning is defined by the data module).
    if args.dataset == 'pacs':
        num_shot_test = 32
    elif args.dataset == 'ucm':
        num_shot_test = 8
    else:
        num_shot_test = -1
    test_dataloader = get_dataloader(
            args, args.train_batch_size, None,
            'test', tokenize_captions,
            collate_fn, num_shot=num_shot_test)

    num_epochs = 50
    num_local_epochs = 1
    num_classes = len(args.categories)

    # Replace the classification head so it matches the number of categories.
    server_model = models.resnet18(pretrained=args.pretrained)
    server_model.fc = torch.nn.Linear(server_model.fc.in_features, num_classes)
    server_model.to(device)

    # The loss module holds no per-client state, so a single instance is shared.
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):

        state_dict_list = []
        for train_dataloader in trainloaders:
            # Every client starts the round from the current server weights
            # with a fresh optimizer (momentum is not carried across rounds).
            client_model = copy.deepcopy(server_model)
            optimizer = torch.optim.SGD(client_model.parameters(), momentum=0.9, lr=0.01)
            client_model.to(device)
            client_model.train()

            for _ in range(num_local_epochs):
                # local update
                for batch in tqdm(train_dataloader):
                    optimizer.zero_grad()
                    # Only the images and class labels are used for training;
                    # the tokenized captions in the batch are not needed here.
                    outputs = client_model(batch['pixel_values'].to(device))
                    labels = batch['class_ids'].to(device)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

            state_dict_list.append(client_model.state_dict())
            del client_model

        # Aggregate: uniform average of all client weights becomes the new server model.
        server_model.load_state_dict(average_model_weights(state_dict_list))

        # for eval
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            server_model.eval()
            # domain id -> list of per-sample correctness flags
            d_count = defaultdict(list)
            with torch.no_grad():
                total_correct = 0
                total_samples = 0
                for batch in test_dataloader:
                    inputs = batch["pixel_values"].to(device)
                    outputs = server_model(inputs)
                    labels = batch['class_ids'].to(device)
                    preds = torch.argmax(outputs, dim=1)
                    total_correct += preds.eq(labels).sum().item()
                    total_samples += inputs.size(0)
                    for i, did in enumerate(batch['domain_ids']):
                        d_count[did.item()].append(labels[i].item() == preds[i].item())

            # Per-domain accuracy in percent, printed on one line.
            tot_acc = 0.
            for k, hits in d_count.items():
                acc = sum(hits) / len(hits) * 100
                tot_acc += acc
                print(f"{args.domains[k]}: {round(acc, 3)}", end=", ")

            # 'ucm' and 'dermamnist' report accuracy over all test samples;
            # the other datasets report the mean of the per-domain accuracies.
            if args.dataset in ['ucm', 'dermamnist']:
                print(f"{train_setting}/Epoch {epoch+1}: Accuracy = {round(total_correct/total_samples*100, 3)}")
            else:
                print(f"{train_setting}/Epoch {epoch+1}: Accuracy = {round(tot_acc/len(d_count.keys()), 3)}")

def average_model_weights(state_dict_list):
    """Return the element-wise, unweighted mean of several model state dicts.

    Args:
        state_dict_list: Non-empty list of state dicts. Every entry must
            provide all keys of the first entry.

    Returns:
        dict mapping each key of the first state dict to the mean of that
        entry across all state dicts. Tensor entries keep the dtype of the
        first state dict's tensor, so integer buffers stay integer (the
        fractional part of their mean is dropped).

    Raises:
        ValueError: If ``state_dict_list`` is empty.
    """
    if not state_dict_list:
        raise ValueError("state_dict_list must contain at least one state dict")

    num_models = len(state_dict_list)
    avg_state_dict = {}
    for key, reference in state_dict_list[0].items():
        mean = sum(state_dict[key] for state_dict in state_dict_list) / num_models
        if torch.is_tensor(reference):
            # True division promotes integer tensors to float; cast back so
            # the averaged state dict has the same dtypes as its inputs.
            mean = mean.to(reference.dtype)
        avg_state_dict[key] = mean
    return avg_state_dict

if __name__ == "__main__":
    # 'ucm' and 'dermamnist' are run under three training settings; every
    # other dataset only uses the plain 'train' setting.
    if args.dataset in ['ucm', 'dermamnist']:
        train_settings = ['train_niid001', 'train', 'train_niid05']
    else:
        train_settings = ['train']

    # Repeat every setting for three seeds.
    for seed in [0, 1, 2]:
        for train_setting in train_settings:
            train(seed, train_setting)
