import torch
import timm
import torchvision.transforms as transforms
from torch import nn, optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import copy
from models.resnets import resnet20s



def cosine_similarity_torch(A, B, eps=1e-8):
    """Return the cosine similarity between two 1-D tensors.

    Args:
        A: 1-D floating-point tensor.
        B: 1-D floating-point tensor with as many elements as ``A``.
        eps: Lower bound applied to the product of the two norms, so an
            all-zero input yields ``0`` instead of ``nan``.

    Returns:
        A 0-dim tensor holding ``dot(A, B) / max(||A|| * ||B||, eps)``.
    """
    # torch.dot only accepts 1-D tensors of equal length, so shape problems
    # surface there rather than producing a silently wrong value.
    denominator = (torch.norm(A) * torch.norm(B)).clamp_min(eps)
    return torch.dot(A, B) / denominator

def cosine_similarity_manual(A, B):
    """Return the cosine similarity of two numeric sequences in pure Python.

    Args:
        A: Iterable of numbers.
        B: Iterable of numbers with the same length as ``A``.

    Returns:
        ``dot(A, B) / (|A| * |B|)``.

    Raises:
        ValueError: If ``A`` and ``B`` differ in length.
        ZeroDivisionError: If either input has zero magnitude (plain Python
            numbers).
    """
    # Materialise once: each input is traversed twice below, which would
    # exhaust a generator on the first pass.
    vec_a = list(A)
    vec_b = list(B)
    if len(vec_a) != len(vec_b):
        # zip() would truncate the dot product while the magnitudes still
        # cover the full vectors, giving a meaningless ratio.
        raise ValueError(
            f"length mismatch: len(A)={len(vec_a)} != len(B)={len(vec_b)}"
        )

    dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
    magnitude_a = sum(a ** 2 for a in vec_a) ** 0.5
    magnitude_b = sum(b ** 2 for b in vec_b) ** 0.5

    return dot_product / (magnitude_a * magnitude_b)

def train(model, device, train_loader, optimizer, criterion):
    """Run one optimisation pass over every batch in ``train_loader``.

    Args:
        model: Module to train; switched to training mode first.
        device: Device that each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches whose
            targets are encoded as -1 / +1.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(output, target)``.
    """
    model.train()
    for inputs, labels in train_loader:
        # Map the {-1, +1} labels onto class indices {0, 1}.
        class_ids = ((labels + 1) // 2).to(device)
        inputs = inputs.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), class_ids)
        batch_loss.backward()
        optimizer.step()

def eval(model, test_loader, criterion=nn.CrossEntropyLoss(), device=None):
    """Evaluate ``model`` on ``test_loader`` and return accuracy and mean loss.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    use it.

    Args:
        model: Module to evaluate; switched to eval mode.
        test_loader: Loader yielding ``(data, target)`` batches with targets
            encoded as -1 / +1. It must expose ``.dataset`` with a length.
        criterion: Loss callable. If it has ``reduction == "mean"`` (also
            assumed when the attribute is absent), each batch loss is
            weighted by its batch size so the result is a per-sample mean.
        device: Device to move batches to. When ``None`` the module-level
            global ``device`` is used, matching the previous behaviour.

    Returns:
        ``(test_acc, test_loss)``: accuracy in percent and the loss averaged
        over ``len(test_loader.dataset)`` samples.

    Raises:
        NameError: If ``device`` is ``None`` and no module-level ``device``
            exists.
    """
    if device is None:
        try:
            device = globals()["device"]
        except KeyError:
            raise NameError(
                "eval() needs a device: pass device=... or define a "
                "module-level 'device'"
            ) from None

    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Convert labels to 0 and 1 for testing
            target = (target + 1) // 2
            data, target = data.to(device), target.to(device)
            output = model(data)

            batch_loss = criterion(output, target).item()
            if getattr(criterion, "reduction", "mean") == "mean":
                # A mean-reduced loss is already divided by the batch size;
                # undo that so dividing by the dataset size below is correct
                # even when the last batch is smaller.
                batch_loss *= target.size(0)
            total_loss += batch_loss

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss = total_loss / num_samples
    test_acc = 100. * correct / num_samples
    return test_acc, test_loss

