import torch
from config import parse_args

import torchvision.models as models
from domainnet_data import get_dataloader
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import random
import numpy as np
from collections import defaultdict

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and CUDA), NumPy and Python's ``random`` module with
    the same value, then switches cuDNN to deterministic mode with
    autotuning (benchmark) disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Force deterministic cuDNN kernels and turn off the autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Device string passed to `.to(device)` for the model and batches in `train`.
device = 'cuda'
# Run configuration; parsed once at import time and read as a module-level
# global by the functions below.
args = parse_args()

# CLIP tokenizer loaded from the "tokenizer" subfolder of the configured
# pretrained model, at the configured revision. Shared by `tokenize_captions`
# (encoding) and `collate_fn` (padding).
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)

def tokenize_captions(examples, is_train=False):
    """Tokenize a batch of captions into unpadded token-id lists.

    Args:
        examples: Iterable of captions. Each item is either a single string
            or a list / NumPy array of candidate strings.
        is_train: If true, one candidate is drawn with ``random.choice`` from
            each multi-caption item; otherwise the first candidate is used.

    Returns:
        The tokenizer's ``input_ids`` for the selected captions, truncated to
        ``tokenizer.model_max_length`` and left unpadded
        (``padding="do_not_pad"``).

    Raises:
        ValueError: If an item is neither a string nor a list / NumPy array.
    """
    captions = []
    for caption in examples:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            # The message must not reference names that do not exist in this
            # scope; doing so would surface as NameError instead of ValueError.
            raise ValueError(
                "Each caption should be either a string or a list of strings, "
                f"got {type(caption).__name__}."
            )
    inputs = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="do_not_pad",
        truncation=True,
    )
    return inputs.input_ids
        
def collate_fn(examples):
    """Assemble a list of dataset samples into a single batch dictionary.

    Samples are read positionally: index 0 is the image tensor, index 1 the
    token ids, index 2 the domain id and index 3 the class id.

    Args:
        examples: List of indexable samples as described above.

    Returns:
        Dict with ``pixel_values`` (stacked float images), ``input_ids`` and
        ``attention_mask`` (token ids padded by the tokenizer and returned as
        PyTorch tensors), plus ``domain_ids`` and ``class_ids`` tensors.
    """
    images = torch.stack([sample[0] for sample in examples])
    images = images.to(memory_format=torch.contiguous_format).float()

    # Token sequences have different lengths, so the tokenizer pads them
    # (``padding=True``) before they are returned as tensors.
    token_lists = [sample[1] for sample in examples]
    padded = tokenizer.pad({"input_ids": token_lists}, padding=True, return_tensors="pt")

    domains = torch.tensor([sample[2] for sample in examples])
    classes = torch.tensor([sample[3] for sample in examples])

    batch = {"pixel_values": images}
    batch["input_ids"] = padded.input_ids
    batch["attention_mask"] = padded.attention_mask
    batch["domain_ids"] = domains
    batch["class_ids"] = classes
    return batch

def train(seed, train_setting):
    """Train a ResNet-18 classifier from scratch and report test accuracy.

    The model is trained for 100 epochs with SGD (lr 0.01, momentum 0.9, no
    scheduler) on the loader built for ``train_setting``. Every 5 epochs, and
    after the final epoch, per-domain and overall accuracy on the ``'test'``
    loader are printed.

    Args:
        seed: Seed passed to ``setup_seed`` before any data or model is built.
        train_setting: Setting name forwarded to ``get_dataloader`` for the
            training loader; also echoed in the printed progress lines.

    Returns:
        None. Results are only printed.
    """
    print(f"Training with seed {seed} and setting {train_setting}")
    # Seed first so that loader construction and the weight initialisation
    # below are both governed by `seed`.
    setup_seed(seed)

    train_dataloader = get_dataloader(
            args, args.train_batch_size, None,
            train_setting, tokenize_captions,
            collate_fn, num_shot=100)

    test_dataloader = get_dataloader(
            args, args.train_batch_size, None,
            'test', tokenize_captions,
            collate_fn)

    num_epochs = 100
    num_classes = 10

    model = models.resnet18(pretrained=False)
    # Swap in the task-specific head BEFORE creating the optimizer; otherwise
    # the optimizer holds the discarded layer's parameters and the new head
    # is never updated.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            # Only images and class labels are used; the tokenized captions
            # in the batch are not needed by the classifier.
            outputs = model(batch['pixel_values'].to(device))
            labels = batch['class_ids'].to(device)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        if (epoch+1) % 5 == 0 or epoch == num_epochs - 1:
            model.eval()
            total_correct = 0
            total_samples = 0
            # domain id -> list of per-sample correctness flags
            d_count = defaultdict(list)

            with torch.no_grad():
                for batch in test_dataloader:
                    inputs = batch["pixel_values"].to(device)
                    outputs = model(inputs)
                    labels = batch['class_ids'].to(device)
                    preds = torch.argmax(outputs, dim=1)
                    hits = preds.eq(labels)
                    total_correct += hits.sum().item()
                    total_samples += inputs.size(0)
                    # Convert once per batch instead of calling .item() per
                    # sample, which would sync with the device every time.
                    for did, hit in zip(batch['domain_ids'].tolist(), hits.tolist()):
                        d_count[did].append(hit)

            accuracy = total_correct / total_samples * 100
            for k, results in d_count.items():
                acc = sum(results) / len(results) * 100
                print(f"{args.domains[k]}: {round(acc, 3)}", end=", ")
            print(f"{train_setting}/Epoch {epoch+1}: Accuracy = {round(accuracy, 3)}")

if __name__ == "__main__":
    # Run every training setting under each seed, one run after another.
    seeds = (0, 1, 2)
    settings = ('train', 'train_syn_base', 'train_syn_spec', 'train_syn_opt')
    for seed in seeds:
        for train_setting in settings:
            train(seed, train_setting)
