import torch
from config import parse_args

import torchvision.models as models
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import random
import numpy as np
from collections import defaultdict
import copy

def setup_seed(seed):
    """Seed every RNG the training run touches and make cuDNN deterministic.

    Seeds, in order: PyTorch's default generator (``torch.manual_seed``),
    the CUDA generator (``torch.cuda.manual_seed``), NumPy's global RNG and
    Python's ``random`` module. Then turns off cuDNN autotuning and requests
    deterministic kernels so repeated runs pick the same algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Use the GPU when one is visible; otherwise fall back to CPU so the script
# still runs on CUDA-less hosts instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
args = parse_args()

# CLIP tokenizer loaded from the "tokenizer" subfolder of the pretrained
# checkpoint; used by tokenize_captions and collate_fn.
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)

def tokenize_captions(examples, is_train=False):
    """Tokenize a batch of captions with the module-level CLIP tokenizer.

    Args:
        examples: Iterable of captions. Each item is either a string or a
            list / tuple / NumPy array of candidate strings.
        is_train: When True, pick a random candidate from each multi-caption
            item; otherwise always take the first candidate.

    Returns:
        List of token-id lists, one per caption, truncated to
        ``tokenizer.model_max_length`` and left unpadded (``collate_fn``
        pads them per batch).

    Raises:
        ValueError: If an item is neither a string nor a sequence of strings.
    """
    captions = []
    for caption in examples:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, tuple, np.ndarray)):
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            # Report the offending type; the message must not reference names
            # that do not exist here, or the ValueError becomes a NameError.
            raise ValueError(
                "Each caption should be either a string or a list of strings, "
                f"got {type(caption).__name__}."
            )
    inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
    return inputs.input_ids
        
def collate_fn(examples):
    """Merge a list of dataset samples into one batch dictionary.

    Each sample is indexed positionally as
    ``(image_tensor, token_ids, domain_id, class_id)``. Images are stacked
    into a contiguous float tensor. Token-id lists are padded to the longest
    sequence in the batch by the CLIP tokenizer, which also produces the
    attention mask. Domain and class ids become tensors.
    """
    def column(idx):
        # Gather field `idx` from every sample, preserving batch order.
        return [sample[idx] for sample in examples]

    images = torch.stack(column(0))
    images = images.to(memory_format=torch.contiguous_format).float()
    padded = tokenizer.pad({"input_ids": column(1)}, padding=True, return_tensors="pt")
    batch = {
        "pixel_values": images,
        "input_ids": padded.input_ids,
        "attention_mask": padded.attention_mask,
        "domain_ids": torch.tensor(column(2)),
        "class_ids": torch.tensor(column(3)),
    }
    return batch

def _load_dataloader_fns(dataset):
    """Return ``(get_dataloader, get_dataloader_domain)`` for *dataset*.

    Each supported dataset lives in a sibling module named ``<dataset>_data``
    exposing both factories. The module is imported lazily, so only the
    selected dataset's module is loaded.

    Raises:
        ValueError: If *dataset* is not a supported name (previously this
            surfaced later as an unbound-name error).
    """
    import importlib

    supported = ('domainnet', 'pacs', 'officehome', 'bloodmnist', 'dermamnist', 'eurosat')
    if dataset not in supported:
        raise ValueError(f"Unsupported dataset {dataset!r}; expected one of {supported}.")
    module = importlib.import_module(f"{dataset}_data")
    return module.get_dataloader, module.get_dataloader_domain


def _evaluate(model, loader):
    """Return the top-1 accuracy (in percent) of *model* over *loader*.

    Expects each batch to be a dict with ``pixel_values`` and ``class_ids``
    tensors. Returns NaN for an empty loader instead of raising
    ZeroDivisionError.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            images = batch["pixel_values"].to(device)
            labels = batch["class_ids"].to(device)
            preds = model(images).argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            seen += images.size(0)
    if seen == 0:
        return float("nan")
    return correct / seen * 100


def train(seed, train_setting, *, num_epochs=50, eval_every=5, lr=0.01):
    """Train and evaluate a ResNet-18 classifier independently per domain.

    For every domain in ``args.domains`` a fresh ResNet-18 is built, with a
    new final layer sized to ``len(args.categories)``. It is trained with SGD
    on that domain's ``train_setting`` split and evaluated every
    ``eval_every`` epochs and after the last epoch. Accuracy is printed and
    nothing is returned.

    Args:
        seed: Seed passed to ``setup_seed`` before any data or model creation.
        train_setting: Split name forwarded to ``get_dataloader_domain``.
        num_epochs: Training epochs per domain.
        eval_every: Evaluation period in epochs; must be >= 1.
        lr: SGD learning rate (momentum fixed at 0.9).

    Raises:
        ValueError: If ``eval_every`` < 1 or ``args.dataset`` is unsupported.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    print(f"Training with seed {seed} and setting {train_setting}")
    setup_seed(seed)

    get_dataloader, get_dataloader_domain = _load_dataloader_fns(args.dataset)

    for domain in args.domains:
        trainloader = get_dataloader_domain(
            args, args.train_batch_size, None,
            train_setting, domain, tokenize_captions,
            collate_fn, num_shot=args.num_shot)

        if args.dataset in ['bloodmnist', 'eurosat', 'dermamnist']:
            # These datasets use a test loader that is not domain-specific.
            # Only eurosat restricts it to num_shot; the others pass -1.
            testloader = get_dataloader(
                args, args.train_batch_size, None,
                'test', tokenize_captions,
                collate_fn, num_shot=args.num_shot if args.dataset == 'eurosat' else -1)
        else:
            testloader = get_dataloader_domain(
                args, args.train_batch_size, None,
                'test', domain, tokenize_captions,
                collate_fn, num_shot=args.num_shot)

        # Read after the loaders are built, in case the data module populates it.
        num_classes = len(args.categories)

        server_model = models.resnet18(pretrained=args.pretrained)
        server_model.fc = torch.nn.Linear(server_model.fc.in_features, num_classes)
        server_model.to(device)

        # Create the optimizer once per model so SGD's momentum buffers persist
        # across epochs; rebuilding it each epoch would reset them.
        optimizer = torch.optim.SGD(server_model.parameters(), momentum=0.9, lr=lr)
        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(num_epochs):
            server_model.train()
            for batch in tqdm(trainloader):
                images = batch['pixel_values'].to(device)
                labels = batch['class_ids'].to(device)
                optimizer.zero_grad()
                loss = criterion(server_model(images), labels)
                loss.backward()
                optimizer.step()

            if (epoch + 1) % eval_every == 0 or epoch == num_epochs - 1:
                accuracy = _evaluate(server_model, testloader)
                print(f"{train_setting}/Epoch {epoch+1}: {domain} Accuracy = {round(accuracy, 3)}")

if __name__ == "__main__":
    # Repeat each training setting over three fixed seeds.
    run_seeds = (0, 1, 2)
    run_settings = ('train',)
    for run_seed in run_seeds:
        for setting in run_settings:
            train(run_seed, setting)