def correlation_compute(model1, model2, test_loader, device=None):
    """Average per-sample cosine similarity of two models' centred outputs.

    For every batch, both models' outputs are centred by subtracting the
    batch mean (over dim 0). The first two output columns are then compared
    row by row with cosine similarity. The similarities are summed over all
    batches and divided by ``len(test_loader.dataset)``.

    A row whose centred slice is all zeros (for example any batch of size 1)
    yields ``nan``, which propagates to the result.

    Args:
        model1: First module; switched to eval mode.
        model2: Second module; switched to eval mode.
        test_loader: Loader yielding ``(data, target)`` batches; the targets
            are ignored. It must expose ``.dataset`` with a length.
        device: Device to move batches to. When ``None`` the module-level
            global ``device`` is used, matching the previous behaviour.

    Returns:
        A 0-dim tensor with the mean cosine similarity.

    Raises:
        NameError: If ``device`` is ``None`` and no module-level ``device``
            exists.
    """
    if device is None:
        try:
            device = globals()["device"]
        except KeyError:
            raise NameError(
                "correlation_compute() needs a device: pass device=... or "
                "define a module-level 'device'"
            ) from None

    model1.eval()
    model2.eval()
    total = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            out1 = model1(data)
            out2 = model2(data)
            # Centre each output column on its batch mean, then keep only the
            # first two columns.
            centred1 = (out1 - torch.mean(out1, 0))[:, 0:2]
            centred2 = (out2 - torch.mean(out2, 0))[:, 0:2]
            # Row-wise cosine similarity for the whole batch at once.
            dots = (centred1 * centred2).sum(dim=1)
            norms = centred1.norm(dim=1) * centred2.norm(dim=1)
            total = total + (dots / norms).sum()
    return total / len(test_loader.dataset)

# Script entry point: runs the coloured-MNIST training and model-merging experiments
if __name__ == "__main__":

    # Define device
    # NOTE: `device` becomes a module-level global here; eval() and
    # correlation_compute() look it up by this exact name, so it must not be
    # renamed or moved into a function.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Transformations
    # NOTE(review): this pipeline is not referenced anywhere else in the
    # visible code. CustomMNIST forces the parent transform to None and
    # applies its own Resize/ToTensor steps.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Custom MNIST Dataset
    class CustomMNIST(MNIST):
        """MNIST with coloured digits and a binary odd/even label.

        Every sample is converted to RGB and its non-black pixels are painted
        red or green. The colour is chosen by a Bernoulli draw with
        probability ``r1`` and by the sample's binary label. The image is
        then resized to 224x224 and returned as a tensor.

        Args:
            *args: Forwarded to ``MNIST``.
            r1: Probability passed to ``np.random.binomial``. A draw of 1
                paints label +1 red and label -1 green; a draw of 0 swaps
                the two colours.
            flip_label: When false, odd digits are labelled +1 and even
                digits -1; when true the labels are swapped. The colour
                always follows the final label.
            **kwargs: Forwarded to ``MNIST``. Any ``transform`` is
                overridden with ``None``.

        Note:
            The colour is re-drawn on every ``__getitem__`` call, so the
            same index can return differently coloured images.
        """

        def __init__(self, *args, r1, flip_label=False, **kwargs):
            # The parent must hand back the raw PIL image; colouring and the
            # tensor conversion both happen in __getitem__.
            kwargs['transform'] = None
            super().__init__(*args, **kwargs)
            self.r1 = r1
            self.flip_label = flip_label
            # Built once here rather than on every sample access.
            self._to_pil = transforms.ToPILImage()
            self._to_model_input = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])

        def __getitem__(self, index):
            """Return ``(image_tensor, label)`` where ``label`` is +1 or -1."""
            # Raw PIL image and digit label from the parent (no transform).
            img, digit = super().__getitem__(index)

            # Binary label: odd -> +1, even -> -1, swapped by flip_label.
            is_odd = digit % 2 == 1
            target = 1 if is_odd != bool(self.flip_label) else -1

            # Work on an RGB copy so the digit pixels can be recoloured.
            img_array = np.array(img.convert("RGB"))
            mask = img_array[:, :, 0] > 0  # non-black pixels form the digit

            if target == 1:
                primary, secondary = (255, 0, 0), (0, 255, 0)
            else:
                primary, secondary = (0, 255, 0), (255, 0, 0)
            col = np.random.binomial(1, self.r1, 1)
            img_array[mask] = primary if col[0] == 1 else secondary

            # Back to PIL, then resize and convert to a tensor for the model.
            img = self._to_model_input(self._to_pil(img_array))
            return img, target

    class merged_NN(nn.Module):
        """Model whose weights are a learnable mix of two task vectors.

        The effective weights on every forward pass are::

            original + (lm1 - 1) * tv1 + lm2 * tv2

        where ``original`` are the parameters of ``init_model`` captured at
        construction time, and ``lm1`` / ``lm2`` are scalar parameters
        initialised uniformly in [0, 1).

        Args:
            new_model: Module used to run the forward pass.
            init_model: Module whose parameters serve as the fixed base.
            tv1: Flat task vector, laid out in the order of
                ``new_model.named_parameters()``.
            tv2: Second flat task vector with the same layout.
            device: Device the mixing coefficients are moved to before being
                combined with the task vectors.
        """

        def __init__(self, new_model, init_model, tv1, tv2, device):
            super().__init__()
            self.lm1 = nn.parameter.Parameter(nn.init.uniform_(torch.empty(1), 0, 1))
            self.lm2 = nn.parameter.Parameter(nn.init.uniform_(torch.empty(1), 0, 1))
            self.model = new_model
            with torch.no_grad():
                self.original_parameters = {
                    name: param.clone()
                    for name, param in init_model.named_parameters()
                }
            self.device = device
            # The task vectors are constants here; detaching stops gradients
            # from flowing into whatever graph produced them.
            self.tv1 = tv1.detach()
            self.tv2 = tv2.detach()

        def forward(self, x):
            """Run ``self.model`` on ``x`` using the merged weights."""
            try:
                from torch.func import functional_call
            except ImportError:  # older torch releases
                from torch.nn.utils.stateless import functional_call

            lm1 = self.lm1.to(self.device)
            lm2 = self.lm2.to(self.device)

            merged = {}
            start = 0
            for name, param in self.model.named_parameters():
                end = start + param.numel()
                merged[name] = (
                    self.original_parameters[name]
                    + (lm1 - 1) * self.tv1[start:end].view(param.size())
                    + lm2 * self.tv2[start:end].view(param.size())
                )
                # Keep the wrapped model's stored weights in sync, as the
                # earlier `.data` assignment did.
                with torch.no_grad():
                    param.copy_(merged[name])
                start = end

            # Assigning through `.data` would cut the autograd graph and
            # leave lm1 / lm2 without gradients. functional_call runs the
            # module with the differentiable merged tensors instead.
            return functional_call(self.model, merged, (x,))


    # Experiment 1, Task 1: [255,0,0], flip: False, Task 2: [251,44,0], flip: False
    # Experiment 2, Task 1: [255,0,0], flip: False, Task 2: [0,255,0], flip: False
    # Experiment 3, Task 1: [255,0,0], flip: False, Task 2: [251,44,0], flip: True
    # Index used to name the saved accuracy maps (data/map_<num>.npy).
    num = 0

    # Each tuple is (ratio1, ratio2, flip1, flip2). Repeated tuples re-run
    # the same configuration.
    for ratio1, ratio2, flip1, flip2 in [(0.87, 0.05, False, False), (0.87, 0.05, False, False), (0.87, 0.05, False, False),
                                         (0.86, 0.05, False, False), (0.95, 0.05, False, False), (0.95, 0.07, False, False),
                                         (0.95, 0.07, False, False), (0.95, 0.07, False, False)]:

        # Datasets: tasks 1 and 2 use the swept colour ratios, task 3 is the
        # target task (r1=0.9), and the "alpha" test set uses r1=0.5.
        train_set_task1 = CustomMNIST(root='./data', train=True, download=True, r1=ratio1, flip_label=flip1)
        train_set_task2 = CustomMNIST(root='./data', train=True, download=True, r1=ratio2, flip_label=flip2)
        train_set_task3 = CustomMNIST(root='./data', train=True, download=True, r1=0.9, flip_label=flip2)

        test_set_task1 = CustomMNIST(root='./data', train=False, download=True, r1=ratio1, flip_label=flip1)
        test_set_task2 = CustomMNIST(root='./data', train=False, download=True, r1=ratio2, flip_label=flip2)
        test_set_task3 = CustomMNIST(root='./data', train=False, download=True, r1=0.9, flip_label=False)
        test_set_task_alpha = CustomMNIST(root='./data', train=False, download=True, r1=0.5, flip_label=False)

        # Data loaders
        train_loader_task1 = DataLoader(train_set_task1, batch_size=256, shuffle=True)
        train_loader_task2 = DataLoader(train_set_task2, batch_size=256, shuffle=True)
        train_loader_task3 = DataLoader(train_set_task3, batch_size=256, shuffle=True)
        test_loader_task1 = DataLoader(test_set_task1, batch_size=256, shuffle=False)
        test_loader_task2 = DataLoader(test_set_task2, batch_size=256, shuffle=False)
        test_loader_task_eval = DataLoader(test_set_task3, batch_size=256, shuffle=False)
        test_loader_task_ev_alpha = DataLoader(test_set_task_alpha, batch_size=256, shuffle=False)

        # Model: every task starts from the same pretrained ViT.
        # (An earlier variant used resnet20s(num_classes=2) with a frozen
        # `fc` layer instead.)
        model_init = timm.create_model("vit_small_patch16_224.augreg_in21k_ft_in1k", pretrained=True).to(device)
        model_task1 = copy.deepcopy(model_init).to(device)
        model_task2 = copy.deepcopy(model_init).to(device)
        model_task3 = copy.deepcopy(model_init).to(device)

        # The classification head is frozen for tasks 1 and 2 only.
        for param in model_task1.head.parameters():
            param.requires_grad = False
        for param in model_task2.head.parameters():
            param.requires_grad = False

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer_task1 = optim.Adam(model_task1.parameters(), lr=0.001)
        optimizer_task2 = optim.Adam(model_task2.parameters(), lr=0.001)
        optimizer_task3 = optim.Adam(model_task3.parameters(), lr=0.001)

        # Fine-tune the three models, one epoch each.
        for epoch in tqdm(range(1)):
            train(model_task1, device, train_loader_task1, optimizer_task1, criterion)
            train(model_task2, device, train_loader_task2, optimizer_task2, criterion)
            train(model_task3, device, train_loader_task3, optimizer_task3, criterion)

        # All three models are scored on the task-3 test set.
        test_acc_task1, test_loss_task1 = eval(model_task1, test_loader_task_eval, nn.CrossEntropyLoss())
        test_acc_task2, test_loss_task2 = eval(model_task2, test_loader_task_eval, nn.CrossEntropyLoss())
        test_acc_task3, test_loss_task3 = eval(model_task3, test_loader_task_eval, nn.CrossEntropyLoss())

        print(f'Task accuracy single 1: {test_acc_task1}')
        print(f'Task accuracy singl 2: {test_acc_task2}')
        print(f'Task accuracy FT: {test_acc_task3}')

        # Flat task vectors (fine-tuned minus initial weights). Built under
        # no_grad so they are plain constants and do not keep an autograd
        # graph tied to the model parameters.
        with torch.no_grad():
            model_init_params = torch.cat([p.view(-1) for p in model_init.parameters()])
            model_task1_params = torch.cat([p.view(-1) for p in model_task1.parameters()])
            model_task_vector1 = model_task1_params - model_init_params
            model_task2_params = torch.cat([p.view(-1) for p in model_task2.parameters()])
            model_task_vector2 = model_task2_params - model_init_params

        # Output correlations between the models on the r1=0.5 test set.
        alpha1 = correlation_compute(model_task1, model_task3, test_loader_task_ev_alpha)
        alpha2 = correlation_compute(model_task2, model_task3, test_loader_task_ev_alpha)
        alpha12 = correlation_compute(model_task1, model_task2, test_loader_task_ev_alpha)
        print(f'alpha1: {alpha1}; alpha2: {alpha2}; alpha12: {alpha12}')

        # The merging study only runs when task 2 is clearly anti-correlated
        # with the target task.
        if alpha2 < -0.1:
            new_model0 = copy.deepcopy(model_task1).to(device)
            init_model = copy.deepcopy(new_model0).to(device)
            for params in init_model.parameters():
                params.requires_grad = False

            merging_model = merged_NN(new_model0, init_model, model_task_vector1, model_task_vector2, device)

            optimizer_merge = optim.Adam(merging_model.parameters(), lr=0.001)
            for epoch in tqdm(range(1)):
                train(merging_model, device, train_loader_task3, optimizer_merge, criterion)
            test_acc_task3, test_loss_task3 = eval(merging_model, test_loader_task_eval, nn.CrossEntropyLoss())

            print(f'Task accuracy: {test_acc_task3}')
            learned_lm1 = merging_model.lm1[0].detach().cpu().numpy()
            learned_lm2 = merging_model.lm2[0].detach().cpu().numpy()
            print(f'lambda1: {learned_lm1}, lambda2: {learned_lm2},')

            # Grid search over (lambda1, lambda2) in steps of 0.2.
            lambda1_steps = range(-20, 31, 2)  # lambda1 = -2.0 ... 3.0
            lambda2_steps = range(-20, 21, 2)  # lambda2 = -2.0 ... 2.0
            acc_max = 0
            Lm1 = 0
            Lm2 = 0
            # Rows: lambda2 from high (row 0) to low; columns: lambda1 from
            # low to high.
            acc_map = np.zeros((len(lambda2_steps), len(lambda1_steps)))
            last_row = len(lambda2_steps) - 1

            # One scratch model is reused; its weights are reset from the
            # task-1 snapshot before each combination is applied.
            new_model = copy.deepcopy(model_task1).to(device)
            for col_idx, step1 in enumerate(lambda1_steps):
                lmbda1 = step1 / 10.0
                for row_offset, step2 in enumerate(lambda2_steps):
                    lmbda2 = step2 / 10.0

                    # theta = theta_task1 + (lambda1 - 1) * tv1 + lambda2 * tv2
                    with torch.no_grad():
                        start = 0
                        for param in new_model.parameters():
                            end = start + param.numel()
                            param.copy_(model_task1_params[start:end].view(param.size()))
                            param.add_((lmbda1 - 1) * model_task_vector1[start:end].view(param.size()))
                            param.add_(lmbda2 * model_task_vector2[start:end].view(param.size()))
                            start = end

                    # Evaluate the merged weights on the task-3 test set.
                    test_acc_try, test_loss_task3 = eval(new_model, test_loader_task_eval, nn.CrossEntropyLoss())
                    print(f'Lambda1: {lmbda1}, Lambda2: {lmbda2},Task accuracy: {test_acc_try}')
                    # Integer loop counters index the map directly, which
                    # avoids truncating a float expression with int().
                    acc_map[last_row - row_offset, col_idx] = test_acc_try
                    if acc_max < test_acc_try:
                        acc_max = test_acc_try
                        Lm1 = lmbda1
                        Lm2 = lmbda2
            print(f'Max: Lambda1: {Lm1}, Lambda2: {Lm2},Task accuracy: {acc_max}')
            np.save(f'data/map_{num}.npy', acc_map)
            print(acc_map)
            num += 1



